From 5801a4421147b936dac92762747017c1576ab154 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Mon, 3 Jun 2019 12:22:48 +0100
Subject: [PATCH] ENH: Implementation of CoefficientField class, and routine to
 convert from coefficient field to displacement field. Check for TOPUP intent
 codes

---
 fsl/data/image.py          |   8 +-
 fsl/transform/nonlinear.py | 269 ++++++++++++++++++++++++++++++++++---
 2 files changed, 255 insertions(+), 22 deletions(-)

diff --git a/fsl/data/image.py b/fsl/data/image.py
index b2f07648e..ac9f514e1 100644
--- a/fsl/data/image.py
+++ b/fsl/data/image.py
@@ -320,9 +320,11 @@ class Nifti(notifier.Notifier, meta.Meta):
         # $FSLDIR/src/fnirt/fnirt_file_writer.cpp
         # and fsl.transform.nonlinear for more
         # details.
-        if intent in (constants.FSL_CUBIC_SPLINE_COEFFICIENTS,
-                      constants.FSL_DCT_COEFFICIENTS,
-                      constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS):
+        if intent in (constants.FSL_DCT_COEFFICIENTS,
+                      constants.FSL_CUBIC_SPLINE_COEFFICIENTS,
+                      constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS,
+                      constants.FSL_TOPUP_CUBIC_SPLINE_COEFFICIENTS,
+                      constants.FSL_TOPUP_QUADRATIC_SPLINE_COEFFICIENTS):
 
             log.debug('FNIRT coefficient field detected - generating affine')
 
diff --git a/fsl/transform/nonlinear.py b/fsl/transform/nonlinear.py
index 9e903912e..4f4022c6d 100644
--- a/fsl/transform/nonlinear.py
+++ b/fsl/transform/nonlinear.py
@@ -5,17 +5,37 @@
 # Author: Paul McCarthy <pauldmccarthy@gmail.com>
 #
 """This module contains data structures and functions for working with
-nonlinear transformations.
+FNIRT-style nonlinear transformations.
+
+
+The :class:`DisplacementField` and :class:`CoefficientField` can be used to
+load and interact with FNIRT transformation images. The following utility
+functions are also available:
+
+.. autosummary::
+   :nosignatures:
+
+   detectDisplacementType
+   convertDisplacementType
+   convertDisplacementSpace
+   coefficientFieldToDisplacementField
 """
 
 
+import              logging
+import itertools as it
+
 import numpy as np
 
-import fsl.data.image as fslimage
+import fsl.data.constants as constants
+import fsl.data.image     as fslimage
 
 from . import affine
 
 
+log = logging.getLogger(__name__)
+
+
 class NonLinearTransform(fslimage.Image):
     """Class which represents a nonlinear transformation. This is just a base
     class for the :class:`DisplacementField` and :class:`CoefficientField`
@@ -27,10 +47,10 @@ class NonLinearTransform(fslimage.Image):
     coordinate system.
 
 
-    In FSL, non-linear transformations are defined in the same space as the
-    reference image. At a given location in the reference image space, the
-    non-linear mapping at that location can be used to calculate the
-    corresponding location in the source image space. Therefore, these
+    In FSL, non-linear transformations are defined in terms of the reference
+    image coordinate system.  At a given location in the reference image
+    space, the non-linear mapping at that location can be used to calculate
+    the corresponding location in the source image space. Therefore, these
     non-linear transformation effectively encode a transformation *from* the
     reference image *to* the source image.
     """
@@ -39,7 +59,7 @@ class NonLinearTransform(fslimage.Image):
     def __init__(self,
                  image,
                  src,
-                 ref=None,
+                 ref,
                  srcSpace=None,
                  refSpace=None,
                  **kwargs):
@@ -52,9 +72,6 @@ class NonLinearTransform(fslimage.Image):
         :arg src:      :class:`.Nifti` representing the source image.
 
         :arg ref:      :class:`.Nifti` representing the reference image.
-                       If not provided, it is assumed that this
-                       ``NonLinearTransform`` is defined in the same
-                       space as the reference.
 
         :arg srcSpace: Coordinate system in the source image that this
                        ``NonLinearTransform`` maps from. Defaults to ``'fsl'``.
@@ -65,7 +82,6 @@ class NonLinearTransform(fslimage.Image):
         All other arguments are passed through to :meth:`.Image.__init__`.
         """
 
-        if ref      is None: ref      = self
         if srcSpace is None: srcSpace = 'fsl'
         if refSpace is None: refSpace = 'fsl'
 
@@ -76,12 +92,6 @@ class NonLinearTransform(fslimage.Image):
 
         fslimage.Image.__init__(self, image, **kwargs)
 
-        # Displacement fields must be
-        # defined in the same space
-        # as the reference image
-        if not self.sameSpace(ref):
-            raise ValueError('Invalid reference image: {}'.format(ref))
-
         self.__src      = fslimage.Nifti(src.header.copy())
         self.__ref      = fslimage.Nifti(ref.header.copy())
         self.__srcSpace = srcSpace
@@ -120,6 +130,20 @@ class NonLinearTransform(fslimage.Image):
         return self.__refSpace
 
 
+    def transform(self, coords, from_=None, to=None):
+        """Transform coordinates from the reference image space to the source
+        image space. Implemented by sub-classes.
+
+        :arg coords: A sequence of XYZ coordinates, or ``numpy`` array of shape
+                     ``(n, 3)`` containing ``n`` sets of coordinates in the
+                     reference space.
+        :arg from_:  Reference image space that ``coords`` are defined in
+        :arg to:     Source image space to transform ``coords`` into
+        :returns    ``coords``, transformed into the source image space
+        """
+        raise NotImplementedError()
+
+
 class DisplacementField(NonLinearTransform):
     """Class which represents a displacement field which, at each voxel,
     contains an absolute or relative displacement between a source space and a
@@ -127,9 +151,12 @@ class DisplacementField(NonLinearTransform):
     """
 
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, image, src, ref=None, **kwargs):
         """Create a ``DisplacementField``.
 
+        :arg ref:      Optional. If not provided, it is assumed that the
+                       reference is defined in the same space as ``image``.
+
         :arg dispType: Either ``'absolute'`` or ``'relative'``, indicating
                        the type of this displacement field. If not provided,
                        will be inferred via the :func:`detectDisplacementType`
@@ -139,12 +166,18 @@ class DisplacementField(NonLinearTransform):
         :meth:`NonLinearTransform.__init__`.
         """
 
+        if ref is None:
+            ref = self
+
         dispType = kwargs.pop('dispType',  None)
 
         if dispType not in (None, 'relative', 'absolute'):
             raise ValueError('Invalid value for dispType: {}'.format(dispType))
 
-        NonLinearTransform.__init__(self, *args, **kwargs)
+        NonLinearTransform.__init__(self, image, src, ref, **kwargs)
+
+        if not self.sameSpace(self.ref):
+            raise ValueError('Invalid reference image: {}'.format(self.ref))
 
         self.__dispType = dispType
 
@@ -237,6 +270,184 @@ class DisplacementField(NonLinearTransform):
         return outcoords
 
 
+class CoefficientField(NonLinearTransform):
+    """Class which represents a quadratic or cubic B-spline coefficient field
+    generated by FNIRT.
+    """
+
+    def __init__(self, image, src, ref, **kwargs):
+        """Create a ``CoefficientField``.
+
+        :arg image:
+        :arg src:
+        :arg ref:
+        """
+
+        NonLinearTransform.__init__(self, image, src, ref, **kwargs)
+
+        # FNIRT uses NIFTI header fields in
+        # non-standard ways to store some
+        # additional information about the
+        # coefficient field. See
+        # $FSLDIR/src/fnirt/fnirt_file_writer.cpp
+        # for more details.
+
+        # The field type (quadratic, cubic,
+        # or discrete-cosine-transform) is
+        # inferred from the intent. There is
+        # no support in this implementation
+        # for DCT fields
+        if self.intent == constants.FSL_CUBIC_SPLINE_COEFFICIENTS:
+            self.__fieldType = 'cubic'
+        elif self.intent == constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS:
+            self.__fieldType = 'quadratic'
+        else:
+            self.__fieldType = 'cubic'
+            log.warning('Unrecognised/unsupported coefficient field type '
+                        '(assuming cubic): {}'.format(self.intent))
+
+        # Knot spacing (in voxels) is
+        # stored in the pixdims
+        kx, ky, kz = self.pixdim[:3]
+        self.__knotSpacing = (kx, ky, kz)
+
+        # The sform contains an initial
+        # global src-to-ref affine
+        # (the starting point for the
+        # non-linear registration)
+        self.__srcToRefMat = self.header.get_sform()
+
+        # The fieldToRefMat affine allows us
+        # to transform coefficient field voxel
+        # coordinates into displacement field/
+        # reference image voxel coordinates.
+        self.__fieldToRefMat = affine.scaleOffsetXform((kx, ky, kz), 0)
+        self.__refToFieldMat = affine.invert(self.__fieldToRefMat)
+
+
+    @property
+    def fieldType(self):
+        """Return the type of the coefficient field, either ``'cubic'`` or
+        ``'quadratic'``.
+        """
+        return self.__fieldType
+
+
+    @property
+    def srcToRefMat(self):
+        """Return an initial global affine transformation from the source
+        image to the reference image.
+        """
+        return np.copy(self.__srcToRefMat)
+
+
+    @property
+    def knotSpacing(self):
+        """Return a tuple containing spline knot spacings along the x, y, and
+        z axes.
+        """
+        return self.__knotSpacing
+
+    @property
+    def fieldToRefMat(self):
+        """Return an affine transformation which can transform coefficient
+        field voxel coordinates into reference image voxel coordinates.
+        """
+        return np.copy(self.__fieldToRefMat)
+
+
+    @property
+    def refToFieldMat(self):
+        """Return an affine transformation which can transform reference
+        image voxel coordinates into coefficient field voxel coordinates.
+        """
+        return np.copy(self.__refToFieldMat)
+
+
+    def transform(self, coords, from_=None, to=None):
+        raise NotImplementedError()
+
+
+    def displacements(self, coords):
+        """Calculate the relative displacemenets for the given coordinates.
+
+        :arg coords: ``(N, 3)`` array of reference image voxel coordinates.
+        :return:      A ``(N, 3)`` array  of relative displacements to the
+                      source image for ``coords``
+        """
+
+        if self.fieldType != 'cubic':
+            raise NotImplementedError()
+
+        # See
+        #   https://www.cs.jhu.edu/~cis/cista/746/papers/RueckertFreeFormBreastMRI.pdf
+        #   https://www.fmrib.ox.ac.uk/datasets/techrep/tr07ja2/tr07ja2.pdf
+
+        # Cubic b-spline basis functions
+        def b0(u):
+            return ((1 - u) ** 3) / 6
+
+        def b1(u):
+            return (3 * (u ** 3) - 6 * (u ** 2) + 4) / 6
+
+        def b2(u):
+            return (-3 * (u ** 3) + 3 * (u ** 2)  + 3 * u + 1) / 6
+
+        def b3(u):
+            return (u ** 3) / 6
+
+        b = [b0, b1, b2, b3]
+
+        fdata      = self.data
+        nx, ny, nz = self.shape[:3]
+        ix, iy, iz = self.ref.shape[:3]
+
+        # Convert the given voxel coordinates
+        # into the corresponding coefficient
+        # field voxel coordinates
+        x, y, z    = coords.T
+        i, j, k    = affine.transform(coords, self.refToFieldMat).T
+
+        # i, j, k: coefficient field indices
+        # u, v, w: position of the ref voxel
+        #          on the current spline
+        u = np.remainder(i, 1)
+        v = np.remainder(j, 1)
+        w = np.remainder(k, 1)
+        i = np.floor(i).astype(np.int)
+        j = np.floor(j).astype(np.int)
+        k = np.floor(k).astype(np.int)
+
+        disps = np.zeros(coords.shape)
+
+        for l, m, n in it.product(range(4), range(4), range(4)):
+            il   = i + l
+            jm   = j + m
+            kn   = k + n
+            mask = ((il >= 0)  &
+                    (il <  nx) &
+                    (jm >= 0)  &
+                    (jm <  ny) &
+                    (kn >= 0)  &
+                    (kn <  nz))
+
+            il = il[mask]
+            jm = jm[mask]
+            kn = kn[mask]
+            uu = u[ mask]
+            vv = v[ mask]
+            ww = w[ mask]
+
+            cx, cy, cz = fdata[il, jm, kn, :].T
+            c          = b[l](uu) * b[m](vv) * b[n](ww)
+
+            disps[mask, 0] += c * cx
+            disps[mask, 1] += c * cy
+            disps[mask, 2] += c * cz
+
+        return disps
+
+
 def detectDisplacementType(field):
     """Attempt to automatically determine whether a displacement field is
     specified in absolute or relative coordinates.
@@ -366,3 +577,23 @@ def convertDisplacementSpace(field, from_, to):
         srcSpace=to,
         refSpace=from_,
         dispType=field.displacementType)
+
+
+def coefficientFieldToDisplacementField(field):
+    """Convert a FNIRT quadratic or cubic B-spline coefficient field into
+    a relative displacement field.
+
+    :arg field: :class:`CoefficientField` to convert
+    :return:    :class:`DisplacementField` calculated from ``field``
+    """
+    ix, iy, iz = field.ref.shape[:3]
+    x,  y,  z  = np.meshgrid(np.arange(ix),
+                             np.arange(iy),
+                             np.arange(iz), indexing='ij')
+    x          = x.flatten()
+    y          = y.flatten()
+    z          = z.flatten()
+    xyz        = np.vstack((x, y, z)).T
+    disps      = field.displacements(xyz).reshape((ix, iy, iz, 3))
+
+    return DisplacementField(disps, field.src, field.ref, dispType='relative')
-- 
GitLab