From bca18bf7311bd68801c0261e9e21696536c46919 Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <MichielCottaar@protonmail.com> Date: Fri, 1 May 2020 13:42:39 +0100 Subject: [PATCH] BUG: test IO for CIFTI --- fsl/data/cifti.py | 14 +++++--- tests/test_cifti.py | 85 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 6 deletions(-) diff --git a/fsl/data/cifti.py b/fsl/data/cifti.py index 96b8507ad..eea43b665 100644 --- a/fsl/data/cifti.py +++ b/fsl/data/cifti.py @@ -86,13 +86,13 @@ class Cifti: else: new_axes = list(self.axes) - data = self.data + data = self.arr if data.ndim == 1: # CIFTI axes are always at least 2D data = data[None, :] new_axes.insert(0, cifti2_axes.ScalarAxis(['default'])) - return cifti2_axes.Cifti2Image(data, header=new_axes) + return nib.Cifti2Image(data, header=new_axes) @classmethod def from_cifti(cls, filename, writable=False): @@ -136,7 +136,7 @@ class Cifti: - if set to "series" a SeriesAxis is used :return: """ - self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension)) + self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension, mustExist=False)) @classmethod def from_gifti(cls, filename, mask_values=(0, np.nan)): @@ -214,6 +214,8 @@ class DenseCifti(Cifti): @property def extension(self, ): + if self.arr.ndim == 1: + return dense_extensions[cifti2_axes.ScalarAxis] return dense_extensions[type(self.axes[-2])] def to_image(self, fill=0) -> image.Image: @@ -268,12 +270,14 @@ class ParcelCifti(Cifti): Represents sparse data defined at specific parcels """ def __init__(self, *args, **kwargs): - super().__init__(self, *args, **kwargs) - if not isinstance(self.parcel_axis, cifti2_axes.BrainModelAxis): + super().__init__(*args, **kwargs) + if not isinstance(self.parcel_axis, cifti2_axes.ParcelsAxis): raise ValueError(f"ParcelCifti expects a ParcelsAxis as last axes object, not {type(self.parcel_axis)}") @property def extension(self, ): + if self.arr.ndim == 1: + return parcel_extensions[cifti2_axes.ScalarAxis] return parcel_extensions[type(self.axes[-2])] @property diff --git a/tests/test_cifti.py b/tests/test_cifti.py index f61afe393..67851a3e3 100644 --- a/tests/test_cifti.py +++ b/tests/test_cifti.py @@ -1,10 +1,43 @@ -import pytest from fsl.data import cifti import os.path as op import numpy as np import nibabel as nib from numpy import testing import tests +from nibabel.cifti2 import cifti2_axes + + +def volumetric_brain_model(): + mask = np.random.randint(2, size=(10, 10, 10)) > 0 + return cifti2_axes.BrainModelAxis.from_mask(mask, affine=np.eye(4)) + + +def surface_brain_model(): + mask = np.random.randint(2, size=100) > 0 + return cifti2_axes.BrainModelAxis.from_mask(mask, name='cortex') + + +def volumetric_parcels(): + mask = np.random.randint(5, size=(10, 10, 10)) + return cifti2_axes.ParcelsAxis( + [f'vol_{idx}' for idx in range(1, 5)], + voxels=[np.stack(np.where(mask == idx), axis=-1) for idx in range(1, 5)], + vertices=[{} for _ in range(1, 5)], + ) + + +def surface_parcels(): + mask = np.random.randint(5, size=100) + return cifti2_axes.ParcelsAxis( + [f'surf_{idx}' for idx in range(1, 5)], + voxels=[np.zeros((0, 3), dtype=int) for _ in range(1, 5)], + vertices=[{'CIFTI_STRUCTURE_CORTEX': np.where(mask == idx)[0]} for idx in range(1, 5)], + nvertices={'CIFTI_STRUCTURE_CORTEX': 100}, + ) + + +def gen_data(axes): + return np.random.randn(*(5 if ax is None else len(ax) for ax in axes)) def test_read_gifti(): @@ -45,3 +78,53 @@ def test_read_nifti(): testing.assert_equal(data.arr, values[mask]) testing.assert_allclose(data.brain_model_axis.affine, affine) assert len(data.brain_model_axis.nvertices) == 0 + + +def check_io(data: cifti.Cifti, extension): + with tests.testdir(): + data.save("test") + assert op.isfile(f'test.{extension}.nii') + loaded = cifti.load("test") + if data.arr.ndim == 1: + testing.assert_equal(data.arr, loaded.arr[0]) + assert data.axes == loaded.axes[1:] + else: + testing.assert_equal(data.arr, loaded.arr) + assert data.axes == loaded.axes + + +def test_io_cifti(): + for cifti_class, cifti_type, main_axis_options in ( + (cifti.DenseCifti, 'd', (volumetric_brain_model(), surface_brain_model(), + volumetric_brain_model() + surface_brain_model())), + (cifti.ParcelCifti, 'p', (volumetric_parcels(), surface_parcels(), + volumetric_parcels() + surface_parcels())), + ): + for main_axis in main_axis_options: + with tests.testdir(): + data_1d = cifti_class(gen_data([main_axis]), [main_axis]) + check_io(data_1d, f'{cifti_type}scalar') + + connectome = cifti_class(gen_data([main_axis, main_axis]), (main_axis, main_axis)) + check_io(connectome, f'{cifti_type}conn') + + scalar_axis = cifti2_axes.ScalarAxis(['A', 'B', 'C']) + scalar = cifti_class(gen_data([scalar_axis, main_axis]), (scalar_axis, main_axis)) + check_io(scalar, f'{cifti_type}scalar') + + label_axis = cifti2_axes.LabelAxis(['A', 'B', 'C'], {1: ('some parcel', (1, 0, 0, 1))}) + label = cifti_class(gen_data([label_axis, main_axis]), (label_axis, main_axis)) + check_io(label, f'{cifti_type}label') + + series_axis = cifti2_axes.SeriesAxis(10, 3, 50, unit='HERTZ') + series = cifti_class(gen_data([series_axis, main_axis]), (series_axis, main_axis)) + check_io(series, f'{cifti_type}tseries') + + if cifti_type == 'd': + parcel_axis = surface_parcels() + dpconn = cifti_class(gen_data([parcel_axis, main_axis]), (parcel_axis, main_axis)) + check_io(dpconn, 'dpconn') + else: + dense_axis = surface_brain_model() + pdconn = cifti_class(gen_data([dense_axis, main_axis]), (dense_axis, main_axis)) + check_io(pdconn, 'pdconn') -- GitLab