Skip to content
Snippets Groups Projects
Forked from FSL / fslpy
2684 commits behind the upstream repository.
imagewrapper.py 12.71 KiB
#!/usr/bin/env python
#
# imagewrapper.py - The ImageWrapper class.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module provides the :class:`ImageWrapper` class, which can be used
to manage data access to ``nibabel`` NIFTI images.
"""


import logging
import collections

import numpy   as np
import nibabel as nib

import fsl.utils.notifier as notifier
import fsl.utils.memoize  as memoize


log = logging.getLogger(__name__)


class ImageWrapper(notifier.Notifier):
    """The ``ImageWrapper`` class is a convenience class which manages data
    access to ``nibabel`` NIFTI images. The ``ImageWrapper`` class can be
    used to:
    
      - Control whether the image is loaded into memory, or kept on disk
    
      - Incrementally update the known image data range, as more image
        data is read in.


    The ``ImageWrapper`` implements the :class:`.Notifier` interface.
    Listeners can register to be notified whenever the known image data range
    is updated. The data range can be accessed via the :attr:`dataRange`
    property.


    The ``ImageWrapper`` class uses the following functions (also defined in 
    this module) to keep track of the portion of the image that has currently
    been included in the data range calculation:

    .. autosummary::
       :nosignatures:

       sliceObjToSliceTuple
       sliceTupleToSliceObj
       sliceCovered
       calcSliceExpansion
       adjustCoverage
    """

    
    def __init__(self, image, name=None, loadData=False):
        """Create an ``ImageWrapper``.

        :arg image:    A ``nibabel.Nifti1Image``.

        :arg name:     A name for this ``ImageWrapper``, solely used for debug
                       log messages.

        :arg loadData: If ``True``, the image data is loaded into memory.
                       Otherwise it is kept on disk (and data access is
                       performed through the ``nibabel.Nifti1Image.dataobj``
                       array proxy).
        """
        
        self.__image = image
        self.__name  = name

        # Save the number of 'real' dimensions,
        # that is the number of dimensions minus
        # any trailing dimensions of length 1
        self.__numRealDims = len(image.shape)
        for d in reversed(image.shape):
            if d == 1: self.__numRealDims -= 1
            else:      break

        # And save the number of
        # 'padding' dimensions too.
        self.__numPadDims = len(image.shape) - self.__numRealDims

        # NOTE(review): get_header() (and get_data(), used
        # below) are deprecated in recent nibabel releases -
        # confirm which nibabel versions must be supported.
        hdr = image.get_header()

        # The current known image data range. This
        # gets updated as more image data gets read.
        # We default to whatever is stored in the
        # header (which may or may not contain useful
        # values).
        self.__range = (float(hdr['cal_min']), float(hdr['cal_max']))

        # The coverage array is used to keep track of
        # the portions of the image which have been
        # considered in the data range calculation.
        # We use this coverage to avoid unnecessarily
        # re-calculating the data range on the same
        # part of the image.
        #
        # First of all, we're going to store a separate
        # 'coverage' for each 2D slice in the 3D image
        # (or 3D volume for 4D images). This effectively
        # means a separate coverage for each index in the
        # last 'real' image dimension (see above).
        # 
        # For each slice/volume, the coverage is
        # stored as sequences of (low, high) indices, one
        # for each dimension in the slice/volume (e.g.
        # row/column for a slice, or row/column/depth
        # for a volume).
        #
        # All of these indices are stored in a big numpy
        # array:
        #   - first dimension:  low/high index
        #   - second dimension: image dimension
        #   - third dimension:  slice/volume index
        #
        # An integer array has no way of representing
        # 'nothing covered yet' (NaN cannot be stored
        # in it), so an object array is used, with
        # every entry initialised to None. None is the
        # 'unset' marker which the sliceCovered,
        # calcSliceExpansion and adjustCoverage
        # functions test for.
        self.__coverage = np.empty(
            (2, self.__numRealDims - 1, image.shape[self.__numRealDims - 1]),
            dtype=object)

        self.__coverage[:] = None

        if loadData:
            self.loadData()

        
    @property
    def dataRange(self):
        """Returns the currently known data range as a tuple of ``(min, max)``
        values.
        """
        return tuple(self.__range)


    def loadData(self):
        """Forces all of the image data to be loaded into memory.

        .. note:: This method will be called by :meth:`__init__` if its
                  ``loadData`` parameter is ``True``.
        """

        # If the data is not already
        # loaded, this will cause 
        # nibabel to load and cache it
        self.__image.get_data()


    def __getData(self, sliceobj, isTuple=False):
        """Retrieves the image data at the location specified by
        ``sliceobj``.

        :arg sliceobj: Something which can be used to slice the image data,
                       or a sequence of ``(low, high)`` index pairs (see
                       ``isTuple``).

        :arg isTuple:  Set to ``True`` if ``sliceobj`` is a sequence of
                       ``(low, high)`` index pairs, in which case it is
                       converted with :func:`sliceTupleToSliceObj` before
                       being used.

        :returns:      The data obtained by indexing the in-memory array
                       or the ``dataobj`` array proxy with ``sliceobj``.
        """

        if isTuple:
            sliceobj = sliceTupleToSliceObj(sliceobj)

        # If the image has not been loaded
        # into memory,  we can use the nibabel
        # ArrayProxy. Otherwise if it is in
        # memory, we can access it directly.
        #
        # Furthermore, if it is in memory and
        # has been modified, the ArrayProxy
        # will give us out-of-date values (as
        # the ArrayProxy reads from disk). So
        # we have to read from the in-memory
        # array to get changed values.
        if self.__image.in_memory: return self.__image.get_data()[sliceobj]
        else:                      return self.__image.dataobj[   sliceobj] 


    @memoize.Instanceify(memoize.memoize(args=[0]))
    def __updateDataRangeOnRead(self, slices, data):
        """Called by :meth:`__getitem__`. Calculates the minimum/maximum
        values of the portion of the image specified by ``slices`` (expanded
        so as to be contiguous with the existing coverage), and updates the
        known data range of the image. Listeners are notified if the range
        changes.

        :arg slices: A sequence of ``(low, high)`` index pairs, one for each
                     dimension in the image. Tuples are used instead of
                     ``slice`` objects, because this method is memoized (and
                     ``slice`` objects are unhashable).
        
        :arg data:   The image data at the given ``slices`` (as a ``numpy``
                     array). Currently unused - the expanded regions
                     returned by :func:`calcSliceExpansion` are read
                     separately via :meth:`__getData`.
        """

        oldmin, oldmax = self.__range

        log.debug('Updating image {} data range (current range: '
                  '[{}, {}]; current coverage: {})'.format(
                      self.__name,
                      self.__range[0],
                      self.__range[1],
                      self.__coverage))

        volumes, expansions = calcSliceExpansion(slices,
                                                 self.__coverage,
                                                 self.__numRealDims,
                                                 self.__numPadDims)

        newmin = oldmin
        newmax = oldmax

        # Read in each expanded region, and
        # fold its min/max into the new range.
        for exp in expansions:

            expData = self.__getData(exp, isTuple=True)
            dmin    = float(np.nanmin(expData))
            dmax    = float(np.nanmax(expData))

            if newmin is None or dmin < newmin: newmin = dmin
            if newmax is None or dmax > newmax: newmax = dmax

        self.__range = (newmin, newmax)

        # Record the regions that have now been
        # included in the range calculation.
        for vol, exp in zip(volumes, expansions):
            self.__coverage[..., vol] = adjustCoverage(
                self.__coverage[..., vol], exp)

        # TODO floating point error
        if newmin != oldmin or newmax != oldmax:
            log.debug('Image {} range changed: [{}, {}] -> [{}, {}]'.format(
                self.__name,
                oldmin,
                oldmax,
                newmin,
                newmax))
            self.notify()

            
    def __getitem__(self, sliceobj):
        """Returns the image data for the given ``sliceobj``, and updates
        the known image data range if necessary.

        .. note:: If the image data is in memory, it is accessed 
                  directly, via the ``nibabel.Nifti1Image.get_data`` 
                  method. Otherwise the image data is accessed through 
                  the ``nibabel.Nifti1Image.dataobj`` array proxy.

        :arg sliceobj: Something which can slice the image data.
        """

        log.debug('Getting image data: {}'.format(sliceobj))

        sliceobj = nib.fileslice.canonical_slicers(
            sliceobj, self.__image.shape)

        # TODO Cache 3D images for large 4D volumes, 
        #      so you don't have to hit the disk?

        data = self.__getData(sliceobj)

        # TODO If full range is 
        #      known, return now.

        slices = sliceObjToSliceTuple(sliceobj, self.__image.shape)

        # Only recalculate the data range if the
        # requested region has not yet been covered.
        if not sliceCovered(slices,
                            self.__coverage,
                            self.__image.shape,
                            self.__numRealDims):
            
            self.__updateDataRangeOnRead(slices, data)

        return data


def sliceObjToSliceTuple(sliceobj, shape):
    """Turns an array slice object into a tuple of (low, high) index
    pairs, one pair for each dimension in the given shape

    :arg sliceobj: Something which can be used to slice an array of shape
                   ``shape`` - a single ``slice``/integer, or a sequence
                   of them. Only the ``start`` and ``stop`` of each
                   ``slice`` are used; the ``step`` is ignored, and
                   negative indices are not normalised.

    :arg shape:    Shape of the array being sliced.

    :returns:      A tuple of ``(low, high)`` tuples. Dimensions which are
                   not specified in ``sliceobj`` are given the full range
                   ``(0, shape[dim])``.
    """

    # ``collections.Sequence`` was removed in Python 3.10 - the
    # abstract base classes live in ``collections.abc``. Fall
    # back to the old location for legacy interpreters.
    try:
        from collections.abc import Sequence
    except ImportError:
        from collections import Sequence

    indices = []

    if not isinstance(sliceobj, Sequence):
        sliceobj = [sliceobj]

    # Pad with full slices for any
    # trailing unspecified dimensions
    if len(sliceobj) != len(shape):
        missing  = len(shape) - len(sliceobj)
        sliceobj = list(sliceobj) + [slice(None) for i in range(missing)]

    for dim, s in enumerate(sliceobj):

        # each element in the slices tuple should 
        # be a slice object or an integer
        if isinstance(s, slice): i = [s.start, s.stop]
        else:                    i = [s,       s + 1]

        # Open-ended slices span to the
        # start/end of the dimension
        if i[0] is None: i[0] = 0
        if i[1] is None: i[1] = shape[dim]

        indices.append(tuple(i))

    return tuple(indices)


def sliceTupleToSliceObj(slices):
    """Converts a sequence of (low, high) index pairs into the equivalent
    tuple of ``slice`` objects, each with a step of 1.

    :arg slices: A sequence of (low, high) index pairs.

    :returns:    A tuple containing one ``slice(low, high, 1)`` for each
                 pair in ``slices``.
    """
    return tuple(slice(start, stop, 1) for start, stop in slices)


def adjustCoverage(oldCoverage, slices): 
    """Adjusts/expands the given ``oldCoverage`` so that it covers the
    given set of ``slices``.

    :arg oldCoverage: A ``numpy`` array of shape ``(2, n)`` containing
                      the (low, high) index pairs for a single slice/volume
                      in the image. Entries which are ``None`` or ``NaN``
                      are treated as unset (nothing covered yet).
    
    :arg slices:      A sequence of (low, high) index pairs. If ``slices``
                      contains more dimensions than are specified in
                      ``oldCoverage``, the trailing dimensions are ignored.

    :return: A ``numpy`` array (of type ``uint32``, and with the same shape
             as ``oldCoverage``) containing the adjusted/expanded coverage.
    """

    newCoverage = np.zeros(oldCoverage.shape, dtype=np.uint32)

    for dim in range(oldCoverage.shape[1]):

        low,      high      = slices[        dim]
        lowCover, highCover = oldCoverage[:, dim]

        # An unset bound (None, or NaN in a floating point
        # coverage array) is simply replaced by the slice
        # bound. The explicit NaN test is required, as any
        # comparison against NaN evaluates to False.
        lowUnset  = lowCover  is None or np.isnan(lowCover)
        highUnset = highCover is None or np.isnan(highCover)

        if lowUnset  or low  < lowCover:  lowCover  = low
        if highUnset or high > highCover: highCover = high

        newCoverage[:, dim] = lowCover, highCover

    return newCoverage


def sliceCovered(slices, coverage, shape, realDims):
    """Returns ``True`` if the portion of the image data specified by the
    given ``slices`` lies within the given ``coverage``, ``False``
    otherwise.

    :arg slices:   A sequence of (low, high) index pairs, one for each
                   image dimension.

    :arg coverage: A ``numpy`` array of shape ``(2, nd, nv)``, indexed by
                   low/high, dimension, and slice/volume. Entries which
                   are ``None`` or ``NaN`` are treated as unset (nothing
                   covered yet).

    :arg shape:    Shape of the image. Only the first ``realDims - 1``
                   entries are used.

    :arg realDims: Number of 'real' dimensions in the image. The pair at
                   ``slices[realDims - 1]`` gives the range of
                   slices/volumes to be checked.
    """

    lowVol, highVol = slices[realDims - 1]
    shape           = shape[:realDims - 1]

    for vol in range(lowVol, highVol):

        for dim, size in enumerate(shape):

            lowCover, highCover = coverage[:, dim, vol]
            lowSlice, highSlice = slices[     dim] 

            # Nothing has been covered yet for this
            # slice/volume. The NaN test must be explicit,
            # as comparisons against NaN are always False
            # (which would make the region look covered).
            if lowCover  is None or np.isnan(lowCover):  return False
            if highCover is None or np.isnan(highCover): return False

            if lowSlice  is None: lowSlice  = 0
            if highSlice is None: highSlice = size

            if lowSlice  < lowCover:  return False
            if highSlice > highCover: return False

    return True


def calcSliceExpansion(slices, coverage, realDims, padDims):
    """Calculates, for each slice/volume referred to by ``slices``, the
    region which spans both the existing coverage of that slice/volume
    and the requested ``slices``.

    :arg slices:   A sequence of (low, high) index pairs, one for each
                   image dimension.

    :arg coverage: A ``numpy`` array of shape ``(2, nd, nv)``, indexed by
                   low/high, dimension, and slice/volume. Entries which
                   are ``None`` or ``NaN`` are treated as unset (nothing
                   covered yet).

    :arg realDims: Number of 'real' dimensions in the image. The pair at
                   ``slices[realDims - 1]`` gives the range of
                   slices/volumes to be processed.

    :arg padDims:  Number of trailing dimensions for which a ``(0, 1)``
                   pair is appended to each expansion.

    :returns:      A tuple containing:

                     - A list of slice/volume indices

                     - A list of expansions, one for each slice/volume
                       index. Each expansion is a list of integer
                       ``(low, high)`` pairs, one for each image
                       dimension (including padding dimensions).
    """

    # One per volume
    lowVol, highVol = slices[realDims - 1] 

    expansions = []
    volumes    = list(range(lowVol, highVol))

    # TODO Reduced slice duplication.
    #      You know what this means.

    for vol in volumes:

        expansion = []

        for dim in range(realDims - 1):

            lowCover, highCover = coverage[:, dim, vol]
            lowSlice, highSlice = slices[     dim]

            # An unset coverage bound (None or NaN) contributes
            # nothing - the slice bound is used as-is. NaN must
            # be tested for explicitly, as min/max would
            # otherwise pass it straight through.
            if lowCover  is None or np.isnan(lowCover):  lowCover  = lowSlice
            if highCover is None or np.isnan(highCover): highCover = highSlice

            # The bounds are used to build slice objects, so
            # must be integers even if the coverage array has
            # a floating point type.
            expansion.append((int(min(lowCover,  lowSlice)),
                              int(max(highCover, highSlice))))

        expansion.append((vol, vol + 1))
        for i in range(padDims):
            expansion.append((0, 1))

        expansions.append(expansion)

    return volumes, expansions