From db21c20b3de5e145a56359dd3aef4a77c21620b0 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@protonmail.com>
Date: Fri, 1 May 2020 11:22:02 +0100
Subject: [PATCH] BUG: set undefined axes to None

---
 fsl/data/cifti.py | 88 +++++++++++++++++++++++++++++------------------
 1 file changed, 54 insertions(+), 34 deletions(-)

diff --git a/fsl/data/cifti.py b/fsl/data/cifti.py
index 5e6e97170..d615fb87a 100644
--- a/fsl/data/cifti.py
+++ b/fsl/data/cifti.py
@@ -7,7 +7,7 @@ The data can be read from NIFTI, GIFTI, or CIFTI files.
 Non-sparse volumetric or surface representations can be extracte.
 """
 from nibabel.cifti2 import cifti2_axes
-from typing import Sequence
+from typing import Sequence, Optional
 import numpy as np
 from fsl.data.image import Image
 import nibabel as nib
@@ -40,45 +40,59 @@ class Cifti:
     - :py:class:`BrainModelAxis <cifti2_axes.BrainModelAxis>`
     - :py:class:`ParcelsAxis <cifti2_axes.ParcelsAxis>`
     """
-    def __init__(self, arr: np.ndarray, axes: Sequence[cifti2_axes.Axis]):
+    def __init__(self, arr: np.ndarray, axes: Sequence[Optional[cifti2_axes.Axis], ...]):
         """
         Defines a new dataset in greyordinate space
 
-        :param data: (..., N) array for N greyordinates or parcels
+        :param data: (..., N) array for N greyordinates or parcels; can contain Nones for undefined axes
         :param axes: sequence of CIFTI axes describing the data along each dimension
         """
         self.arr = arr
+        axes = tuple(axes)
+        while self.arr.ndim > len(axes):
+            axes = (None, ) + axes
         self.axes = axes
-        if arr.shape[-len(axes):] != tuple(len(ax) for ax in axes):
-            raise ValueError(f"Shape of axes {tuple(len(ax) for ax in axes)} does not match shape of array {self.arr.shape}")
+        if not all(ax is None or len(ax) == sz for ax, sz in zip(axes, self.arr.shape)):
+            raise ValueError(f"Shape of axes {tuple(-1 if ax is None else len(ax) for ax in axes)} does not "
+                             f"match shape of array {self.arr.shape}")
 
-    def to_cifti(self, other_axes=None):
+    def to_cifti(self, default_axis=None):
         """
         Create a CIFTI image from the data
 
-        :param other_axes: overwrites the :mod:`cifti2_axes` to be used to write to create the CIFTI image
+        :param default_axis: What to use as an axis along any undefined dimensions
+
+            - By default an error is raised
+            - if set to "scalar" a ScalarAxis is used with names of "default {index}"
+            - if set to "series" a SeriesAxis is used
+
         :return: nibabel CIFTI image
         """
-        if other_axes is None:
-            if len(self.axes) != self.data.ndim:
-                raise ValueError("Can not store to CIFTI without defining what is stored along the other dimensions")
-            other_axes = self.axes[:-1]
+        if any(ax is None for ax in self.axes):
+            if default_axis is None:
+                raise ValueError("Can not store to CIFTI without defining what is stored along each dimension")
+            elif default_axis == 'scalar':
+                def get_axis(n: int):
+                    return cifti2_axes.ScalarAxis([f'default {idx + 1}' for idx in range(n)])
+            elif default_axis == 'series':
+                def get_axis(n: int):
+                    return cifti2_axes.SeriesAxis(0, 1, n)
+            else:
+                raise ValueError(f"default_axis should be set to None, 'scalar', or 'series', not {default_axis}")
+            new_axes = [
+                get_axis(sz) if ax is None else ax
+                for ax, sz in zip(self.axes, self.arr.shape)
+            ]
         else:
