Skip to content
Snippets Groups Projects
Commit bca18bf7 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

BUG: test IO for CIFTI

parent 6910c699
No related branches found
No related tags found
No related merge requests found
...@@ -86,13 +86,13 @@ class Cifti: ...@@ -86,13 +86,13 @@ class Cifti:
else: else:
new_axes = list(self.axes) new_axes = list(self.axes)
data = self.data data = self.arr
if data.ndim == 1: if data.ndim == 1:
# CIFTI axes are always at least 2D # CIFTI axes are always at least 2D
data = data[None, :] data = data[None, :]
new_axes.insert(0, cifti2_axes.ScalarAxis(['default'])) new_axes.insert(0, cifti2_axes.ScalarAxis(['default']))
return cifti2_axes.Cifti2Image(data, header=new_axes) return nib.Cifti2Image(data, header=new_axes)
@classmethod @classmethod
def from_cifti(cls, filename, writable=False): def from_cifti(cls, filename, writable=False):
...@@ -136,7 +136,7 @@ class Cifti: ...@@ -136,7 +136,7 @@ class Cifti:
- if set to "series" a SeriesAxis is used - if set to "series" a SeriesAxis is used
:return: :return:
""" """
self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension)) self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension, mustExist=False))
@classmethod @classmethod
def from_gifti(cls, filename, mask_values=(0, np.nan)): def from_gifti(cls, filename, mask_values=(0, np.nan)):
...@@ -214,6 +214,8 @@ class DenseCifti(Cifti): ...@@ -214,6 +214,8 @@ class DenseCifti(Cifti):
@property @property
def extension(self, ): def extension(self, ):
if self.arr.ndim == 1:
return dense_extensions[cifti2_axes.ScalarAxis]
return dense_extensions[type(self.axes[-2])] return dense_extensions[type(self.axes[-2])]
def to_image(self, fill=0) -> image.Image: def to_image(self, fill=0) -> image.Image:
...@@ -268,12 +270,14 @@ class ParcelCifti(Cifti): ...@@ -268,12 +270,14 @@ class ParcelCifti(Cifti):
Represents sparse data defined at specific parcels Represents sparse data defined at specific parcels
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(self, *args, **kwargs) super().__init__(*args, **kwargs)
if not isinstance(self.parcel_axis, cifti2_axes.BrainModelAxis): if not isinstance(self.parcel_axis, cifti2_axes.ParcelsAxis):
raise ValueError(f"ParcelCifti expects a ParcelsAxis as last axes object, not {type(self.parcel_axis)}") raise ValueError(f"ParcelCifti expects a ParcelsAxis as last axes object, not {type(self.parcel_axis)}")
@property @property
def extension(self, ): def extension(self, ):
if self.arr.ndim == 1:
return parcel_extensions[cifti2_axes.ScalarAxis]
return parcel_extensions[type(self.axes[-2])] return parcel_extensions[type(self.axes[-2])]
@property @property
......
import pytest
from fsl.data import cifti from fsl.data import cifti
import os.path as op import os.path as op
import numpy as np import numpy as np
import nibabel as nib import nibabel as nib
from numpy import testing from numpy import testing
import tests import tests
from nibabel.cifti2 import cifti2_axes
def volumetric_brain_model():
mask = np.random.randint(2, size=(10, 10, 10)) > 0
return cifti2_axes.BrainModelAxis.from_mask(mask, affine=np.eye(4))
def surface_brain_model():
mask = np.random.randint(2, size=100) > 0
return cifti2_axes.BrainModelAxis.from_mask(mask, name='cortex')
def volumetric_parcels():
mask = np.random.randint(5, size=(10, 10, 10))
return cifti2_axes.ParcelsAxis(
[f'vol_{idx}' for idx in range(1, 5)],
voxels=[np.stack(np.where(mask == idx), axis=-1) for idx in range(1, 5)],
vertices=[{} for _ in range(1, 5)],
)
def surface_parcels():
mask = np.random.randint(5, size=100)
return cifti2_axes.ParcelsAxis(
[f'surf_{idx}' for idx in range(1, 5)],
voxels=[np.zeros((0, 3), dtype=int) for _ in range(1, 5)],
vertices=[{'CIFTI_STRUCTURE_CORTEX': np.where(mask == idx)[0]} for idx in range(1, 5)],
nvertices={'CIFTI_STRUCTURE_CORTEX': 100},
)
def gen_data(axes):
return np.random.randn(*(5 if ax is None else len(ax) for ax in axes))
def test_read_gifti(): def test_read_gifti():
...@@ -45,3 +78,53 @@ def test_read_nifti(): ...@@ -45,3 +78,53 @@ def test_read_nifti():
testing.assert_equal(data.arr, values[mask]) testing.assert_equal(data.arr, values[mask])
testing.assert_allclose(data.brain_model_axis.affine, affine) testing.assert_allclose(data.brain_model_axis.affine, affine)
assert len(data.brain_model_axis.nvertices) == 0 assert len(data.brain_model_axis.nvertices) == 0
def check_io(data: cifti.Cifti, extension):
with tests.testdir():
data.save("test")
assert op.isfile(f'test.{extension}.nii')
loaded = cifti.load("test")
if data.arr.ndim == 1:
testing.assert_equal(data.arr, loaded.arr[0])
assert data.axes == loaded.axes[1:]
else:
testing.assert_equal(data.arr, loaded.arr)
assert data.axes == loaded.axes
def test_io_cifti():
for cifti_class, cifti_type, main_axis_options in (
(cifti.DenseCifti, 'd', (volumetric_brain_model(), surface_brain_model(),
volumetric_brain_model() + surface_brain_model())),
(cifti.ParcelCifti, 'p', (volumetric_parcels(), surface_parcels(),
volumetric_parcels() + surface_parcels())),
):
for main_axis in main_axis_options:
with tests.testdir():
data_1d = cifti_class(gen_data([main_axis]), [main_axis])
check_io(data_1d, f'{cifti_type}scalar')
connectome = cifti_class(gen_data([main_axis, main_axis]), (main_axis, main_axis))
check_io(connectome, f'{cifti_type}conn')
scalar_axis = cifti2_axes.ScalarAxis(['A', 'B', 'C'])
scalar = cifti_class(gen_data([scalar_axis, main_axis]), (scalar_axis, main_axis))
check_io(scalar, f'{cifti_type}scalar')
label_axis = cifti2_axes.LabelAxis(['A', 'B', 'C'], {1: ('some parcel', (1, 0, 0, 1))})
label = cifti_class(gen_data([label_axis, main_axis]), (label_axis, main_axis))
check_io(label, f'{cifti_type}label')
series_axis = cifti2_axes.SeriesAxis(10, 3, 50, unit='HERTZ')
series = cifti_class(gen_data([series_axis, main_axis]), (series_axis, main_axis))
check_io(series, f'{cifti_type}tseries')
if cifti_type == 'd':
parcel_axis = surface_parcels()
dpconn = cifti_class(gen_data([parcel_axis, main_axis]), (parcel_axis, main_axis))
check_io(dpconn, 'dpconn')
else:
dense_axis = surface_brain_model()
pdconn = cifti_class(gen_data([dense_axis, main_axis]), (dense_axis, main_axis))
check_io(pdconn, 'pdconn')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment