From c428f937281546c309eca6dba982885f34b13cbc Mon Sep 17 00:00:00 2001 From: Paul McCarthy <pauld.mccarthy@gmail.com> Date: Thu, 22 Sep 2016 14:38:23 +0100 Subject: [PATCH] That last ImageWrapper fix didn't fix anything. This should. --- fsl/data/image.py | 25 ++++++----------- fsl/data/imagewrapper.py | 60 ++++++++++++++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/fsl/data/image.py b/fsl/data/image.py index 90a1c8e63..84f4ea32a 100644 --- a/fsl/data/image.py +++ b/fsl/data/image.py @@ -161,26 +161,17 @@ class Nifti(object): - A sequence/tuple containing the zooms/pixdims. """ + # The canonicalShape method figures out + # the data shape that we should use. origShape = list(header.get_data_shape()) - shape = list(origShape) + shape = imagewrapper.canonicalShape(origShape) pixdims = list(header.get_zooms()) - # Squeeze out empty dimensions, as - # 3D image can sometimes be listed - # as having 4 or more dimensions - for i in reversed(range(len(shape))): - if shape[i] == 1: shape = shape[:i] - else: break - - # But make sure the shape - # has at 3 least dimensions - if len(shape) < 3: - shape = shape + [1] * (3 - len(shape)) - - # The same goes for the pixdim - if get_zooms() - # doesn't return at least 3 values, we'll fall - # back to the pixdim field in the header. - if len(pixdims) < 3: + # if get_zooms() doesn't return at + # least len(shape) values, we'll + # fall back to the pixdim field in + # the header. + if len(pixdims) < len(shape): pixdims = header['pixdim'][1:] pixdims = pixdims[:len(shape)] diff --git a/fsl/data/imagewrapper.py b/fsl/data/imagewrapper.py index 79a5b0c37..b05510d8a 100644 --- a/fsl/data/imagewrapper.py +++ b/fsl/data/imagewrapper.py @@ -168,6 +168,15 @@ class ImageWrapper(notifier.Notifier): # 'padding' dimensions too. self.__numPadDims = len(image.shape) - self.__numRealDims + # Too many shapes! Figure out + # what shape we should present + # the data as (e.g. at least 3 + # dimensions). This is used in + # __getitem__ to force the + # result to have the correct + # dimensionality. + self.__canonicalShape = canonicalShape(image.shape) + # The internal state is stored # in these attributes - they're # initialised in the reset method. @@ -335,14 +344,6 @@ class ImageWrapper(notifier.Notifier): if isTuple: sliceobj = sliceTupleToSliceObj(sliceobj) - # Truncate some dimensions from the - # slice object if it has too many - # (e.g. trailing dims of length 1). - ndims = len(self.__image.shape) - - if len(sliceobj) > ndims: - sliceobj = sliceobj[:ndims] - # If the image has not been loaded # into memory, we can use the nibabel # ArrayProxy. Otheriwse if it is in @@ -541,6 +542,12 @@ class ImageWrapper(notifier.Notifier): if not isinstance(sliceobj, tuple): sliceobj = (sliceobj,) + # Figure out the number of dimensions + # that the result should have, given + # this slice object. + expNdims = len(self.__canonicalShape) - \ + len([s for s in sliceobj if isinstance(s, int)]) + # Truncate some dimensions from the # slice object if it has too many # (e.g. trailing dims of length 1). @@ -550,20 +557,26 @@ class ImageWrapper(notifier.Notifier): if len(sliceobj) > ndims: sliceobj = sliceobj[:ndims] - sliceobj = nib.fileslice.canonical_slicers(sliceobj, shape) - # TODO Cache 3D images for large 4D volumes, # so you don't have to hit the disk? - data = self.__getData(sliceobj) + sliceobj = nib.fileslice.canonical_slicers(sliceobj, shape) + data = self.__getData(sliceobj) if not self.__covered: - slices = sliceObjToSliceTuple(sliceobj, self.__image.shape) + slices = sliceObjToSliceTuple(sliceobj, shape) if not sliceCovered(slices, self.__coverage): self.__updateDataRangeOnRead(slices, data) + # Make sure that the result has + # the shape that the caller is + # expecting. + ndims = len(data.shape) + if ndims < expNdims: + data = data.reshape(list(data.shape) + [1] * (expNdims - ndims)) + return data @@ -594,6 +607,29 @@ class ImageWrapper(notifier.Notifier): self.__updateDataRangeOnWrite(slices, values) +def canonicalShape(shape): + """Calculates a *canonical* shape, how the given ``shape`` should + be presented. The shape is forced to be at least three dimensions, + with any other trailing dimensions of length 1 ignored. + """ + + shape = list(shape) + + # Squeeze out empty dimensions, as + # 3D image can sometimes be listed + # as having 4 or more dimensions + for i in reversed(range(len(shape))): + if shape[i] == 1: shape = shape[:i] + else: break + + # But make sure the shape + # has at 3 least dimensions + if len(shape) < 3: + shape = shape + [1] * (3 - len(shape)) + + return shape + + def sliceObjToSliceTuple(sliceobj, shape): """Turns an array slice object into a tuple of (low, high) index pairs, one pair for each dimension in the given shape -- GitLab