Skip to content
Snippets Groups Projects
nonlinear.py 29.5 KiB
Newer Older
# nonlinear.py - Functions/classes for non-linear transformations.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module contains data structures and functions for working with
FNIRT-style nonlinear transformations.


The :class:`DeformationField` and :class:`CoefficientField` can be used to
load and interact with FNIRT-style transformation images. The following
utility functions are also available:


.. autosummary::
   :nosignatures:

   detectDeformationType
   convertDeformationType
   convertDeformationSpace
   applyDeformation
   coefficientFieldToDeformationField
"""

import itertools                   as it

import numpy                       as np
import scipy.ndimage.interpolation as ndinterp

import fsl.utils.memoize           as memoize
import fsl.data.image              as fslimage
import fsl.utils.image.resample    as resample
from . import                         affine


class NonLinearTransform(fslimage.Image):
    """Class which represents a nonlinear transformation. This is just a base
    class for the :class:`DeformationField` and :class:`CoefficientField`
    classes.


    A nonlinear transformation is an :class:`.Image` which contains
    some mapping between a source image coordinate system and a reference image
    coordinate system.


    In FSL, non-linear transformations are defined in terms of the reference
    image coordinate system.  At a given location in the reference image
    space, the non-linear mapping at that location can be used to calculate
    the corresponding location in the source image space. Therefore, these
    non-linear transformations effectively encode a transformation *from* the
    reference image *to* the (unwarped) source image.
    """


    def __init__(self,
                 image,
                 src,
                 ref,
                 srcSpace=None,
                 refSpace=None,
                 **kwargs):
        """Create a ``NonLinearTransform``. See the :meth:`.Nifti.getAffine`
        method for an overview of the values that ``srcSpace`` and ``refSpace``
        may take.

        :arg image:       A string containing the name of an image file to
                          load, or a :mod:`numpy` array, or a :mod:`nibabel`
                          image object.

        :arg src:         :class:`.Nifti` representing the source image.

        :arg ref:         :class:`.Nifti` representing the reference image.

        :arg srcSpace:    Coordinate system in the source image that this
                          ``NonLinearTransform`` maps to. Defaults to
                          ``'fsl'``.

        :arg refSpace:    Coordinate system in the reference image that this
                          ``NonLinearTransform`` maps from. Defaults to
                          ``'fsl'``.

        :raises ValueError: If ``srcSpace`` or ``refSpace`` is not one of
                            ``'fsl'``, ``'voxel'`` or ``'world'``.

        All other arguments are passed through to :meth:`.Image.__init__`.
        """

        # TODO Could make more general by replacing
        # srcSpace and refSpace with src/ref affines,
        # which transform to/from (e.g.) source/ref
        # voxels to/from the space required by the
        # deformation field

        if srcSpace is None: srcSpace = 'fsl'
        if refSpace is None: refSpace = 'fsl'

        if srcSpace not in ('fsl', 'voxel', 'world') or \
           refSpace not in ('fsl', 'voxel', 'world'):
            raise ValueError('Invalid source/reference space: {} -> {}'.format(
                srcSpace, refSpace))

        fslimage.Image.__init__(self, image, **kwargs)

        # Only header copies of src/ref are kept, so
        # later changes to the headers of the images
        # that were passed in do not affect this
        # transform. Note that this happens after
        # Image.__init__, so a subclass may pass
        # ref=self.
        self.__src      = fslimage.Nifti(src.header.copy())
        self.__ref      = fslimage.Nifti(ref.header.copy())
        self.__srcSpace = srcSpace
        self.__refSpace = refSpace


    @property
    def src(self):
        """Return a reference to the :class:`.Nifti` instance representing
        the source image.
        """
        return self.__src


    @property
    def ref(self):
        """Return a reference to the :class:`.Nifti` instance representing
        the reference image.
        """
        return self.__ref


    @property
    def srcSpace(self):
        """Return the source image coordinate system this
        ``NonLinearTransform`` maps to - see :meth:`.Nifti.getAffine`.
        """
        return self.__srcSpace


    @property
    def refSpace(self):
        """Return the reference image coordinate system this
        ``NonLinearTransform`` maps from - see :meth:`.Nifti.getAffine`.
        """
        return self.__refSpace


    def transform(self, coords, from_=None, to=None):
        """Transform coordinates from the reference image space to the source
        image space. Implemented by sub-classes.

        See the :meth:`.Nifti.getAffine` method for an overview of the values
        that ``from_`` and ``to`` may take.

        :arg coords: A sequence of XYZ coordinates, or ``numpy`` array of shape
                     ``(n, 3)`` containing ``n`` sets of coordinates in the
                     reference space.
        :arg from_:  Reference image space that ``coords`` are defined in
        :arg to:     Source image space to transform ``coords`` into
        :returns:    The corresponding coordinates in the source image space.
        :raises NotImplementedError: Always - sub-classes must override this.
        """
        raise NotImplementedError()
class DeformationField(NonLinearTransform):
    """Class which represents a deformation (a.k.a. warp) field which, at each
    voxel, contains either:

      - a relative displacement from the reference image space to the source
        image space, or
      - absolute coordinates in the source space


    It is assumed that the a ``DeformationField`` is aligned with the
    reference image in their world coordinate systems (i.e. their ``sform``
    affines project the reference image and the deformation field into
    alignment).
    """


    def __init__(self, image, src, ref=None, **kwargs):
        """Create a ``DeformationField``.

        :arg image:   A string containing the name of an image file to load,
                      or a :mod:`numpy` array, or a :mod:`nibabel` image
                      object.

        :arg src:     :class:`.Nifti` representing the source image.

        :arg ref:     Optional. If not provided, it is assumed that the
                      reference is defined in the same space as ``image``.

        :arg defType: Either ``'absolute'`` or ``'relative'``, indicating
                      the type of this displacement field. If not provided,
                      will be inferred via the :func:`detectDeformationType`
                      function.

        :raises ValueError: If ``defType`` is invalid, or if the field is not
                            in the same space as the reference image.

        All other arguments are passed through to
        :meth:`NonLinearTransform.__init__`.
        """

        # With no explicit reference, the field is its
        # own reference. NonLinearTransform.__init__
        # only reads ref.header after Image.__init__
        # has run, so passing self here is safe.
        if ref is None:
            ref = self

        defType = kwargs.pop('defType', None)
        if defType not in (None, 'relative', 'absolute'):
            raise ValueError('Invalid value for defType: {}'.format(defType))

        NonLinearTransform.__init__(self, image, src, ref, **kwargs)

        if not self.sameSpace(self.ref):
            raise ValueError('Invalid reference image: {}'.format(self.ref))

        # None means "not yet known" - it is
        # detected lazily by deformationType.
        self.__defType = defType


    @property
    def deformationType(self):
        """The type of this ``DeformationField`` - ``'absolute'`` or
        ``'relative'``. If not given at creation time, it is inferred on
        first access via :func:`detectDeformationType`, and then cached.
        """
        if self.__defType is None:
            self.__defType = detectDeformationType(self)
        return self.__defType


    @property
    def absolute(self):
        """``True`` if this ``DeformationField`` contains absolute
        coordinates.
        """
        return self.deformationType == 'absolute'


    @property
    def relative(self):
        """``True`` if this ``DeformationField`` contains relative
        displacements.
        """
        return self.deformationType == 'relative'


    def transform(self, coords, from_=None, to=None):
        """Transform the given XYZ coordinates from the reference image space
        to the source image space.

        :arg coords: A sequence of XYZ coordinates, or ``numpy`` array of shape
                     ``(n, 3)`` containing ``n`` sets of coordinates in the
                     reference space.
        :arg from_:  Reference image space that ``coords`` are defined in
        :arg to:     Source image space to transform ``coords`` into

        :returns:    ``coords``, transformed into the source image space.
                     Coordinates which fall outside of the deformation field
                     are set to ``nan``.
        """

        if from_ is None: from_ = self.refSpace
        if to    is None: to    = self.srcSpace

        coords = np.asanyarray(coords)

        # We may need to pre-transform the
        # coordinates so they are in the
        # same reference image space as the
        # displacements
        if from_ != self.refSpace:
            xform  = self.ref.getAffine(from_, self.refSpace)
            coords = affine.transform(coords, xform)

        # We also need to get the coordinates
        # in field voxels, so we can look up
        # the displacements/coordinates. We
        # can get this through the assumption
        # that field and ref are aligned in
        # the world coordinate system
        xform = affine.concat(self    .getAffine('world',       'voxel'),
                              self.ref.getAffine(self.refSpace, 'world'))

        if np.all(np.isclose(xform, np.eye(4))):
            voxels = coords
        else:
            voxels = affine.transform(coords, xform)

        # Nearest-neighbour lookup: round to the
        # closest field voxel, and mask out the
        # coordinates that are out of bounds of
        # the deformation field. (np.int was
        # removed in NumPy 1.24 - use int.)
        voxels  = np.round(voxels).astype(int)
        voxmask = (voxels >= [0, 0, 0]) & (voxels < self.shape[:3])
        voxmask = voxmask.all(axis=1)
        voxels  = voxels[voxmask]

        xs, ys, zs = voxels.T

        if self.absolute: disps = self.data[xs, ys, zs, :]
        else:             disps = self.data[xs, ys, zs, :] + coords[voxmask]

        # Make sure the coordinates are in
        # the requested source image space
        if to != self.srcSpace:
            xform = self.src.getAffine(self.srcSpace, to)
            disps = affine.transform(disps, xform)

        # Nans for input coordinates which
        # were outside of the field
        outcoords          = np.full(coords.shape, np.nan)
        outcoords[voxmask] = disps

        return outcoords
class CoefficientField(NonLinearTransform):
    """Class which represents a B-spline coefficient field generated by FNIRT.

    The :meth:`displacements` method can be used to calculate relative
    displacements for a set of reference space voxel coordinates.


    A FNIRT nonlinear transformation often contains a *premat*, a global
    affine transformation from the source space to the reference space, which
    was calculated with FLIRT, and used as the starting point for the
    non-linear optimisation performed by FNIRT.


    This affine may be provided when creating a ``CoefficientField`` as the
    ``srcToRefMat`` argument to :meth:`__init__`, and is subsequently accessed
    via the :meth:`srcToRefMat` attribute.
    """


    def __init__(self,
                 image,
                 src,
                 ref,
                 srcSpace=None,
                 refSpace=None,
                 fieldType='cubic',
                 knotSpacing=None,
                 fieldToRefMat=None,
                 srcToRefMat=None,
                 **kwargs):
        """Create a ``CoefficientField``.

        :arg fieldType:     Must currently be ``'cubic'``

        :arg knotSpacing:   A tuple containing the spline knot spacings along
                            each axis.

        :arg fieldToRefMat: Affine transformation which can transform
                            coefficient field voxel coordinates into reference
                            image voxel coordinates. Defaults to the identity.

        :arg srcToRefMat:   Optional initial global affine transformation from
                            the source image to the reference image. This is
                            assumed to be a FLIRT-style matrix, i.e. it
                            transforms from source image ``srcSpace``
                            coordinates into reference image ``refSpace``
                            coordinates (typically ``'fsl'`` coordinates, i.e.
                            scaled voxels potentially with a left-right flip).

        :raises ValueError: If ``fieldType`` is not ``'cubic'``.

        See the :class:`NonLinearTransform` class for details on the other
        arguments.
        """

        if fieldType not in ('cubic',):
            raise ValueError('Unsupported field type: {}'.format(fieldType))

        if srcToRefMat   is not None: srcToRefMat   = np.copy(srcToRefMat)
        if fieldToRefMat is     None: fieldToRefMat = np.eye(4)

        NonLinearTransform.__init__(self,
                                    image,
                                    src,
                                    ref,
                                    srcSpace,
                                    refSpace,
                                    **kwargs)

        # knotSpacing defaults to None, so it
        # must not be passed blindly to tuple()
        if knotSpacing is not None:
            knotSpacing = tuple(knotSpacing)

        self.__fieldType     = fieldType
        self.__knotSpacing   = knotSpacing
        self.__refToSrcMat   = None
        self.__srcToRefMat   = srcToRefMat
        self.__fieldToRefMat = np.copy(fieldToRefMat)
        self.__refToFieldMat = affine.invert(self.__fieldToRefMat)

        if srcToRefMat is not None:
            self.__refToSrcMat = affine.invert(srcToRefMat)


    @property
    def fieldType(self):
        """Return the type of the coefficient field, currently always
        ``'cubic'``.
        """
        return self.__fieldType


    @property
    def knotSpacing(self):
        """Return a tuple containing spline knot spacings along the x, y, and
        z axes, or ``None`` if no spacing was provided.
        """
        return self.__knotSpacing


    @property
    def fieldToRefMat(self):
        """Return a copy of the affine transformation which can transform
        coefficient field voxel coordinates into reference image voxel
        coordinates.
        """
        return np.copy(self.__fieldToRefMat)


    @property
    def refToFieldMat(self):
        """Return a copy of the affine transformation which can transform
        reference image voxel coordinates into coefficient field voxel
        coordinates.
        """
        return np.copy(self.__refToFieldMat)


    @property
    def srcToRefMat(self):
        """Return a copy of the initial source-to-reference affine, or
        ``None`` if there isn't one.
        """
        if self.__srcToRefMat is None:
            return None
        return np.copy(self.__srcToRefMat)


    @property
    def refToSrcMat(self):
        """Return a copy of the inverse of the initial source-to-reference
        affine, or ``None`` if there isn't one.
        """
        if self.__refToSrcMat is None:
            return None
        return np.copy(self.__refToSrcMat)


    @memoize.Instanceify(memoize.memoize)
    def asDeformationField(self, defType='relative', premat=True):
        """Convert this ``CoefficientField`` to a :class:`DeformationField`.

        This method is a wrapper around
        :func:`coefficientFieldToDeformationField`.

        :arg defType: ``'relative'`` (the default) or ``'absolute'``.
        :arg premat:  If ``True`` (the default), the :meth:`srcToRefMat` is
                      encoded into the deformation field.
        """
        return coefficientFieldToDeformationField(self, defType, premat)


    def transform(self, coords, from_=None, to=None, premat=True):
        """Overrides :meth:`NonLinearTransform.transform`. Transforms the
        given ``coords`` from the reference image space into the source image
        space.

        :arg coords: A sequence of XYZ coordinates, or ``numpy`` array of shape
                     ``(n, 3)`` containing ``n`` sets of coordinates in the
                     reference space.
        :arg from_:  Reference image space that ``coords`` are defined in
        :arg to:     Source image space to transform ``coords`` into
        :arg premat: If ``True``, the inverse :meth:`srcToRefMat` is applied
                     to the coordinates after the displacements have been
                     added.

        :returns:    ``coords``, transformed into the source image space
        """
        # asDeformationField is wrapped with memoize, so the
        # conversion is presumably cached across calls.
        df = self.asDeformationField(premat=premat)
        return df.transform(coords, from_, to)


    def displacements(self, coords):
        """Calculate the relative displacements for the given coordinates.

        :arg coords: ``(N, 3)`` array of reference image voxel coordinates.
        :return:     A ``(N, 3)`` array of relative displacements to the
                     source image for ``coords``
        :raises NotImplementedError: If :meth:`fieldType` is not ``'cubic'``.
        """

        if self.fieldType != 'cubic':
            raise NotImplementedError()

        # See
        #   https://www.cs.jhu.edu/~cis/cista/746/papers/\
        #     RueckertFreeFormBreastMRI.pdf
        #   https://www.fmrib.ox.ac.uk/datasets/techrep/tr07ja2/tr07ja2.pdf

        # Cubic b-spline basis functions
        def b0(u):
            return ((1 - u) ** 3) / 6

        def b1(u):
            return (3 * (u ** 3) - 6 * (u ** 2) + 4) / 6

        def b2(u):
            return (-3 * (u ** 3) + 3 * (u ** 2)  + 3 * u + 1) / 6

        def b3(u):
            return (u ** 3) / 6

        b = [b0, b1, b2, b3]

        coords     = np.asanyarray(coords)
        fdata      = self.data
        nx, ny, nz = self.shape[:3]

        # Convert the given voxel coordinates
        # into the corresponding coefficient
        # field voxel coordinates
        i, j, k = affine.transform(coords, self.refToFieldMat).T

        # u, v, w: position of the ref voxel
        #          on the current spline (the
        #          fractional part)
        # i, j, k: coefficient field indices
        #          (the integer part). np.int
        #          was removed in NumPy 1.24.
        u = np.remainder(i, 1)
        v = np.remainder(j, 1)
        w = np.remainder(k, 1)
        i = np.floor(i).astype(int)
        j = np.floor(j).astype(int)
        k = np.floor(k).astype(int)

        disps = np.zeros(coords.shape)

        # Accumulate the contribution of the 4x4x4
        # neighbourhood of coefficients starting at
        # (i, j, k), skipping any neighbours which
        # fall outside of the coefficient field.
        for di, dj, dk in it.product(range(4), range(4), range(4)):
            il   = i + di
            jm   = j + dj
            kn   = k + dk
            mask = (il >= 0)  & \
                   (il <  nx) & \
                   (jm >= 0)  & \
                   (jm <  ny) & \
                   (kn >= 0)  & \
                   (kn <  nz)

            il = il[mask]
            jm = jm[mask]
            kn = kn[mask]
            uu = u[ mask]
            vv = v[ mask]
            ww = w[ mask]

            cx, cy, cz = fdata[il, jm, kn, :].T
            c          = b[di](uu) * b[dj](vv) * b[dk](ww)

            disps[mask, 0] += c * cx
            disps[mask, 1] += c * cy
            disps[mask, 2] += c * cz

        return disps


def detectDeformationType(field):
    """Attempt to automatically determine whether a deformation field is
    specified in absolute or relative coordinates.

    :arg field: A :class:`DeformationField`

    :returns:   ``'absolute'`` if it looks like ``field`` contains absolute
                coordinates, ``'relative'`` otherwise.
    """

    # This test is based on the assumption
    # that a deformation field containing
    # absolute coordinates will have a
    # greater standard deviation than one
    # which contains relative coordinates.
    #
    # absdata is the field data taken as-is.
    # reldata is what we get by treating that
    # data as absolute and converting it to
    # relative (subtracting the ref coords).
    absdata = field.data
    reldata = convertDeformationType(field, 'relative')
    stdabs  = absdata.std(axis=(0, 1, 2)).sum()
    stdrel  = reldata.std(axis=(0, 1, 2)).sum()

    if stdabs > stdrel: return 'absolute'
    else:               return 'relative'


def convertDeformationType(field, defType=None):
    """Convert a deformation field between storing absolute coordinates or
    relative displacements.

    :arg field:   A :class:`DeformationField` instance
    :arg defType: Either ``'absolute'`` or ``'relative'``. If not provided,
                  the opposite type to ``field.deformationType`` is used.
    :returns:     A ``numpy.array`` containing the adjusted deformation field.
    :raises ValueError: If ``defType`` is not ``None``, ``'absolute'`` or
                        ``'relative'``.
    """

    if defType is None:
        if field.deformationType == 'absolute': defType = 'relative'
        else:                                   defType = 'absolute'

    # Previously an invalid defType fell through
    # both branches and silently returned None.
    if defType not in ('absolute', 'relative'):
        raise ValueError('Invalid value for defType: {}'.format(defType))

    # Regardless of the conversion direction,
    # we need the coordinates of every voxel
    # in the reference coordinate system.
    dx, dy, dz = field.shape[:3]
    xform      = field.getAffine('voxel', field.refSpace)

    coords     = np.meshgrid(np.arange(dx),
                             np.arange(dy),
                             np.arange(dz), indexing='ij')
    coords     = np.array(coords).transpose((1, 2, 3, 0))
    coords     = affine.transform(coords.reshape((-1, 3)), xform)
    coords     = coords.reshape((dx, dy, dz, 3))

    # If converting from relative to absolute,
    # we just add the coordinates to (what is
    # assumed to be) the relative shift. Or,
    # to convert from absolute to relative,
    # we subtract the reference image voxels.
    if defType == 'absolute': return field.data + coords
    else:                     return field.data - coords
def convertDeformationSpace(field, from_, to):
    """Adjust the source and/or reference spaces of the given deformation
    field. See the :meth:`.Nifti.getAffine` method for the valid values for
    the ``from_`` and ``to`` arguments.

    :arg field: A :class:`DeformationField` instance
    :arg from_: New reference image coordinate system
    :arg to:    New source image coordinate system

    :returns:   A new :class:`DeformationField` which transforms between
                the reference ``from_`` coordinate system and the source ``to``
                coordinate system. It has the same deformation type as
                ``field``. If no conversion is needed, ``field`` itself is
                returned.
    """

    if field.srcSpace == to and field.refSpace == from_:
        return field

    # Get the field in absolute coordinates
    # if necessary - these are our source
    # coordinates in the original "to" space.
    if field.relative: srccoords = convertDeformationType(field, 'absolute')
    else:              srccoords = field.data

    srccoords = srccoords.reshape((-1, 3))

    # Now transform those source coordinates
    # from the original source space to the
    # source space specified by "to"
    if to != field.srcSpace:

        srcmat    = field.src.getAffine(field.srcSpace, to)
        srccoords = affine.transform(srccoords, srcmat)

    # If we have been asked to return an
    # absolute deformation, the reference
    # "from_" coordinate system is
    # irrelevant - we're done.
    if field.absolute:
        fieldcoords = srccoords

    # Otherwise our deformation field
    # will contain relative displacements
    # between the reference image "from_"
    # coordinate system and the source
    # image "to" coordinate system. We
    # need to re-calculate the relative
    # displacements between the new
    # reference "from_" space and source
    # "to" space.
    else:
        refcoords = np.meshgrid(np.arange(field.shape[0]),
                                np.arange(field.shape[1]),
                                np.arange(field.shape[2]), indexing='ij')
        refcoords = np.array(refcoords)
        refcoords = refcoords.transpose((1, 2, 3, 0)).reshape((-1, 3))

        if from_ != 'voxel':
            refmat    = field.ref.getAffine('voxel', from_)
            refcoords = affine.transform(refcoords, refmat)

        fieldcoords = srccoords - refcoords

    return DeformationField(
        fieldcoords.reshape(field.shape),
        header=field.header,
        src=field.src,
        ref=field.ref,
        srcSpace=to,
        refSpace=from_,
        defType=field.deformationType)
def applyDeformation(image, field, ref=None, order=1, mode=None, cval=None):
    """Applies a :class:`DeformationField` to an :class:`.Image`.

    The image is transformed into the space of the field's reference image
    space. See the ``scipy.ndimage.interpolation.map_coordinates`` function
    for details on the ``order``, ``mode`` and ``cval`` options.

    If an alternate reference image is provided via the ``ref`` argument,
    the deformation field is resampled into its space, and then applied to
    the input image. It is therefore assumed that an alternate ``ref`` is
    aligned in world coordinates with the field's actual reference image.

    :arg image: :class:`.Image` to be transformed

    :arg field: :class:`DeformationField` to use

    :arg ref:   Alternate reference image - if not provided, ``field.ref``
                is used

    :arg order: Spline interpolation order, passed through to the
                ``scipy.ndimage.map_coordinates`` function - ``0``
                corresponds to nearest neighbour interpolation, ``1``
                (the default) to linear interpolation, and ``3`` to
                cubic interpolation.

    :arg mode:  How to handle regions which are outside of the image FOV.
                Defaults to ``'nearest'``.

    :arg cval:  Constant value to use when ``mode='constant'``.

    :return:    ``numpy.array`` containing the transformed image data.
    """

    if order is None: order = 1
    if mode  is None: mode  = 'nearest'
    if cval  is None: cval  = 0

    # We need the field to contain
    # absolute source image voxel
    # coordinates
    field = convertDeformationSpace(field, 'voxel', 'voxel')
    if field.deformationType != 'absolute':
        field = DeformationField(convertDeformationType(field, 'absolute'),
                                 header=field.header,
                                 src=field.src,
                                 ref=field.ref,
                                 srcSpace='voxel',
                                 refSpace='voxel',
                                 defType='absolute')

    # Resample to alternate reference image
    # space if provided - regions of the
    # field outside of the reference image
    # space will contain -1s, so will be
    # detected as out of bounds by
    # map_coordinates
    if ref is not None:
        fielddata = resample.resampleToReference(field,
                                                 ref,
                                                 order=1,
                                                 mode='constant',
                                                 cval=-1)[0]

    # Otherwise use the field as-is. A
    # DeformationField is an Image, not a
    # numpy array, so we need its data.
    else:
        fielddata = field.data

    # map_coordinates expects the coordinate
    # axis first, i.e. shape (3, X, Y, Z)
    fielddata = fielddata.transpose((3, 0, 1, 2))
    return ndinterp.map_coordinates(image.data,
                                    fielddata,
                                    order=order,
                                    mode=mode,
                                    cval=cval)


def coefficientFieldToDeformationField(field, defType='relative', premat=True):
    """Convert a :class:`CoefficientField` into a :class:`DeformationField`.

    :arg field:   :class:`CoefficientField` to convert
    :arg defType: The type of deformation field - either ``'relative'`` (the
                  default) or ``'absolute'``.
    :arg premat:  If ``True`` (the default), the :meth:`srcToRefMat` is
                  encoded into the deformation field.
    :return:      :class:`DeformationField` calculated from ``field``.
    :raises ValueError: If ``defType`` is not ``'relative'`` or
                        ``'absolute'``.
    """

    if defType not in ('relative', 'absolute'):
        raise ValueError('Invalid value for defType: {}'.format(defType))

    # Generate coordinates for every
    # voxel in the reference image
    ix, iy, iz = field.ref.shape[:3]
    x,  y,  z  = np.meshgrid(np.arange(ix),
                             np.arange(iy),
                             np.arange(iz), indexing='ij')
    x          = x.flatten()
    y          = y.flatten()
    z          = z.flatten()
    xyz        = np.vstack((x, y, z)).T

    # There are three spaces to consider here:
    #
    #  - ref space:         Reference image scaled voxels ("fsl" space)
    #
    #  - aligned-src space: Source image scaled voxels, after the
    #                       source image has been linearly aligned to
    #                       the reference via the field.srcToRefMat
    #                       This will typically be equivalent to ref
    #                       space
    #
    #  - orig-src space:    Source image scaled voxels, in the coordinate
    #                       system of the original source image, without
    #                       linear alignment to the reference image

    # The displacements method will
    # return relative displacements
    # from ref space to aligned-src
    # space. The ref header is passed so
    # that the new field has the ref
    # affine - DeformationField rejects
    # fields not in the same space as ref.
    disps   = field.displacements(xyz).reshape((ix, iy, iz, 3))
    rdfield = DeformationField(disps,
                               header=field.ref.header,
                               src=field.src,
                               ref=field.ref,
                               srcSpace=field.srcSpace,
                               refSpace=field.refSpace,
                               defType='relative')

    if (defType == 'relative') and (not premat):
        return rdfield

    # Convert to absolute - the
    # convertDeformationType function
    # will return absolute coordinates in
    # aligned-src space
    disps = convertDeformationType(rdfield, 'absolute')

    # Apply the premat if requested -
    # this will transform the coordinates
    # from aligned-src to orig-src space.
    if premat and field.srcToRefMat is not None:

        # We apply the premat in the same way
        # that fnirtfileutils does - applying
        # the inverse affine to every ref space
        # voxel coordinate, then adding it to
        # the existing displacements.
        shape     = disps.shape
        disps     = disps.reshape(-1, 3)
        prematDisp = affine.concat(field.refToSrcMat - np.eye(4),
                                   field.ref.getAffine('voxel', 'fsl'))
        disps     = disps + affine.transform(xyz, prematDisp)
        disps     = disps.reshape(shape)

        # note that convertwarp applies a premat
        # differently - its method is equivalent
        # to directly transforming the existing
        # absolute displacements, i.e.:
        #
        #   disps = affine.transform(disps, refToSrc)

    adfield = DeformationField(disps,
                               header=field.ref.header,
                               src=field.src,
                               ref=field.ref,
                               srcSpace=field.srcSpace,
                               refSpace=field.refSpace,
                               defType='absolute')

    # Now either return absolute displacements,
    # or convert back to relative displacements
    if defType == 'absolute':
        return adfield

    return DeformationField(convertDeformationType(adfield, 'relative'),
                            src=field.src,
                            ref=field.ref,
                            srcSpace=field.srcSpace,
                            refSpace=field.refSpace,
                            header=field.ref.header,
                            defType='relative')