From 4cd6c30bb8ed296f9601c7d3570b97b3daedfbde Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@protonmail.com>
Date: Fri, 1 May 2020 14:44:16 +0100
Subject: [PATCH] BUG: add tests for extracting images/surface data

---
 fsl/data/cifti.py   |  30 ++++++------
 tests/test_cifti.py | 117 ++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 128 insertions(+), 19 deletions(-)

diff --git a/fsl/data/cifti.py b/fsl/data/cifti.py
index eea43b665..a1521d56f 100644
--- a/fsl/data/cifti.py
+++ b/fsl/data/cifti.py
@@ -224,10 +224,10 @@ class DenseCifti(Cifti):
         """
         if self.brain_model_axis.volume_mask.sum() == 0:
             raise ValueError(f"Can not create volume without voxels in {self}")
-        data = np.full(self.brain_model_axis.volume_shape + self.data.shape[:-1], fill,
-                       dtype=self.data.dtype)
+        data = np.full(self.brain_model_axis.volume_shape + self.arr.shape[:-1], fill,
+                       dtype=self.arr.dtype)
         voxels = self.brain_model_axis.voxel[self.brain_model_axis.volume_mask]
-        data[tuple(voxels.T)] = np.transpose(self.data, (-1,) + tuple(range(self.data.ndim - 1)))[
+        data[tuple(voxels.T)] = np.transpose(self.arr, (-1,) + tuple(range(self.arr.ndim - 1)))[
             self.brain_model_axis.volume_mask]
         return image.Image(data, xform=self.brain_model_axis.affine)
 
@@ -249,8 +249,8 @@ class DenseCifti(Cifti):
         if anatomy.cifti not in self.brain_model_axis.name:
             raise ValueError(f"No surface data for {anatomy.cifti} found")
         slc, bm = None, None
-        arr = np.full(self.data.shape[:-1] + (self.brain_model_axis.nvertices[anatomy.cifti],), fill,
-                      dtype=self.data.dtype)
+        arr = np.full(self.arr.shape[:-1] + (self.brain_model_axis.nvertices[anatomy.cifti],), fill,
+                      dtype=self.arr.dtype)
         for name, slc_try, bm_try in self.brain_model_axis.iter_structures():
             if name == anatomy.cifti:
                 if partial:
@@ -258,11 +258,11 @@ class DenseCifti(Cifti):
                         raise ValueError(f"Surface {anatomy} does not form a contiguous block")
                     slc, bm = slc_try, bm_try
                 else:
-                    arr[..., bm_try.vertex] = self.data[..., slc_try]
+                    arr[..., bm_try.vertex] = self.arr[..., slc_try]
         if not partial:
             return arr
         else:
-            return bm.vertex, self.data[..., slc]
+            return bm.vertex, self.arr[..., slc]
 
 
 class ParcelCifti(Cifti):
@@ -290,14 +290,14 @@ class ParcelCifti(Cifti):
         """
         data = np.full(self.parcel_axis.volume_shape + self.arr.shape[:-1], fill, dtype=self.arr.dtype)
         written = np.zeros(self.parcel_axis.volume_shape, dtype='bool')
-        for idx, write_to in enumerate(self.parcel_axis).voxels:
-            if written[write_to].any():
+        for idx, write_to in enumerate(self.parcel_axis.voxels):
+            if written[tuple(write_to.T)].any():
                 raise ValueError("Duplicate voxels in different parcels")
-            data[write_to] = self.arr[np.newaxis, ..., idx]
-            written[write_to] = True
+            data[tuple(write_to.T)] = self.arr[np.newaxis, ..., idx]
+            written[tuple(write_to.T)] = True
         if not written.any():
             raise ValueError("Parcellation does not contain any volumetric data")
