Skip to content
Snippets Groups Projects
Commit e52fc826 authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

Merge branch 'rf/resample_image' into 'master'

Rf/resample image

See merge request fsl/fslpy!142
parents 434b82f6 9e8ecb14
No related branches found
No related tags found
No related merge requests found
...@@ -2,6 +2,18 @@ This document contains the ``fslpy`` release history in reverse chronological ...@@ -2,6 +2,18 @@ This document contains the ``fslpy`` release history in reverse chronological
order. order.
2.4.0 (Under development)
-------------------------
Changed
^^^^^^^
* The :mod:`.resample_image` script has been updated to support resampling of
images with more than 3 dimensions.
2.3.1 (Friday July 5th 2019) 2.3.1 (Friday July 5th 2019)
---------------------------- ----------------------------
......
...@@ -20,6 +20,40 @@ import fsl.utils.image.resample as resample ...@@ -20,6 +20,40 @@ import fsl.utils.image.resample as resample
import fsl.data.image as fslimage import fsl.data.image as fslimage
def intlist(val):
"""Turn a string of comma-separated ints into a list of ints. """
return [int(v) for v in val.split(',')]
def floatlist(val):
"""Turn a string of comma-separated floats into a list of floats. """
return [float(v) for v in val.split(',')]
def sanitiseList(parser, vals, img, arg):
"""Make sure that ``vals`` has the same number of elements as ``img`` has
dimensions. Used to sanitise the ``--shape`` and ``--dim`` options.
"""
if vals is None:
return vals
nvals = len(vals)
if nvals < 3:
parser.error('At least three values are '
'required for {}'.format(arg))
if nvals > img.ndim:
parser.error('Input only has {} dimensions - too many values '
'specified for {}'.format(img.ndim, arg))
if nvals < img.ndim:
vals = list(vals) + list(img.shape[nvals:])
return vals
ARGS = { ARGS = {
'input' : ('input',), 'input' : ('input',),
'output' : ('output',), 'output' : ('output',),
...@@ -36,8 +70,8 @@ OPTS = { ...@@ -36,8 +70,8 @@ OPTS = {
'input' : dict(type=parse_data.Image), 'input' : dict(type=parse_data.Image),
'output' : dict(type=parse_data.ImageOut), 'output' : dict(type=parse_data.ImageOut),
'reference' : dict(type=parse_data.Image, metavar='IMAGE'), 'reference' : dict(type=parse_data.Image, metavar='IMAGE'),
'shape' : dict(type=int, nargs=3, metavar=('X', 'Y', 'Z')), 'shape' : dict(type=intlist, metavar=('X,Y,Z,...')),
'dim' : dict(type=float, nargs=3, metavar=('X', 'Y', 'Z')), 'dim' : dict(type=floatlist, metavar=('X,Y,Z,...')),
'interp' : dict(choices=('nearest', 'linear', 'cubic'), 'interp' : dict(choices=('nearest', 'linear', 'cubic'),
default='linear'), default='linear'),
'origin' : dict(choices=('centre', 'corner'), default='centre'), 'origin' : dict(choices=('centre', 'corner'), default='centre'),
...@@ -110,6 +144,14 @@ def parseArgs(argv): ...@@ -110,6 +144,14 @@ def parseArgs(argv):
args = parser.parse_args(argv) args = parser.parse_args(argv)
args.interp = INTERPS[ args.interp] args.interp = INTERPS[ args.interp]
args.dtype = DTYPES.get(args.dtype, args.input.dtype) args.dtype = DTYPES.get(args.dtype, args.input.dtype)
args.shape = sanitiseList(parser, args.shape, args.input, 'shape')
args.dim = sanitiseList(parser, args.dim, args.input, 'dim')
if (args.reference is not None) and \
(args.input.ndim > 3) and \
(args.reference.ndim > 3):
print('Reference and image are both >3D - only '
'resampling along the spatial dimensions.')
return args return args
...@@ -154,6 +196,19 @@ def main(argv=None): ...@@ -154,6 +196,19 @@ def main(argv=None):
xform = None xform = None
resampled = fslimage.Image(resampled, xform=xform, header=hdr) resampled = fslimage.Image(resampled, xform=xform, header=hdr)
# Adjust the pixdims of the
# higher dimensions if they
# have been resampled
if len(resampled.shape) > 3:
oldPixdim = args.input.pixdim[3:]
oldShape = args.input.shape[ 3:]
newShape = resampled .shape[ 3:]
for i, (p, o, n) in enumerate(zip(oldPixdim, oldShape, newShape), 4):
resampled.header['pixdim'][i] = p * o / n
resampled.save(args.output) resampled.save(args.output)
return 0 return 0
......
...@@ -45,15 +45,43 @@ def resampleToReference(image, reference, **kwargs): ...@@ -45,15 +45,43 @@ def resampleToReference(image, reference, **kwargs):
This is a wrapper around :func:`resample` - refer to its documenttion This is a wrapper around :func:`resample` - refer to its documenttion
for details on the other arguments and the return values. for details on the other arguments and the return values.
When resampling to a reference image, resampling will only be applied
along the spatial (first three) dimensions.
:arg image: :class:`.Image` to resample :arg image: :class:`.Image` to resample
:arg reference: :class:`.Nifti` defining the space to resample ``image`` :arg reference: :class:`.Nifti` defining the space to resample ``image``
into into
""" """
oldShape = list(image.shape)
newShape = list(reference.shape[:3])
# If the input image is >3D, pad the
# new shape so that we only resample
# along the first 3 dimensions.
if len(newShape) < len(oldShape):
newShape = newShape + oldShape[len(newShape):]
# Align the two images together
# via their vox-to-world affines.
matrix = transform.concat(image.worldToVoxMat, reference.voxToWorldMat)
# If the input image is >3D, we
# have to adjust the resampling
# matrix to take into account the
# additional dimensions (see scipy.
# ndimage.affine_transform)
if len(newShape) > 3:
rotmat = matrix[:3, :3]
offsets = matrix[:3, 3]
matrix = np.eye(len(newShape) + 1)
matrix[:3, :3] = rotmat
matrix[:3, -1] = offsets
kwargs['mode'] = kwargs.get('mode', 'constant') kwargs['mode'] = kwargs.get('mode', 'constant')
kwargs['newShape'] = reference.shape kwargs['newShape'] = newShape
kwargs['matrix'] = transform.concat(image.worldToVoxMat, kwargs['matrix'] = matrix
reference.voxToWorldMat)
return resample(image, **kwargs) return resample(image, **kwargs)
...@@ -182,7 +210,12 @@ def resample(image, ...@@ -182,7 +210,12 @@ def resample(image,
# might not return a 4x4 matrix, so we # might not return a 4x4 matrix, so we
# make sure it is valid. # make sure it is valid.
if matrix.shape != (4, 4): if matrix.shape != (4, 4):
matrix = np.vstack((matrix[:3, :4], [0, 0, 0, 1])) rotmat = matrix[:3, :3]
offsets = matrix[:3, -1]
matrix = np.eye(4)
matrix[:3, :3] = rotmat
matrix[:3, -1] = offsets
matrix = transform.concat(image.voxToWorldMat, matrix) matrix = transform.concat(image.voxToWorldMat, matrix)
return data, matrix return data, matrix
......
...@@ -6,12 +6,21 @@ import numpy as np ...@@ -6,12 +6,21 @@ import numpy as np
import pytest import pytest
import scipy.ndimage as ndimage
import fsl.data.image as fslimage import fsl.data.image as fslimage
import fsl.utils.transform as transform import fsl.utils.transform as transform
import fsl.utils.image.resample as resample import fsl.utils.image.resample as resample
from . import make_random_image from . import make_random_image
def random_affine():
return transform.compose(
0.25 + 4.75 * np.random.random(3),
-50 + 100 * np.random.random(3),
-np.pi + 2 * np.pi * np.random.random(3))
def test_resample(seed): def test_resample(seed):
...@@ -183,13 +192,7 @@ def test_resampleToPixdims(): ...@@ -183,13 +192,7 @@ def test_resampleToPixdims():
def test_resampleToReference(): def test_resampleToReference1():
def random_v2w():
return transform.compose(
0.25 + 4.75 * np.random.random(3),
-50 + 100 * np.random.random(3),
-np.pi + 2 * np.pi * np.random.random(3))
# Basic test - output has same # Basic test - output has same
# dimensions/space as reference # dimensions/space as reference
...@@ -197,8 +200,8 @@ def test_resampleToReference(): ...@@ -197,8 +200,8 @@ def test_resampleToReference():
ishape = np.random.randint(5, 50, 3) ishape = np.random.randint(5, 50, 3)
rshape = np.random.randint(5, 50, 3) rshape = np.random.randint(5, 50, 3)
iv2w = random_v2w() iv2w = random_affine()
rv2w = random_v2w() rv2w = random_affine()
img = fslimage.Image(make_random_image(dims=ishape, xform=iv2w)) img = fslimage.Image(make_random_image(dims=ishape, xform=iv2w))
ref = fslimage.Image(make_random_image(dims=rshape, xform=rv2w)) ref = fslimage.Image(make_random_image(dims=rshape, xform=rv2w))
res = resample.resampleToReference(img, ref) res = resample.resampleToReference(img, ref)
...@@ -206,10 +209,12 @@ def test_resampleToReference(): ...@@ -206,10 +209,12 @@ def test_resampleToReference():
assert res.sameSpace(ref) assert res.sameSpace(ref)
def test_resampleToReference2():
# More specific test - output # More specific test - output
# data gets transformed correctly # data gets transformed correctly
# into reference space # into reference space
img = np.zeros((5, 5, 5), dtype=np.float) img = np.zeros((5, 5, 5), dtype=np.float)
img[1, 1, 1] = 1 img[1, 1, 1] = 1
img = fslimage.Image(img) img = fslimage.Image(img)
...@@ -223,3 +228,38 @@ def test_resampleToReference(): ...@@ -223,3 +228,38 @@ def test_resampleToReference():
exp[2, 2, 2] = 1 exp[2, 2, 2] = 1
assert np.all(np.isclose(res[0], exp)) assert np.all(np.isclose(res[0], exp))
def test_resampleToReference3():
# Test resampling image to ref
# with mismatched dimensions
imgdata = np.random.randint(0, 65536, (5, 5, 5))
img = fslimage.Image(imgdata, xform=transform.scaleOffsetXform(
(2, 2, 2), (0.5, 0.5, 0.5)))
# reference/expected data when
# resampled to ref (using nn interp).
# Same as image, upsampled by a
# factor of 2
refdata = np.repeat(np.repeat(np.repeat(imgdata, 2, 0), 2, 1), 2, 2)
refdata = np.array([refdata] * 8).transpose((1, 2, 3, 0))
ref = fslimage.Image(refdata)
# We should be able to use a 4D reference
resampled, xform = resample.resampleToReference(img, ref, order=0, mode='nearest')
assert np.all(resampled == ref.data[..., 0])
# If resampling a 4D image with a 3D reference,
# the fourth dimension should be passed through
resampled, xform = resample.resampleToReference(ref, img, order=0, mode='nearest')
exp = np.array([imgdata] * 8).transpose((1, 2, 3, 0))
assert np.all(resampled == exp)
# When resampling 4D to 4D, only the
# first 3 dimensions should be resampled
imgdata = np.array([imgdata] * 15).transpose((1, 2, 3, 0))
img = fslimage.Image(imgdata, xform=img.voxToWorldMat)
exp = np.array([refdata[..., 0]] * 15).transpose((1, 2, 3, 0))
resampled, xform = resample.resampleToReference(img, ref, order=0, mode='nearest')
assert np.all(resampled == exp)
...@@ -15,10 +15,11 @@ from fsl.data.image import Image ...@@ -15,10 +15,11 @@ from fsl.data.image import Image
from .. import make_random_image from .. import make_random_image
def test_resample_image_shape(): def test_resample_image_shape():
with tempdir(): with tempdir():
img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10))) img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10)))
resample_image.main('image resampled -s 20 20 20'.split()) resample_image.main('image resampled -s 20,20,20'.split())
res = Image('resampled') res = Image('resampled')
expv2w = transform.concat( expv2w = transform.concat(
...@@ -32,18 +33,36 @@ def test_resample_image_shape(): ...@@ -32,18 +33,36 @@ def test_resample_image_shape():
np.array(transform.axisBounds(res.shape, res.voxToWorldMat)) - 0.25, np.array(transform.axisBounds(res.shape, res.voxToWorldMat)) - 0.25,
transform.axisBounds(img.shape, img.voxToWorldMat))) transform.axisBounds(img.shape, img.voxToWorldMat)))
resample_image.main('image resampled -s 20 20 20 -o corner'.split()) resample_image.main('image resampled -s 20,20,20 -o corner'.split())
res = Image('resampled') res = Image('resampled')
assert np.all(np.isclose( assert np.all(np.isclose(
transform.axisBounds(res.shape, res.voxToWorldMat), transform.axisBounds(res.shape, res.voxToWorldMat),
transform.axisBounds(img.shape, img.voxToWorldMat))) transform.axisBounds(img.shape, img.voxToWorldMat)))
def test_resample_image_shape_4D():
with tempdir():
# Can specify three dims
img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10, 10)))
resample_image.main('image resampled -s 20,20,20'.split())
res = Image('resampled')
assert np.all(np.isclose(res.shape, (20, 20, 20, 10)))
assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 1)))
# Or resample along the higher dims
resample_image.main('image resampled -s 20,20,20,20'.split())
res = Image('resampled')
assert np.all(np.isclose(res.shape, (20, 20, 20, 20)))
assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 0.5)))
def test_resample_image_dim(): def test_resample_image_dim():
with tempdir(): with tempdir():
img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10))) img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10)))
resample_image.main('image resampled -d 0.5 0.5 0.5'.split()) resample_image.main('image resampled -d 0.5,0.5,0.5'.split())
res = Image('resampled') res = Image('resampled')
expv2w = transform.concat( expv2w = transform.concat(
...@@ -70,12 +89,55 @@ def test_resample_image_ref(): ...@@ -70,12 +89,55 @@ def test_resample_image_ref():
assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5))) assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5)))
assert np.all(np.isclose(res.voxToWorldMat, expv2w)) assert np.all(np.isclose(res.voxToWorldMat, expv2w))
# 3D / 4D
img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10)))
ref = Image(make_random_image('ref.nii.gz', dims=(20, 20, 20, 20),
pixdims=(0.5, 0.5, 0.5, 1)))
resample_image.main('image resampled -r ref'.split())
res = Image('resampled')
assert np.all(np.isclose(res.shape, (20, 20, 20)))
assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5)))
# 4D / 3D
img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10, 10)))
ref = Image(make_random_image('ref.nii.gz', dims=(20, 20, 20),
pixdims=(0.5, 0.5, 0.5)))
resample_image.main('image resampled -r ref'.split())
res = Image('resampled')
assert np.all(np.isclose(res.shape, (20, 20, 20, 10)))
assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 1)))
# 4D / 4D - no resampling along fourth dim
img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10, 10)))
ref = Image(make_random_image('ref.nii.gz', dims=(20, 20, 20, 20),
pixdims=(0.5, 0.5, 0.5, 1)))
resample_image.main('image resampled -r ref'.split())
res = Image('resampled')
assert np.all(np.isclose(res.shape, (20, 20, 20, 10)))
assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 1)))
def test_resample_image_bad_options(): def test_resample_image_bad_options():
with tempdir(): with tempdir():
img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10))) img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10)))
# No args - should print help and exit(0)
with pytest.raises(SystemExit) as e:
resample_image.main([])
assert e.value.code == 0
with pytest.raises(SystemExit) as e:
resample_image.main('image resampled -d 0.5,0.5,0.5 '
'-s 20,20,20'.split())
assert e.value.code != 0
with pytest.raises(SystemExit) as e:
resample_image.main('image resampled -s 20,20'.split())
assert e.value.code != 0
with pytest.raises(SystemExit) as e: with pytest.raises(SystemExit) as e:
resample_image.main('image resampled -d 0.5 0.5 0.5 ' resample_image.main('image resampled -s 20,20,20,20'.split())
'-s 20 20 20'.split())
assert e.value.code != 0 assert e.value.code != 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment