From db21c20b3de5e145a56359dd3aef4a77c21620b0 Mon Sep 17 00:00:00 2001 From: Michiel Cottaar <MichielCottaar@protonmail.com> Date: Fri, 1 May 2020 11:22:02 +0100 Subject: [PATCH] BUG: set undefined axes to None --- fsl/data/cifti.py | 88 +++++++++++++++++++++++++++++------------------ 1 file changed, 54 insertions(+), 34 deletions(-) diff --git a/fsl/data/cifti.py b/fsl/data/cifti.py index 5e6e97170..d615fb87a 100644 --- a/fsl/data/cifti.py +++ b/fsl/data/cifti.py @@ -7,7 +7,7 @@ The data can be read from NIFTI, GIFTI, or CIFTI files. Non-sparse volumetric or surface representations can be extracte. """ from nibabel.cifti2 import cifti2_axes -from typing import Sequence +from typing import Sequence, Optional import numpy as np from fsl.data.image import Image import nibabel as nib @@ -40,45 +40,59 @@ class Cifti: - :py:class:`BrainModelAxis <cifti2_axes.BrainModelAxis>` - :py:class:`ParcelsAxis <cifti2_axes.ParcelsAxis>` """ - def __init__(self, arr: np.ndarray, axes: Sequence[cifti2_axes.Axis]): + def __init__(self, arr: np.ndarray, axes: Sequence[Optional[cifti2_axes.Axis], ...]): """ Defines a new dataset in greyordinate space - :param data: (..., N) array for N greyordinates or parcels + :param data: (..., N) array for N greyordinates or parcels; can contain Nones for undefined axes :param axes: sequence of CIFTI axes describing the data along each dimension """ self.arr = arr + axes = tuple(axes) + while self.arr.ndim > len(axes): + axes = (None, ) + axes self.axes = axes - if arr.shape[-len(axes):] != tuple(len(ax) for ax in axes): - raise ValueError(f"Shape of axes {tuple(len(ax) for ax in axes)} does not match shape of array {self.arr.shape}") + if not all(ax is None or len(ax) == sz for ax, sz in zip(axes, self.arr.shape)): + raise ValueError(f"Shape of axes {tuple(-1 if ax is None else len(ax) for ax in axes)} does not " + f"match shape of array {self.arr.shape}") - def to_cifti(self, other_axes=None): + def to_cifti(self, default_axis=None): """ Create a CIFTI image from the data - :param other_axes: overwrites the :mod:`cifti2_axes` to be used to write to create the CIFTI image + :param default_axis: What to use as an axis along any undefined dimensions + + - By default an error is raised + - if set to "scalar" a ScalarAxis is used with names of "default {index}" + - if set to "series" a SeriesAxis is used + :return: nibabel CIFTI image """ - if other_axes is None: - if len(self.axes) != self.data.ndim: - raise ValueError("Can not store to CIFTI without defining what is stored along the other dimensions") - other_axes = self.axes[:-1] + if any(ax is None for ax in self.axes): + if default_axis is None: + raise ValueError("Can not store to CIFTI without defining what is stored along each dimension") + elif default_axis == 'scalar': + def get_axis(n: int): + return cifti2_axes.ScalarAxis([f'default {idx + 1}' for idx in range(n)]) + elif default_axis == 'series': + def get_axis(n: int): + return cifti2_axes.SeriesAxis(0, 1, n) + else: + raise ValueError(f"default_axis should be set to None, 'scalar', or 'series', not {default_axis}") + new_axes = [ + get_axis(sz) if ax is None else ax + for ax, sz in zip(self.axes, self.arr.shape) + ] else: - if len(other_axes) != self.data.ndim - 1: - raise ValueError("Number of axis does not match dimensionality of the data") - if tuple(len(ax) for ax in other_axes) != self.data.shape[:-1]: - raise ValueError("Size of other axes does not match data size") + new_axes = list(self.axes) data = self.data if data.ndim == 1: # CIFTI axes are always at least 2D data = data[None, :] - other_axes = [cifti2_axes.ScalarAxis(['default'])] + new_axes.insert(0, cifti2_axes.ScalarAxis(['default'])) - return cifti2_axes.Cifti2Image( - data, - header=list(other_axes) + [self.axes[-1]] - ) + return cifti2_axes.Cifti2Image(data, header=new_axes) @classmethod def from_cifti(cls, filename, writable=False): @@ -110,25 +124,30 @@ class Cifti: return ParcelCifti(data, axes) raise ValueError("Last axis of CIFTI object should be a BrainModelAxis or ParcelsAxis") - def write(self, cifti_filename, other_axes=None): + def save(self, cifti_filename, default_axis=None): """ Writes this sparse representation to/from a filename :param cifti_filename: output filename - :param other_axes: overwrites the :mod:`cifti2_axes` to be used to write to the file + :param default_axis: What to use as an axis along any undefined dimensions + + - By default an error is raised + - if set to "scalar" a ScalarAxis is used with names of "default {index}" + - if set to "series" a SeriesAxis is used :return: """ - self.to_cifti(other_axes).to_filename(addExt(cifti_filename, defaultExt=self.extension)) + self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension)) - def read(cls, filename, mask_values=(0, np.nan), writable=False): + @classmethod + def load(cls, filename, mask_values=(0, np.nan), writable=False): """ Reads greyordinate data from the given file File can be: - - NIFTI mask - - GIFTI mask - - CIFTI file + - NIFTI mask + - GIFTI mask + - CIFTI file :param filename: input filename :param mask_values: which values are outside of the mask for NIFTI or GIFTI input @@ -140,17 +159,19 @@ class Cifti: else: img = filename - if isinstance(img, nib.Nifti1Image): - if writable: - raise ValueError("Can not open NIFTI file in writable mode") - return cls.from_image(Image(img), mask_values) if isinstance(img, nib.Cifti2Image): return cls.from_cifti(img, writable=writable) if isinstance(img, nib.GiftiImage): if writable: raise ValueError("Can not open GIFTI file in writable mode") return cls.from_gifti(img, mask_values) - raise ValueError(f"I do not know how to convert {type(img)} into greyordinates (from {filename})") + try: + vol_img = Image(img) + except ValueError: + raise ValueError(f"I do not know how to convert {type(img)} into greyordinates (from {filename})") + if writable: + raise ValueError("Can not open NIFTI file in writable mode") + return cls.from_image(vol_img, mask_values) @classmethod def from_gifti(cls, filename, mask_values=(0, np.nan)): @@ -191,7 +212,7 @@ class Cifti: """ Creates a new greyordinate object from a NIFTI file - :param filename: NIFTI filename or Image object + :param image: FSL :class:`image.Image` object :param mask_values: which values to mask out :return: greyordinate object representing the unmasked voxels """ @@ -211,7 +232,6 @@ class Cifti: return cifti2_axes.GreyOrdinates(inverted_data, [bm_axes]) - class DenseCifti(Cifti): """ Represents sparse data defined for a subset of voxels and vertices (i.e., greyordinates) -- GitLab