Skip to content
Snippets Groups Projects
Commit 6154a5b6 authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

ENH: New fnirt module for reading/writing fnirt transforms. Expand/clean up

nonlinear module. Expand x5 module to handle non-linear. Rename x5 functions
to be more generic.
parent 1165eeb9
No related branches found
No related tags found
No related merge requests found
...@@ -34,7 +34,14 @@ from .flirt import ( # noqa ...@@ -34,7 +34,14 @@ from .flirt import ( # noqa
flirtMatrixToSform, flirtMatrixToSform,
sformToFlirtMatrix) sformToFlirtMatrix)
from .fnirt import ( # noqa
readFnirt,
writeFnirt,
toFnirt,
fromFnirt)
from .x5 import ( # noqa from .x5 import ( # noqa
readFlirtX5, readLinearX5,
writeFlirtX5 writeLinearX5,
) readNonLinearX5,
writeNonLinearX5)
...@@ -10,6 +10,8 @@ matrices. The following functions are available: ...@@ -10,6 +10,8 @@ matrices. The following functions are available:
.. autosummary:: .. autosummary::
:nosignatures: :nosignatures:
readFlirt
writeFlirt
fromFlirt fromFlirt
toFlirt toFlirt
flirtMatrixToSform flirtMatrixToSform
......
#!/usr/bin/env python
#
# fnirt.py - Functions for working with FNIRT non-linear transformations.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module contains functions for working with FNIRT non-linear
transformation matrices. The following functions are available:
.. autosummary::
:nosignatures:
readFnirt
writeFnirt
"""
import logging
import fsl.data.constants as constants
log = logging.getLogger(__name__)
def readFnirt(fname, src, ref, dispType=None):
"""
"""
# Figure out whether the file
# is a displacement field or
# a coefficient field
import fsl.data.image as fslimage
from . import nonlinear
img = fslimage.Image(fname, loadData=False)
dispfields = (constants.FSL_FNIRT_DISPLACEMENT_FIELD,
constants.FSL_TOPUP_FIELD)
coeffields = (constants.FSL_CUBIC_SPLINE_COEFFICIENTS,
constants.FSL_DCT_COEFFICIENTS,
constants.FSL_QUADRATIC_SPLINE_COEFFICIENTS,
constants.FSL_TOPUP_CUBIC_SPLINE_COEFFICIENTS,
constants.FSL_TOPUP_QUADRATIC_SPLINE_COEFFICIENTS)
kwargs = {
'src' : src,
'ref' : ref,
'srcSpace' : 'fsl',
'refSpace' : 'fsl',
'dispType' : None,
}
if img.intent in dispfields:
return nonlinear.DisplacementField(fname, **kwargs)
elif img.intent in coeffields:
pass # return nonlinear.CoefficientField(fname, **kwargs)
else:
raise ValueError('Cannot determine type of nonlinear '
'file {}'.format(fname))
def writeFnirt(field, fname):
"""
"""
field.save(fname)
def toFnirt(field):
pass
def fromFnirt(field, from_='voxel', to='world'):
"""
"""
from . import nonlinear
return nonlinear.convertDisplacementSpace(field, from_=from_, to=to)
...@@ -4,6 +4,9 @@ ...@@ -4,6 +4,9 @@
# #
# Author: Paul McCarthy <pauldmccarthy@gmail.com> # Author: Paul McCarthy <pauldmccarthy@gmail.com>
# #
"""This module contains data structures and functions for working with
nonlinear transformations.
"""
import numpy as np import numpy as np
...@@ -14,61 +17,124 @@ from . import affine ...@@ -14,61 +17,124 @@ from . import affine
class NonLinearTransform(fslimage.Image): class NonLinearTransform(fslimage.Image):
"""Class which represents a FNIRT non-linear transformation """Class which represents a nonlinear transformation. This is just a base
class for the :class:`DisplacementField` and :class:`CoefficientField`
classes.
A nonlinear transformation is an :class:`.Image` which contains
some mapping from a source image coordinate system to a reference image
coordinate system.
""" """
def __init__(self, *args, **kwargs):
""" def __init__(self,
image,
src,
ref=None,
srcSpace=None,
refSpace=None,
**kwargs):
"""Create a ``NonLinearTransform``.
:arg image: A string containing the name of an image file to load,
or a :mod:`numpy` array, or a :mod:`nibabel` image
object.
:arg src: :class:`.Nifti` representing the sourceimage
:arg ref: :class:`.Nifti` representing the reference image.
If not provided, it is assumed that this
``NonLinearTransform`` is defined in the same
space as the reference.
:arg srcSpace: Coordinate system in the source image that this
``NonLinearTransform`` maps from. Defaults to ``'fsl'``.
:arg refSpace: Coordinate system in the reference image that this
``NonLinearTransform`` maps to. Defaults to ``'fsl'``.
All other arguments are passed through to :meth:`.Image.__init__`.
""" """
src = kwargs.pop('src', None)
ref = kwargs.pop('ref', None)
srcSpace = kwargs.pop('srceSpace', 'fsl')
refSpace = kwargs.pop('refSpace', 'fsl')
fslimage.Image.__init__(self, *args, **kwargs) if ref is None: ref = self
if srcSpace is None: srcSpace = 'fsl'
if refSpace is None: refSpace = 'fsl'
if not (isinstance(src, (fslimage.Nifti, type(None))) and
isinstance(ref, fslimage.Nifti)):
raise ValueError('Invalid source/reference: {} -> {}'.format(
src, ref))
if src is not None: src = src .header.copy() if srcSpace not in ('fsl', 'voxel', 'world') or \
if ref is not None: ref = ref .header.copy() refSpace not in ('fsl', 'voxel', 'world'):
else: ref = self.header.copy() raise ValueError('Invalid source/reference space: {} -> {}'.format(
srcSpace, refSpace))
self.__src = src fslimage.Image.__init__(self, image, **kwargs)
self.__ref = ref
self.__src = fslimage.Nifti(src.header.copy())
self.__ref = fslimage.Nifti(ref.header.copy())
self.__srcSpace = srcSpace self.__srcSpace = srcSpace
self.__refSpace = refSpace self.__refSpace = refSpace
@property @property
def src(self): def src(self):
"""Return a reference to the :class:`.Nifti` instance representing
the source image.
"""
return self.__src return self.__src
@property @property
def ref(self): def ref(self):
"""Return a reference to the :class:`.Nifti` instance representing
the reference image.
"""
return self.__ref return self.__ref
@property @property
def srcSpace(self): def srcSpace(self):
"""Return the source image coordinate system this
``NonLinearTransform`` maps from - see :meth:`.Nifti.getAffine`.
"""
return self.__srcSpace return self.__srcSpace
@property @property
def refSpace(self): def refSpace(self):
"""Return the reference image coordinate system this
``NonLinearTransform`` maps to - see :meth:`.Nifti.getAffine`.
"""
return self.__refSpace return self.__refSpace
class DisplacementField(NonLinearTransform): class DisplacementField(NonLinearTransform):
"""Class which represents a FNIRT displacement field which, at each voxel, """Class which represents a displacement field which, at each voxel,
contains an absolute or relative displacement from a source space to a contains an absolute or relative displacement from a source space to a
reference space. reference space.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
""" """Create a ``DisplacementField``.
:arg dispType: Either ``'absolute'`` or ``'relative'``, indicating
the type of this displacement field. If not provided,
will be inferred via the :func:`detectDisplacementType`
function.
All other arguments are passed through to
:meth:`NonLinearTransform.__init__`.
""" """
dispType = kwargs.pop('dispType', None) dispType = kwargs.pop('dispType', None)
if dispType not in (None, 'relative', 'absolute'):
raise ValueError('Invalid value for dispType: {}'.format(dispType))
NonLinearTransform.__init__(self, *args, **kwargs) NonLinearTransform.__init__(self, *args, **kwargs)
self.__dispType = dispType self.__dispType = dispType
...@@ -76,6 +142,9 @@ class DisplacementField(NonLinearTransform): ...@@ -76,6 +142,9 @@ class DisplacementField(NonLinearTransform):
@property @property
def displacementType(self): def displacementType(self):
"""The type of this ``DisplacementField`` - ``'absolute'`` or
``'relative'``.
"""
if self.__dispType is None: if self.__dispType is None:
self.__dispType = detectDisplacementType(self) self.__dispType = detectDisplacementType(self)
return self.__dispType return self.__dispType
...@@ -83,14 +152,23 @@ class DisplacementField(NonLinearTransform): ...@@ -83,14 +152,23 @@ class DisplacementField(NonLinearTransform):
@property @property
def absolute(self): def absolute(self):
"""``True`` if this ``DisplacementField`` contains absolute
displacements.
"""
return self.displacementType == 'absolute' return self.displacementType == 'absolute'
@property @property
def relative(self): def relative(self):
"""``True`` if this ``DisplacementField`` contains relative
displacements.
"""
return self.displacementType == 'relative' return self.displacementType == 'relative'
def transform(self, coords):
raise NotImplementedError()
def detectDisplacementType(field): def detectDisplacementType(field):
"""Attempt to automatically determine whether a displacement field is """Attempt to automatically determine whether a displacement field is
...@@ -104,9 +182,9 @@ def detectDisplacementType(field): ...@@ -104,9 +182,9 @@ def detectDisplacementType(field):
# This test is based on the assumption # This test is based on the assumption
# that a displacement field containing # that a displacement field containing
# absolute oordinates will have a greater # absolute coordinates will have a
# standard deviation than one which # greater standard deviation than one
# contains relative coordinates. # which contains relative coordinates.
absdata = field[:] absdata = field[:]
reldata = convertDisplacementType(field, 'relative') reldata = convertDisplacementType(field, 'relative')
stdabs = absdata.std(axis=(0, 1, 2)).sum() stdabs = absdata.std(axis=(0, 1, 2)).sum()
...@@ -119,6 +197,12 @@ def detectDisplacementType(field): ...@@ -119,6 +197,12 @@ def detectDisplacementType(field):
def convertDisplacementType(field, dispType=None): def convertDisplacementType(field, dispType=None):
"""Convert a displacement field between storing absolute and relative """Convert a displacement field between storing absolute and relative
displacements. displacements.
:arg field: A :class:`DisplacementField` instance
:arg dispType: Either ``'absolute'`` or ``'relative'``. If not provided,
the opposite type to ``field.displacementType`` is used.
:returns: A ``numpy.array`` containing the adjusted displacement
field.
""" """
if dispType is None: if dispType is None:
...@@ -146,13 +230,19 @@ def convertDisplacementType(field, dispType=None): ...@@ -146,13 +230,19 @@ def convertDisplacementType(field, dispType=None):
elif dispType == 'relative': return field.data - coords elif dispType == 'relative': return field.data - coords
def convertDisplacementSpace(field, src, from_, to, ref=None, dispType=None): def convertDisplacementSpace(field, from_, to):
"""Adjust the source and/or reference spaces of the given displacement """Adjust the source and/or reference spaces of the given displacement
field. field. See the :meth:`.Nifti.getAffine` method for the valid values for
""" the ``from_`` and ``to`` arguments.
:arg field: A :class:`DisplacementField` instance
:arg from_: New source image coordinate system
:arg to: New reference image coordinate system
if ref is None: ref = field :returns: A new :class:`DisplacementField` which transforms from
if dispType is None: dispType = field.displacementType the source ``from_`` coordinate system to the reference ``to``
coordinate system.
"""
# Get the field in absolute # Get the field in absolute
# coordinates if necessary # coordinates if necessary
...@@ -161,10 +251,10 @@ def convertDisplacementSpace(field, src, from_, to, ref=None, dispType=None): ...@@ -161,10 +251,10 @@ def convertDisplacementSpace(field, src, from_, to, ref=None, dispType=None):
else: srccoords = fieldcoords else: srccoords = fieldcoords
# Now transform those source # Now transform those source
# coordinates from the original # coordinates from the original
# source space to the source # source space to the source
# space specified by "from_" # space specified by "from_"
srcmat = src.getAffine(field.srcSpace, from_) srcmat = field.src.getAffine(field.srcSpace, from_)
srccoords = srccoords.reshape((-1, 3)) srccoords = srccoords.reshape((-1, 3))
srccoords = affine.transform(srccoords, srcmat) srccoords = affine.transform(srccoords, srcmat)
...@@ -172,7 +262,7 @@ def convertDisplacementSpace(field, src, from_, to, ref=None, dispType=None): ...@@ -172,7 +262,7 @@ def convertDisplacementSpace(field, src, from_, to, ref=None, dispType=None):
# an absolute displacement, the # an absolute displacement, the
# reference "to" coordinate system # reference "to" coordinate system
# is irrelevant - we're done. # is irrelevant - we're done.
if dispType == 'absolute': if field.absolute:
fieldcoords = srccoords fieldcoords = srccoords
# Otherwise our displacement field # Otherwise our displacement field
...@@ -184,7 +274,7 @@ def convertDisplacementSpace(field, src, from_, to, ref=None, dispType=None): ...@@ -184,7 +274,7 @@ def convertDisplacementSpace(field, src, from_, to, ref=None, dispType=None):
# displacements from source "from_" # displacements from source "from_"
# space into reference "to" space. # space into reference "to" space.
else: else:
refmat = ref.getAffine(field.refSpace, to) refmat = field.ref.getAffine(field.refSpace, to)
refcoords = fieldcoords.reshape((-1, 3)) refcoords = fieldcoords.reshape((-1, 3))
refcoords = affine.transform(refcoords, refmat) refcoords = affine.transform(refcoords, refmat)
fieldcoords = srccoords - refcoords fieldcoords = srccoords - refcoords
...@@ -192,8 +282,8 @@ def convertDisplacementSpace(field, src, from_, to, ref=None, dispType=None): ...@@ -192,8 +282,8 @@ def convertDisplacementSpace(field, src, from_, to, ref=None, dispType=None):
return DisplacementField( return DisplacementField(
fieldcoords.reshape(field.shape), fieldcoords.reshape(field.shape),
header=field.header, header=field.header,
src=src, src=field.src,
ref=ref, ref=field.ref,
srcSpace=from_, srcSpace=from_,
refSpace=to, refSpace=to,
dispType=dispType) dispType=field.displacementType)
...@@ -17,13 +17,13 @@ import numpy.linalg as npla ...@@ -17,13 +17,13 @@ import numpy.linalg as npla
import nibabel as nib import nibabel as nib
import h5py import h5py
from . import flirt import fsl.version as version
def _writeLinearTransform(group, xform): def _writeMetadata(group):
group.attrs['Type'] = 'linear' group.attrs['Format'] = 'X5'
group.create_dataset('Transform', data=xform) group.attrs['Version'] = '0.0.1'
group.create_dataset('Inverse', data=npla.inv(xform)) group.attrs['Metadata'] = json.dumps({'fslpy' : version.__version__})
def _readLinearTransform(group): def _readLinearTransform(group):
...@@ -32,13 +32,15 @@ def _readLinearTransform(group): ...@@ -32,13 +32,15 @@ def _readLinearTransform(group):
return np.array(group['Transform']) return np.array(group['Transform'])
def _writeLinearMapping(group, img): def _writeLinearTransform(group, xform):
group.attrs['Type'] = 'image'
group.attrs['Size'] = img.shape[ :3] xform = np.asarray(xform, dtype=np.float32)
group.attrs['Scales'] = img.pixdim[:3] inv = np.asarray(npla.inv(xform), dtype=np.float32)
group.attrs['Type'] = 'linear'
group.create_dataset('Transform', data=xform)
group.create_dataset('Inverse', data=inv)
mapping = group.create_group('Mapping')
_writeLinearTransform(mapping, img.getAffine('voxel', 'world'))
def _readLinearMapping(group): def _readLinearMapping(group):
...@@ -58,17 +60,71 @@ def _readLinearMapping(group): ...@@ -58,17 +60,71 @@ def _readLinearMapping(group):
return fslimage.Nifti(hdr) return fslimage.Nifti(hdr)
def writeFlirtX5(fname, xform, src, ref): def _writeLinearMapping(group, img):
group.attrs['Type'] = 'image'
group.attrs['Size'] = np.asarray(img.shape[ :3], np.uint32)
group.attrs['Scales'] = np.asarray(img.pixdim[:3], np.float32)
mapping = group.create_group('Mapping')
_writeLinearTransform(mapping, img.getAffine('voxel', 'world'))
def _readNonLinearTransform(group):
if group.attrs['Type'] != 'nonlinear':
raise ValueError('Not a nonlinear transform')
return np.array(group['Transform'])
def _writeNonLinearTransform(group, field):
""" """
""" """
group.attrs['Type'] = 'nonlinear'
group.create_dataset('Transform', data=field, dtype=np.float32)
xform = flirt.fromFlirt(xform, src, ref, 'world', 'world')
with h5py.File(fname, 'w') as f: def readLinearX5(fname):
f.attrs['Format'] = 'X5' """
f.attrs['Version'] = '0.0.1' """
f.attrs['Metadata'] = json.dumps({'software' : 'flirt'}) with h5py.File(fname, 'r') as f:
xform = _readLinearTransform(f['/'])
src = _readLinearMapping( f['/From'])
ref = _readLinearMapping( f['/To'])
return xform, src, ref
def writeLinearX5(fname, xform, src, ref):
"""
::
/Format # "X5"
/Version # "0.0.1"
/Metadata # json string containing unstructured metadata
/Type # "linear"
/Transform # the transform itself
/Inverse # optional pre-calculated inverse
/From/Type # "image" - could in principle be something other than
# "image" (e.g. "surface"), in which case the "Size" and
# "Scales" entries might be replaced with something else
/From/Size # voxel dimensions
/From/Scales # voxel pixdims
/From/Mapping/Type # "linear" - could be also be "nonlinear"
/From/Mapping/Transform # source voxel-to-world sform
/From/Mapping/Inverse # optional inverse
/To/Type # "image"
/To/Size # voxel dimensions
/To/Scales # voxel pixdims
/To/Mapping/Type # "linear"
/To/Mapping/Transform # reference voxel-to-world sform
/To/Mapping/Inverse # optional inverse
"""
with h5py.File(fname, 'w') as f:
_writeMetadata(f)
_writeLinearTransform(f, xform) _writeLinearTransform(f, xform)
from_ = f.create_group('/From') from_ = f.create_group('/From')
...@@ -78,12 +134,59 @@ def writeFlirtX5(fname, xform, src, ref): ...@@ -78,12 +134,59 @@ def writeFlirtX5(fname, xform, src, ref):
_writeLinearMapping(to, ref) _writeLinearMapping(to, ref)
def readFlirtX5(fname): def readNonLinearX5(fname):
""" """
""" """
from . import nonlinear
with h5py.File(fname, 'r') as f: with h5py.File(fname, 'r') as f:
xform = _readLinearTransform(f['/']) field = _readNonLinearTransform(f['/'])
src = _readLinearMapping( f['/From']) src = _readLinearMapping(f['/From'])
ref = _readLinearMapping( f['/To']) ref = _readLinearMapping(f['/To'])
return xform, src, ref # TODO coefficient fields
return nonlinear.DisplacementField(field,
src=src,
ref=ref,
srcSpace='world',
refSpace='world')
def writeNonLinearX5(fname, field):
"""
::
/Format # "X5"
/Version # "0.0.1"
/Metadata # json string containing unstructured metadata
/Type # "nonlinear"
/Transform # the displacement/coefficient field itself
/Inverse # optional pre-calculated inverse
/From/Type # "image"
/From/Size # voxel dimensions
/From/Scales # voxel pixdims
/From/Mapping/Type # "linear"
/From/Mapping/Transform # source voxel-to-world sform
/From/Mapping/Inverse # optional inverse
/To/Type # "image"
/To/Size # voxel dimensions
/To/Scales # voxel pixdims
/To/Mapping/Type # "linear"
/To/Mapping/Transform # reference voxel-to-world sform
/To/Mapping/Inverse # optional inverse
"""
# TODO coefficient fields
with h5py.File(fname, 'w') as f:
_writeMetadata(f)
_writeNonLinearTransform(f, field.data)
from_ = f.create_group('/From')
to = f.create_group('/To')
_writeLinearMapping(from_, field.src)
_writeLinearMapping(to, field.ref)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment