From bca18bf7311bd68801c0261e9e21696536c46919 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@protonmail.com>
Date: Fri, 1 May 2020 13:42:39 +0100
Subject: [PATCH] BUG: test IO for CIFTI

---
 fsl/data/cifti.py   | 14 +++++---
 tests/test_cifti.py | 85 ++++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 93 insertions(+), 6 deletions(-)

diff --git a/fsl/data/cifti.py b/fsl/data/cifti.py
index 96b8507ad..eea43b665 100644
--- a/fsl/data/cifti.py
+++ b/fsl/data/cifti.py
@@ -86,13 +86,13 @@ class Cifti:
         else:
             new_axes = list(self.axes)
 
-        data = self.data
+        data = self.arr
         if data.ndim == 1:
             # CIFTI axes are always at least 2D
             data = data[None, :]
             new_axes.insert(0, cifti2_axes.ScalarAxis(['default']))
 
-        return cifti2_axes.Cifti2Image(data, header=new_axes)
+        return nib.Cifti2Image(data, header=new_axes)
 
     @classmethod
     def from_cifti(cls, filename, writable=False):
@@ -136,7 +136,7 @@ class Cifti:
             - if set to "series" a SeriesAxis is used
         :return:
         """
-        self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension))
+        self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension, mustExist=False))
 
     @classmethod
     def from_gifti(cls, filename, mask_values=(0, np.nan)):
@@ -214,6 +214,8 @@ class DenseCifti(Cifti):
 
     @property
     def extension(self, ):
+        if self.arr.ndim == 1:
+            return dense_extensions[cifti2_axes.ScalarAxis]
         return dense_extensions[type(self.axes[-2])]
 
     def to_image(self, fill=0) -> image.Image:
@@ -268,12 +270,14 @@ class ParcelCifti(Cifti):
     Represents sparse data defined at specific parcels
     """
     def __init__(self, *args, **kwargs):
-        super().__init__(self, *args, **kwargs)
-        if not isinstance(self.parcel_axis, cifti2_axes.BrainModelAxis):
+        super().__init__(*args, **kwargs)
+        if not isinstance(self.parcel_axis, cifti2_axes.ParcelsAxis):
             raise ValueError(f"ParcelCifti expects a ParcelsAxis as last axes object, not {type(self.parcel_axis)}")
 
     @property
     def extension(self, ):
+        if self.arr.ndim == 1:
+            return parcel_extensions[cifti2_axes.ScalarAxis]
         return parcel_extensions[type(self.axes[-2])]
 
     @property
diff --git a/tests/test_cifti.py b/tests/test_cifti.py
index f61afe393..67851a3e3 100644
--- a/tests/test_cifti.py
+++ b/tests/test_cifti.py
@@ -1,10 +1,43 @@
-import pytest
 from fsl.data import cifti
 import os.path as op
 import numpy as np
 import nibabel as nib
 from numpy import testing
 import tests
+from nibabel.cifti2 import cifti2_axes
+
+
+def volumetric_brain_model():
+    mask = np.random.randint(2, size=(10, 10, 10)) > 0
+    return cifti2_axes.BrainModelAxis.from_mask(mask, affine=np.eye(4))
+
+
+def surface_brain_model():
+    mask = np.random.randint(2, size=100) > 0
+    return cifti2_axes.BrainModelAxis.from_mask(mask, name='cortex')
+
+
+def volumetric_parcels():
+    mask = np.random.randint(5, size=(10, 10, 10))
+    return cifti2_axes.ParcelsAxis(
+        [f'vol_{idx}' for idx in range(1, 5)],
+        voxels=[np.stack(np.where(mask == idx), axis=-1) for idx in range(1, 5)],
+        vertices=[{} for _ in range(1, 5)],
+    )
+
+
+def surface_parcels():
+    mask = np.random.randint(5, size=100)
+    return cifti2_axes.ParcelsAxis(
+        [f'surf_{idx}' for idx in range(1, 5)],
+        voxels=[np.zeros((0, 3), dtype=int) for _ in range(1, 5)],
+        vertices=[{'CIFTI_STRUCTURE_CORTEX': np.where(mask == idx)[0]} for idx in range(1, 5)],
+        nvertices={'CIFTI_STRUCTURE_CORTEX': 100},
+    )
+
+
+def gen_data(axes):
+    return np.random.randn(*(5 if ax is None else len(ax) for ax in axes))
 
 
 def test_read_gifti():
@@ -45,3 +78,53 @@ def test_read_nifti():
             testing.assert_equal(data.arr, values[mask])
             testing.assert_allclose(data.brain_model_axis.affine, affine)
             assert len(data.brain_model_axis.nvertices) == 0
+
+
+def check_io(data: cifti.Cifti, extension):
+    with tests.testdir():
+        data.save("test")
+        assert op.isfile(f'test.{extension}.nii')
+        loaded = cifti.load("test")
+        if data.arr.ndim == 1:
+            testing.assert_equal(data.arr, loaded.arr[0])
+            assert data.axes == loaded.axes[1:]
+        else:
+            testing.assert_equal(data.arr, loaded.arr)
+            assert data.axes == loaded.axes
+
+
+def test_io_cifti():
+    for cifti_class, cifti_type, main_axis_options in (
+        (cifti.DenseCifti, 'd', (volumetric_brain_model(), surface_brain_model(),
+                                 volumetric_brain_model() + surface_brain_model())),
+        (cifti.ParcelCifti, 'p', (volumetric_parcels(), surface_parcels(),
+                                  volumetric_parcels() + surface_parcels())),
+    ):
+        for main_axis in main_axis_options:
+            with tests.testdir():
+                data_1d = cifti_class(gen_data([main_axis]), [main_axis])
+                check_io(data_1d, f'{cifti_type}scalar')
+
+                connectome = cifti_class(gen_data([main_axis, main_axis]), (main_axis, main_axis))
+                check_io(connectome, f'{cifti_type}conn')
+
+                scalar_axis = cifti2_axes.ScalarAxis(['A', 'B', 'C'])
+                scalar = cifti_class(gen_data([scalar_axis, main_axis]), (scalar_axis, main_axis))
+                check_io(scalar, f'{cifti_type}scalar')
+
+                label_axis = cifti2_axes.LabelAxis(['A', 'B', 'C'], {1: ('some parcel', (1, 0, 0, 1))})
+                label = cifti_class(gen_data([label_axis, main_axis]), (label_axis, main_axis))
+                check_io(label, f'{cifti_type}label')
+
+                series_axis = cifti2_axes.SeriesAxis(10, 3, 50, unit='HERTZ')
+                series = cifti_class(gen_data([series_axis, main_axis]), (series_axis, main_axis))
+                check_io(series, f'{cifti_type}tseries')
+
+                if cifti_type == 'd':
+                    parcel_axis = surface_parcels()
+                    dpconn = cifti_class(gen_data([parcel_axis, main_axis]), (parcel_axis, main_axis))
+                    check_io(dpconn, 'dpconn')
+                else:
+                    dense_axis = surface_brain_model()
+                    pdconn = cifti_class(gen_data([dense_axis, main_axis]), (dense_axis, main_axis))
+                    check_io(pdconn, 'pdconn')
-- 
GitLab