From 8660298fc44f6ccbb75d5debb054b113d3e927e2 Mon Sep 17 00:00:00 2001 From: Paul McCarthy <pauld.mccarthy@gmail.com> Date: Tue, 13 Dec 2016 12:30:37 +0000 Subject: [PATCH] ImageWrapper bugfixes - more robust support for images with dodgy/misreported dimensionality --- fsl/data/imagewrapper.py | 223 ++++++++++++++++++++++++++++++--------- 1 file changed, 172 insertions(+), 51 deletions(-) diff --git a/fsl/data/imagewrapper.py b/fsl/data/imagewrapper.py index a3b7657fa..d9bacb798 100644 --- a/fsl/data/imagewrapper.py +++ b/fsl/data/imagewrapper.py @@ -15,21 +15,25 @@ Terminology There are some confusing terms used in this module, so it may be useful to get their definitions straight: - - *Coverage*: The portion of an image that has been covered in the data - range calculation. The ``ImageWrapper`` keeps track of - the coverage for individual volumes within a 4D image (or - slices in a 3D image). - - - *Slice*: Portion of the image data which is being accessed. A slice - comprises either a tuple of ``slice`` objects (or integers), - or a sequence of ``(low, high)`` tuples, specifying the - index range into each image dimension that is covered by - the slice. - - - *Expansion*: A sequence of ``(low, high)`` tuples, specifying an - index range into each image dimension, that is used to - *expand* the *coverage* of an image, based on a given set of - *slices*. + - *Coverage*: The portion of an image that has been covered in the data + range calculation. The ``ImageWrapper`` keeps track of + the coverage for individual volumes within a 4D image (or + slices in a 3D image). + + - *Slice*: Portion of the image data which is being accessed. A slice + comprises either a tuple of ``slice`` objects (or integers), + or a sequence of ``(low, high)`` tuples, specifying the + index range into each image dimension that is covered by + the slice. + + - *Expansion*: A sequence of ``(low, high)`` tuples, specifying an + index range into each image dimension, that is used to + *expand* the *coverage* of an image, based on a given set + of *slices*. + + - *Fancy slice*: Any object which is used to slice an array, and is not + an ``int``, ``slice``, or ``Ellipsis``, or sequence of + these. """ @@ -86,6 +90,21 @@ class ImageWrapper(notifier.Notifier): need provide an index of 0 for that dimensions, for all data accesses. + *Data access* + + + The ``ImageWrapper`` can be indexed in one of two ways: + + - With basic ``numpy``-like multi-dimensional array slicing (with step + sizes of 1) + + - With boolean array indexing, where the boolean/mask array has the + same shape as the image data. + + See https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html for + more details on numpy indexing. + + *Data range* @@ -302,6 +321,15 @@ class ImageWrapper(notifier.Notifier): return self.__covered + @property + def shape(self): + """Returns the shape that the image data is presented as. This is + the same as the underlying image shape, but with trailing dimensions + of length 1 removed, and at least three dimensions. + """ + return self.__canonicalShape + + def coverage(self, vol): """Returns the current image data coverage for the specified volume (for a 4D image, slice for a 3D image, or vector for a 2D images). @@ -547,52 +575,53 @@ class ImageWrapper(notifier.Notifier): :arg sliceobj: Something which can slice the image data. """ - + log.debug('Getting image data: {}'.format(sliceobj)) - image = self.__image - shape = image.shape - ndims = len(shape) - - fancy = isValidFancySliceObj(sliceobj, shape) + shape = self.__canonicalShape + realShape = self.__image.shape + sliceobj = canonicalSliceObj( sliceobj, shape) + fancy = isValidFancySliceObj(sliceobj, shape) + expNdims, expShape = expectedShape( sliceobj, shape) - if fancy: - expNdims = ndims - else: - - if not isinstance(sliceobj, tuple): - sliceobj = (sliceobj,) - - # Figure out the number of dimensions - # that the result should have, given - # this slice object. - expNdims = len(self.__canonicalShape) - \ - len([s for s in sliceobj if isinstance(s, int)]) - - # Truncate some dimensions from the - # slice object if it has too many - # (e.g. trailing dims of length 1). - if len(sliceobj) > ndims: - sliceobj = sliceobj[:ndims] - # TODO Cache 3D images for large 4D volumes, # so you don't have to hit the disk? - sliceobj = canonicalSliceObj(sliceobj, shape) + # Make the slice object compatible with the + # actual image shape, and retrieve the data. + sliceobj = canonicalSliceObj(sliceobj, realShape) data = self.__getData(sliceobj) + # Update data range for the + # data that we just read in if not self.__covered: - slices = sliceObjToSliceTuple(sliceobj, shape) + slices = sliceObjToSliceTuple(sliceobj, realShape) if not sliceCovered(slices, self.__coverage): self.__updateDataRangeOnRead(slices, data) # Make sure that the result has the # shape that the caller is expecting. - if not fancy and ndims != expNdims: - data = data.reshape(list(data.shape) + [1] * (expNdims - ndims)) - + if fancy: data = data.reshape((data.size, )) + else: data = data.reshape(expShape) + + # If expNdims == 0, we should + # return a scalar. If expNdims + # == 0, but data.size != 1, + # something is wrong somewhere + # (and is not being handled + # here). + if expNdims == 0 and data.size == 1: + + # Funny behaviour with numpy scalar arrays. + # data[()] returns a numpy scalar (which is + # what we want). But data.item() returns a + # python scalar. And if the data is a + # ndarray with 0 dims, data[0] will raise + # an error! + data = data[()] + return data @@ -608,8 +637,37 @@ class ImageWrapper(notifier.Notifier): loaded into memory. """ - sliceobj = canonicalSliceObj( sliceobj, self.__image.shape) - slices = sliceObjToSliceTuple(sliceobj, self.__image.shape) + realShape = self.__image.shape + sliceobj = canonicalSliceObj( sliceobj, realShape) + slices = sliceObjToSliceTuple(sliceobj, realShape) + + # If the image shape does not match its + # 'display' shape (either less three + # dims, or has trailing dims of length + # 1), we might need to re-shape the + # values to prevent numpy from raising + # an error in the assignment below. + if realShape != self.__canonicalShape: + + expNdims, expShape = expectedShape(sliceobj, realShape) + + # If we are slicing a scalar, the + # assigned value has to be scalar. + if expNdims == 0 and isinstance(values, collections.Sequence): + + if len(values) > 1: + raise IndexError('Invalid assignment: [{}] = {}'.format( + sliceobj, len(values))) + + values = values[0] + + # Make sure that the values + # have a compatible shape. + else: + + values = np.array(values) + if values.shape != expShape: + values = values.reshape(expShape) # The image data has to be in memory # for the data to be changed. If it's @@ -681,7 +739,7 @@ def isValidFancySliceObj(sliceobj, shape): # which have the same shape as the image return (isinstance(sliceobj, np.ndarray) and sliceobj.dtype == np.bool and - sliceobj.shape == shape) + np.prod(sliceobj.shape) == np.prod(shape)) def canonicalSliceObj(sliceobj, shape): @@ -689,10 +747,20 @@ def canonicalSliceObj(sliceobj, shape): ``nibabel.fileslice.canonical_slicers` function. """ - if not isValidFancySliceObj(sliceobj, shape): - sliceobj = nib.fileslice.canonical_slicers(sliceobj, shape) + # Fancy slice objects must have + # the same shape as the data + if isValidFancySliceObj(sliceobj, shape): + return sliceobj.reshape(shape) - return sliceobj + else: + + if not isinstance(sliceobj, tuple): + sliceobj = (sliceobj,) + + if len(sliceobj) > len(shape): + sliceobj = sliceobj[:len(shape)] + + return nib.fileslice.canonical_slicers(sliceobj, shape) def canonicalShape(shape): @@ -718,6 +786,59 @@ def canonicalShape(shape): return shape +def expectedShape(sliceobj, shape): + """Given a slice object, and the shape of an array to which + that slice object is going to be applied, returns the expected + shape of the result. + + .. note:: It is assumed that the ``sliceobj`` has been passed through + the :func:`canonicalSliceObj` function. + + :arg sliceobj: Something which can be used to slice an array + of shape ``shape``. + + :arg shape: Shape of the array being sliced. + + :returns: A tuple containing: + + - Expected number of dimensions of the result + + - Expected shape of the result (or ``None`` if + ``sliceobj`` is fancy). + """ + + if isValidFancySliceObj(sliceobj, shape): + return 1, None + + # Truncate some dimensions from the + # slice object if it has too many + # (e.g. trailing dims of length 1). + elif len(sliceobj) > len(shape): + sliceobj = sliceobj[:len(shape)] + + # Figure out the number of dimensions + # that the result should have, given + # this slice object. + expShape = [] + + for i in range(len(sliceobj)): + + # Each dimension which has an + # int slice will be collapsed + if isinstance(sliceobj[i], int): + continue + + start = sliceobj[i].start + stop = sliceobj[i].stop + + if start is None: start = 0 + if stop is None: stop = shape[i] + + expShape.append(stop - start) + + return len(expShape), expShape + + def sliceObjToSliceTuple(sliceobj, shape): """Turns an array slice object into a tuple of (low, high) index pairs, one pair for each dimension in the given shape -- GitLab