Skip to content
Snippets Groups Projects
Commit bca18bf7 authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

BUG: test IO for CIFTI

parent 6910c699
No related branches found
No related tags found
No related merge requests found
......@@ -86,13 +86,13 @@ class Cifti:
else:
new_axes = list(self.axes)
data = self.data
data = self.arr
if data.ndim == 1:
# CIFTI axes are always at least 2D
data = data[None, :]
new_axes.insert(0, cifti2_axes.ScalarAxis(['default']))
return cifti2_axes.Cifti2Image(data, header=new_axes)
return nib.Cifti2Image(data, header=new_axes)
@classmethod
def from_cifti(cls, filename, writable=False):
......@@ -136,7 +136,7 @@ class Cifti:
- if set to "series" a SeriesAxis is used
:return:
"""
self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension))
self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension, mustExist=False))
@classmethod
def from_gifti(cls, filename, mask_values=(0, np.nan)):
......@@ -214,6 +214,8 @@ class DenseCifti(Cifti):
@property
def extension(self, ):
if self.arr.ndim == 1:
return dense_extensions[cifti2_axes.ScalarAxis]
return dense_extensions[type(self.axes[-2])]
def to_image(self, fill=0) -> image.Image:
......@@ -268,12 +270,14 @@ class ParcelCifti(Cifti):
Represents sparse data defined at specific parcels
"""
def __init__(self, *args, **kwargs):
super().__init__(self, *args, **kwargs)
if not isinstance(self.parcel_axis, cifti2_axes.BrainModelAxis):
super().__init__(*args, **kwargs)
if not isinstance(self.parcel_axis, cifti2_axes.ParcelsAxis):
raise ValueError(f"ParcelCifti expects a ParcelsAxis as last axes object, not {type(self.parcel_axis)}")
@property
def extension(self, ):
if self.arr.ndim == 1:
return parcel_extensions[cifti2_axes.ScalarAxis]
return parcel_extensions[type(self.axes[-2])]
@property
......
import pytest
from fsl.data import cifti
import os.path as op
import numpy as np
import nibabel as nib
from numpy import testing
import tests
from nibabel.cifti2 import cifti2_axes
def volumetric_brain_model():
mask = np.random.randint(2, size=(10, 10, 10)) > 0
return cifti2_axes.BrainModelAxis.from_mask(mask, affine=np.eye(4))
def surface_brain_model():
mask = np.random.randint(2, size=100) > 0
return cifti2_axes.BrainModelAxis.from_mask(mask, name='cortex')
def volumetric_parcels():
mask = np.random.randint(5, size=(10, 10, 10))
return cifti2_axes.ParcelsAxis(
[f'vol_{idx}' for idx in range(1, 5)],
voxels=[np.stack(np.where(mask == idx), axis=-1) for idx in range(1, 5)],
vertices=[{} for _ in range(1, 5)],
)
def surface_parcels():
mask = np.random.randint(5, size=100)
return cifti2_axes.ParcelsAxis(
[f'surf_{idx}' for idx in range(1, 5)],
voxels=[np.zeros((0, 3), dtype=int) for _ in range(1, 5)],
vertices=[{'CIFTI_STRUCTURE_CORTEX': np.where(mask == idx)[0]} for idx in range(1, 5)],
nvertices={'CIFTI_STRUCTURE_CORTEX': 100},
)
def gen_data(axes):
return np.random.randn(*(5 if ax is None else len(ax) for ax in axes))
def test_read_gifti():
......@@ -45,3 +78,53 @@ def test_read_nifti():
testing.assert_equal(data.arr, values[mask])
testing.assert_allclose(data.brain_model_axis.affine, affine)
assert len(data.brain_model_axis.nvertices) == 0
def check_io(data: cifti.Cifti, extension):
with tests.testdir():
data.save("test")
assert op.isfile(f'test.{extension}.nii')
loaded = cifti.load("test")
if data.arr.ndim == 1:
testing.assert_equal(data.arr, loaded.arr[0])
assert data.axes == loaded.axes[1:]
else:
testing.assert_equal(data.arr, loaded.arr)
assert data.axes == loaded.axes
def test_io_cifti():
for cifti_class, cifti_type, main_axis_options in (
(cifti.DenseCifti, 'd', (volumetric_brain_model(), surface_brain_model(),
volumetric_brain_model() + surface_brain_model())),
(cifti.ParcelCifti, 'p', (volumetric_parcels(), surface_parcels(),
volumetric_parcels() + surface_parcels())),
):
for main_axis in main_axis_options:
with tests.testdir():
data_1d = cifti_class(gen_data([main_axis]), [main_axis])
check_io(data_1d, f'{cifti_type}scalar')
connectome = cifti_class(gen_data([main_axis, main_axis]), (main_axis, main_axis))
check_io(connectome, f'{cifti_type}conn')
scalar_axis = cifti2_axes.ScalarAxis(['A', 'B', 'C'])
scalar = cifti_class(gen_data([scalar_axis, main_axis]), (scalar_axis, main_axis))
check_io(scalar, f'{cifti_type}scalar')
label_axis = cifti2_axes.LabelAxis(['A', 'B', 'C'], {1: ('some parcel', (1, 0, 0, 1))})
label = cifti_class(gen_data([label_axis, main_axis]), (label_axis, main_axis))
check_io(label, f'{cifti_type}label')
series_axis = cifti2_axes.SeriesAxis(10, 3, 50, unit='HERTZ')
series = cifti_class(gen_data([series_axis, main_axis]), (series_axis, main_axis))
check_io(series, f'{cifti_type}tseries')
if cifti_type == 'd':
parcel_axis = surface_parcels()
dpconn = cifti_class(gen_data([parcel_axis, main_axis]), (parcel_axis, main_axis))
check_io(dpconn, 'dpconn')
else:
dense_axis = surface_brain_model()
pdconn = cifti_class(gen_data([dense_axis, main_axis]), (dense_axis, main_axis))
check_io(pdconn, 'pdconn')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment