From c428f937281546c309eca6dba982885f34b13cbc Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauld.mccarthy@gmail.com>
Date: Thu, 22 Sep 2016 14:38:23 +0100
Subject: [PATCH] That last ImageWrapper fix didn't fix anything. This should.

---
 fsl/data/image.py        | 25 ++++++-----------
 fsl/data/imagewrapper.py | 60 ++++++++++++++++++++++++++++++++--------
 2 files changed, 56 insertions(+), 29 deletions(-)

diff --git a/fsl/data/image.py b/fsl/data/image.py
index 90a1c8e63..84f4ea32a 100644
--- a/fsl/data/image.py
+++ b/fsl/data/image.py
@@ -161,26 +161,17 @@ class Nifti(object):
          - A sequence/tuple containing the zooms/pixdims.
         """
 
+        # The canonicalShape method figures out
+        # the data shape that we should use.
         origShape = list(header.get_data_shape())
-        shape     = list(origShape)
+        shape     = imagewrapper.canonicalShape(origShape)
         pixdims   = list(header.get_zooms())
 
-        # Squeeze out empty dimensions, as
-        # 3D image can sometimes be listed
-        # as having 4 or more dimensions 
-        for i in reversed(range(len(shape))):
-            if shape[i] == 1: shape = shape[:i]
-            else:             break
-
-        # But make sure the shape 
-        # has at 3 least dimensions
-        if len(shape) < 3:
-            shape = shape + [1] * (3 - len(shape))
-
-        # The same goes for the pixdim - if get_zooms()
-        # doesn't return at least 3 values, we'll fall
-        # back to the pixdim field in the header.
-        if len(pixdims) < 3:
+        # if get_zooms() doesn't return at
+        # least len(shape) values, we'll
+        # fall back to the pixdim field in
+        # the header.
+        if len(pixdims) < len(shape):
             pixdims = header['pixdim'][1:]
 
         pixdims = pixdims[:len(shape)]
diff --git a/fsl/data/imagewrapper.py b/fsl/data/imagewrapper.py
index 79a5b0c37..b05510d8a 100644
--- a/fsl/data/imagewrapper.py
+++ b/fsl/data/imagewrapper.py
@@ -168,6 +168,15 @@ class ImageWrapper(notifier.Notifier):
         # 'padding' dimensions too.
         self.__numPadDims = len(image.shape) - self.__numRealDims
 
+        # Too many shapes! Figure out
+        # what shape we should present
+        # the data as (e.g. at least 3
+        # dimensions). This is used in
+        # __getitem__ to force the
+        # result to have the correct
+        # dimensionality.
+        self.__canonicalShape = canonicalShape(image.shape)
+
         # The internal state is stored
         # in these attributes - they're
         # initialised in the reset method.
@@ -335,14 +344,6 @@ class ImageWrapper(notifier.Notifier):
         if isTuple:
             sliceobj = sliceTupleToSliceObj(sliceobj)
 
-        # Truncate some dimensions from the
-        # slice object if it has too many
-        # (e.g. trailing dims of length 1).
-        ndims = len(self.__image.shape)
-
-        if len(sliceobj) > ndims:
-            sliceobj = sliceobj[:ndims]
-
         # If the image has not been loaded
         # into memory, we can use the nibabel
         # ArrayProxy. Otheriwse if it is in
@@ -541,6 +542,12 @@ class ImageWrapper(notifier.Notifier):
         if not isinstance(sliceobj, tuple):
             sliceobj = (sliceobj,)
 
+        # Figure out the number of dimensions
+        # that the result should have, given
+        # this slice object.
+        expNdims = len(self.__canonicalShape) - \
+                   len([s for s in sliceobj if isinstance(s, int)])
+
         # Truncate some dimensions from the
         # slice object if it has too many
         # (e.g. trailing dims of length 1).
@@ -550,20 +557,26 @@ class ImageWrapper(notifier.Notifier):
         if len(sliceobj) > ndims:
             sliceobj = sliceobj[:ndims]
 
-        sliceobj = nib.fileslice.canonical_slicers(sliceobj, shape)
-
         # TODO Cache 3D images for large 4D volumes, 
         #      so you don't have to hit the disk?
 
-        data = self.__getData(sliceobj)
+        sliceobj = nib.fileslice.canonical_slicers(sliceobj, shape)
+        data     = self.__getData(sliceobj)
 
         if not self.__covered:
 
-            slices = sliceObjToSliceTuple(sliceobj, self.__image.shape)
+            slices = sliceObjToSliceTuple(sliceobj, shape)
 
             if not sliceCovered(slices, self.__coverage):
                 self.__updateDataRangeOnRead(slices, data)
 
+        # Make sure that the result has
+        # the shape that the caller is
+        # expecting.
+        ndims = len(data.shape)
+        if ndims < expNdims:
+            data = data.reshape(list(data.shape) + [1] * (expNdims - ndims))
+
         return data
 
 
@@ -594,6 +607,29 @@ class ImageWrapper(notifier.Notifier):
         self.__updateDataRangeOnWrite(slices, values)
 
 
+def canonicalShape(shape):
+    """Calculates a *canonical* shape, how the given ``shape`` should
+    be presented. The shape is forced to be at least three dimensions,
+    with any other trailing dimensions of length 1 ignored.
+    """
+
+    shape = list(shape)
+
+    # Squeeze out empty dimensions, as
+    # 3D image can sometimes be listed
+    # as having 4 or more dimensions 
+    for i in reversed(range(len(shape))):
+        if shape[i] == 1: shape = shape[:i]
+        else:             break
+
+    # But make sure the shape 
+    # has at 3 least dimensions
+    if len(shape) < 3:
+        shape = shape + [1] * (3 - len(shape))
+
+    return shape
+
+
 def sliceObjToSliceTuple(sliceobj, shape):
     """Turns an array slice object into a tuple of (low, high) index
     pairs, one pair for each dimension in the given shape
-- 
GitLab