From fc2a1e61f7f121817f2374bcdd4d704b24b6dcf4 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Mon, 1 Jul 2019 15:04:10 +0930
Subject: [PATCH] ENH: Finish coefficientFieldToDisplacementField
 implementation. Add some docs

---
 fsl/transform/affine.py    |   2 +-
 fsl/transform/fnirt.py     |  16 +++-
 fsl/transform/nonlinear.py | 192 ++++++++++++++++++++++++++++++-------
 3 files changed, 170 insertions(+), 40 deletions(-)

diff --git a/fsl/transform/affine.py b/fsl/transform/affine.py
index 5e4f272e5..817ef9be3 100644
--- a/fsl/transform/affine.py
+++ b/fsl/transform/affine.py
@@ -5,7 +5,7 @@
 # Author: Paul McCarthy <pauldmccarthy@gmail.com>
 #
 """This module contains utility functions for working with affine
-transformations. The following funcyions are available:
+transformations. The following functions are available:
 
 .. autosummary::
    :nosignatures:
diff --git a/fsl/transform/fnirt.py b/fsl/transform/fnirt.py
index 43ce0dd04..962599a27 100644
--- a/fsl/transform/fnirt.py
+++ b/fsl/transform/fnirt.py
@@ -88,13 +88,19 @@ def _readFnirtCoefficientField(fname, img, src, ref):
     # The sform contains an initial
     # global src-to-ref affine
     # (the starting point for the
-    # non-linear registration)
+    # non-linear registration). This
+    # is encoded as a flirt matrix,
+    # i.e. it transforms from
+    # source-scaled-voxels to
+    # ref-scaled-voxels
     srcToRefMat = img.header.get_sform()
 
-    # The fieldToRefMat affine allows us
-    # to transform coefficient field voxel
-    # coordinates into displacement field/
-    # reference image voxel coordinates.
+    # The fieldToRefMat affine tells
+    # the CoefficientField class how
+    # to transform coefficient field
+    # voxel coordinates into
+    # displacement field/reference
+    # image voxel coordinates.
     fieldToRefMat = affine.scaleOffsetXform(knotSpacing, 0)
 
     return nonlinear.CoefficientField(fname,
diff --git a/fsl/transform/nonlinear.py b/fsl/transform/nonlinear.py
index 3ca34374e..45b544491 100644
--- a/fsl/transform/nonlinear.py
+++ b/fsl/transform/nonlinear.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python
 #
-# nonlinear.py -
+# nonlinear.py - Functions/classes for non-linear transformations.
 #
 # Author: Paul McCarthy <pauldmccarthy@gmail.com>
 #
@@ -216,7 +216,7 @@ class DisplacementField(NonLinearTransform):
                      reference space.
         :arg from_:  Reference image space that ``coords`` are defined in
         :arg to:     Source image space to transform ``coords`` into
-        :returns    ``coords``, transformed into the source image space
+        :returns     ``coords``, transformed into the source image space
         """
 
         if from_ is None: from_ = self.refSpace
@@ -270,10 +270,24 @@ class DisplacementField(NonLinearTransform):
 
 
 class CoefficientField(NonLinearTransform):
-    """Class which represents a quadratic or cubic B-spline coefficient field
-    generated by FNIRT.
+    """Class which represents a cubic B-spline coefficient field generated by
+    FNIRT.
+
+    The :meth:`displacements` method can be used to calculate relative
+    displacements for a set of reference space voxel coordinates.
+
+
+    A FNIRT coefficient field typically contains a *premat*, a global affine
+    transformation from the source space to the reference space, which was
+    used as the starting point for the non-linear optimisation performed by
+    FNIRT.
+
+    This affine must be provided when creating a ``CoefficientField``, and is
+    subsequently accessed via the :meth:`srcToRefMat` or :meth:`premat`
+    attributes.
     """
 
+
     def __init__(self,
                  image,
                  src,
@@ -287,13 +301,27 @@ class CoefficientField(NonLinearTransform):
                  **kwargs):
         """Create a ``CoefficientField``.
 