-        return image.Image(data, xform=self.brain_model_axis.affine)
+        return image.Image(data, xform=self.parcel_axis.affine)
 
     def surface(self, anatomy, fill=np.nan, partial=False):
         """
@@ -315,9 +315,9 @@ class ParcelCifti(Cifti):
         if anatomy.cifti not in self.parcel_axis.nvertices:
             raise ValueError(f"No surface data for {anatomy.cifti} found")
 
-        arr = np.full(self.data.shape[:-1] + (self.parcel_axis.nvertices[anatomy.cifti],), fill,
-                      dtype=self.data.dtype)
-        written = np.zeros(self.parcel_axis.nvertices[anatomy.cifti])
+        arr = np.full(self.arr.shape[:-1] + (self.parcel_axis.nvertices[anatomy.cifti],), fill,
+                      dtype=self.arr.dtype)
+        written = np.zeros(self.parcel_axis.nvertices[anatomy.cifti], dtype='bool')
         for idx, vertices in enumerate(self.parcel_axis.vertices):
             if anatomy.cifti not in vertices:
                 continue
diff --git a/tests/test_cifti.py b/tests/test_cifti.py
index 67851a3e3..1a292582b 100644
--- a/tests/test_cifti.py
+++ b/tests/test_cifti.py
@@ -17,23 +17,33 @@ def surface_brain_model():
     return cifti2_axes.BrainModelAxis.from_mask(mask, name='cortex')
 
 
-def volumetric_parcels():
+def volumetric_parcels(return_mask=False):
     mask = np.random.randint(5, size=(10, 10, 10))
-    return cifti2_axes.ParcelsAxis(
+    axis = cifti2_axes.ParcelsAxis(
         [f'vol_{idx}' for idx in range(1, 5)],
         voxels=[np.stack(np.where(mask == idx), axis=-1) for idx in range(1, 5)],
         vertices=[{} for _ in range(1, 5)],
+        volume_shape=mask.shape,
+        affine=np.eye(4),
     )
+    if return_mask:
+        return axis, mask
+    else:
+        return axis
 
 
-def surface_parcels():
+def surface_parcels(return_mask=False):
     mask = np.random.randint(5, size=100)
-    return cifti2_axes.ParcelsAxis(
+    axis = cifti2_axes.ParcelsAxis(
         [f'surf_{idx}' for idx in range(1, 5)],
         voxels=[np.zeros((0, 3), dtype=int) for _ in range(1, 5)],
         vertices=[{'CIFTI_STRUCTURE_CORTEX': np.where(mask == idx)[0]} for idx in range(1, 5)],
         nvertices={'CIFTI_STRUCTURE_CORTEX': 100},
     )
+    if return_mask:
+        return axis, mask
+    else:
+        return axis
 
 
 def gen_data(axes):
@@ -128,3 +138,102 @@ def test_io_cifti():
                     dense_axis = surface_brain_model()
                     pdconn = cifti_class(gen_data([dense_axis, main_axis]), (dense_axis, main_axis))
                     check_io(pdconn, 'pdconn')
+
+
+def test_extract_dense():
+    vol_bm = volumetric_brain_model()
+    surf_bm = surface_brain_model()
+    for bm in (vol_bm + surf_bm, surf_bm + vol_bm):
+        for ndim, no_other_axis in ((1, True), (2, False), (2, True)):
+            if ndim == 1:
+                data = cifti.DenseCifti(gen_data([bm]), [bm])
+            else:
+                scl = cifti2_axes.ScalarAxis(['A', 'B', 'C'])
+                data = cifti.DenseCifti(gen_data([scl, bm]),
+                                        [None if no_other_axis else scl, bm])
+
+            # extract volume
+            ref_arr = data.arr[..., data.brain_model_axis.volume_mask]
+            vol_image = data.to_image(fill=np.nan)
+            if ndim == 1:
+                assert vol_image.shape == data.brain_model_axis.volume_shape
+            else:
+                assert vol_image.shape == data.brain_model_axis.volume_shape + (3, )
+            assert np.isfinite(vol_image.data).sum() == len(vol_bm) * (3 if ndim == 2 else 1)
+            testing.assert_equal(vol_image.data[tuple(vol_bm.voxel.T)], ref_arr.T)
+
+            from_image = cifti.DenseCifti.from_image(vol_image)
+            assert from_image.brain_model_axis == vol_bm
+            testing.assert_equal(from_image.arr, ref_arr)
+
+            # extract surface
+            ref_arr = data.arr[..., data.brain_model_axis.surface_mask]
+            mask, surf_data = data.surface('cortex', partial=True)
+            assert surf_data.shape[-1] < 100
+            testing.assert_equal(ref_arr, surf_data)
+            testing.assert_equal(surf_bm.vertex, mask)
+
+            surf_data_full = data.surface('cortex', fill=np.nan)
+            assert surf_data_full.shape[-1] == 100
+            mask_full = np.isfinite(surf_data_full)
+            if ndim == 2:
+                assert (mask_full.any(0) == mask_full.all(0)).all()
+                mask_full = mask_full[0]
+            assert mask_full.sum() == len(surf_bm)
+            assert mask_full[..., mask].sum() == len(surf_bm)
+            testing.assert_equal(surf_data_full[..., mask_full], ref_arr)
+
+
+def test_extract_parcel():
+    vol_parcel, vol_mask = volumetric_parcels(return_mask=True)
+    surf_parcel, surf_mask = surface_parcels(return_mask=True)
+    parcel = vol_parcel + surf_parcel
+    for ndim, no_other_axis in ((1, True), (2, False), (2, True)):
+        if ndim == 1:
+            data = cifti.ParcelCifti(gen_data([parcel]), [parcel])
+        else:
+            scl = cifti2_axes.ScalarAxis(['A', 'B', 'C'])
+            data = cifti.ParcelCifti(gen_data([scl, parcel]),
+                                     [None if no_other_axis else scl, parcel])
+
+        # extract volume
+        vol_image = data.to_image(fill=np.nan)
+        if ndim == 1:
+            assert vol_image.shape == data.parcel_axis.volume_shape
+        else:
+            assert vol_image.shape == data.parcel_axis.volume_shape + (3, )
+        assert np.isfinite(vol_image.data).sum() == np.sum(vol_mask != 0) * (3 if ndim == 2 else 1)
+        if ndim == 1:
+            testing.assert_equal(vol_mask != 0, np.isfinite(vol_image.data))
+            for idx in range(1, 5):
+                testing.assert_allclose(vol_image.data[vol_mask == idx], data.arr[..., idx - 1])
+        else:
+            for idx in range(3):
+                testing.assert_equal(vol_mask != 0, np.isfinite(vol_image.data[..., idx]))
+                for idx2 in range(1, 5):
+                    testing.assert_allclose(vol_image.data[vol_mask == idx2, idx], data.arr[idx, idx2 - 1])
+
+        # extract surface
+        mask, surf_data = data.surface('cortex', partial=True)
+        assert surf_data.shape[-1] == (surf_mask != 0).sum()
+        assert (surf_mask[mask] != 0).all()
+        print(data.arr)
+        for idx in range(1, 5):
+            if ndim == 1:
+                testing.assert_equal(surf_data.T[surf_mask[mask] == idx], data.arr[idx + 3])
+            else:
+                for idx2 in range(3):
+                    testing.assert_equal(surf_data.T[surf_mask[mask] == idx, idx2], data.arr[idx2, idx + 3])
+
+        surf_data_full = data.surface('cortex', partial=False)
+        assert surf_data_full.shape[-1] == 100
+        if ndim == 1:
+            testing.assert_equal(np.isfinite(surf_data_full), surf_mask != 0)
+            for idx in range(1, 5):
+                testing.assert_equal(surf_data_full.T[surf_mask == idx], data.arr[idx + 3])
+        else:
+            for idx2 in range(3):
+                testing.assert_equal(np.isfinite(surf_data_full)[idx2], (surf_mask != 0))
+                for idx in range(1, 5):
+                    testing.assert_equal(surf_data_full.T[surf_mask == idx, idx2], data.arr[idx2, idx + 3])
+
-- 
GitLab