Skip to content
Snippets Groups Projects
Commit c428f937 authored by Paul McCarthy's avatar Paul McCarthy
Browse files

That last ImageWrapper fix didn't fix anything. This should.

parent 22a51496
No related branches found
No related tags found
No related merge requests found
...@@ -161,26 +161,17 @@ class Nifti(object): ...@@ -161,26 +161,17 @@ class Nifti(object):
- A sequence/tuple containing the zooms/pixdims. - A sequence/tuple containing the zooms/pixdims.
""" """
# The canonicalShape method figures out
# the data shape that we should use.
origShape = list(header.get_data_shape()) origShape = list(header.get_data_shape())
shape = list(origShape) shape = imagewrapper.canonicalShape(origShape)
pixdims = list(header.get_zooms()) pixdims = list(header.get_zooms())
# Squeeze out empty dimensions, as # if get_zooms() doesn't return at
# 3D image can sometimes be listed # least len(shape) values, we'll
# as having 4 or more dimensions # fall back to the pixdim field in
for i in reversed(range(len(shape))): # the header.
if shape[i] == 1: shape = shape[:i] if len(pixdims) < len(shape):
else: break
# But make sure the shape
# has at 3 least dimensions
if len(shape) < 3:
shape = shape + [1] * (3 - len(shape))
# The same goes for the pixdim - if get_zooms()
# doesn't return at least 3 values, we'll fall
# back to the pixdim field in the header.
if len(pixdims) < 3:
pixdims = header['pixdim'][1:] pixdims = header['pixdim'][1:]
pixdims = pixdims[:len(shape)] pixdims = pixdims[:len(shape)]
......
...@@ -168,6 +168,15 @@ class ImageWrapper(notifier.Notifier): ...@@ -168,6 +168,15 @@ class ImageWrapper(notifier.Notifier):
# 'padding' dimensions too. # 'padding' dimensions too.
self.__numPadDims = len(image.shape) - self.__numRealDims self.__numPadDims = len(image.shape) - self.__numRealDims
# Too many shapes! Figure out
# what shape we should present
# the data as (e.g. at least 3
# dimensions). This is used in
# __getitem__ to force the
# result to have the correct
# dimensionality.
self.__canonicalShape = canonicalShape(image.shape)
# The internal state is stored # The internal state is stored
# in these attributes - they're # in these attributes - they're
# initialised in the reset method. # initialised in the reset method.
...@@ -335,14 +344,6 @@ class ImageWrapper(notifier.Notifier): ...@@ -335,14 +344,6 @@ class ImageWrapper(notifier.Notifier):
if isTuple: if isTuple:
sliceobj = sliceTupleToSliceObj(sliceobj) sliceobj = sliceTupleToSliceObj(sliceobj)
# Truncate some dimensions from the
# slice object if it has too many
# (e.g. trailing dims of length 1).
ndims = len(self.__image.shape)
if len(sliceobj) > ndims:
sliceobj = sliceobj[:ndims]
# If the image has not been loaded # If the image has not been loaded
# into memory, we can use the nibabel # into memory, we can use the nibabel
# ArrayProxy. Otheriwse if it is in # ArrayProxy. Otheriwse if it is in
...@@ -541,6 +542,12 @@ class ImageWrapper(notifier.Notifier): ...@@ -541,6 +542,12 @@ class ImageWrapper(notifier.Notifier):
if not isinstance(sliceobj, tuple): if not isinstance(sliceobj, tuple):
sliceobj = (sliceobj,) sliceobj = (sliceobj,)
# Figure out the number of dimensions
# that the result should have, given
# this slice object.
expNdims = len(self.__canonicalShape) - \
len([s for s in sliceobj if isinstance(s, int)])
# Truncate some dimensions from the # Truncate some dimensions from the
# slice object if it has too many # slice object if it has too many
# (e.g. trailing dims of length 1). # (e.g. trailing dims of length 1).
...@@ -550,20 +557,26 @@ class ImageWrapper(notifier.Notifier): ...@@ -550,20 +557,26 @@ class ImageWrapper(notifier.Notifier):
if len(sliceobj) > ndims: if len(sliceobj) > ndims:
sliceobj = sliceobj[:ndims] sliceobj = sliceobj[:ndims]
sliceobj = nib.fileslice.canonical_slicers(sliceobj, shape)
# TODO Cache 3D images for large 4D volumes, # TODO Cache 3D images for large 4D volumes,
# so you don't have to hit the disk? # so you don't have to hit the disk?
data = self.__getData(sliceobj) sliceobj = nib.fileslice.canonical_slicers(sliceobj, shape)
data = self.__getData(sliceobj)
if not self.__covered: if not self.__covered:
slices = sliceObjToSliceTuple(sliceobj, self.__image.shape) slices = sliceObjToSliceTuple(sliceobj, shape)
if not sliceCovered(slices, self.__coverage): if not sliceCovered(slices, self.__coverage):
self.__updateDataRangeOnRead(slices, data) self.__updateDataRangeOnRead(slices, data)
# Make sure that the result has
# the shape that the caller is
# expecting.
ndims = len(data.shape)
if ndims < expNdims:
data = data.reshape(list(data.shape) + [1] * (expNdims - ndims))
return data return data
...@@ -594,6 +607,29 @@ class ImageWrapper(notifier.Notifier): ...@@ -594,6 +607,29 @@ class ImageWrapper(notifier.Notifier):
self.__updateDataRangeOnWrite(slices, values) self.__updateDataRangeOnWrite(slices, values)
def canonicalShape(shape):
"""Calculates a *canonical* shape, how the given ``shape`` should
be presented. The shape is forced to be at least three dimensions,
with any other trailing dimensions of length 1 ignored.
"""
shape = list(shape)
# Squeeze out empty dimensions, as
# 3D image can sometimes be listed
# as having 4 or more dimensions
for i in reversed(range(len(shape))):
if shape[i] == 1: shape = shape[:i]
else: break
# But make sure the shape
# has at 3 least dimensions
if len(shape) < 3:
shape = shape + [1] * (3 - len(shape))
return shape
def sliceObjToSliceTuple(sliceobj, shape): def sliceObjToSliceTuple(sliceobj, shape):
"""Turns an array slice object into a tuple of (low, high) index """Turns an array slice object into a tuple of (low, high) index
pairs, one pair for each dimension in the given shape pairs, one pair for each dimension in the given shape
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment