Skip to content
Snippets Groups Projects
Commit 4cd6c30b authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

BUG: add tests for extracting images/surface data

parent bca18bf7
No related branches found
No related tags found
No related merge requests found
...@@ -224,10 +224,10 @@ class DenseCifti(Cifti): ...@@ -224,10 +224,10 @@ class DenseCifti(Cifti):
""" """
if self.brain_model_axis.volume_mask.sum() == 0: if self.brain_model_axis.volume_mask.sum() == 0:
raise ValueError(f"Can not create volume without voxels in {self}") raise ValueError(f"Can not create volume without voxels in {self}")
data = np.full(self.brain_model_axis.volume_shape + self.data.shape[:-1], fill, data = np.full(self.brain_model_axis.volume_shape + self.arr.shape[:-1], fill,
dtype=self.data.dtype) dtype=self.arr.dtype)
voxels = self.brain_model_axis.voxel[self.brain_model_axis.volume_mask] voxels = self.brain_model_axis.voxel[self.brain_model_axis.volume_mask]
data[tuple(voxels.T)] = np.transpose(self.data, (-1,) + tuple(range(self.data.ndim - 1)))[ data[tuple(voxels.T)] = np.transpose(self.arr, (-1,) + tuple(range(self.arr.ndim - 1)))[
self.brain_model_axis.volume_mask] self.brain_model_axis.volume_mask]
return image.Image(data, xform=self.brain_model_axis.affine) return image.Image(data, xform=self.brain_model_axis.affine)
...@@ -249,8 +249,8 @@ class DenseCifti(Cifti): ...@@ -249,8 +249,8 @@ class DenseCifti(Cifti):
if anatomy.cifti not in self.brain_model_axis.name: if anatomy.cifti not in self.brain_model_axis.name:
raise ValueError(f"No surface data for {anatomy.cifti} found") raise ValueError(f"No surface data for {anatomy.cifti} found")
slc, bm = None, None slc, bm = None, None
arr = np.full(self.data.shape[:-1] + (self.brain_model_axis.nvertices[anatomy.cifti],), fill, arr = np.full(self.arr.shape[:-1] + (self.brain_model_axis.nvertices[anatomy.cifti],), fill,
dtype=self.data.dtype) dtype=self.arr.dtype)
for name, slc_try, bm_try in self.brain_model_axis.iter_structures(): for name, slc_try, bm_try in self.brain_model_axis.iter_structures():
if name == anatomy.cifti: if name == anatomy.cifti:
if partial: if partial:
...@@ -258,11 +258,11 @@ class DenseCifti(Cifti): ...@@ -258,11 +258,11 @@ class DenseCifti(Cifti):
raise ValueError(f"Surface {anatomy} does not form a contiguous block") raise ValueError(f"Surface {anatomy} does not form a contiguous block")
slc, bm = slc_try, bm_try slc, bm = slc_try, bm_try
else: else:
arr[..., bm_try.vertex] = self.data[..., slc_try] arr[..., bm_try.vertex] = self.arr[..., slc_try]
if not partial: if not partial:
return arr return arr
else: else:
return bm.vertex, self.data[..., slc] return bm.vertex, self.arr[..., slc]
class ParcelCifti(Cifti): class ParcelCifti(Cifti):
...@@ -290,14 +290,14 @@ class ParcelCifti(Cifti): ...@@ -290,14 +290,14 @@ class ParcelCifti(Cifti):
""" """
data = np.full(self.parcel_axis.volume_shape + self.arr.shape[:-1], fill, dtype=self.arr.dtype) data = np.full(self.parcel_axis.volume_shape + self.arr.shape[:-1], fill, dtype=self.arr.dtype)
written = np.zeros(self.parcel_axis.volume_shape, dtype='bool') written = np.zeros(self.parcel_axis.volume_shape, dtype='bool')
for idx, write_to in enumerate(self.parcel_axis).voxels: for idx, write_to in enumerate(self.parcel_axis.voxels):
if written[write_to].any(): if written[tuple(write_to.T)].any():
raise ValueError("Duplicate voxels in different parcels") raise ValueError("Duplicate voxels in different parcels")
data[write_to] = self.arr[np.newaxis, ..., idx] data[tuple(write_to.T)] = self.arr[np.newaxis, ..., idx]
written[write_to] = True written[tuple(write_to.T)] = True
if not written.any(): if not written.any():
raise ValueError("Parcellation does not contain any volumetric data") raise ValueError("Parcellation does not contain any volumetric data")
return image.Image(data, xform=self.brain_model_axis.affine) return image.Image(data, xform=self.parcel_axis.affine)
def surface(self, anatomy, fill=np.nan, partial=False): def surface(self, anatomy, fill=np.nan, partial=False):
""" """
...@@ -315,9 +315,9 @@ class ParcelCifti(Cifti): ...@@ -315,9 +315,9 @@ class ParcelCifti(Cifti):
if anatomy.cifti not in self.parcel_axis.nvertices: if anatomy.cifti not in self.parcel_axis.nvertices:
raise ValueError(f"No surface data for {anatomy.cifti} found") raise ValueError(f"No surface data for {anatomy.cifti} found")
arr = np.full(self.data.shape[:-1] + (self.parcel_axis.nvertices[anatomy.cifti],), fill, arr = np.full(self.arr.shape[:-1] + (self.parcel_axis.nvertices[anatomy.cifti],), fill,
dtype=self.data.dtype) dtype=self.arr.dtype)
written = np.zeros(self.parcel_axis.nvertices[anatomy.cifti]) written = np.zeros(self.parcel_axis.nvertices[anatomy.cifti], dtype='bool')
for idx, vertices in enumerate(self.parcel_axis.vertices): for idx, vertices in enumerate(self.parcel_axis.vertices):
if anatomy.cifti not in vertices: if anatomy.cifti not in vertices:
continue continue
......
...@@ -17,23 +17,33 @@ def surface_brain_model(): ...@@ -17,23 +17,33 @@ def surface_brain_model():
return cifti2_axes.BrainModelAxis.from_mask(mask, name='cortex') return cifti2_axes.BrainModelAxis.from_mask(mask, name='cortex')
def volumetric_parcels(): def volumetric_parcels(return_mask=False):
mask = np.random.randint(5, size=(10, 10, 10)) mask = np.random.randint(5, size=(10, 10, 10))
return cifti2_axes.ParcelsAxis( axis = cifti2_axes.ParcelsAxis(
[f'vol_{idx}' for idx in range(1, 5)], [f'vol_{idx}' for idx in range(1, 5)],
voxels=[np.stack(np.where(mask == idx), axis=-1) for idx in range(1, 5)], voxels=[np.stack(np.where(mask == idx), axis=-1) for idx in range(1, 5)],
vertices=[{} for _ in range(1, 5)], vertices=[{} for _ in range(1, 5)],
volume_shape=mask.shape,
affine=np.eye(4),
) )
if return_mask:
return axis, mask
else:
return axis
def surface_parcels(): def surface_parcels(return_mask=False):
mask = np.random.randint(5, size=100) mask = np.random.randint(5, size=100)
return cifti2_axes.ParcelsAxis( axis = cifti2_axes.ParcelsAxis(
[f'surf_{idx}' for idx in range(1, 5)], [f'surf_{idx}' for idx in range(1, 5)],
voxels=[np.zeros((0, 3), dtype=int) for _ in range(1, 5)], voxels=[np.zeros((0, 3), dtype=int) for _ in range(1, 5)],
vertices=[{'CIFTI_STRUCTURE_CORTEX': np.where(mask == idx)[0]} for idx in range(1, 5)], vertices=[{'CIFTI_STRUCTURE_CORTEX': np.where(mask == idx)[0]} for idx in range(1, 5)],
nvertices={'CIFTI_STRUCTURE_CORTEX': 100}, nvertices={'CIFTI_STRUCTURE_CORTEX': 100},
) )
if return_mask:
return axis, mask
else:
return axis
def gen_data(axes): def gen_data(axes):
...@@ -128,3 +138,102 @@ def test_io_cifti(): ...@@ -128,3 +138,102 @@ def test_io_cifti():
dense_axis = surface_brain_model() dense_axis = surface_brain_model()
pdconn = cifti_class(gen_data([dense_axis, main_axis]), (dense_axis, main_axis)) pdconn = cifti_class(gen_data([dense_axis, main_axis]), (dense_axis, main_axis))
check_io(pdconn, 'pdconn') check_io(pdconn, 'pdconn')
def test_extract_dense():
vol_bm = volumetric_brain_model()
surf_bm = surface_brain_model()
for bm in (vol_bm + surf_bm, surf_bm + vol_bm):
for ndim, no_other_axis in ((1, True), (2, False), (2, True)):
if ndim == 1:
data = cifti.DenseCifti(gen_data([bm]), [bm])
else:
scl = cifti2_axes.ScalarAxis(['A', 'B', 'C'])
data = cifti.DenseCifti(gen_data([scl, bm]),
[None if no_other_axis else scl, bm])
# extract volume
ref_arr = data.arr[..., data.brain_model_axis.volume_mask]
vol_image = data.to_image(fill=np.nan)
if ndim == 1:
assert vol_image.shape == data.brain_model_axis.volume_shape
else:
assert vol_image.shape == data.brain_model_axis.volume_shape + (3, )
assert np.isfinite(vol_image.data).sum() == len(vol_bm) * (3 if ndim == 2 else 1)
testing.assert_equal(vol_image.data[tuple(vol_bm.voxel.T)], ref_arr.T)
from_image = cifti.DenseCifti.from_image(vol_image)
assert from_image.brain_model_axis == vol_bm
testing.assert_equal(from_image.arr, ref_arr)
# extract surface
ref_arr = data.arr[..., data.brain_model_axis.surface_mask]
mask, surf_data = data.surface('cortex', partial=True)
assert surf_data.shape[-1] < 100
testing.assert_equal(ref_arr, surf_data)
testing.assert_equal(surf_bm.vertex, mask)
surf_data_full = data.surface('cortex', fill=np.nan)
assert surf_data_full.shape[-1] == 100
mask_full = np.isfinite(surf_data_full)
if ndim == 2:
assert (mask_full.any(0) == mask_full.all(0)).all()
mask_full = mask_full[0]
assert mask_full.sum() == len(surf_bm)
assert mask_full[..., mask].sum() == len(surf_bm)
testing.assert_equal(surf_data_full[..., mask_full], ref_arr)
def test_extract_parcel():
vol_parcel, vol_mask = volumetric_parcels(return_mask=True)
surf_parcel, surf_mask = surface_parcels(return_mask=True)
parcel = vol_parcel + surf_parcel
for ndim, no_other_axis in ((1, True), (2, False), (2, True)):
if ndim == 1:
data = cifti.ParcelCifti(gen_data([parcel]), [parcel])
else:
scl = cifti2_axes.ScalarAxis(['A', 'B', 'C'])
data = cifti.ParcelCifti(gen_data([scl, parcel]),
[None if no_other_axis else scl, parcel])
# extract volume
vol_image = data.to_image(fill=np.nan)
if ndim == 1:
assert vol_image.shape == data.parcel_axis.volume_shape
else:
assert vol_image.shape == data.parcel_axis.volume_shape + (3, )
assert np.isfinite(vol_image.data).sum() == np.sum(vol_mask != 0) * (3 if ndim == 2 else 1)
if ndim == 1:
testing.assert_equal(vol_mask != 0, np.isfinite(vol_image.data))
for idx in range(1, 5):
testing.assert_allclose(vol_image.data[vol_mask == idx], data.arr[..., idx - 1])
else:
for idx in range(3):
testing.assert_equal(vol_mask != 0, np.isfinite(vol_image.data[..., idx]))
for idx2 in range(1, 5):
testing.assert_allclose(vol_image.data[vol_mask == idx2, idx], data.arr[idx, idx2 - 1])
# extract surface
mask, surf_data = data.surface('cortex', partial=True)
assert surf_data.shape[-1] == (surf_mask != 0).sum()
assert (surf_mask[mask] != 0).all()
print(data.arr)
for idx in range(1, 5):
if ndim == 1:
testing.assert_equal(surf_data.T[surf_mask[mask] == idx], data.arr[idx + 3])
else:
for idx2 in range(3):
testing.assert_equal(surf_data.T[surf_mask[mask] == idx, idx2], data.arr[idx2, idx + 3])
surf_data_full = data.surface('cortex', partial=False)
assert surf_data_full.shape[-1] == 100
if ndim == 1:
testing.assert_equal(np.isfinite(surf_data_full), surf_mask != 0)
for idx in range(1, 5):
testing.assert_equal(surf_data_full.T[surf_mask == idx], data.arr[idx + 3])
else:
for idx2 in range(3):
testing.assert_equal(np.isfinite(surf_data_full)[idx2], (surf_mask != 0))
for idx in range(1, 5):
testing.assert_equal(surf_data_full.T[surf_mask == idx, idx2], data.arr[idx2, idx + 3])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment