From 4bbd521a88ef2db2221d4ff6b670e8f1a9be4fbd Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Fri, 31 May 2019 13:10:53 +0100
Subject: [PATCH] RF,BF: Fixes/adjustments to nonlinear module, and in
 particular to convertDisplacementSpace

---
 fsl/transform/nonlinear.py | 98 ++++++++++++++++++++++----------------
 1 file changed, 57 insertions(+), 41 deletions(-)

diff --git a/fsl/transform/nonlinear.py b/fsl/transform/nonlinear.py
index 19f41c034..9e903912e 100644
--- a/fsl/transform/nonlinear.py
+++ b/fsl/transform/nonlinear.py
@@ -26,10 +26,13 @@ class NonLinearTransform(fslimage.Image):
     some mapping between a source image coordinate system and a reference image
     coordinate system.
 
+
     In FSL, non-linear transformations are defined in the same space as the
     reference image. At a given location in the reference image space, the
     non-linear mapping at that location can be used to calculate the
-    corresponding location in the source image space.
+    corresponding location in the source image space. Therefore, these
+    non-linear transformation effectively encode a transformation *from* the
+    reference image *to* the source image.
     """
 
 
@@ -46,7 +49,7 @@ class NonLinearTransform(fslimage.Image):
                        or a :mod:`numpy` array, or a :mod:`nibabel` image
                        object.
 
-        :arg src:      :class:`.Nifti` representing the sourceimage
+        :arg src:      :class:`.Nifti` representing the source image.
 
         :arg ref:      :class:`.Nifti` representing the reference image.
                        If not provided, it is assumed that this
@@ -66,11 +69,6 @@ class NonLinearTransform(fslimage.Image):
         if srcSpace is None: srcSpace = 'fsl'
         if refSpace is None: refSpace = 'fsl'
 
-        if not (isinstance(src, (fslimage.Nifti, type(None))) and
-                isinstance(ref,  fslimage.Nifti)):
-            raise ValueError('Invalid source/reference: {} -> {}'.format(
-                src, ref))
-
         if srcSpace not in ('fsl', 'voxel', 'world') or \
            refSpace not in ('fsl', 'voxel', 'world'):
             raise ValueError('Invalid source/reference space: {} -> {}'.format(
@@ -78,6 +76,12 @@ class NonLinearTransform(fslimage.Image):
 
         fslimage.Image.__init__(self, image, **kwargs)
 
+        # Displacement fields must be
+        # defined in the same space
+        # as the reference image
+        if not self.sameSpace(ref):
+            raise ValueError('Invalid reference image: {}'.format(ref))
+
         self.__src      = fslimage.Nifti(src.header.copy())
         self.__ref      = fslimage.Nifti(ref.header.copy())
         self.__srcSpace = srcSpace
@@ -172,7 +176,8 @@ class DisplacementField(NonLinearTransform):
 
 
     def transform(self, coords, from_=None, to=None):
-        """Transform the given XYZ coordinates from the reference to the source.
+        """Transform the given XYZ coordinates from the reference image space
+        to the source image space.
 
         :arg coords: A sequence of XYZ coordinates, or ``numpy`` array of shape
                      ``(n, 3)`` containing ``n`` sets of coordinates in the
@@ -192,22 +197,21 @@ class DisplacementField(NonLinearTransform):
         # same reference image space as the
         # displacements
         if from_ != self.refSpace:
-            coords = affine.transform(
-                coords, self.ref.getAffine(from_, self.refSpace))
+            xform  = self.ref.getAffine(from_, self.refSpace)
+            coords = affine.transform(coords, xform)
 
         # We also need to make sure that the
         # coordinates are in voxels, so we
         # can look up the displacements
         if self.refSpace != 'voxel':
-            voxels = affine.transform(
-                coords, self.ref.getAffine(self.refSpace, 'voxel'))
+            xform  = self.ref.getAffine(self.refSpace, 'voxel')
+            voxels = affine.transform(coords, xform)
         else:
             voxels = coords
 
-        voxels = np.round(voxels).astype(np.int)
-
         # Mask out the coordinates
         # that are out of bounds
+        voxels  = np.round(voxels).astype(np.int)
         voxmask = (voxels >= [0, 0, 0]) & (voxels < self.shape[:3])
         voxmask = voxmask.all(axis=1)
         voxels  = voxels[voxmask]
@@ -221,8 +225,8 @@ class DisplacementField(NonLinearTransform):
         # are in the requested
         # source image space
         if to != self.srcSpace:
-            disps = affine.transform(
-                disps, self.src.getAffine(self.srcSpace, to))
+            xform = self.src.getAffine(self.srcSpace, to)
+            disps = affine.transform(disps, xform)
 
         # Nans for input coordinates
         # which were outside of the
@@ -276,12 +280,13 @@ def convertDisplacementType(field, dispType=None):
     # we need the coordinates of every voxel
     # in the reference FSL coordinate system.
     dx, dy, dz = field.shape[:3]
-    v2fsl      = field.getAffine('voxel', field.srcSpace)
+    xform      = field.getAffine('voxel', field.refSpace)
+
     coords     = np.meshgrid(np.arange(dx),
                              np.arange(dy),
                              np.arange(dz), indexing='ij')
     coords     = np.array(coords).transpose((1, 2, 3, 0))
-    coords     = affine.transform(coords.reshape((-1, 3)), v2fsl)
+    coords     = affine.transform(coords.reshape((-1, 3)), xform)
     coords     = coords.reshape((dx, dy, dz, 3))
 
     # If converting from relative to absolute,
@@ -299,47 +304,58 @@ def convertDisplacementSpace(field, from_, to):
     the ``from_`` and ``to`` arguments.
 
     :arg field: A :class:`DisplacementField` instance
-    :arg from_: New source image coordinate system
-    :arg to:    New reference image coordinate system
+    :arg from_: New reference image coordinate system
+    :arg to:    New source image coordinate system
 
     :returns:   A new :class:`DisplacementField` which transforms between
-                the source ``from_`` coordinate system and the reference ``to``
+                the reference ``from_`` coordinate system and the source ``to``
                 coordinate system.
     """
 
-    # Get the field in absolute
-    # coordinates if necessary
+    # Get the field in absolute coordinates
+    # if necessary - these are our source
+    # coordinates in the original "to" space.
     fieldcoords = field.data
     if field.relative: srccoords = convertDisplacementType(field)
     else:              srccoords = fieldcoords
 
-    # Now transform those source
-    # coordinates from the original
-    # source space to the source
-    # space specified by "from_"
-    srcmat    = field.src.getAffine(field.srcSpace, from_)
     srccoords = srccoords.reshape((-1, 3))
-    srccoords = affine.transform(srccoords, srcmat)
+
+    # Now transform those source coordinates
+    # from the original source space to the
+    # source space specified by "to"
+    if to != field.srcSpace:
+
+        srcmat    = field.src.getAffine(field.srcSpace, to)
+        srccoords = affine.transform(srccoords, srcmat)
 
     # If we have been asked to return
     # an absolute displacement, the
-    # reference "to" coordinate system
-    #  is irrelevant - we're done.
+    # reference "from_" coordinate
+    # system is irrelevant - we're done.
     if field.absolute:
         fieldcoords = srccoords
 
     # Otherwise our displacement field
     # will contain relative displacements
-    # between the reference image "to"
+    # between the reference image "from_"
     # coordinate system and the source
-    # image "from_" coordinate system.
-    # We need to re-calculate the relative
-    # displacements between source "from_"
-    # space and reference "to" space.
+    # image "to" coordinate system. We
+    # need to re-calculate the relative
+    # displacements between the new
+    # reference "from_" space and source
+    # "to" space.
     else:
-        refmat      = field.ref.getAffine(field.refSpace, to)
-        refcoords   = fieldcoords.reshape((-1, 3))
-        refcoords   = affine.transform(refcoords, refmat)
+        refcoords = np.meshgrid(np.arange(field.shape[0]),
+                                np.arange(field.shape[1]),
+                                np.arange(field.shape[2]), indexing='ij')
+        refcoords = np.array(refcoords)
+        refcoords = refcoords.transpose((1, 2, 3, 0)).reshape((-1, 3))
+
+        if from_ != 'voxel':
+            refmat    = field.ref.getAffine('voxel', from_)
+            refcoords = affine.transform(refcoords, refmat)
+
         fieldcoords = srccoords - refcoords
 
     return DisplacementField(
@@ -347,6 +363,6 @@ def convertDisplacementSpace(field, from_, to):
         header=field.header,
         src=field.src,
         ref=field.ref,
-        srcSpace=from_,
-        refSpace=to,
+        srcSpace=to,
+        refSpace=from_,
         dispType=field.displacementType)
-- 
GitLab