From 8053b4c7adc28715f7bf0cd92ffa4ccd08f0fe1e Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Sat, 18 May 2019 17:16:05 +0100
Subject: [PATCH] ENH: DisplacementField.transform method

---
 fsl/data/image.py          |  2 +-
 fsl/transform/nonlinear.py | 71 +++++++++++++++++++++++++++++++++++---
 2 files changed, 68 insertions(+), 5 deletions(-)

diff --git a/fsl/data/image.py b/fsl/data/image.py
index b706ee100..637e745af 100644
--- a/fsl/data/image.py
+++ b/fsl/data/image.py
@@ -227,7 +227,7 @@ class Nifti(notifier.Notifier, meta.Meta):
     ``'header       A header field has changed. This will occur when the
                     :meth:`intent` is changed.
     =============== ========================================================
-    """
+    """  # noqa
 
 
     def __init__(self, header):
diff --git a/fsl/transform/nonlinear.py b/fsl/transform/nonlinear.py
index 7ddcf36dd..dbe18d34b 100644
--- a/fsl/transform/nonlinear.py
+++ b/fsl/transform/nonlinear.py
@@ -23,8 +23,13 @@ class NonLinearTransform(fslimage.Image):
 
 
     A nonlinear transformation is an :class:`.Image` which contains
-    some mapping from a source image coordinate system to a reference image
+    some mapping between a source image coordinate system and a reference image
     coordinate system.
+
+    In FSL, non-linear transformations are defined in the same space as the
+    reference image. At a given location in the reference image space, the
+    non-linear mapping at that location can be used to calculate the
+    corresponding location in the source image space.
     """
 
 
@@ -113,7 +118,7 @@ class NonLinearTransform(fslimage.Image):
 
 class DisplacementField(NonLinearTransform):
     """Class which represents a displacement field which, at each voxel,
-    contains an absolute or relative displacement from a source space to a
+    contains an absolute or relative displacement between a source space and a
     reference space.
     """
 
@@ -166,8 +171,66 @@ class DisplacementField(NonLinearTransform):
         return self.displacementType == 'relative'
 
 
-    def transform(self, coords):
-        raise NotImplementedError()
+    def transform(self, coords, from_=None, to=None):
+        """Transform the given XYZ coordinates from the reference to the source.
+
+        :arg coords: A sequence of XYZ coordinates, or ``numpy`` array of shape
+                     ``(n, 3)`` containing ``n`` sets of coordinates in the
+                     reference space.
+        :arg from_:  Reference image space that ``coords`` are defined in
+        :arg to:     Source image space to transform ``coords`` into
+        :returns    ``coords``, transformed into the source image space
+        """
+
+        if from_ is None: from_ = self.refSpace
+        if to    is None: to    = self.srcSpace
+
+        coords = np.asanyarray(coords)
+
+        # We may need to pre-transform the
+        # coordinates so they are in the
+        # same reference image space as the
+        # displacements
+        if from_ != self.refSpace:
+            coords = affine.transform(
+                coords, self.ref.getAffine(from_, self.refSpace))
+
+        # We also need to make sure that the
+        # coordinates are in voxels, so we
+        # can look up the displacements
+        if self.refSpace != 'voxel':
+            voxels = affine.transform(
+                coords, self.ref.getAffine(self.refSpace, 'voxel'))
+        else:
+            voxels = coords
+
+        voxels = np.round(voxels).astype(np.int)
+
+        # Mask out the coordinates
+        # that are out of bounds
+        voxmask = (voxels >= [0, 0, 0]) & (voxels < self.shape[:3])
+        voxmask = voxmask.all(axis=1)
+        voxels  = voxels[voxmask]
+
+        xs, ys, zs = voxels.T
+
+        if self.absolute: disps = self.data[xs, ys, zs, :]
+        else:             disps = self.data[xs, ys, zs, :] + coords
+
+        # Make sure the coordinates
+        # are in the requested
+        # source image space
+        if to != self.srcSpace:
+            disps = affine.transform(
+                disps, self.src.getAffine(self.srcSpace, to))
+
+        # Nans for input coordinates
+        # which were outside of the
+        # field
+        outcoords          = np.full(coords.shape, np.nan)
+        outcoords[voxmask] = disps
+
+        return outcoords
 
 
 def detectDisplacementType(field):
-- 
GitLab