From 8660298fc44f6ccbb75d5debb054b113d3e927e2 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauld.mccarthy@gmail.com>
Date: Tue, 13 Dec 2016 12:30:37 +0000
Subject: [PATCH] ImageWrapper bugfixes - more robust support for images with
 dodgy/misreported dimensionality

---
 fsl/data/imagewrapper.py | 223 ++++++++++++++++++++++++++++++---------
 1 file changed, 172 insertions(+), 51 deletions(-)

diff --git a/fsl/data/imagewrapper.py b/fsl/data/imagewrapper.py
index a3b7657fa..d9bacb798 100644
--- a/fsl/data/imagewrapper.py
+++ b/fsl/data/imagewrapper.py
@@ -15,21 +15,25 @@ Terminology
 There are some confusing terms used in this module, so it may be useful to
 get their definitions straight:
 
-  - *Coverage*:  The portion of an image that has been covered in the data
-                 range calculation. The ``ImageWrapper`` keeps track of
-                 the coverage for individual volumes within a 4D image (or
-                 slices in a 3D image).
-
-  - *Slice*:     Portion of the image data which is being accessed. A slice
-                 comprises either a tuple of ``slice`` objects (or integers),
-                 or a sequence of ``(low, high)`` tuples, specifying the
-                 index range into each image dimension that is covered by
-                 the slice.
-
-  - *Expansion*: A sequence of ``(low, high)`` tuples, specifying an
-                 index range into each image dimension, that is used to
-                 *expand* the *coverage* of an image, based on a given set of
-                 *slices*.
+  - *Coverage*:    The portion of an image that has been covered in the data
+                   range calculation. The ``ImageWrapper`` keeps track of
+                   the coverage for individual volumes within a 4D image (or
+                   slices in a 3D image).
+
+  - *Slice*:       Portion of the image data which is being accessed. A slice
+                   comprises either a tuple of ``slice`` objects (or integers),
+                   or a sequence of ``(low, high)`` tuples, specifying the
+                   index range into each image dimension that is covered by
+                   the slice.
+
+  - *Expansion*:   A sequence of ``(low, high)`` tuples, specifying an
+                   index range into each image dimension, that is used to
+                   *expand* the *coverage* of an image, based on a given set 
+                   of *slices*.
+
+  - *Fancy slice*: Any object which is used to slice an array, and is not
+                   an ``int``, ``slice``, or ``Ellipsis``, or sequence of
+                   these.
 """
 
 
@@ -86,6 +90,21 @@ class ImageWrapper(notifier.Notifier):
     need provide an index of 0 for that dimensions, for all data accesses.
 
 
+    *Data access*
+
+    
+    The ``ImageWrapper`` can be indexed in one of two ways:
+
+       - With basic ``numpy``-like multi-dimensional array slicing (with step
+         sizes of 1)
+    
+       - With boolean array indexing, where the boolean/mask array has the
+         same shape as the image data.
+
+    See https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html for
+    more details on numpy indexing.
+ 
+
     *Data range*
 
     
@@ -302,6 +321,15 @@ class ImageWrapper(notifier.Notifier):
         return self.__covered
 
 
+    @property
+    def shape(self):
+        """Returns the shape that the image data is presented as. This is
+        the same as the underlying image shape, but with trailing dimensions
+        of length 1 removed, and at least three dimensions.
+        """
+        return self.__canonicalShape
+
+
     def coverage(self, vol):
         """Returns the current image data coverage for the specified volume
         (for a 4D image, slice for a 3D image, or vector for a 2D images).
@@ -547,52 +575,53 @@ class ImageWrapper(notifier.Notifier):
 
         :arg sliceobj: Something which can slice the image data.
         """
-
+        
         log.debug('Getting image data: {}'.format(sliceobj))
         
-        image = self.__image
-        shape = image.shape
-        ndims = len(shape)
-
-        fancy = isValidFancySliceObj(sliceobj, shape)
+        shape              = self.__canonicalShape
+        realShape          = self.__image.shape
+        sliceobj           = canonicalSliceObj(   sliceobj, shape)
+        fancy              = isValidFancySliceObj(sliceobj, shape)
+        expNdims, expShape = expectedShape(       sliceobj, shape)
 
-        if fancy:
-            expNdims = ndims
-        else:
-            
-            if not isinstance(sliceobj, tuple):
-                sliceobj = (sliceobj,)
-
-            # Figure out the number of dimensions
-            # that the result should have, given
-            # this slice object.
-            expNdims = len(self.__canonicalShape) - \
-                       len([s for s in sliceobj if isinstance(s, int)])
-
-            # Truncate some dimensions from the
-            # slice object if it has too many
-            # (e.g. trailing dims of length 1).
-            if len(sliceobj) > ndims:
-                sliceobj = sliceobj[:ndims]
-                
         # TODO Cache 3D images for large 4D volumes, 
         #      so you don't have to hit the disk?
 
-        sliceobj = canonicalSliceObj(sliceobj, shape)
+        # Make the slice object compatible with the
+        # actual image shape, and retrieve the data.
+        sliceobj = canonicalSliceObj(sliceobj, realShape)
         data     = self.__getData(sliceobj)
 
+        # Update data range for the 
+        # data that we just read in
         if not self.__covered:
 
-            slices = sliceObjToSliceTuple(sliceobj, shape)
+            slices = sliceObjToSliceTuple(sliceobj, realShape)
 
             if not sliceCovered(slices, self.__coverage):
                 self.__updateDataRangeOnRead(slices, data)
 
         # Make sure that the result has the
         # shape that the caller is expecting.
-        if not fancy and ndims != expNdims:
-            data = data.reshape(list(data.shape) + [1] * (expNdims - ndims))
-
+        if fancy: data = data.reshape((data.size, ))
+        else:     data = data.reshape(expShape)
+
+        # If expNdims == 0, we should
+        # return a scalar. If expNdims
+        # == 0, but data.size != 1,
+        # something is wrong somewhere
+        # (and is not being handled
+        # here).
+        if expNdims == 0 and data.size == 1:
+
+            # Funny behaviour with numpy scalar arrays.
+            # data[()] returns a numpy scalar (which is
+            # what we want). But data.item() returns a
+            # python scalar. And if the data is a
+            # ndarray with 0 dims, data[0] will raise
+            # an error!
+            data = data[()]
+                
         return data
 
 
@@ -608,8 +637,37 @@ class ImageWrapper(notifier.Notifier):
                   loaded into memory. 
         """
 
-        sliceobj = canonicalSliceObj(   sliceobj, self.__image.shape)
-        slices   = sliceObjToSliceTuple(sliceobj, self.__image.shape)
+        realShape = self.__image.shape
+        sliceobj  = canonicalSliceObj(   sliceobj, realShape)
+        slices    = sliceObjToSliceTuple(sliceobj, realShape)
+
+        # If the image shape does not match its
+        # 'display' shape (either less three
+        # dims, or  has trailing dims of length
+        # 1), we might need to re-shape the
+        # values to prevent numpy from raising
+        # an error in the assignment below.
+        if realShape != self.__canonicalShape:
+            
+            expNdims, expShape = expectedShape(sliceobj, realShape)
+
+            # If we are slicing a scalar, the
+            # assigned value has to be scalar.
+            if expNdims == 0 and isinstance(values, collections.Sequence):
+
+                if len(values) > 1:
+                    raise IndexError('Invalid assignment: [{}] = {}'.format(
+                        sliceobj, len(values)))
+        
+                values = values[0]
+
+            # Make sure that the values 
+            # have a compatible shape.
+            else:
+                
+                values = np.array(values)
+                if values.shape != expShape:
+                    values = values.reshape(expShape)
 
         # The image data has to be in memory
         # for the data to be changed. If it's
@@ -681,7 +739,7 @@ def isValidFancySliceObj(sliceobj, shape):
     # which have the same shape as the image
     return (isinstance(sliceobj, np.ndarray) and
             sliceobj.dtype == np.bool        and
-            sliceobj.shape == shape)
+            np.prod(sliceobj.shape) == np.prod(shape))
 
 
 def canonicalSliceObj(sliceobj, shape):
@@ -689,10 +747,20 @@ def canonicalSliceObj(sliceobj, shape):
     ``nibabel.fileslice.canonical_slicers` function.
     """
 
-    if not isValidFancySliceObj(sliceobj, shape):
-        sliceobj = nib.fileslice.canonical_slicers(sliceobj, shape)
+    # Fancy slice objects must have 
+    # the same shape as the data
+    if isValidFancySliceObj(sliceobj, shape):
+        return sliceobj.reshape(shape)
 
-    return sliceobj
+    else:
+
+        if not isinstance(sliceobj, tuple):
+            sliceobj = (sliceobj,)
+        
+        if len(sliceobj) > len(shape):
+            sliceobj = sliceobj[:len(shape)]
+
+        return nib.fileslice.canonical_slicers(sliceobj, shape)
     
 
 def canonicalShape(shape):
@@ -718,6 +786,59 @@ def canonicalShape(shape):
     return shape
 
 
+def expectedShape(sliceobj, shape):
+    """Given a slice object, and the shape of an array to which
+    that slice object is going to be applied, returns the expected
+    shape of the result.
+
+    .. note:: It is assumed that the ``sliceobj`` has been passed through
+              the :func:`canonicalSliceObj` function.
+
+    :arg sliceobj: Something which can be used to slice an array
+                   of shape ``shape``.
+
+    :arg shape:    Shape of the array being sliced.
+
+    :returns:      A tuple containing:
+    
+                     - Expected number of dimensions of the result
+    
+                     - Expected shape of the result (or ``None`` if
+                       ``sliceobj`` is fancy).
+    """
+    
+    if isValidFancySliceObj(sliceobj, shape):
+        return 1, None
+
+    # Truncate some dimensions from the
+    # slice object if it has too many
+    # (e.g. trailing dims of length 1).
+    elif len(sliceobj) > len(shape):
+        sliceobj = sliceobj[:len(shape)]
+
+    # Figure out the number of dimensions
+    # that the result should have, given
+    # this slice object. 
+    expShape = []
+
+    for i in range(len(sliceobj)):
+
+        # Each dimension which has an 
+        # int slice will be collapsed
+        if isinstance(sliceobj[i], int):
+            continue
+
+        start = sliceobj[i].start
+        stop  = sliceobj[i].stop
+        
+        if start is None: start = 0
+        if stop  is None: stop  = shape[i]
+
+        expShape.append(stop - start)
+
+    return len(expShape), expShape
+
+
 def sliceObjToSliceTuple(sliceobj, shape):
     """Turns an array slice object into a tuple of (low, high) index
     pairs, one pair for each dimension in the given shape
-- 
GitLab