From fce649f84ce77fc6a52a7af2537d08e74932de78 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Mon, 3 Jun 2019 14:47:59 +0100
Subject: [PATCH] RF: Move FSL-specific coefficient field stuff from
 nonlinear.py into fnirt.py - CoefficientField can now represent a generic
 b-spline coefficient field. Cleaned up fnirt.py, other tweaks and comments

---
 fsl/transform/__init__.py  |   1 -
 fsl/transform/fnirt.py     | 179 +++++++++++++++++++++++++++++--------
 fsl/transform/nonlinear.py |  95 ++++++++++----------
 fsl/transform/x5.py        |  25 ++++--
 4 files changed, 208 insertions(+), 92 deletions(-)

diff --git a/fsl/transform/__init__.py b/fsl/transform/__init__.py
index 8a93bf594..8961c5182 100644
--- a/fsl/transform/__init__.py
+++ b/fsl/transform/__init__.py
@@ -36,7 +36,6 @@ from .flirt  import (  # noqa
 
 from .fnirt import (  # noqa
     readFnirt,
-    writeFnirt,
     toFnirt,
     fromFnirt)
 
diff --git a/fsl/transform/fnirt.py b/fsl/transform/fnirt.py
index d64e0589d..43ce0dd04 100644
--- a/fsl/transform/fnirt.py
+++ b/fsl/transform/fnirt.py
@@ -11,7 +11,8 @@ transformation matrices. The following functions are available:
    :nosignatures:
 
    readFnirt
-   writeFnirt
+   toFnirt
+   fromFnirt
 """
 
 
@@ -23,59 +24,167 @@ import fsl.data.constants as constants
 log = logging.getLogger(__name__)
 
 
-def readFnirt(fname, src, ref, dispType=None):
+def _readFnirtDisplacementField(fname, img, src, ref):
+    """Loads ``fname``, assumed to be a FNIRT displacement field.
+
+    :arg fname: File name of FNIRT displacement field
+    :arg img:   ``fname`` loaded as an :class:`.Image`
+    :arg src:   Source image
+    :arg ref:   Reference image
+    :return:    A :class:`.DisplacementField`
     """
+    from . import nonlinear
+    return nonlinear.DisplacementField(fname,
+                                       src,
+                                       ref,
+                                       srcSpace='fsl',
+                                       refSpace='fsl')
+
+
+def _readFnirtCoefficientField(fname, img, src, ref):
+    """Loads ``fname``, assumed to be a FNIRT coefficient field.
+
+    :arg fname: File name of FNIRT coefficient field
+    :arg img:   ``fname`` loaded as an :class:`.Image`
+    :arg src:   Source image
+    :arg ref:   Reference image
+    :return:    A :class:`.CoefficientField`
+    """
+
+    from . import affine
+    from . import nonlinear
+
+    # FNIRT uses NIFTI header fields in
+    # non-standard ways to store some
+    # additional information about the
+    # coefficient field. See
+    # $FSLDIR/src/fnirt/fnirt_file_writer.cpp
+    # for more details.
+
+    # The field type (quadratic, cubic,
+    # or discrete-cosine-transform) is
+    # inferred from the intent. There is
+    # no support in this implementation
+    # for DCT fields
+    cubics = (constants.FSL_CUBIC_SPLINE_COEFFICIENTS,
+              constants.FSL_TOPUP_CUBIC_SPLINE_COEFFICIENTS)
+    quads  = (constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS,
+              constants.FSL_TOPUP_QUADRATIC_SPLINE_COEFFICIENTS)
+
+    if img.intent in cubics:
+        fieldType = 'cubic'
+    elif img.intent in quads:
+        fieldType = 'quadratic'
+    else:
+        fieldType = 'cubic'
+        log.warning('Unrecognised/unsupported coefficient '
+                    'field type (assuming cubic b-spline): '
+                    '{}'.format(img.intent))
+
+    # Knot spacing (in voxels) is
+    # stored in the pixdims
+    knotSpacing = img.pixdim[:3]
+
+    # The sform contains an initial
+    # global src-to-ref affine
+    # (the starting point for the
+    # non-linear registration)
+    srcToRefMat = img.header.get_sform()
+
+    # The fieldToRefMat affine allows us
+    # to transform coefficient field voxel
+    # coordinates into displacement field/
+    # reference image voxel coordinates.
+    fieldToRefMat = affine.scaleOffsetXform(knotSpacing, 0)
+
+    return nonlinear.CoefficientField(fname,
+                                      src,
+                                      ref,
+                                      srcSpace='fsl',
+                                      refSpace='fsl',
+                                      fieldType=fieldType,
+                                      knotSpacing=knotSpacing,
+                                      srcToRefMat=srcToRefMat,
+                                      fieldToRefMat=fieldToRefMat)
+
+
+def readFnirt(fname, src, ref):
+    """Reads a non-linear FNIRT transformation image, returning
+    a :class:`.DisplacementField` or :class:`.CoefficientField` depending
+    on the file type.
+
+    :arg fname: File name of FNIRT transformation
+    :arg src:   Source image
+    :arg ref:   Reference image
     """
 
     # Figure out whether the file
     # is a displacement field or
     # a coefficient field
-    import fsl.data.image     as fslimage
-    from   .              import nonlinear
+    import fsl.data.image as fslimage
+
+    img   = fslimage.Image(fname, loadData=False)
+    disps = (constants.FSL_FNIRT_DISPLACEMENT_FIELD,
+             constants.FSL_TOPUP_FIELD)
+    coefs = (constants.FSL_CUBIC_SPLINE_COEFFICIENTS,
+             constants.FSL_DCT_COEFFICIENTS,
+             constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS,
+             constants.FSL_TOPUP_CUBIC_SPLINE_COEFFICIENTS,
+             constants.FSL_TOPUP_QUADRATIC_SPLINE_COEFFICIENTS)
+
+    if img.intent in disps:
+        return _readFnirtDisplacementField(fname, img, src, ref)
+    elif img.intent in coefs:
+        return _readFnirtCoefficientField(fname, img, src, ref)
+    else:
+        raise ValueError('Cannot determine type of nonlinear '
+                         'file {}'.format(fname))
 
-    img = fslimage.Image(fname, loadData=False)
 
-    dispfields = (constants.FSL_FNIRT_DISPLACEMENT_FIELD,
-                  constants.FSL_TOPUP_FIELD)
-    coeffields = (constants.FSL_CUBIC_SPLINE_COEFFICIENTS,
-                  constants.FSL_DCT_COEFFICIENTS,
-                  constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS,
-                  constants.FSL_TOPUP_CUBIC_SPLINE_COEFFICIENTS,
-                  constants.FSL_TOPUP_QUADRATIC_SPLINE_COEFFICIENTS)
+def toFnirt(field):
+    """Convert a :class:`.NonLinearTransform` to a FNIRT-compatible
+    :class:`.DisplacementField`.
 
-    kwargs = {
-        'src'      : src,
-        'ref'      : ref,
-        'srcSpace' : 'fsl',
-        'refSpace' : 'fsl',
-        'dispType' : None,
-    }
+    :arg field: :class:`.NonLinearTransform` to convert
+    :return:    A FNIRT-compatible :class:`.DisplacementField`.
+    """
 
-    if img.intent in dispfields:
-        return nonlinear.DisplacementField(fname, **kwargs)
+    from . import nonlinear
 
-    elif img.intent in coeffields:
-        pass  # return nonlinear.CoefficientField(fname, **kwargs)
+    # We can't convert a CoefficientField,
+    # because the coefficients will have
+    # been calculated between some other
+    # source/reference coordinate systems,
+    # and we can't adjust the coefficients
+    # to encode an FSL->FSL deformation.
+    if isinstance(field, nonlinear.CoefficientField):
+        field = nonlinear.coefficientFieldToDisplacementField(field)
 
-    else:
-        raise ValueError('Cannot determine type of nonlinear '
-                         'file {}'.format(fname))
+    field = nonlinear.convertDisplacementSpace(field, from_='fsl', to='fsl')
+    field.header['intent_code'] = constants.FSL_FNIRT_DISPLACEMENT_FIELD
 
+    return field
 
-def writeFnirt(field, fname):
-    """
-    """
-    field.save(fname)
 
+def fromFnirt(field, from_='world', to='world'):
+    """Convert a FNIRT-style :class:`.NonLinearTransform` to a generic
+    :class:`.DisplacementField`.
 
-def toFnirt(field):
-    pass
+    :arg field: A FNIRT-style :class:`.CoefficientField` or
+                :class:`.DisplacementField`
 
+    :arg from_: Desired reference image coordinate system
 
-def fromFnirt(field, from_='voxel', to='world'):
-    """
-    """
+    :arg to:    Desired source image coordinate system
 
+    :return:    A :class:`.DisplacementField` which contains displacements
+                from the reference image ``from_`` cordinate system to the
+                source image ``to`` coordinate syste.
+    """
     from . import nonlinear
 
+    # see comments in toFnirt
+    if isinstance(field, nonlinear.CoefficientField):
+        field = nonlinear.coefficientFieldToDisplacementField(field)
+
     return nonlinear.convertDisplacementSpace(field, from_=from_, to=to)
diff --git a/fsl/transform/nonlinear.py b/fsl/transform/nonlinear.py
index 4f4022c6d..3ca34374e 100644
--- a/fsl/transform/nonlinear.py
+++ b/fsl/transform/nonlinear.py
@@ -27,8 +27,7 @@ import itertools as it
 
 import numpy as np
 
-import fsl.data.constants as constants
-import fsl.data.image     as fslimage
+import fsl.data.image as fslimage
 
 from . import affine
 
@@ -74,10 +73,10 @@ class NonLinearTransform(fslimage.Image):
         :arg ref:      :class:`.Nifti` representing the reference image.
 
         :arg srcSpace: Coordinate system in the source image that this
-                       ``NonLinearTransform`` maps from. Defaults to ``'fsl'``.
+                       ``NonLinearTransform`` maps to. Defaults to ``'fsl'``.
 
         :arg refSpace: Coordinate system in the reference image that this
-                       ``NonLinearTransform`` maps to. Defaults to ``'fsl'``.
+                       ``NonLinearTransform`` maps from. Defaults to ``'fsl'``.
 
         All other arguments are passed through to :meth:`.Image.__init__`.
         """
@@ -275,53 +274,41 @@ class CoefficientField(NonLinearTransform):
     generated by FNIRT.
     """
 
-    def __init__(self, image, src, ref, **kwargs):
+    def __init__(self,
+                 image,
+                 src,
+                 ref,
+                 srcSpace,
+                 refSpace,
+                 fieldType,
+                 knotSpacing,
+                 srcToRefMat,
+                 fieldToRefMat,
+                 **kwargs):
         """Create a ``CoefficientField``.
 
-        :arg image:
-        :arg src:
-        :arg ref:
+        :arg fieldType:
+        :arg knotSpacing:
+        :arg srcToRefMat:
+        :arg fieldToRefMat:
         """
 
-        NonLinearTransform.__init__(self, image, src, ref, **kwargs)
-
-        # FNIRT uses NIFTI header fields in
-        # non-standard ways to store some
-        # additional information about the
-        # coefficient field. See
-        # $FSLDIR/src/fnirt/fnirt_file_writer.cpp
-        # for more details.
-
-        # The field type (quadratic, cubic,
-        # or discrete-cosine-transform) is
-        # inferred from the intent. There is
-        # no support in this implementation
-        # for DCT fields
-        if self.intent == constants.FSL_CUBIC_SPLINE_COEFFICIENTS:
-            self.__fieldType = 'cubic'
-        elif self.intent == constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS:
-            self.__fieldType = 'quadratic'
-        else:
-            self.__fieldType = 'cubic'
-            log.warning('Unrecognised/unsupported coefficient field type '
-                        '(assuming cubic): {}'.format(self.intent))
-
-        # Knot spacing (in voxels) is
-        # stored in the pixdims
-        kx, ky, kz = self.pixdim[:3]
-        self.__knotSpacing = (kx, ky, kz)
-
-        # The sform contains an initial
-        # global src-to-ref affine
-        # (the starting point for the
-        # non-linear registration)
-        self.__srcToRefMat = self.header.get_sform()
-
-        # The fieldToRefMat affine allows us
-        # to transform coefficient field voxel
-        # coordinates into displacement field/
-        # reference image voxel coordinates.
-        self.__fieldToRefMat = affine.scaleOffsetXform((kx, ky, kz), 0)
+        if fieldType not in ('quadratic', 'cubic'):
+            raise ValueError('Unsupported field type: {}'.format(fieldType))
+
+        NonLinearTransform.__init__(self,
+                                    image,
+                                    src,
+                                    ref,
+                                    srcSpace,
+                                    refSpace,
+                                    refSpace='voxel',
+                                    **kwargs)
+
+        self.__fieldType     = fieldType
+        self.__knotSpacing   = tuple(knotSpacing)
+        self.__srcToRefMat   = np.copy(srcToRefMat)
+        self.__fieldToRefMat = np.copy(fieldToRefMat)
         self.__refToFieldMat = affine.invert(self.__fieldToRefMat)
 
 
@@ -348,6 +335,7 @@ class CoefficientField(NonLinearTransform):
         """
         return self.__knotSpacing
 
+
     @property
     def fieldToRefMat(self):
         """Return an affine transformation which can transform coefficient
@@ -580,12 +568,13 @@ def convertDisplacementSpace(field, from_, to):
 
 
 def coefficientFieldToDisplacementField(field):
-    """Convert a FNIRT quadratic or cubic B-spline coefficient field into
-    a relative displacement field.
+    """Convert a :class:`CoefficientField` into a relative
+    :class:`DisplacementField`.
 
     :arg field: :class:`CoefficientField` to convert
     :return:    :class:`DisplacementField` calculated from ``field``
     """
+
     ix, iy, iz = field.ref.shape[:3]
     x,  y,  z  = np.meshgrid(np.arange(ix),
                              np.arange(iy),
@@ -596,4 +585,10 @@ def coefficientFieldToDisplacementField(field):
     xyz        = np.vstack((x, y, z)).T
     disps      = field.displacements(xyz).reshape((ix, iy, iz, 3))
 
-    return DisplacementField(disps, field.src, field.ref, dispType='relative')
+    return DisplacementField(disps,
+                             src=field.src,
+                             ref=field.ref,
+                             srcSpace=field.srcSpace,
+                             refSpace=field.refSpace,
+                             header=field.ref.header,
+                             dispType='relative')
diff --git a/fsl/transform/x5.py b/fsl/transform/x5.py
index d6659a5e7..1ebd1daab 100644
--- a/fsl/transform/x5.py
+++ b/fsl/transform/x5.py
@@ -111,7 +111,7 @@ def writeLinearX5(fname, xform, src, ref):
                                       # "Scales" entries might be replaced with something else
         /From/Size                    # voxel dimensions
         /From/Scales                  # voxel pixdims
-        /From/Mapping/Type            # "linear" - could be also be "nonlinear"
+        /From/Mapping/Type            # "linear"
         /From/Mapping/Transform       # source voxel-to-world sform
         /From/Mapping/Inverse         # optional inverse
 
@@ -121,7 +121,7 @@ def writeLinearX5(fname, xform, src, ref):
         /To/Mapping/Type              # "linear"
         /To/Mapping/Transform         # reference voxel-to-world sform
         /To/Mapping/Inverse           # optional inverse
-    """
+    """  # noqa
 
     with h5py.File(fname, 'w') as f:
         _writeMetadata(f)
@@ -160,24 +160,37 @@ def writeNonLinearX5(fname, field):
         /Version                      # "0.0.1"
         /Metadata                     # json string containing unstructured metadata
 
-        /Type                         # "nonlinear"
         /Transform                    # the displacement/coefficient field itself
+        /Type                         # "nonlinear"
+        /SubType                      # "displacement" / "deformation"
+        /Representation               # "cubic bspline" / "quadratic bspline"
         /Inverse                      # optional pre-calculated inverse
 
+        /Pre/Type                     # "linear"
+        /Pre/Transform                # ref world-to-[somewhere], to prepare ref
+                                      # world coordinates as inputs to the nonlinear
+                                      # transform
+        /Pre/Inverse                  # optional pre-calculated inverse
+        /Post/Type                    #  "linear"
+        /Post/Transform               # source [somewhere]-to-world, to transform
+                                      # source coordinates produced by the nonlinear
+                                      # transform into world coordinates
+        /Post/Inverse                 # optional pre-calculated inverse
+
         /From/Type                    # "image"
         /From/Size                    # voxel dimensions
         /From/Scales                  # voxel pixdims
-        /From/Mapping/Type            # "linear"
         /From/Mapping/Transform       # source voxel-to-world sform
+        /From/Mapping/Type            # "linear"
         /From/Mapping/Inverse         # optional inverse
 
         /To/Type                      # "image"
         /To/Size                      # voxel dimensions
         /To/Scales                    # voxel pixdims
-        /To/Mapping/Type              # "linear"
         /To/Mapping/Transform         # reference voxel-to-world sform
+        /To/Mapping/Type              # "linear"
         /To/Mapping/Inverse           # optional inverse
-    """
+    """  # noqa
 
     # TODO coefficient fields
 
-- 
GitLab