Skip to content
Snippets Groups Projects
Commit 84300e0c authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

BF,RF: FIxes to fnirt/x5

parent f57b0a92
No related branches found
No related tags found
No related merge requests found
...@@ -24,21 +24,24 @@ import fsl.data.constants as constants ...@@ -24,21 +24,24 @@ import fsl.data.constants as constants
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
def _readFnirtDisplacementField(fname, img, src, ref): def _readFnirtDisplacementField(fname, img, src, ref, dispType=None):
"""Loads ``fname``, assumed to be a FNIRT displacement field. """Loads ``fname``, assumed to be a FNIRT displacement field.
:arg fname: File name of FNIRT displacement field :arg fname: File name of FNIRT displacement field
:arg img: ``fname`` loaded as an :class:`.Image` :arg img: ``fname`` loaded as an :class:`.Image`
:arg src: Source image :arg src: Source image
:arg ref: Reference image :arg ref: Reference image
:return: A :class:`.DisplacementField` :arg dispType: Displacement type - either ``'absolute'`` or ``'relative'``.
If not provided, is automatically inferred from the data.
:return: A :class:`.DisplacementField`
""" """
from . import nonlinear from . import nonlinear
return nonlinear.DisplacementField(fname, return nonlinear.DisplacementField(fname,
src, src,
ref, ref,
srcSpace='fsl', srcSpace='fsl',
refSpace='fsl') refSpace='fsl',
dispType=dispType)
def _readFnirtCoefficientField(fname, img, src, ref): def _readFnirtCoefficientField(fname, img, src, ref):
...@@ -114,14 +117,16 @@ def _readFnirtCoefficientField(fname, img, src, ref): ...@@ -114,14 +117,16 @@ def _readFnirtCoefficientField(fname, img, src, ref):
fieldToRefMat=fieldToRefMat) fieldToRefMat=fieldToRefMat)
def readFnirt(fname, src, ref): def readFnirt(fname, src, ref, dispType=None):
"""Reads a non-linear FNIRT transformation image, returning """Reads a non-linear FNIRT transformation image, returning
a :class:`.DisplacementField` or :class:`.CoefficientField` depending a :class:`.DisplacementField` or :class:`.CoefficientField` depending
on the file type. on the file type.
:arg fname: File name of FNIRT transformation :arg fname: File name of FNIRT transformation
:arg src: Source image :arg src: Source image
:arg ref: Reference image :arg ref: Reference image
:arg dispType: Displacement type - either ``'absolute'`` or ``'relative'``.
If not provided, is automatically inferred from the data.
""" """
# Figure out whether the file # Figure out whether the file
...@@ -139,7 +144,7 @@ def readFnirt(fname, src, ref): ...@@ -139,7 +144,7 @@ def readFnirt(fname, src, ref):
constants.FSL_TOPUP_QUADRATIC_SPLINE_COEFFICIENTS) constants.FSL_TOPUP_QUADRATIC_SPLINE_COEFFICIENTS)
if img.intent in disps: if img.intent in disps:
return _readFnirtDisplacementField(fname, img, src, ref) return _readFnirtDisplacementField(fname, img, src, ref, dispType)
elif img.intent in coefs: elif img.intent in coefs:
return _readFnirtCoefficientField(fname, img, src, ref) return _readFnirtCoefficientField(fname, img, src, ref)
else: else:
......
...@@ -402,10 +402,10 @@ def readNonLinearX5(fname): ...@@ -402,10 +402,10 @@ def readNonLinearX5(fname):
root = f['/'] root = f['/']
_validateNonLinearTransform(root) _validateNonLinearTransform(root)
if root['SubType'] == 'displacement': subtype = root.attrs['SubType']
field = _readDisplacementField(root)
elif root['SubType'] == 'coefficient': if subtype == 'displacement': field = _readDisplacementField(root)
field = _readCoefficientField(root) elif subtype == 'coefficient': field = _readCoefficientField(root)
return field return field
...@@ -484,8 +484,8 @@ def _writeAffine(group, xform): ...@@ -484,8 +484,8 @@ def _writeAffine(group, xform):
:arg xform: ``numpy`` array containing a ``(4, 4)`` affine transformation. :arg xform: ``numpy`` array containing a ``(4, 4)`` affine transformation.
""" """
xform = np.asarray(xform, dtype=np.float64) xform = np.asarray(xform, dtype=np.float64)
inv = np.asarray(affine.inverse(xform), dtype=np.float64) inv = np.asarray(affine.invert(xform), dtype=np.float64)
group.attrs['Type'] = 'linear' group.attrs['Type'] = 'linear'
group.create_dataset('Transform', data=xform) group.create_dataset('Transform', data=xform)
...@@ -624,7 +624,7 @@ def _writeNonLinearCommon(group, field): ...@@ -624,7 +624,7 @@ def _writeNonLinearCommon(group, field):
_writeAffine(group.create_group('Post'), post) _writeAffine(group.create_group('Post'), post)
if field.srcToRefMat is not None: if field.srcToRefMat is not None:
_writeAffine(group.create_group('InitialAlignment', field.srcToRefMat)) _writeAffine(group.create_group('InitialAlignment'), field.srcToRefMat)
def _readDisplacementField(group): def _readDisplacementField(group):
...@@ -674,15 +674,15 @@ def _readCoefficientField(group): ...@@ -674,15 +674,15 @@ def _readCoefficientField(group):
src, ref, pre, post, init, srcSpace, refSpace = _readNonLinearCommon(group) src, ref, pre, post, init, srcSpace, refSpace = _readNonLinearCommon(group)
field = np.array(group['Transform']) field = np.array(group['Transform'])
ftype = group['Representation'] ftype = group.attrs['Representation']
spacing = group['Parameters/Spacing'] spacing = group['Parameters'].attrs['Spacing']
refToField = _readAffine(group['Parameters/ReferenceToField']) refToField = _readAffine(group['Parameters/ReferenceToField'])
fieldToRef = affine.invert(refToField) fieldToRef = affine.invert(refToField)
if ftype == 'quadratic bspline': ftype = 'quadratic' if ftype == 'quadratic bspline': ftype = 'quadratic'
elif ftype == 'cubic bspline': ftype = 'cubic' elif ftype == 'cubic bspline': ftype = 'cubic'
if spacing.shape != 3: if spacing.shape != (3,):
raise X5Error('Invalid spacing: {}'.format(spacing)) raise X5Error('Invalid spacing: {}'.format(spacing))
field = nonlinear.CoefficientField(field, field = nonlinear.CoefficientField(field,
...@@ -714,7 +714,7 @@ def _writeCoefficientField(group, field): ...@@ -714,7 +714,7 @@ def _writeCoefficientField(group, field):
elif field.fieldType == 'quadratic': elif field.fieldType == 'quadratic':
group.attrs['Representation'] = 'quadratic bspline' group.attrs['Representation'] = 'quadratic bspline'
xform = np.field.data.astype(np.float64) xform = field.data.astype(np.float64)
group.create_dataset('Transform', data=xform) group.create_dataset('Transform', data=xform)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment