Skip to content
Snippets Groups Projects
Commit 8660298f authored by Paul McCarthy's avatar Paul McCarthy
Browse files

ImageWrapper bugfixes - more robust support for images with dodgy/misreported

dimensionality
parent d6177a78
No related branches found
No related tags found
No related merge requests found
......@@ -15,21 +15,25 @@ Terminology
There are some confusing terms used in this module, so it may be useful to
get their definitions straight:
- *Coverage*: The portion of an image that has been covered in the data
range calculation. The ``ImageWrapper`` keeps track of
the coverage for individual volumes within a 4D image (or
slices in a 3D image).
- *Slice*: Portion of the image data which is being accessed. A slice
comprises either a tuple of ``slice`` objects (or integers),
or a sequence of ``(low, high)`` tuples, specifying the
index range into each image dimension that is covered by
the slice.
- *Expansion*: A sequence of ``(low, high)`` tuples, specifying an
index range into each image dimension, that is used to
*expand* the *coverage* of an image, based on a given set of
*slices*.
- *Coverage*: The portion of an image that has been covered in the data
range calculation. The ``ImageWrapper`` keeps track of
the coverage for individual volumes within a 4D image (or
slices in a 3D image).
- *Slice*: Portion of the image data which is being accessed. A slice
comprises either a tuple of ``slice`` objects (or integers),
or a sequence of ``(low, high)`` tuples, specifying the
index range into each image dimension that is covered by
the slice.
- *Expansion*: A sequence of ``(low, high)`` tuples, specifying an
index range into each image dimension, that is used to
*expand* the *coverage* of an image, based on a given set
of *slices*.
- *Fancy slice*: Any object which is used to slice an array, and is not
an ``int``, ``slice``, or ``Ellipsis``, or sequence of
these.
"""
......@@ -86,6 +90,21 @@ class ImageWrapper(notifier.Notifier):
need provide an index of 0 for that dimensions, for all data accesses.
*Data access*
The ``ImageWrapper`` can be indexed in one of two ways:
- With basic ``numpy``-like multi-dimensional array slicing (with step
sizes of 1)
- With boolean array indexing, where the boolean/mask array has the
same shape as the image data.
See https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html for
more details on numpy indexing.
*Data range*
......@@ -302,6 +321,15 @@ class ImageWrapper(notifier.Notifier):
return self.__covered
@property
def shape(self):
"""Returns the shape that the image data is presented as. This is
the same as the underlying image shape, but with trailing dimensions
of length 1 removed, and at least three dimensions.
"""
return self.__canonicalShape
def coverage(self, vol):
"""Returns the current image data coverage for the specified volume
(for a 4D image, slice for a 3D image, or vector for a 2D images).
......@@ -547,52 +575,53 @@ class ImageWrapper(notifier.Notifier):
:arg sliceobj: Something which can slice the image data.
"""
log.debug('Getting image data: {}'.format(sliceobj))
image = self.__image
shape = image.shape
ndims = len(shape)
fancy = isValidFancySliceObj(sliceobj, shape)
shape = self.__canonicalShape
realShape = self.__image.shape
sliceobj = canonicalSliceObj( sliceobj, shape)
fancy = isValidFancySliceObj(sliceobj, shape)
expNdims, expShape = expectedShape( sliceobj, shape)
if fancy:
expNdims = ndims
else:
if not isinstance(sliceobj, tuple):
sliceobj = (sliceobj,)
# Figure out the number of dimensions
# that the result should have, given
# this slice object.
expNdims = len(self.__canonicalShape) - \
len([s for s in sliceobj if isinstance(s, int)])
# Truncate some dimensions from the
# slice object if it has too many
# (e.g. trailing dims of length 1).
if len(sliceobj) > ndims:
sliceobj = sliceobj[:ndims]
# TODO Cache 3D images for large 4D volumes,
# so you don't have to hit the disk?
sliceobj = canonicalSliceObj(sliceobj, shape)
# Make the slice object compatible with the
# actual image shape, and retrieve the data.
sliceobj = canonicalSliceObj(sliceobj, realShape)
data = self.__getData(sliceobj)
# Update data range for the
# data that we just read in
if not self.__covered:
slices = sliceObjToSliceTuple(sliceobj, shape)
slices = sliceObjToSliceTuple(sliceobj, realShape)
if not sliceCovered(slices, self.__coverage):
self.__updateDataRangeOnRead(slices, data)
# Make sure that the result has the
# shape that the caller is expecting.
if not fancy and ndims != expNdims:
data = data.reshape(list(data.shape) + [1] * (expNdims - ndims))
if fancy: data = data.reshape((data.size, ))
else: data = data.reshape(expShape)
# If expNdims == 0, we should
# return a scalar. If expNdims
# == 0, but data.size != 1,
# something is wrong somewhere
# (and is not being handled
# here).
if expNdims == 0 and data.size == 1:
# Funny behaviour with numpy scalar arrays.
# data[()] returns a numpy scalar (which is
# what we want). But data.item() returns a
# python scalar. And if the data is a
# ndarray with 0 dims, data[0] will raise
# an error!
data = data[()]
return data
......@@ -608,8 +637,37 @@ class ImageWrapper(notifier.Notifier):
loaded into memory.
"""
sliceobj = canonicalSliceObj( sliceobj, self.__image.shape)
slices = sliceObjToSliceTuple(sliceobj, self.__image.shape)
realShape = self.__image.shape
sliceobj = canonicalSliceObj( sliceobj, realShape)
slices = sliceObjToSliceTuple(sliceobj, realShape)
# If the image shape does not match its
# 'display' shape (either less three
# dims, or has trailing dims of length
# 1), we might need to re-shape the
# values to prevent numpy from raising
# an error in the assignment below.
if realShape != self.__canonicalShape:
expNdims, expShape = expectedShape(sliceobj, realShape)
# If we are slicing a scalar, the
# assigned value has to be scalar.
if expNdims == 0 and isinstance(values, collections.Sequence):
if len(values) > 1:
raise IndexError('Invalid assignment: [{}] = {}'.format(
sliceobj, len(values)))
values = values[0]
# Make sure that the values
# have a compatible shape.
else:
values = np.array(values)
if values.shape != expShape:
values = values.reshape(expShape)
# The image data has to be in memory
# for the data to be changed. If it's
......@@ -681,7 +739,7 @@ def isValidFancySliceObj(sliceobj, shape):
# which have the same shape as the image
return (isinstance(sliceobj, np.ndarray) and
sliceobj.dtype == np.bool and
sliceobj.shape == shape)
np.prod(sliceobj.shape) == np.prod(shape))
def canonicalSliceObj(sliceobj, shape):
......@@ -689,10 +747,20 @@ def canonicalSliceObj(sliceobj, shape):
``nibabel.fileslice.canonical_slicers` function.
"""
if not isValidFancySliceObj(sliceobj, shape):
sliceobj = nib.fileslice.canonical_slicers(sliceobj, shape)
# Fancy slice objects must have
# the same shape as the data
if isValidFancySliceObj(sliceobj, shape):
return sliceobj.reshape(shape)
return sliceobj
else:
if not isinstance(sliceobj, tuple):
sliceobj = (sliceobj,)
if len(sliceobj) > len(shape):
sliceobj = sliceobj[:len(shape)]
return nib.fileslice.canonical_slicers(sliceobj, shape)
def canonicalShape(shape):
......@@ -718,6 +786,59 @@ def canonicalShape(shape):
return shape
def expectedShape(sliceobj, shape):
"""Given a slice object, and the shape of an array to which
that slice object is going to be applied, returns the expected
shape of the result.
.. note:: It is assumed that the ``sliceobj`` has been passed through
the :func:`canonicalSliceObj` function.
:arg sliceobj: Something which can be used to slice an array
of shape ``shape``.
:arg shape: Shape of the array being sliced.
:returns: A tuple containing:
- Expected number of dimensions of the result
- Expected shape of the result (or ``None`` if
``sliceobj`` is fancy).
"""
if isValidFancySliceObj(sliceobj, shape):
return 1, None
# Truncate some dimensions from the
# slice object if it has too many
# (e.g. trailing dims of length 1).
elif len(sliceobj) > len(shape):
sliceobj = sliceobj[:len(shape)]
# Figure out the number of dimensions
# that the result should have, given
# this slice object.
expShape = []
for i in range(len(sliceobj)):
# Each dimension which has an
# int slice will be collapsed
if isinstance(sliceobj[i], int):
continue
start = sliceobj[i].start
stop = sliceobj[i].stop
if start is None: start = 0
if stop is None: stop = shape[i]
expShape.append(stop - start)
return len(expShape), expShape
def sliceObjToSliceTuple(sliceobj, shape):
"""Turns an array slice object into a tuple of (low, high) index
pairs, one pair for each dimension in the given shape
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment