From 4e6e03ca1cbd346ca3147f140d1174d9b5bc3d39 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Thu, 18 Jul 2019 13:23:46 +0100
Subject: [PATCH] RF: Minor refactorings; the only real change is that a
 deformation field is no longer assumed to be aligned with the reference in
 voxels - it is only assumed to be aligned to the reference in world
 coordinates.

---
 fsl/transform/nonlinear.py | 96 +++++++++++++++++++++++---------------
 1 file changed, 59 insertions(+), 37 deletions(-)

diff --git a/fsl/transform/nonlinear.py b/fsl/transform/nonlinear.py
index 6cc516e55..650badc7c 100644
--- a/fsl/transform/nonlinear.py
+++ b/fsl/transform/nonlinear.py
@@ -9,8 +9,8 @@ FNIRT-style nonlinear transformations.
 
 
 The :class:`DeformationField` and :class:`CoefficientField` can be used to
-load and interact with FNIRT transformation images. The following utility
-functions are also available:
+load and interact with FNIRT-style transformation images. The following
+utility functions are also available:
 
 
 .. autosummary::
@@ -53,7 +53,7 @@ class NonLinearTransform(fslimage.Image):
     space, the non-linear mapping at that location can be used to calculate
     the corresponding location in the source image space. Therefore, these
     non-linear transformation effectively encode a transformation *from* the
-    reference image *to* the source image.
+    reference image *to* the (unwarped) source image.
     """
 
 
@@ -63,9 +63,10 @@ class NonLinearTransform(fslimage.Image):
                  ref,
                  srcSpace=None,
                  refSpace=None,
-
                  **kwargs):
-        """Create a ``NonLinearTransform``.
+        """Create a ``NonLinearTransform``. See the :meth:`.Nifti.getAffine`
+        method for an overview of the values that ``srcSpace`` and ``refSpace``
+        may take.
 
         :arg image:       A string containing the name of an image file to
                           load, or a :mod:`numpy` array, or a :mod:`nibabel`
@@ -86,6 +87,12 @@ class NonLinearTransform(fslimage.Image):
         All other arguments are passed through to :meth:`.Image.__init__`.
         """
 
+        # TODO Could make more general by replacing
+        # srcSpace and refSpace with src/ref affines,
+        # which transform tofrom (e.g.) source/ref
+        # voxels to/from the space required by the
+        # deformation field
+
         if srcSpace is None: srcSpace = 'fsl'
         if refSpace is None: refSpace = 'fsl'
 
@@ -96,10 +103,10 @@ class NonLinearTransform(fslimage.Image):
 
         fslimage.Image.__init__(self, image, **kwargs)
 
-        self.__src         = fslimage.Nifti(src.header.copy())
-        self.__ref         = fslimage.Nifti(ref.header.copy())
-        self.__srcSpace    = srcSpace
-        self.__refSpace    = refSpace
+        self.__src      = fslimage.Nifti(src.header.copy())
+        self.__ref      = fslimage.Nifti(ref.header.copy())
+        self.__srcSpace = srcSpace
+        self.__refSpace = refSpace
 
 
     @property
@@ -138,6 +145,9 @@ class NonLinearTransform(fslimage.Image):
         """Transform coordinates from the reference image space to the source
         image space. Implemented by sub-classes.
 
+        See the :meth:`.Nifti.getAffine` method for an overview of the values
+        that ``from_`` and ``to`` may take.
+
         :arg coords: A sequence of XYZ coordinates, or ``numpy`` array of shape
                      ``(n, 3)`` containing ``n`` sets of coordinates in the
                      reference space.
@@ -146,7 +156,7 @@ class NonLinearTransform(fslimage.Image):
 
         :arg to:     Source image space to transform ``coords`` into
 
-        :returns:    ``coords``, transformed into the source image space
+        :returns:    The corresponding coordinates in the source image space.
         """
         raise NotImplementedError()
 
@@ -155,9 +165,14 @@ class DeformationField(NonLinearTransform):
     """Class which represents a deformation (a.k.a. warp) field which, at each
     voxel, contains either:
 
-      - a relative displacement from the reference space to the source space,
-        or
+      - a relative displacement from the reference image space to the source
+        image space, or
       - absolute coordinates in the source space
+
+
+    It is assumed that the a ``DeformationField`` is aligned with the
+    reference image in their world coordinate systems (i.e. their ``sform``
+    affines project both images into alignment).
     """
 
 
@@ -246,17 +261,23 @@ class DeformationField(NonLinearTransform):
             xform  = self.ref.getAffine(from_, self.refSpace)
             coords = affine.transform(coords, xform)
 
-        # We also need to make sure that the
-        # coordinates are in voxels, so we
-        # can look up the displacements
-        if self.refSpace != 'voxel':
-            xform  = self.ref.getAffine(self.refSpace, 'voxel')
-            voxels = affine.transform(coords, xform)
-        else:
+        # We also need to get the coordinates
+        # in field voxels, so we can look up
+        # the displacements/coordinates. We
+        # can get this through the assumption
+        # that field and ref are aligned in
+        # the world coordinate system
+        xform = affine.concat(self    .getAffine('world', 'voxel'),
+                              self.ref.getAffine(from_,   'world'))
+
+        if np.all(np.isclose(xform, np.eye(4))):
             voxels = coords
+        else:
+            voxels = affine.transform(coords, xform)
 
         # Mask out the coordinates
-        # that are out of bounds
+        # that are out of bounds of
+        # the deformation field
         voxels  = np.round(voxels).astype(np.int)
         voxmask = (voxels >= [0, 0, 0]) & (voxels < self.shape[:3])
         voxmask = voxmask.all(axis=1)
@@ -304,16 +325,16 @@ class CoefficientField(NonLinearTransform):
                  image,
                  src,
                  ref,
-                 srcSpace,
-                 refSpace,
-                 fieldType,
-                 knotSpacing,
-                 fieldToRefMat,
+                 srcSpace=None,
+                 refSpace=None,
+                 fieldType='cubic',
+                 knotSpacing=None,
+                 fieldToRefMat=None,
                  srcToRefMat=None,
                  **kwargs):
         """Create a ``CoefficientField``.
 
-        :arg fieldType:     Must be ``'cubic'``
+        :arg fieldType:     Must currently be ``'cubic'``
 
         :arg knotSpacing:   A tuple containing the spline knot spacings along
                             each axis.
@@ -337,8 +358,8 @@ class CoefficientField(NonLinearTransform):
         if fieldType not in ('cubic',):
             raise ValueError('Unsupported field type: {}'.format(fieldType))
 
-        if srcToRefMat is not None:
-            srcToRefMat = np.copy(srcToRefMat)
+        if srcToRefMat   is not None: srcToRefMat   = np.copy(srcToRefMat)
+        if fieldToRefMat is     None: fieldToRefMat = np.eye(4)
 
         NonLinearTransform.__init__(self,
                                     image,
@@ -410,6 +431,9 @@ class CoefficientField(NonLinearTransform):
     @memoize.Instanceify(memoize.memoize)
     def asDeformationField(self, defType='relative', premat=True):
         """Convert this ``CoefficientField`` to a :class:`DeformationField`.
+
+        This method is a wrapper around
+        :func:`coefficientFieldToDeformationField`
         """
         return coefficientFieldToDeformationField(self, defType, premat)
 
@@ -558,7 +582,7 @@ def convertDeformationType(field, defType=None):
 
     # Regardless of the conversion direction,
     # we need the coordinates of every voxel
-    # in the reference FSL coordinate system.
+    # in the reference coordinate system.
     dx, dy, dz = field.shape[:3]
     xform      = field.getAffine('voxel', field.refSpace)
 
@@ -570,9 +594,9 @@ def convertDeformationType(field, defType=None):
     coords     = coords.reshape((dx, dy, dz, 3))
 
     # If converting from relative to absolute,
-    # we just add the voxel coordinates to
-    # (what is assumed to be) the relative shift.
-    # Or, to convert from absolute to relative,
+    # we just add the coordinates to (what is
+    # assumed to be) the relative shift. Or,
+    # to convert from absolute to relative,
     # we subtract the reference image voxels.
     if   defType == 'absolute': return field.data + coords
     elif defType == 'relative': return field.data - coords
@@ -648,9 +672,7 @@ def convertDeformationSpace(field, from_, to):
         defType=field.deformationType)
 
 
-def coefficientFieldToDeformationField(field,
-                                       defType='relative',
-                                       premat=True):
+def coefficientFieldToDeformationField(field, defType='relative', premat=True):
     """Convert a :class:`CoefficientField` into a :class:`DeformationField`.
 
     :arg field:   :class:`CoefficientField` to convert
@@ -695,11 +717,11 @@ def coefficientFieldToDeformationField(field,
     # space.
     disps   = field.displacements(xyz).reshape((ix, iy, iz, 3))
     rdfield = DeformationField(disps,
+                               header=field.ref.header,
                                src=field.src,
                                ref=field.ref,
                                srcSpace=field.srcSpace,
                                refSpace=field.refSpace,
-                               header=field.ref.header,
                                defType='relative')
 
     if (defType == 'relative') and (not premat):
@@ -736,11 +758,11 @@ def coefficientFieldToDeformationField(field,
         #   disps = affine.transform(disps, refToSrc)
 
     adfield = DeformationField(disps,
+                               header=field.ref.header,
                                src=field.src,
                                ref=field.ref,
                                srcSpace=field.srcSpace,
                                refSpace=field.refSpace,
-                               header=field.ref.header,
                                defType='absolute')
 
     # Not either return absolute displacements,
-- 
GitLab