-        :arg fieldType:
-        :arg knotSpacing:
-        :arg srcToRefMat:
-        :arg fieldToRefMat:
+        :arg fieldType:     Must be ``'cubic'``
+
+        :arg knotSpacing:   A tuple containing the spline knot spacings along
+                            each axis.
+
+        :arg srcToRefMat:   Initial global affine transformation from the
+                            source image to the reference image. This is
+                            assumed to be a FLIRT-style matrix, i.e. it
+                            transforms from source image FSL coordinates
+                            into reference image FSL coordinates (scaled
+                            voxels).
+
+        :arg fieldToRefMat: Affine transformation which can transform reference
+                            image voxel coordinates into coefficient field
+                            voxel coordinates.
+
+        See the :class:`NonLinearTransform` class for details on the other
+        arguments.
         """
 
-        if fieldType not in ('quadratic', 'cubic'):
+        if fieldType not in ('cubic',):
             raise ValueError('Unsupported field type: {}'.format(fieldType))
 
         NonLinearTransform.__init__(self,
@@ -302,7 +330,6 @@ class CoefficientField(NonLinearTransform):
                                     ref,
                                     srcSpace,
                                     refSpace,
-                                    refSpace='voxel',
                                     **kwargs)
 
         self.__fieldType     = fieldType
@@ -314,8 +341,8 @@ class CoefficientField(NonLinearTransform):
 
     @property
     def fieldType(self):
-        """Return the type of the coefficient field, either ``'cubic'`` or
-        ``'quadratic'``.
+        """Return the type of the coefficient field, currently always
+        ``'cubic'``.
         """
         return self.__fieldType
 
@@ -352,7 +379,20 @@ class CoefficientField(NonLinearTransform):
         return np.copy(self.__refToFieldMat)
 
 
-    def transform(self, coords, from_=None, to=None):
+    def transform(self, coords, from_=None, to=None, premat=True):
+        """Overrides :meth:`NonLinearTransform.transform`. Transforms the
+        given ``coords`` from the reference image space into the source image
+        space.
+
+        :arg coords: A sequence of XYZ coordinates, or ``numpy`` array of shape
+                     ``(n, 3)`` containing ``n`` sets of coordinates in the
+                     reference space.
+        :arg from_:  Reference image space that ``coords`` are defined in
+        :arg to:     Source image space to transform ``coords`` into
+        :returns    ``coords``, transformed into the source image space
+        :arg premat: If ``True``, the inverse :meth:`srcToRefMat` is applied
+                     to
+        """
         raise NotImplementedError()
 
 
@@ -368,7 +408,8 @@ class CoefficientField(NonLinearTransform):
             raise NotImplementedError()
 
         # See
-        #   https://www.cs.jhu.edu/~cis/cista/746/papers/RueckertFreeFormBreastMRI.pdf
+        #   https://www.cs.jhu.edu/~cis/cista/746/papers/\
+        #     RueckertFreeFormBreastMRI.pdf
         #   https://www.fmrib.ox.ac.uk/datasets/techrep/tr07ja2/tr07ja2.pdf
 
         # Cubic b-spline basis functions
@@ -412,12 +453,12 @@ class CoefficientField(NonLinearTransform):
             il   = i + l
             jm   = j + m
             kn   = k + n
-            mask = ((il >= 0)  &
-                    (il <  nx) &
-                    (jm >= 0)  &
-                    (jm <  ny) &
-                    (kn >= 0)  &
-                    (kn <  nz))
+            mask = (il >= 0)  & \
+                   (il <  nx) & \
+                   (jm >= 0)  & \
+                   (jm <  ny) & \
+                   (kn >= 0)  & \
+                   (kn <  nz)
 
             il = il[mask]
             jm = jm[mask]
@@ -567,14 +608,24 @@ def convertDisplacementSpace(field, from_, to):
         dispType=field.displacementType)
 
 
-def coefficientFieldToDisplacementField(field):
-    """Convert a :class:`CoefficientField` into a relative
-    :class:`DisplacementField`.
+def coefficientFieldToDisplacementField(field,
+                                        dispType='relative',
+                                        premat=True):
+    """Convert a :class:`CoefficientField` into a :class:`DisplacementField`.
+
+    :arg field:    :class:`CoefficientField` to convert
+
+    :arg dispType: The type of displcaement field - either ``'relative'`` (the
+                   default) or ``'absolute'``.
 
-    :arg field: :class:`CoefficientField` to convert
-    :return:    :class:`DisplacementField` calculated from ``field``
+    :arg premat:   If ``True`` (the default), the :meth:`srcToRefMat` is
+                   encoded into the displacements.
+
+    :return:       :class:`DisplacementField` calculated from ``field``.
     """
 
+    # Generate coordinates for every
+    # voxel in the reference image
     ix, iy, iz = field.ref.shape[:3]
     x,  y,  z  = np.meshgrid(np.arange(ix),
                              np.arange(iy),
@@ -583,12 +634,85 @@ def coefficientFieldToDisplacementField(field):
     y          = y.flatten()
     z          = z.flatten()
     xyz        = np.vstack((x, y, z)).T
-    disps      = field.displacements(xyz).reshape((ix, iy, iz, 3))
-
-    return DisplacementField(disps,
-                             src=field.src,
-                             ref=field.ref,
-                             srcSpace=field.srcSpace,
-                             refSpace=field.refSpace,
-                             header=field.ref.header,
-                             dispType='relative')
+
+    # There are three spaces to consider here:
+    #
+    #  - ref space:         Reference image scaled voxels ("fsl" space)
+    #
+    #  - aligned-src space: Source image scaled voxels, after the
+    #                       source image has been linearly aligned to
+    #                       the reference via the field.srcToRefMat
+    #                       This will typically be equivalent to ref
+    #                       space
+    #
+    #  - orig-src space:    Source image scaled voxels, in the coordinate
+    #                       system of the original source image, without
+    #                       linear alignment to the reference image
+
+    # The displacements method will
+    # return relative displacements
+    # from ref space to aligned-src
+    # space.
+    disps   = field.displacements(xyz).reshape((ix, iy, iz, 3))
+    rdfield = DisplacementField(disps,
+                                src=field.src,
+                                ref=field.ref,
+                                srcSpace=field.srcSpace,
+                                refSpace=field.refSpace,
+                                header=field.ref.header,
+                                dispType='relative')
+
+    if (dispType == 'relative') and (not premat):
+        return rdfield
+
+    # Convert to absolute - the
+    # displacements will now be
+    # absolute coordinates in
+    # aligned-src space
+    disps = convertDisplacementType(rdfield)
+
+    # Apply the premat if requested -
+    # this will transform the coordinates
+    # from aligned-src to orig-src space.
+    if premat:
+
+        # We apply the premat in the same way
+        # that fnirtfileutils does - applying
+        # the inverse affine to every ref space
+        # voxel coordinate, then adding it to
+        # the existing displacements.
+        shape    = disps.shape
+        disps    = disps.reshape(-1, 3)
+        refToSrc = affine.invert(field.srcToRefMat)
+        premat   = affine.concat(refToSrc - np.eye(4),
+                                 field.ref.getAffine('voxel', 'fsl'))
+        disps    = disps + affine.transform(xyz, premat)
+        disps    = disps.reshape(shape)
+
+        # note that convertwarp applies a premat
+        # differently - its method is equivalent
+        # to directly transforming the existing
+        # absolute displacements, i.e.:
+        #
+        #   disps = affine.transform(disps, refToSrc)
+
+    adfield = DisplacementField(disps,
+                                src=field.src,
+                                ref=field.ref,
+                                srcSpace=field.srcSpace,
+                                refSpace=field.refSpace,
+                                header=field.ref.header,
+                                dispType='absolute')
+
+    # Not either return absolute displacements,
+    # or convert back to relative displacements
+    if dispType == 'absolute':
+        return adfield
+    else:
+        return DisplacementField(convertDisplacementType(adfield),
+                                 src=field.src,
+                                 ref=field.ref,
+                                 srcSpace=field.srcSpace,
+                                 refSpace=field.refSpace,
+                                 header=field.ref.header,
+                                 dispType='relative')
-- 
GitLab