diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2790ec5d8549a8ae2352c5aef6cf6c8d862d0842..b579ea69f50c35f8e5060ead812ad7e4c131401e 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,18 @@ This document contains the ``fslpy`` release history in reverse chronological order. +2.4.0 (Under development) +------------------------- + + +Changed +^^^^^^^ + + +* The :mod:`.resample_image` script has been updated to support resampling of + images with more than 3 dimensions. + + 2.3.1 (Friday July 5th 2019) ---------------------------- diff --git a/fsl/scripts/resample_image.py b/fsl/scripts/resample_image.py index 151c301bf3af108485a67f3d21fd65df9efedca2..b2ef4b9c079b5b01a0bf9df298e11633b5508edf 100644 --- a/fsl/scripts/resample_image.py +++ b/fsl/scripts/resample_image.py @@ -20,6 +20,40 @@ import fsl.utils.image.resample as resample import fsl.data.image as fslimage +def intlist(val): + """Turn a string of comma-separated ints into a list of ints. """ + return [int(v) for v in val.split(',')] + + +def floatlist(val): + """Turn a string of comma-separated floats into a list of floats. """ + return [float(v) for v in val.split(',')] + + +def sanitiseList(parser, vals, img, arg): + """Make sure that ``vals`` has the same number of elements as ``img`` has + dimensions. Used to sanitise the ``--shape`` and ``--dim`` options. + """ + + if vals is None: + return vals + + nvals = len(vals) + + if nvals < 3: + parser.error('At least three values are ' + 'required for {}'.format(arg)) + + if nvals > img.ndim: + parser.error('Input only has {} dimensions - too many values ' + 'specified for {}'.format(img.ndim, arg)) + + if nvals < img.ndim: + vals = list(vals) + list(img.shape[nvals:]) + + return vals + + ARGS = { 'input' : ('input',), 'output' : ('output',), @@ -36,8 +70,8 @@ OPTS = { 'input' : dict(type=parse_data.Image), 'output' : dict(type=parse_data.ImageOut), 'reference' : dict(type=parse_data.Image, metavar='IMAGE'), - 'shape' : dict(type=int, nargs=3, metavar=('X', 'Y', 'Z')), - 'dim' : dict(type=float, nargs=3, metavar=('X', 'Y', 'Z')), + 'shape' : dict(type=intlist, metavar=('X,Y,Z,...')), + 'dim' : dict(type=floatlist, metavar=('X,Y,Z,...')), 'interp' : dict(choices=('nearest', 'linear', 'cubic'), default='linear'), 'origin' : dict(choices=('centre', 'corner'), default='centre'), @@ -110,6 +144,14 @@ def parseArgs(argv): args = parser.parse_args(argv) args.interp = INTERPS[ args.interp] args.dtype = DTYPES.get(args.dtype, args.input.dtype) + args.shape = sanitiseList(parser, args.shape, args.input, 'shape') + args.dim = sanitiseList(parser, args.dim, args.input, 'dim') + + if (args.reference is not None) and \ + (args.input.ndim > 3) and \ + (args.reference.ndim > 3): + print('Reference and image are both >3D - only ' + 'resampling along the spatial dimensions.') return args @@ -154,6 +196,19 @@ def main(argv=None): xform = None resampled = fslimage.Image(resampled, xform=xform, header=hdr) + + # Adjust the pixdims of the + # higher dimensions if they + # have been resampled + if len(resampled.shape) > 3: + + oldPixdim = args.input.pixdim[3:] + oldShape = args.input.shape[ 3:] + newShape = resampled .shape[ 3:] + + for i, (p, o, n) in enumerate(zip(oldPixdim, oldShape, newShape), 4): + resampled.header['pixdim'][i] = p * o / n + resampled.save(args.output) return 0 diff --git a/fsl/utils/image/resample.py b/fsl/utils/image/resample.py index cc49a621dca6a5759421e8353578fbb7038eca8e..91e79e6c0416cc01187bc4c636f59ec56aa23668 100644 --- a/fsl/utils/image/resample.py +++ b/fsl/utils/image/resample.py @@ -45,15 +45,43 @@ def resampleToReference(image, reference, **kwargs): This is a wrapper around :func:`resample` - refer to its documenttion for details on the other arguments and the return values. + When resampling to a reference image, resampling will only be applied + along the spatial (first three) dimensions. + :arg image: :class:`.Image` to resample :arg reference: :class:`.Nifti` defining the space to resample ``image`` into """ + oldShape = list(image.shape) + newShape = list(reference.shape[:3]) + + # If the input image is >3D, pad the + # new shape so that we only resample + # along the first 3 dimensions. + if len(newShape) < len(oldShape): + newShape = newShape + oldShape[len(newShape):] + + # Align the two images together + # via their vox-to-world affines. + matrix = transform.concat(image.worldToVoxMat, reference.voxToWorldMat) + + # If the input image is >3D, we + # have to adjust the resampling + # matrix to take into account the + # additional dimensions (see scipy. + # ndimage.affine_transform) + if len(newShape) > 3: + rotmat = matrix[:3, :3] + offsets = matrix[:3, 3] + matrix = np.eye(len(newShape) + 1) + matrix[:3, :3] = rotmat + matrix[:3, -1] = offsets + kwargs['mode'] = kwargs.get('mode', 'constant') - kwargs['newShape'] = reference.shape - kwargs['matrix'] = transform.concat(image.worldToVoxMat, - reference.voxToWorldMat) + kwargs['newShape'] = newShape + kwargs['matrix'] = matrix + return resample(image, **kwargs) @@ -182,7 +210,12 @@ def resample(image, # might not return a 4x4 matrix, so we # make sure it is valid. if matrix.shape != (4, 4): - matrix = np.vstack((matrix[:3, :4], [0, 0, 0, 1])) + rotmat = matrix[:3, :3] + offsets = matrix[:3, -1] + matrix = np.eye(4) + matrix[:3, :3] = rotmat + matrix[:3, -1] = offsets + matrix = transform.concat(image.voxToWorldMat, matrix) return data, matrix diff --git a/tests/test_image_resample.py b/tests/test_image_resample.py index 8802ded1eeee128e0d4cc39f833bb27e213527da..d9f474bf45330e5691535a545faa7ddcac9159a4 100644 --- a/tests/test_image_resample.py +++ b/tests/test_image_resample.py @@ -6,12 +6,21 @@ import numpy as np import pytest +import scipy.ndimage as ndimage + import fsl.data.image as fslimage import fsl.utils.transform as transform import fsl.utils.image.resample as resample from . import make_random_image +def random_affine(): + return transform.compose( + 0.25 + 4.75 * np.random.random(3), + -50 + 100 * np.random.random(3), + -np.pi + 2 * np.pi * np.random.random(3)) + + def test_resample(seed): @@ -183,13 +192,7 @@ def test_resampleToPixdims(): -def test_resampleToReference(): - - def random_v2w(): - return transform.compose( - 0.25 + 4.75 * np.random.random(3), - -50 + 100 * np.random.random(3), - -np.pi + 2 * np.pi * np.random.random(3)) +def test_resampleToReference1(): # Basic test - output has same # dimensions/space as reference @@ -197,8 +200,8 @@ def test_resampleToReference(): ishape = np.random.randint(5, 50, 3) rshape = np.random.randint(5, 50, 3) - iv2w = random_v2w() - rv2w = random_v2w() + iv2w = random_affine() + rv2w = random_affine() img = fslimage.Image(make_random_image(dims=ishape, xform=iv2w)) ref = fslimage.Image(make_random_image(dims=rshape, xform=rv2w)) res = resample.resampleToReference(img, ref) @@ -206,10 +209,12 @@ def test_resampleToReference(): assert res.sameSpace(ref) + +def test_resampleToReference2(): + # More specific test - output # data gets transformed correctly # into reference space - img = np.zeros((5, 5, 5), dtype=np.float) img[1, 1, 1] = 1 img = fslimage.Image(img) @@ -223,3 +228,38 @@ def test_resampleToReference(): exp[2, 2, 2] = 1 assert np.all(np.isclose(res[0], exp)) + + +def test_resampleToReference3(): + + # Test resampling image to ref + # with mismatched dimensions + imgdata = np.random.randint(0, 65536, (5, 5, 5)) + img = fslimage.Image(imgdata, xform=transform.scaleOffsetXform( + (2, 2, 2), (0.5, 0.5, 0.5))) + + # reference/expected data when + # resampled to ref (using nn interp). + # Same as image, upsampled by a + # factor of 2 + refdata = np.repeat(np.repeat(np.repeat(imgdata, 2, 0), 2, 1), 2, 2) + refdata = np.array([refdata] * 8).transpose((1, 2, 3, 0)) + ref = fslimage.Image(refdata) + + # We should be able to use a 4D reference + resampled, xform = resample.resampleToReference(img, ref, order=0, mode='nearest') + assert np.all(resampled == ref.data[..., 0]) + + # If resampling a 4D image with a 3D reference, + # the fourth dimension should be passed through + resampled, xform = resample.resampleToReference(ref, img, order=0, mode='nearest') + exp = np.array([imgdata] * 8).transpose((1, 2, 3, 0)) + assert np.all(resampled == exp) + + # When resampling 4D to 4D, only the + # first 3 dimensions should be resampled + imgdata = np.array([imgdata] * 15).transpose((1, 2, 3, 0)) + img = fslimage.Image(imgdata, xform=img.voxToWorldMat) + exp = np.array([refdata[..., 0]] * 15).transpose((1, 2, 3, 0)) + resampled, xform = resample.resampleToReference(img, ref, order=0, mode='nearest') + assert np.all(resampled == exp) diff --git a/tests/test_scripts/test_resample_image.py b/tests/test_scripts/test_resample_image.py index 7fc0cf9765fa4dbba615f88e20e5068b6786e8fd..ed7e944b57e2e4196bf45c10efe341ada0243c84 100644 --- a/tests/test_scripts/test_resample_image.py +++ b/tests/test_scripts/test_resample_image.py @@ -15,10 +15,11 @@ from fsl.data.image import Image from .. import make_random_image + def test_resample_image_shape(): with tempdir(): img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10))) - resample_image.main('image resampled -s 20 20 20'.split()) + resample_image.main('image resampled -s 20,20,20'.split()) res = Image('resampled') expv2w = transform.concat( @@ -32,18 +33,36 @@ def test_resample_image_shape(): np.array(transform.axisBounds(res.shape, res.voxToWorldMat)) - 0.25, transform.axisBounds(img.shape, img.voxToWorldMat))) - resample_image.main('image resampled -s 20 20 20 -o corner'.split()) + resample_image.main('image resampled -s 20,20,20 -o corner'.split()) res = Image('resampled') assert np.all(np.isclose( transform.axisBounds(res.shape, res.voxToWorldMat), transform.axisBounds(img.shape, img.voxToWorldMat))) +def test_resample_image_shape_4D(): + with tempdir(): + + # Can specify three dims + img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10, 10))) + resample_image.main('image resampled -s 20,20,20'.split()) + res = Image('resampled') + + assert np.all(np.isclose(res.shape, (20, 20, 20, 10))) + assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 1))) + + # Or resample along the higher dims + resample_image.main('image resampled -s 20,20,20,20'.split()) + res = Image('resampled') + assert np.all(np.isclose(res.shape, (20, 20, 20, 20))) + assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 0.5))) + + def test_resample_image_dim(): with tempdir(): img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10))) - resample_image.main('image resampled -d 0.5 0.5 0.5'.split()) + resample_image.main('image resampled -d 0.5,0.5,0.5'.split()) res = Image('resampled') expv2w = transform.concat( @@ -70,12 +89,55 @@ def test_resample_image_ref(): assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5))) assert np.all(np.isclose(res.voxToWorldMat, expv2w)) + # 3D / 4D + img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10))) + ref = Image(make_random_image('ref.nii.gz', dims=(20, 20, 20, 20), + pixdims=(0.5, 0.5, 0.5, 1))) + + resample_image.main('image resampled -r ref'.split()) + res = Image('resampled') + assert np.all(np.isclose(res.shape, (20, 20, 20))) + assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5))) + + # 4D / 3D + img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10, 10))) + ref = Image(make_random_image('ref.nii.gz', dims=(20, 20, 20), + pixdims=(0.5, 0.5, 0.5))) + + resample_image.main('image resampled -r ref'.split()) + res = Image('resampled') + assert np.all(np.isclose(res.shape, (20, 20, 20, 10))) + assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 1))) + + # 4D / 4D - no resampling along fourth dim + img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10, 10))) + ref = Image(make_random_image('ref.nii.gz', dims=(20, 20, 20, 20), + pixdims=(0.5, 0.5, 0.5, 1))) + + resample_image.main('image resampled -r ref'.split()) + res = Image('resampled') + assert np.all(np.isclose(res.shape, (20, 20, 20, 10))) + assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 1))) + def test_resample_image_bad_options(): with tempdir(): img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10))) + # No args - should print help and exit(0) + with pytest.raises(SystemExit) as e: + resample_image.main([]) + assert e.value.code == 0 + + with pytest.raises(SystemExit) as e: + resample_image.main('image resampled -d 0.5,0.5,0.5 ' + '-s 20,20,20'.split()) + assert e.value.code != 0 + + with pytest.raises(SystemExit) as e: + resample_image.main('image resampled -s 20,20'.split()) + assert e.value.code != 0 + with pytest.raises(SystemExit) as e: - resample_image.main('image resampled -d 0.5 0.5 0.5 ' - '-s 20 20 20'.split()) + resample_image.main('image resampled -s 20,20,20,20'.split()) assert e.value.code != 0