-            if len(other_axes) != self.data.ndim - 1:
-                raise ValueError("Number of axis does not match dimensionality of the data")
-            if tuple(len(ax) for ax in other_axes) != self.data.shape[:-1]:
-                raise ValueError("Size of other axes does not match data size")
+            new_axes = list(self.axes)
 
         data = self.data
         if data.ndim == 1:
             # CIFTI axes are always at least 2D
             data = data[None, :]
-            other_axes = [cifti2_axes.ScalarAxis(['default'])]
+            new_axes.insert(0, cifti2_axes.ScalarAxis(['default']))
 
-        return cifti2_axes.Cifti2Image(
-            data,
-            header=list(other_axes) + [self.axes[-1]]
-        )
+        return cifti2_axes.Cifti2Image(data, header=new_axes)
 
     @classmethod
     def from_cifti(cls, filename, writable=False):
@@ -110,25 +124,30 @@ class Cifti:
             return ParcelCifti(data, axes)
         raise ValueError("Last axis of CIFTI object should be a BrainModelAxis or ParcelsAxis")
 
-    def write(self, cifti_filename, other_axes=None):
+    def save(self, cifti_filename, default_axis=None):
         """
         Writes this sparse representation to/from a filename
 
         :param cifti_filename: output filename
-        :param other_axes: overwrites the :mod:`cifti2_axes` to be used to write to the file
+        :param default_axis: What to use as an axis along any undefined dimensions
+
+            - By default an error is raised
+            - if set to "scalar" a ScalarAxis is used with names of "default {index}"
+            - if set to "series" a SeriesAxis is used
         :return:
         """
-        self.to_cifti(other_axes).to_filename(addExt(cifti_filename, defaultExt=self.extension))
+        self.to_cifti(default_axis).to_filename(addExt(cifti_filename, defaultExt=self.extension))
 
-    def read(cls, filename, mask_values=(0, np.nan), writable=False):
+    @classmethod
+    def load(cls, filename, mask_values=(0, np.nan), writable=False):
         """
         Reads greyordinate data from the given file
 
         File can be:
 
-        - NIFTI mask
-        - GIFTI mask
-        - CIFTI file
+            - NIFTI mask
+            - GIFTI mask
+            - CIFTI file
 
         :param filename: input filename
         :param mask_values: which values are outside of the mask for NIFTI or GIFTI input
@@ -140,17 +159,19 @@ class Cifti:
         else:
             img = filename
 
-        if isinstance(img, nib.Nifti1Image):
-            if writable:
-                raise ValueError("Can not open NIFTI file in writable mode")
-            return cls.from_image(Image(img), mask_values)
         if isinstance(img, nib.Cifti2Image):
             return cls.from_cifti(img, writable=writable)
         if isinstance(img, nib.GiftiImage):
             if writable:
                 raise ValueError("Can not open GIFTI file in writable mode")
             return cls.from_gifti(img, mask_values)
-        raise ValueError(f"I do not know how to convert {type(img)} into greyordinates (from {filename})")
+        try:
+            vol_img = Image(img)
+        except ValueError:
+            raise ValueError(f"I do not know how to convert {type(img)} into greyordinates (from {filename})")
+        if writable:
+            raise ValueError("Can not open NIFTI file in writable mode")
+        return cls.from_image(vol_img, mask_values)
 
     @classmethod
     def from_gifti(cls, filename, mask_values=(0, np.nan)):
@@ -191,7 +212,7 @@ class Cifti:
         """
         Creates a new greyordinate object from a NIFTI file
 
-        :param filename: NIFTI filename or Image object
+        :param image: FSL :class:`image.Image` object
         :param mask_values: which values to mask out
         :return: greyordinate object representing the unmasked voxels
         """
@@ -211,7 +232,6 @@ class Cifti:
         return cifti2_axes.GreyOrdinates(inverted_data, [bm_axes])
 
 
-
 class DenseCifti(Cifti):
     """
     Represents sparse data defined for a subset of voxels and vertices (i.e., greyordinates)
-- 
GitLab