diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 2790ec5d8549a8ae2352c5aef6cf6c8d862d0842..b579ea69f50c35f8e5060ead812ad7e4c131401e 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -2,6 +2,18 @@ This document contains the ``fslpy`` release history in reverse chronological
 order.
 
 
+2.4.0 (Under development)
+-------------------------
+
+
+Changed
+^^^^^^^
+
+
+* The :mod:`.resample_image` script has been updated to support resampling of
+  images with more than 3 dimensions.
+
+
 2.3.1 (Friday July 5th 2019)
 ----------------------------
 
diff --git a/fsl/scripts/resample_image.py b/fsl/scripts/resample_image.py
index 151c301bf3af108485a67f3d21fd65df9efedca2..b2ef4b9c079b5b01a0bf9df298e11633b5508edf 100644
--- a/fsl/scripts/resample_image.py
+++ b/fsl/scripts/resample_image.py
@@ -20,6 +20,40 @@ import fsl.utils.image.resample as resample
 import fsl.data.image           as fslimage
 
 
+def intlist(val):
+    """Turn a string of comma-separated ints into a list of ints. """
+    return [int(v) for v in val.split(',')]
+
+
+def floatlist(val):
+    """Turn a string of comma-separated floats into a list of floats. """
+    return [float(v) for v in val.split(',')]
+
+
+def sanitiseList(parser, vals, img, arg):
+    """Make sure that ``vals`` has the same number of elements as ``img`` has
+    dimensions. Used to sanitise the ``--shape`` and ``--dim`` options.
+    """
+
+    if vals is None:
+        return vals
+
+    nvals = len(vals)
+
+    if nvals < 3:
+        parser.error('At least three values are '
+                     'required for {}'.format(arg))
+
+    if nvals > img.ndim:
+        parser.error('Input only has {} dimensions - too many values '
+                     'specified for {}'.format(img.ndim, arg))
+
+    if nvals < img.ndim:
+        vals = list(vals) + list(img.shape[nvals:])
+
+    return vals
+
+
 ARGS = {
     'input'     : ('input',),
     'output'    : ('output',),
@@ -36,8 +70,8 @@ OPTS = {
     'input'     : dict(type=parse_data.Image),
     'output'    : dict(type=parse_data.ImageOut),
     'reference' : dict(type=parse_data.Image, metavar='IMAGE'),
-    'shape'     : dict(type=int,   nargs=3, metavar=('X', 'Y', 'Z')),
-    'dim'       : dict(type=float, nargs=3, metavar=('X', 'Y', 'Z')),
+    'shape'     : dict(type=intlist,   metavar=('X,Y,Z,...')),
+    'dim'       : dict(type=floatlist, metavar=('X,Y,Z,...')),
     'interp'    : dict(choices=('nearest', 'linear', 'cubic'),
                        default='linear'),
     'origin'    : dict(choices=('centre', 'corner'), default='centre'),
@@ -110,6 +144,14 @@ def parseArgs(argv):
     args        = parser.parse_args(argv)
     args.interp = INTERPS[   args.interp]
     args.dtype  = DTYPES.get(args.dtype, args.input.dtype)
+    args.shape  = sanitiseList(parser, args.shape, args.input, 'shape')
+    args.dim    = sanitiseList(parser, args.dim,   args.input, 'dim')
+
+    if (args.reference is not None) and \
+       (args.input.ndim     > 3)    and \
+       (args.reference.ndim > 3):
+        print('Reference and image are both >3D - only '
+              'resampling along the spatial dimensions.')
 
     return args
 
@@ -154,6 +196,19 @@ def main(argv=None):
         xform = None
 
     resampled = fslimage.Image(resampled, xform=xform, header=hdr)
+
+    # Adjust the pixdims of the
+    # higher dimensions if they
+    # have been resampled
+    if len(resampled.shape) > 3:
+
+        oldPixdim = args.input.pixdim[3:]
+        oldShape  = args.input.shape[ 3:]
+        newShape  = resampled .shape[ 3:]
+
+        for i, (p, o, n) in enumerate(zip(oldPixdim, oldShape, newShape), 4):
+            resampled.header['pixdim'][i] = p * o / n
+
     resampled.save(args.output)
 
     return 0
diff --git a/fsl/utils/image/resample.py b/fsl/utils/image/resample.py
index cc49a621dca6a5759421e8353578fbb7038eca8e..91e79e6c0416cc01187bc4c636f59ec56aa23668 100644
--- a/fsl/utils/image/resample.py
+++ b/fsl/utils/image/resample.py
@@ -45,15 +45,43 @@ def resampleToReference(image, reference, **kwargs):
     This is a wrapper around :func:`resample` - refer to its documenttion
     for details on the other arguments and the return values.
 
+    When resampling to a reference image, resampling will only be applied
+    along the spatial (first three) dimensions.
+
     :arg image:     :class:`.Image` to resample
     :arg reference: :class:`.Nifti` defining the space to resample ``image``
                     into
     """
 
+    oldShape = list(image.shape)
+    newShape = list(reference.shape[:3])
+
+    # If the input image is >3D, pad the
+    # new shape so that we only resample
+    # along the first 3 dimensions.
+    if len(newShape) < len(oldShape):
+        newShape = newShape + oldShape[len(newShape):]
+
+    # Align the two images together
+    # via their vox-to-world affines.
+    matrix = transform.concat(image.worldToVoxMat, reference.voxToWorldMat)
+
+    # If the input image is >3D, we
+    # have to adjust the resampling
+    # matrix to take into account the
+    # additional dimensions (see scipy.
+    # ndimage.affine_transform)
+    if len(newShape) > 3:
+        rotmat  = matrix[:3, :3]
+        offsets = matrix[:3,  3]
+        matrix  = np.eye(len(newShape) + 1)
+        matrix[:3, :3] = rotmat
+        matrix[:3, -1] = offsets
+
     kwargs['mode']     = kwargs.get('mode', 'constant')
-    kwargs['newShape'] = reference.shape
-    kwargs['matrix']   = transform.concat(image.worldToVoxMat,
-                                          reference.voxToWorldMat)
+    kwargs['newShape'] = newShape
+    kwargs['matrix']   = matrix
+
     return resample(image, **kwargs)
 
 
@@ -182,7 +210,12 @@ def resample(image,
     # might not return a 4x4 matrix, so we
     # make sure it is valid.
     if matrix.shape != (4, 4):
-        matrix = np.vstack((matrix[:3, :4], [0, 0, 0, 1]))
+        rotmat         = matrix[:3, :3]
+        offsets        = matrix[:3, -1]
+        matrix         = np.eye(4)
+        matrix[:3, :3] = rotmat
+        matrix[:3, -1] = offsets
+
     matrix = transform.concat(image.voxToWorldMat, matrix)
 
     return data, matrix
diff --git a/tests/test_image_resample.py b/tests/test_image_resample.py
index 8802ded1eeee128e0d4cc39f833bb27e213527da..d9f474bf45330e5691535a545faa7ddcac9159a4 100644
--- a/tests/test_image_resample.py
+++ b/tests/test_image_resample.py
@@ -6,12 +6,21 @@ import numpy     as np
 
 import pytest
 
+import scipy.ndimage       as ndimage
+
 import fsl.data.image           as     fslimage
 import fsl.utils.transform      as     transform
 import fsl.utils.image.resample as     resample
 
 from . import make_random_image
 
+def random_affine():
+    return transform.compose(
+        0.25   + 4.75      * np.random.random(3),
+        -50    + 100       * np.random.random(3),
+        -np.pi + 2 * np.pi * np.random.random(3))
+
+
 
 def test_resample(seed):
 
@@ -183,13 +192,7 @@ def test_resampleToPixdims():
 
 
 
-def test_resampleToReference():
-
-    def random_v2w():
-        return transform.compose(
-            0.25   + 4.75      * np.random.random(3),
-            -50    + 100       * np.random.random(3),
-            -np.pi + 2 * np.pi * np.random.random(3))
+def test_resampleToReference1():
 
     # Basic test - output has same
     # dimensions/space as reference
@@ -197,8 +200,8 @@ def test_resampleToReference():
 
         ishape = np.random.randint(5, 50, 3)
         rshape = np.random.randint(5, 50, 3)
-        iv2w   = random_v2w()
-        rv2w   = random_v2w()
+        iv2w   = random_affine()
+        rv2w   = random_affine()
         img    = fslimage.Image(make_random_image(dims=ishape, xform=iv2w))
         ref    = fslimage.Image(make_random_image(dims=rshape, xform=rv2w))
         res    = resample.resampleToReference(img, ref)
@@ -206,10 +209,12 @@ def test_resampleToReference():
 
         assert res.sameSpace(ref)
 
+
+def test_resampleToReference2():
+
     # More specific test - output
     # data gets transformed correctly
     # into reference space
-
     img          = np.zeros((5, 5, 5), dtype=np.float)
     img[1, 1, 1] = 1
     img          = fslimage.Image(img)
@@ -223,3 +228,38 @@ def test_resampleToReference():
     exp[2, 2, 2] = 1
 
     assert np.all(np.isclose(res[0], exp))
+
+
+def test_resampleToReference3():
+
+    # Test resampling image to ref
+    # with mismatched dimensions
+    imgdata = np.random.randint(0, 65536, (5, 5, 5))
+    img     = fslimage.Image(imgdata, xform=transform.scaleOffsetXform(
+        (2, 2, 2), (0.5, 0.5, 0.5)))
+
+    # reference/expected data when
+    # resampled to ref (using nn interp).
+    # Same as image, upsampled by a
+    # factor of 2
+    refdata = np.repeat(np.repeat(np.repeat(imgdata, 2, 0), 2, 1), 2, 2)
+    refdata = np.array([refdata] * 8).transpose((1, 2, 3, 0))
+    ref     = fslimage.Image(refdata)
+
+    # We should be able to use a 4D reference
+    resampled, xform = resample.resampleToReference(img, ref, order=0, mode='nearest')
+    assert np.all(resampled == ref.data[..., 0])
+
+    # If resampling a 4D image with a 3D reference,
+    # the fourth dimension should be passed through
+    resampled, xform = resample.resampleToReference(ref, img, order=0, mode='nearest')
+    exp = np.array([imgdata] * 8).transpose((1, 2, 3, 0))
+    assert np.all(resampled == exp)
+
+    # When resampling 4D to 4D, only the
+    # first 3 dimensions should be resampled
+    imgdata = np.array([imgdata] * 15).transpose((1, 2, 3, 0))
+    img     = fslimage.Image(imgdata, xform=img.voxToWorldMat)
+    exp     = np.array([refdata[..., 0]] * 15).transpose((1, 2, 3, 0))
+    resampled, xform = resample.resampleToReference(img, ref, order=0, mode='nearest')
+    assert np.all(resampled == exp)
diff --git a/tests/test_scripts/test_resample_image.py b/tests/test_scripts/test_resample_image.py
index 7fc0cf9765fa4dbba615f88e20e5068b6786e8fd..ed7e944b57e2e4196bf45c10efe341ada0243c84 100644
--- a/tests/test_scripts/test_resample_image.py
+++ b/tests/test_scripts/test_resample_image.py
@@ -15,10 +15,11 @@ from fsl.data.image    import Image
 from .. import make_random_image
 
 
+
 def test_resample_image_shape():
     with tempdir():
         img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10)))
-        resample_image.main('image resampled -s 20 20 20'.split())
+        resample_image.main('image resampled -s 20,20,20'.split())
         res = Image('resampled')
 
         expv2w = transform.concat(
@@ -32,18 +33,36 @@ def test_resample_image_shape():
             np.array(transform.axisBounds(res.shape, res.voxToWorldMat)) - 0.25,
                      transform.axisBounds(img.shape, img.voxToWorldMat)))
 
-        resample_image.main('image resampled -s 20 20 20 -o corner'.split())
+        resample_image.main('image resampled -s 20,20,20 -o corner'.split())
         res = Image('resampled')
         assert np.all(np.isclose(
             transform.axisBounds(res.shape, res.voxToWorldMat),
             transform.axisBounds(img.shape, img.voxToWorldMat)))
 
 
+def test_resample_image_shape_4D():
+    with tempdir():
+
+        # Can specify three dims
+        img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10, 10)))
+        resample_image.main('image resampled -s 20,20,20'.split())
+        res = Image('resampled')
+
+        assert np.all(np.isclose(res.shape, (20, 20, 20, 10)))
+        assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 1)))
+
+        # Or resample along the higher dims
+        resample_image.main('image resampled -s 20,20,20,20'.split())
+        res = Image('resampled')
+        assert np.all(np.isclose(res.shape, (20, 20, 20, 20)))
+        assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 0.5)))
+
+
 def test_resample_image_dim():
     with tempdir():
         img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10)))
 
-        resample_image.main('image resampled -d 0.5 0.5 0.5'.split())
+        resample_image.main('image resampled -d 0.5,0.5,0.5'.split())
 
         res = Image('resampled')
         expv2w = transform.concat(
@@ -70,12 +89,55 @@ def test_resample_image_ref():
         assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5)))
         assert np.all(np.isclose(res.voxToWorldMat, expv2w))
 
+        # 3D / 4D
+        img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10)))
+        ref = Image(make_random_image('ref.nii.gz',   dims=(20, 20, 20, 20),
+                                      pixdims=(0.5, 0.5, 0.5, 1)))
+
+        resample_image.main('image resampled -r ref'.split())
+        res    = Image('resampled')
+        assert np.all(np.isclose(res.shape, (20, 20, 20)))
+        assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5)))
+
+        # 4D / 3D
+        img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10, 10)))
+        ref = Image(make_random_image('ref.nii.gz',   dims=(20, 20, 20),
+                                      pixdims=(0.5, 0.5, 0.5)))
+
+        resample_image.main('image resampled -r ref'.split())
+        res    = Image('resampled')
+        assert np.all(np.isclose(res.shape, (20, 20, 20, 10)))
+        assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 1)))
+
+        # 4D / 4D - no resampling along fourth dim
+        img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10, 10)))
+        ref = Image(make_random_image('ref.nii.gz',   dims=(20, 20, 20, 20),
+                                      pixdims=(0.5, 0.5, 0.5, 1)))
+
+        resample_image.main('image resampled -r ref'.split())
+        res    = Image('resampled')
+        assert np.all(np.isclose(res.shape, (20, 20, 20, 10)))
+        assert np.all(np.isclose(res.pixdim, (0.5, 0.5, 0.5, 1)))
+
 
 def test_resample_image_bad_options():
     with tempdir():
         img = Image(make_random_image('image.nii.gz', dims=(10, 10, 10)))
 
+        # No args - should print help and exit(0)
+        with pytest.raises(SystemExit) as e:
+            resample_image.main([])
+        assert e.value.code == 0
+
+        with pytest.raises(SystemExit) as e:
+            resample_image.main('image resampled -d 0.5,0.5,0.5 '
+                                '-s 20,20,20'.split())
+        assert e.value.code != 0
+
+        with pytest.raises(SystemExit) as e:
+            resample_image.main('image resampled -s 20,20'.split())
+        assert e.value.code != 0
+
         with pytest.raises(SystemExit) as e:
-            resample_image.main('image resampled -d 0.5 0.5 0.5 '
-                                '-s 20 20 20'.split())
+            resample_image.main('image resampled -s 20,20,20,20'.split())
         assert e.value.code != 0