Skip to content
Snippets Groups Projects
Commit 64862743 authored by Paul McCarthy's avatar Paul McCarthy
Browse files

Refactorings/documentations/fixes to image wrapper. More to come.

parent 944c6390
No related branches found
No related tags found
No related merge requests found
...@@ -9,7 +9,8 @@ to manage data access to ``nibabel`` NIFTI images. ...@@ -9,7 +9,8 @@ to manage data access to ``nibabel`` NIFTI images.
""" """
import logging import logging
import collections
import numpy as np import numpy as np
import nibabel as nib import nibabel as nib
...@@ -38,9 +39,21 @@ class ImageWrapper(notifier.Notifier): ...@@ -38,9 +39,21 @@ class ImageWrapper(notifier.Notifier):
property. property.
.. todo:: Figure out if NIFTI2 can be supported as well. The ``ImageWrapper`` class uses the following functions (also defined in
this module) to keep track of the portion of the image that has currently
been included in the data range calculation:
.. autosummary::
:nosignatures:
sliceObjToSliceTuple
sliceTupleToSliceObj
sliceCovered
calcSliceExpansion
adjustCoverage
""" """
def __init__(self, image, name=None, loadData=False): def __init__(self, image, name=None, loadData=False):
"""Create an ``ImageWrapper``. """Create an ``ImageWrapper``.
...@@ -85,7 +98,11 @@ class ImageWrapper(notifier.Notifier): ...@@ -85,7 +98,11 @@ class ImageWrapper(notifier.Notifier):
# been included in the data range calculation, so # been included in the data range calculation, so
# we do not unnecessarily re-calculate ranges on # we do not unnecessarily re-calculate ranges on
# the same part of the image. # the same part of the image.
self.__sliceCoverage = [] self.__coverage = []
# TODO Use a numpy array of size
# (2, [numRealDims - 1], [shape[numRealDims - 1]])
# instead of a list of lists
# This is a list of lists of (low, high) pairs, # This is a list of lists of (low, high) pairs,
# one list for each entry in the last dimension # one list for each entry in the last dimension
...@@ -96,7 +113,7 @@ class ImageWrapper(notifier.Notifier): ...@@ -96,7 +113,7 @@ class ImageWrapper(notifier.Notifier):
for i in range(image.shape[self.__numRealDims - 1]): for i in range(image.shape[self.__numRealDims - 1]):
cov = [[None, None] for i in range(self.__numRealDims - 1)] cov = [[None, None] for i in range(self.__numRealDims - 1)]
self.__sliceCoverage.append(cov) self.__coverage.append(cov)
if loadData: if loadData:
self.loadData() self.loadData()
...@@ -128,7 +145,7 @@ class ImageWrapper(notifier.Notifier): ...@@ -128,7 +145,7 @@ class ImageWrapper(notifier.Notifier):
""" """
if isTuple: if isTuple:
sliceobj = sliceTupletoSliceObj(sliceobj) sliceobj = sliceTupleToSliceObj(sliceobj)
# If the image has not been loaded # If the image has not been loaded
# into memory, we can use the nibabel # into memory, we can use the nibabel
...@@ -168,10 +185,10 @@ class ImageWrapper(notifier.Notifier): ...@@ -168,10 +185,10 @@ class ImageWrapper(notifier.Notifier):
self.__name, self.__name,
self.__range[0], self.__range[0],
self.__range[1], self.__range[1],
self.__sliceCoverage)) self.__coverage))
volumes, expansions = calcSliceExpansion(slices, volumes, expansions = calcSliceExpansion(slices,
self.__sliceCoverage, self.__coverage,
self.__numRealDims, self.__numRealDims,
self.__numPadDims) self.__numPadDims)
...@@ -190,10 +207,7 @@ class ImageWrapper(notifier.Notifier): ...@@ -190,10 +207,7 @@ class ImageWrapper(notifier.Notifier):
self.__range = (newmin, newmax) self.__range = (newmin, newmax)
for vol, exp in zip(volumes, expansions): for vol, exp in zip(volumes, expansions):
self.__sliceCoverage[vol] = adjustSliceCoverage( self.__coverage[vol] = adjustCoverage(self.__coverage[vol], exp)
self.__sliceCoverage[vol],
exp,
self.__numRealDims)
# TODO floating point error # TODO floating point error
if newmin != oldmin or newmax != oldmax: if newmin != oldmin or newmax != oldmax:
...@@ -234,7 +248,7 @@ class ImageWrapper(notifier.Notifier): ...@@ -234,7 +248,7 @@ class ImageWrapper(notifier.Notifier):
slices = sliceObjToSliceTuple(sliceobj, self.__image.shape) slices = sliceObjToSliceTuple(sliceobj, self.__image.shape)
if not sliceCovered(slices, if not sliceCovered(slices,
self.__sliceCoverage, self.__coverage,
self.__image.shape, self.__image.shape,
self.__numRealDims): self.__numRealDims):
...@@ -243,14 +257,25 @@ class ImageWrapper(notifier.Notifier): ...@@ -243,14 +257,25 @@ class ImageWrapper(notifier.Notifier):
return data return data
def sliceObjToSliceTuple(sliceobj, shape): def sliceObjToSliceTuple(sliceobj, shape):
"""Turns an array slice object into a tuple of (low, high) index """Turns an array slice object into a tuple of (low, high) index
pairs, one pair for each dimension in the given shape pairs, one pair for each dimension in the given shape
:arg sliceobj: Something which can be used to slice an array of shape
``shape``.
:arg shape: Shape of the array being sliced.
""" """
indices = [] indices = []
if not isinstance(sliceobj, collections.Sequence):
sliceobj = [sliceobj]
if len(sliceobj) != len(shape):
missing = len(shape) - len(sliceobj)
sliceobj = list(sliceobj) + [slice(None) for i in range(missing)]
for dim, s in enumerate(sliceobj): for dim, s in enumerate(sliceobj):
# each element in the slices tuple should # each element in the slices tuple should
...@@ -266,8 +291,11 @@ def sliceObjToSliceTuple(sliceobj, shape): ...@@ -266,8 +291,11 @@ def sliceObjToSliceTuple(sliceobj, shape):
return tuple(indices) return tuple(indices)
def sliceTupletoSliceObj(slices): def sliceTupleToSliceObj(slices):
""" """Turns a sequence of (low, high) index pairs into a tuple of array
slice objects.
:arg slices: A sequence of (low, high) index pairs.
""" """
sliceobj = [] sliceobj = []
...@@ -278,10 +306,41 @@ def sliceTupletoSliceObj(slices): ...@@ -278,10 +306,41 @@ def sliceTupletoSliceObj(slices):
return tuple(sliceobj) return tuple(sliceobj)
def adjustCoverage(oldCoverage, slices):
"""Adjusts/expands the given ``oldCoverage`` so that it covers the
given set of ``slices``.
:arg oldCoverage: A sequence of (low, high) index pairs
:arg slices: A sequence of (low, high) index pairs. If ``slices``
contains more dimensions than are specified in
``oldCoverage``, the trailing dimensions are ignored.
:return: A list of (low, high) tuples containing the adjusted coverage.
"""
newCoverage = []
for dim in range(len(oldCoverage)):
low, high = slices[ dim]
lowCover, highCover = oldCoverage[dim]
if lowCover is None or low < lowCover: lowCover = low
if highCover is None or high > highCover: highCover = high
newCoverage.append((lowCover, highCover))
return newCoverage
def sliceCovered(slices, sliceCoverage, shape, realDims): def sliceCovered(slices, coverage, shape, realDims):
"""Returns ``True`` if the portion of the image data calculated by """Returns ``True`` if the portion of the image data calculated by
the given ``slices` has already been calculated, ``False`` otherwise. the given ``slices` has already been calculated, ``False`` otherwise.
:arg slices:
:arg coverage:
:arg shape:
:arg volDim:
""" """
lowVol, highVol = slices[realDims - 1] lowVol, highVol = slices[realDims - 1]
...@@ -289,11 +348,11 @@ def sliceCovered(slices, sliceCoverage, shape, realDims): ...@@ -289,11 +348,11 @@ def sliceCovered(slices, sliceCoverage, shape, realDims):
for vol in range(lowVol, highVol): for vol in range(lowVol, highVol):
coverage = sliceCoverage[vol] volCoverage = coverage[vol]
for dim, size in enumerate(shape): for dim, size in enumerate(shape):
lowCover, highCover = coverage[dim] lowCover, highCover = volCoverage[dim]
if lowCover is None or highCover is None: if lowCover is None or highCover is None:
return False return False
...@@ -309,7 +368,7 @@ def sliceCovered(slices, sliceCoverage, shape, realDims): ...@@ -309,7 +368,7 @@ def sliceCovered(slices, sliceCoverage, shape, realDims):
return True return True
def calcSliceExpansion(slices, sliceCoverage, realDims, padDims): def calcSliceExpansion(slices, coverage, realDims, padDims):
""" """
""" """
...@@ -324,13 +383,13 @@ def calcSliceExpansion(slices, sliceCoverage, realDims, padDims): ...@@ -324,13 +383,13 @@ def calcSliceExpansion(slices, sliceCoverage, realDims, padDims):
for vol in volumes: for vol in volumes:
coverage = sliceCoverage[vol] volCoverage = coverage[vol]
expansion = [] expansion = []
for dim in range(realDims - 1): for dim in range(realDims - 1):
lowCover, highCover = coverage[dim] lowCover, highCover = volCoverage[dim]
lowSlice, highSlice = slices[ dim] lowSlice, highSlice = slices[ dim]
if lowCover is None: lowCover = lowSlice if lowCover is None: lowCover = lowSlice
if highCover is None: highCover = highSlice if highCover is None: highCover = highSlice
...@@ -345,22 +404,3 @@ def calcSliceExpansion(slices, sliceCoverage, realDims, padDims): ...@@ -345,22 +404,3 @@ def calcSliceExpansion(slices, sliceCoverage, realDims, padDims):
expansions.append(expansion) expansions.append(expansion)
return volumes, expansions return volumes, expansions
def adjustSliceCoverage(oldCoverage, slices, realDims):
"""
"""
newCoverage = []
for dim in range(realDims - 1):
low, high = slices[ dim]
lowCover, highCover = oldCoverage[dim]
if lowCover is None or low < lowCover: lowCover = low
if highCover is None or high < highCover: highCover = high
newCoverage.append((lowCover, highCover))
return newCoverage
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment