Skip to content
Snippets Groups Projects
Commit 5801a442 authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

ENH: Implementation of CoefficientField class, and routine to convert from

coefficient field to displacement field. Check for TOPUP intent codes
parent 2f1e421b
No related branches found
No related tags found
No related merge requests found
...@@ -320,9 +320,11 @@ class Nifti(notifier.Notifier, meta.Meta): ...@@ -320,9 +320,11 @@ class Nifti(notifier.Notifier, meta.Meta):
# $FSLDIR/src/fnirt/fnirt_file_writer.cpp # $FSLDIR/src/fnirt/fnirt_file_writer.cpp
# and fsl.transform.nonlinear for more # and fsl.transform.nonlinear for more
# details. # details.
if intent in (constants.FSL_CUBIC_SPLINE_COEFFICIENTS, if intent in (constants.FSL_DCT_COEFFICIENTS,
constants.FSL_DCT_COEFFICIENTS, constants.FSL_CUBIC_SPLINE_COEFFICIENTS,
constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS): constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS,
constants.FSL_TOPUP_CUBIC_SPLINE_COEFFICIENTS,
constants.FSL_TOPUP_QUADRATIC_SPLINE_COEFFICIENTS):
log.debug('FNIRT coefficient field detected - generating affine') log.debug('FNIRT coefficient field detected - generating affine')
......
...@@ -5,17 +5,37 @@ ...@@ -5,17 +5,37 @@
# Author: Paul McCarthy <pauldmccarthy@gmail.com> # Author: Paul McCarthy <pauldmccarthy@gmail.com>
# #
"""This module contains data structures and functions for working with """This module contains data structures and functions for working with
nonlinear transformations. FNIRT-style nonlinear transformations.
The :class:`DisplacementField` and :class:`CoefficientField` can be used to
load and interact with FNIRT transformation images. The following utility
functions are also available:
.. autosummary::
:nosignatures:
detectDisplacementType
convertDisplacementType
convertDisplacementSpace
coefficientFieldToDisplacementField
""" """
import logging
import itertools as it
import numpy as np import numpy as np
import fsl.data.image as fslimage import fsl.data.constants as constants
import fsl.data.image as fslimage
from . import affine from . import affine
log = logging.getLogger(__name__)
class NonLinearTransform(fslimage.Image): class NonLinearTransform(fslimage.Image):
"""Class which represents a nonlinear transformation. This is just a base """Class which represents a nonlinear transformation. This is just a base
class for the :class:`DisplacementField` and :class:`CoefficientField` class for the :class:`DisplacementField` and :class:`CoefficientField`
...@@ -27,10 +47,10 @@ class NonLinearTransform(fslimage.Image): ...@@ -27,10 +47,10 @@ class NonLinearTransform(fslimage.Image):
coordinate system. coordinate system.
In FSL, non-linear transformations are defined in the same space as the In FSL, non-linear transformations are defined in terms of the reference
reference image. At a given location in the reference image space, the image coordinate system. At a given location in the reference image
non-linear mapping at that location can be used to calculate the space, the non-linear mapping at that location can be used to calculate
corresponding location in the source image space. Therefore, these the corresponding location in the source image space. Therefore, these
non-linear transformation effectively encode a transformation *from* the non-linear transformation effectively encode a transformation *from* the
reference image *to* the source image. reference image *to* the source image.
""" """
...@@ -39,7 +59,7 @@ class NonLinearTransform(fslimage.Image): ...@@ -39,7 +59,7 @@ class NonLinearTransform(fslimage.Image):
def __init__(self, def __init__(self,
image, image,
src, src,
ref=None, ref,
srcSpace=None, srcSpace=None,
refSpace=None, refSpace=None,
**kwargs): **kwargs):
...@@ -52,9 +72,6 @@ class NonLinearTransform(fslimage.Image): ...@@ -52,9 +72,6 @@ class NonLinearTransform(fslimage.Image):
:arg src: :class:`.Nifti` representing the source image. :arg src: :class:`.Nifti` representing the source image.
:arg ref: :class:`.Nifti` representing the reference image. :arg ref: :class:`.Nifti` representing the reference image.
If not provided, it is assumed that this
``NonLinearTransform`` is defined in the same
space as the reference.
:arg srcSpace: Coordinate system in the source image that this :arg srcSpace: Coordinate system in the source image that this
``NonLinearTransform`` maps from. Defaults to ``'fsl'``. ``NonLinearTransform`` maps from. Defaults to ``'fsl'``.
...@@ -65,7 +82,6 @@ class NonLinearTransform(fslimage.Image): ...@@ -65,7 +82,6 @@ class NonLinearTransform(fslimage.Image):
All other arguments are passed through to :meth:`.Image.__init__`. All other arguments are passed through to :meth:`.Image.__init__`.
""" """
if ref is None: ref = self
if srcSpace is None: srcSpace = 'fsl' if srcSpace is None: srcSpace = 'fsl'
if refSpace is None: refSpace = 'fsl' if refSpace is None: refSpace = 'fsl'
...@@ -76,12 +92,6 @@ class NonLinearTransform(fslimage.Image): ...@@ -76,12 +92,6 @@ class NonLinearTransform(fslimage.Image):
fslimage.Image.__init__(self, image, **kwargs) fslimage.Image.__init__(self, image, **kwargs)
# Displacement fields must be
# defined in the same space
# as the reference image
if not self.sameSpace(ref):
raise ValueError('Invalid reference image: {}'.format(ref))
self.__src = fslimage.Nifti(src.header.copy()) self.__src = fslimage.Nifti(src.header.copy())
self.__ref = fslimage.Nifti(ref.header.copy()) self.__ref = fslimage.Nifti(ref.header.copy())
self.__srcSpace = srcSpace self.__srcSpace = srcSpace
...@@ -120,6 +130,20 @@ class NonLinearTransform(fslimage.Image): ...@@ -120,6 +130,20 @@ class NonLinearTransform(fslimage.Image):
return self.__refSpace return self.__refSpace
def transform(self, coords, from_=None, to=None):
"""Transform coordinates from the reference image space to the source
image space. Implemented by sub-classes.
:arg coords: A sequence of XYZ coordinates, or ``numpy`` array of shape
``(n, 3)`` containing ``n`` sets of coordinates in the
reference space.
:arg from_: Reference image space that ``coords`` are defined in
:arg to: Source image space to transform ``coords`` into
:returns ``coords``, transformed into the source image space
"""
raise NotImplementedError()
class DisplacementField(NonLinearTransform): class DisplacementField(NonLinearTransform):
"""Class which represents a displacement field which, at each voxel, """Class which represents a displacement field which, at each voxel,
contains an absolute or relative displacement between a source space and a contains an absolute or relative displacement between a source space and a
...@@ -127,9 +151,12 @@ class DisplacementField(NonLinearTransform): ...@@ -127,9 +151,12 @@ class DisplacementField(NonLinearTransform):
""" """
def __init__(self, *args, **kwargs): def __init__(self, image, src, ref=None, **kwargs):
"""Create a ``DisplacementField``. """Create a ``DisplacementField``.
:arg ref: Optional. If not provided, it is assumed that the
reference is defined in the same space as ``image``.
:arg dispType: Either ``'absolute'`` or ``'relative'``, indicating :arg dispType: Either ``'absolute'`` or ``'relative'``, indicating
the type of this displacement field. If not provided, the type of this displacement field. If not provided,
will be inferred via the :func:`detectDisplacementType` will be inferred via the :func:`detectDisplacementType`
...@@ -139,12 +166,18 @@ class DisplacementField(NonLinearTransform): ...@@ -139,12 +166,18 @@ class DisplacementField(NonLinearTransform):
:meth:`NonLinearTransform.__init__`. :meth:`NonLinearTransform.__init__`.
""" """
if ref is None:
ref = self
dispType = kwargs.pop('dispType', None) dispType = kwargs.pop('dispType', None)
if dispType not in (None, 'relative', 'absolute'): if dispType not in (None, 'relative', 'absolute'):
raise ValueError('Invalid value for dispType: {}'.format(dispType)) raise ValueError('Invalid value for dispType: {}'.format(dispType))
NonLinearTransform.__init__(self, *args, **kwargs) NonLinearTransform.__init__(self, image, src, ref, **kwargs)
if not self.sameSpace(self.ref):
raise ValueError('Invalid reference image: {}'.format(self.ref))
self.__dispType = dispType self.__dispType = dispType
...@@ -237,6 +270,184 @@ class DisplacementField(NonLinearTransform): ...@@ -237,6 +270,184 @@ class DisplacementField(NonLinearTransform):
return outcoords return outcoords
class CoefficientField(NonLinearTransform):
"""Class which represents a quadratic or cubic B-spline coefficient field
generated by FNIRT.
"""
def __init__(self, image, src, ref, **kwargs):
"""Create a ``CoefficientField``.
:arg image:
:arg src:
:arg ref:
"""
NonLinearTransform.__init__(self, image, src, ref, **kwargs)
# FNIRT uses NIFTI header fields in
# non-standard ways to store some
# additional information about the
# coefficient field. See
# $FSLDIR/src/fnirt/fnirt_file_writer.cpp
# for more details.
# The field type (quadratic, cubic,
# or discrete-cosine-transform) is
# inferred from the intent. There is
# no support in this implementation
# for DCT fields
if self.intent == constants.FSL_CUBIC_SPLINE_COEFFICIENTS:
self.__fieldType = 'cubic'
elif self.intent == constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS:
self.__fieldType = 'quadratic'
else:
self.__fieldType = 'cubic'
log.warning('Unrecognised/unsupported coefficient field type '
'(assuming cubic): {}'.format(self.intent))
# Knot spacing (in voxels) is
# stored in the pixdims
kx, ky, kz = self.pixdim[:3]
self.__knotSpacing = (kx, ky, kz)
# The sform contains an initial
# global src-to-ref affine
# (the starting point for the
# non-linear registration)
self.__srcToRefMat = self.header.get_sform()
# The fieldToRefMat affine allows us
# to transform coefficient field voxel
# coordinates into displacement field/
# reference image voxel coordinates.
self.__fieldToRefMat = affine.scaleOffsetXform((kx, ky, kz), 0)
self.__refToFieldMat = affine.invert(self.__fieldToRefMat)
@property
def fieldType(self):
"""Return the type of the coefficient field, either ``'cubic'`` or
``'quadratic'``.
"""
return self.__fieldType
@property
def srcToRefMat(self):
"""Return an initial global affine transformation from the source
image to the reference image.
"""
return np.copy(self.__srcToRefMat)
@property
def knotSpacing(self):
"""Return a tuple containing spline knot spacings along the x, y, and
z axes.
"""
return self.__knotSpacing
@property
def fieldToRefMat(self):
"""Return an affine transformation which can transform coefficient
field voxel coordinates into reference image voxel coordinates.
"""
return np.copy(self.__fieldToRefMat)
@property
def refToFieldMat(self):
"""Return an affine transformation which can transform reference
image voxel coordinates into coefficient field voxel coordinates.
"""
return np.copy(self.__refToFieldMat)
def transform(self, coords, from_=None, to=None):
raise NotImplementedError()
def displacements(self, coords):
"""Calculate the relative displacemenets for the given coordinates.
:arg coords: ``(N, 3)`` array of reference image voxel coordinates.
:return: A ``(N, 3)`` array of relative displacements to the
source image for ``coords``
"""
if self.fieldType != 'cubic':
raise NotImplementedError()
# See
# https://www.cs.jhu.edu/~cis/cista/746/papers/RueckertFreeFormBreastMRI.pdf
# https://www.fmrib.ox.ac.uk/datasets/techrep/tr07ja2/tr07ja2.pdf
# Cubic b-spline basis functions
def b0(u):
return ((1 - u) ** 3) / 6
def b1(u):
return (3 * (u ** 3) - 6 * (u ** 2) + 4) / 6
def b2(u):
return (-3 * (u ** 3) + 3 * (u ** 2) + 3 * u + 1) / 6
def b3(u):
return (u ** 3) / 6
b = [b0, b1, b2, b3]
fdata = self.data
nx, ny, nz = self.shape[:3]
ix, iy, iz = self.ref.shape[:3]
# Convert the given voxel coordinates
# into the corresponding coefficient
# field voxel coordinates
x, y, z = coords.T
i, j, k = affine.transform(coords, self.refToFieldMat).T
# i, j, k: coefficient field indices
# u, v, w: position of the ref voxel
# on the current spline
u = np.remainder(i, 1)
v = np.remainder(j, 1)
w = np.remainder(k, 1)
i = np.floor(i).astype(np.int)
j = np.floor(j).astype(np.int)
k = np.floor(k).astype(np.int)
disps = np.zeros(coords.shape)
for l, m, n in it.product(range(4), range(4), range(4)):
il = i + l
jm = j + m
kn = k + n
mask = ((il >= 0) &
(il < nx) &
(jm >= 0) &
(jm < ny) &
(kn >= 0) &
(kn < nz))
il = il[mask]
jm = jm[mask]
kn = kn[mask]
uu = u[ mask]
vv = v[ mask]
ww = w[ mask]
cx, cy, cz = fdata[il, jm, kn, :].T
c = b[l](uu) * b[m](vv) * b[n](ww)
disps[mask, 0] += c * cx
disps[mask, 1] += c * cy
disps[mask, 2] += c * cz
return disps
def detectDisplacementType(field): def detectDisplacementType(field):
"""Attempt to automatically determine whether a displacement field is """Attempt to automatically determine whether a displacement field is
specified in absolute or relative coordinates. specified in absolute or relative coordinates.
...@@ -366,3 +577,23 @@ def convertDisplacementSpace(field, from_, to): ...@@ -366,3 +577,23 @@ def convertDisplacementSpace(field, from_, to):
srcSpace=to, srcSpace=to,
refSpace=from_, refSpace=from_,
dispType=field.displacementType) dispType=field.displacementType)
def coefficientFieldToDisplacementField(field):
"""Convert a FNIRT quadratic or cubic B-spline coefficient field into
a relative displacement field.
:arg field: :class:`CoefficientField` to convert
:return: :class:`DisplacementField` calculated from ``field``
"""
ix, iy, iz = field.ref.shape[:3]
x, y, z = np.meshgrid(np.arange(ix),
np.arange(iy),
np.arange(iz), indexing='ij')
x = x.flatten()
y = y.flatten()
z = z.flatten()
xyz = np.vstack((x, y, z)).T
disps = field.displacements(xyz).reshape((ix, iy, iz, 3))
return DisplacementField(disps, field.src, field.ref, dispType='relative')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment