Commit db21c20b authored by Michiel Cottaar's avatar Michiel Cottaar
Browse files

BUG: set undefined axes to None

parent c95dacf1
......@@ -7,7 +7,7 @@ The data can be read from NIFTI, GIFTI, or CIFTI files.
Non-sparse volumetric or surface representations can be extracte.
"""
from nibabel.cifti2 import cifti2_axes
from typing import Sequence
from typing import Sequence, Optional
import numpy as np
from fsl.data.image import Image
import nibabel as nib
......@@ -40,45 +40,59 @@ class Cifti:
- :py:class:`BrainModelAxis <cifti2_axes.BrainModelAxis>`
- :py:class:`ParcelsAxis <cifti2_axes.ParcelsAxis>`
"""
def __init__(self, arr: np.ndarray, axes: Sequence[cifti2_axes.Axis]):
def __init__(self, arr: np.ndarray, axes: Sequence[Optional[cifti2_axes.Axis], ...]):
"""
Defines a new dataset in greyordinate space
:param data: (..., N) array for N greyordinates or parcels
:param data: (..., N) array for N greyordinates or parcels; can contain Nones for undefined axes
:param axes: sequence of CIFTI axes describing the data along each dimension
"""
self.arr = arr
axes = tuple(axes)
while self.arr.ndim > len(axes):
axes = (None, ) + axes
self.axes = axes
if arr.shape[-len(axes):] != tuple(len(ax) for ax in axes):
raise ValueError(f"Shape of axes {tuple(len(ax) for ax in axes)} does not match shape of array {self.arr.shape}")
if not all(ax is None or len(ax) == sz for ax, sz in zip(axes, self.arr.shape)):
raise ValueError(f"Shape of axes {tuple(-1 if ax is None else len(ax) for ax in axes)} does not "
f"match shape of array {self.arr.shape}")
def to_cifti(self, other_axes=None):
def to_cifti(self, default_axis=None):
"""
Create a CIFTI image from the data
:param other_axes: overwrites the :mod:`cifti2_axes` to be used to write to create the CIFTI image
:param default_axis: What to use as an axis along any undefined dimensions
- By default an error is raised
- if set to "scalar" a ScalarAxis is used with names of "default {index}"
- if set to "series" a SeriesAxis is used
:return: nibabel CIFTI image
"""
if other_axes is None:
if len(self.axes) != self.data.ndim:
raise ValueError("Can not store to CIFTI without defining what is stored along the other dimensions")
other_axes = self.axes[:-1]
if any(ax is None for ax in self.axes):
if default_axis is None:
raise ValueError("Can not store to CIFTI without defining what is stored along each dimension")
elif default_axis == 'scalar':
def get_axis(n: int):
return cifti2_axes.ScalarAxis([f'default {idx + 1}' for idx in range(n)])
elif default_axis == 'series':
def get_axis(n: int):
return cifti2_axes.SeriesAxis(0, 1, n)
else:
raise ValueError(f"default_axis should be set to None, 'scalar', or 'series', not {default_axis}")
new_axes = [
get_axis(sz) if ax is None else ax
for ax, sz in zip(self.axes, self.arr.shape)
]
else:
if len(other_axes) != self.data.ndim - 1:
raise ValueError("Number of axis does not match dimensionality of the data")
if tuple(len(ax) for ax in other_axes) != self.data.shape[:-1]:
raise ValueError("Size of other axes does not match data size")
new_axes = list(self.axes)
data = self.data
if data.ndim == 1:
# CIFTI axes are always at least 2D
data = data[None, :]
other_axes = [cifti2_axes.ScalarAxis(['default'])]
new_axes.insert(0, cifti2_axes.ScalarAxis(['default']))
return cifti2_axes.Cifti2Image(
data,
header=list(other_axes) + [self.axes[-1]]
)
return cifti2_axes.Cifti2Image(data, header=new_axes)
@classmethod
def from_cifti(cls, filename, writable=False):
......@@ -110,17 +124,22 @@ class Cifti:
return ParcelCifti(data, axes)
raise ValueError("Last axis of CIFTI object should be a BrainModelAxis or ParcelsAxis")
def write(self, cifti_filename, other_axes=None):
def save(self, cifti_filename, default_axis=None):
"""
Writes this sparse representation to/from a filename
:param cifti_filename: output filename
:param other_axes: overwrites the :mod:`cifti2_axes` to be used to write to the file
:param default_axis: What to use as an axis along any undefined dimensions
- By default an error is raised
- if set to "scalar" a ScalarAxis is used with names of "default {index}"
- if set to "series" a SeriesAxis is used
:return:
"""
self.to_cifti(other_axes).to_filename(addExt(cifti_filename, defaultExt=self.extension))
self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension))
def read(cls, filename, mask_values=(0, np.nan), writable=False):
@classmethod
def load(cls, filename, mask_values=(0, np.nan), writable=False):
"""
Reads greyordinate data from the given file
......@@ -140,17 +159,19 @@ class Cifti:
else:
img = filename
if isinstance(img, nib.Nifti1Image):
if writable:
raise ValueError("Can not open NIFTI file in writable mode")
return cls.from_image(Image(img), mask_values)
if isinstance(img, nib.Cifti2Image):
return cls.from_cifti(img, writable=writable)
if isinstance(img, nib.GiftiImage):
if writable:
raise ValueError("Can not open GIFTI file in writable mode")
return cls.from_gifti(img, mask_values)
try:
vol_img = Image(img)
except ValueError:
raise ValueError(f"I do not know how to convert {type(img)} into greyordinates (from {filename})")
if writable:
raise ValueError("Can not open NIFTI file in writable mode")
return cls.from_image(vol_img, mask_values)
@classmethod
def from_gifti(cls, filename, mask_values=(0, np.nan)):
......@@ -191,7 +212,7 @@ class Cifti:
"""
Creates a new greyordinate object from a NIFTI file
:param filename: NIFTI filename or Image object
:param image: FSL :class:`image.Image` object
:param mask_values: which values to mask out
:return: greyordinate object representing the unmasked voxels
"""
......@@ -211,7 +232,6 @@ class Cifti:
return cifti2_axes.GreyOrdinates(inverted_data, [bm_axes])
class DenseCifti(Cifti):
"""
Represents sparse data defined for a subset of voxels and vertices (i.e., greyordinates)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment