diff --git a/fsl/utils/transform.py b/fsl/utils/transform.py
index 4bb17bc570bff89d0614b8e398fb657c411abcb8..d934d21ff5c675fac4f7c96a866c7a1954f9b8c6 100644
--- a/fsl/utils/transform.py
+++ b/fsl/utils/transform.py
@@ -92,7 +92,8 @@ def compose(scales, offsets, rotations, origin=None):
 
     :arg offsets:   Sequence of three offset values.
 
-    :arg rotations: Sequence of three rotation values, in radians.
+    :arg rotations: Sequence of three rotation values, in radians, or
+                    a rotation matrix of shape ``(3, 3)``.
 
     :arg origin:    Origin of rotation - must be scaled by the ``scales``.
                     If not provided, the rotation origin is ``(0, 0, 0)``.
@@ -100,6 +101,12 @@ def compose(scales, offsets, rotations, origin=None):
 
     preRotate  = np.eye(4)
     postRotate = np.eye(4)
+
+    rotations = np.array(rotations)
+
+    if rotations.shape == (3,):
+        rotations = axisAnglesToRotMat(*rotations)
+
     if origin is not None:
         preRotate[ 0, 3] = -origin[0]
         preRotate[ 1, 3] = -origin[1]
@@ -118,7 +125,8 @@ def compose(scales, offsets, rotations, origin=None):
     offset[ 0,  3] = offsets[0]
     offset[ 1,  3] = offsets[1]
     offset[ 2,  3] = offsets[2]
-    rotate[:3, :3] = axisAnglesToRotMat(*rotations)
+
+    rotate[:3, :3] = rotations
 
     return concat(offset, postRotate, rotate, preRotate, scale)