#!/usr/bin/env python
#
# affine.py - Utility functions for working with affine transformations.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
"""This module contains utility functions for working with affine
transformations. The following functions are available:

.. autosummary::
   :nosignatures:

   transform
   scaleOffsetXform
   invert
   concat
   compose
   decompose
   rotMatToAffine
   rotMatToAxisAngles
   axisAnglesToRotMat
   axisBounds
   rmsdev
   rescale

And a few more functions are provided for working with vectors:

.. autosummary::
   :nosignatures:

   veclength
   normalise
   transformNormal
"""

36

37
import collections.abc as abc
Paul McCarthy's avatar
Paul McCarthy committed
38
39
import numpy           as np
import numpy.linalg    as linalg
40
41
42


def invert(x):
    """Inverts the given matrix using ``numpy.linalg.inv``.

    :arg x:   The matrix to invert, e.g. a ``(4, 4)`` affine.

    :returns: The inverse of ``x``.
    """
    return linalg.inv(x)


47
48
49
50
51
52
53
54
55
def concat(*xforms):
    """Combines the given matrices (returns the dot product).

    The matrices are multiplied together from left to right, so
    ``concat(A, B, C)`` is equivalent to ``np.dot(np.dot(A, B), C)``.
    If only one matrix is given, it is returned unchanged.
    """

    product = xforms[0]

    for xform in xforms[1:]:
        product = np.dot(product, xform)

    return product
56
57


58
def veclength(vec):
    """Returns the length of the given vector(s).

    Multiple vectors may be passed in, with a shape of ``(n, 3)``.

    :arg vec: A single vector of three values, or ``n`` vectors as a
              ``(n, 3)`` array-like.

    :returns: A ``numpy`` array of shape ``(n,)`` containing the length
              of each vector (``n`` is ``1`` for a single vector).
    """
    # np.asarray only copies when it has to. np.array(vec, copy=False)
    # raises a ValueError under numpy >= 2 whenever a copy cannot be
    # avoided (e.g. when vec is a list or a tuple).
    vec = np.asarray(vec).reshape(-1, 3)

    # row-wise dot product of each vector with itself
    return np.sqrt(np.einsum('ij,ij->i', vec, vec))
65
66
67


def normalise(vec):
    """Normalises the given vector(s) to unit length.

    Multiple vectors may be passed in, with a shape of ``(n, 3)``.

    :arg vec: A single vector of three values, or ``n`` vectors as a
              ``(n, 3)`` array-like.

    :returns: A ``numpy`` array of shape ``(n, 3)`` containing the
              normalised vectors, or of shape ``(3,)`` if a single
              vector was passed in.
    """
    # np.asarray only copies when it has to. np.array(vec, copy=False)
    # raises a ValueError under numpy >= 2 whenever a copy cannot be
    # avoided (e.g. when vec is a list or a tuple).
    vec = np.asarray(vec).reshape(-1, 3)

    # transpose so that the (n,) lengths broadcast across the vectors
    n   = (vec.T / veclength(vec)).T

    # a single vector is returned as a flat array of three values
    if n.size == 3:
        n = n[0]

    return n
79
80


81
82
83
def scaleOffsetXform(scales, offsets):
    """Creates and returns an affine transformation matrix which encodes
    the specified scale(s) and offset(s).


    :arg scales:  A tuple of up to three values specifying the scale factors
                  for each dimension. If less than length 3, is padded with
                  ``1.0``. A single scalar value is also accepted.

    :arg offsets: A tuple of up to three values specifying the offsets for
                  each dimension. If less than length 3, is padded with
                  ``0.0``. A single scalar value is also accepted.

    :returns:     A ``numpy.float64`` array of size :math:`4 \\times 4`.
    """

    oktypes = (abc.Sequence, np.ndarray)

    # scalars are promoted to single-element lists
    if not isinstance(scales,  oktypes): scales  = [scales]
    if not isinstance(offsets, oktypes): offsets = [offsets]

    # tuples/arrays are converted to lists so they can be padded
    scales  = list(scales)
    offsets = list(offsets)

    # pad up to three values (multiplying a list by a
    # non-positive number results in an empty list)
    scales  = scales  + [1.0] * (3 - len(scales))
    offsets = offsets + [0.0] * (3 - len(offsets))

    xform = np.eye(4, dtype=np.float64)

    # scales go on the diagonal, offsets in the final column
    for i in range(3):
        xform[i, i] = scales[i]
        xform[i, 3] = offsets[i]

    return xform


123
def compose(scales, offsets, rotations, origin=None, shears=None):
    """Compose a transformation matrix out of the given scales, offsets
    and axis rotations.

    The returned matrix is the product ``offset * postRotate * rotate *
    preRotate * scale * shear``, where ``preRotate`` and ``postRotate``
    translate to and from the rotation ``origin``.

    :arg scales:    Sequence of three scale values.

    :arg offsets:   Sequence of three offset values.

    :arg rotations: Sequence of three rotation values, in radians, or
                    a rotation matrix of shape ``(3, 3)``.

    :arg origin:    Origin of rotation - must be scaled by the ``scales``.
                    If not provided, the rotation origin is ``(0, 0, 0)``.

    :arg shears:    Sequence of three shear values, stored (in order) in
                    elements ``[0, 1]``, ``[0, 2]`` and ``[1, 2]`` of the
                    shear matrix. If not provided, no shear is applied.

    :returns:       A ``(4, 4)`` affine transformation matrix.
    """

    # axis angles are converted into a rotation matrix
    rotations = np.array(rotations)

    if rotations.shape == (3,):
        rotations = axisAnglesToRotMat(*rotations)

    # translations which move the rotation
    # origin to (0, 0, 0) and back again
    preRotate  = np.eye(4)
    postRotate = np.eye(4)

    if origin is not None:
        for i in range(3):
            preRotate[ i, 3] = -origin[i]
            postRotate[i, 3] =  origin[i]

    scale  = np.eye(4, dtype=np.float64)
    offset = np.eye(4, dtype=np.float64)
    rotate = np.eye(4, dtype=np.float64)
    shear  = np.eye(4, dtype=np.float64)

    for i in range(3):
        scale[ i, i] = scales[ i]
        offset[i, 3] = offsets[i]

    rotate[:3, :3] = rotations

    if shears is not None:
        shear[0, 1] = shears[0]
        shear[0, 2] = shears[1]
        shear[1, 2] = shears[2]

    return concat(offset, postRotate, rotate, preRotate, scale, shear)
176
177


178
def decompose(xform, angles=True, shears=False):
    """Decomposes the given transformation matrix into separate offsets,
    scales, and rotations, according to the algorithm described in:

    Spencer W. Thomas, Decomposing a matrix into simple transformations, pp
    320-323 in *Graphics Gems II*, James Arvo (editor), Academic Press, 1991,
    ISBN: 0120644819.

    It is assumed that the given transform has no perspective components.

    :arg xform:  A ``(3, 3)`` or ``(4, 4)`` affine transformation matrix.

    :arg angles: If ``True`` (the default), the rotations are returned
                 as axis-angles, in radians. Otherwise, the rotation matrix
                 is returned.

    :arg shears: Defaults to ``False``. If ``True``, shears are returned.

    :returns: The following:

               - A sequence of three scales

               - A sequence of three translations (all ``0`` if ``xform``
                 was a ``(3, 3)`` matrix)

               - A sequence of three rotations, in radians. Or, if
                 ``angles is False``, a rotation matrix.

               - If ``shears is True``, a sequence of three shears.
    """

    # The inline comments in the code below are taken verbatim from
    # the referenced article, [except for notes in square brackets].

    # The next step is to extract the translations. This is trivial;
    # we find t_x = M_{4,1}, t_y = M_{4,2}, and t_z = M_{4,3}. At this
    # point we are left with a 3*3 matrix M' = M_{1..3,1..3}.
    #
    # [The input is copied and transposed, so the "rows" M1, M2
    #  and M3 used below are the columns of the input matrix.]
    xform = np.array(xform).T

    if xform.shape == (4, 4):
        translations = xform[ 3, :3]
        xform        = xform[:3, :3]
    else:
        translations = np.array([0, 0, 0])

    M1 = xform[0]
    M2 = xform[1]
    M3 = xform[2]

    # The process of finding the scaling factors and shear parameters
    # is interleaved. First, find s_x = |M'_1|.
    sx = np.sqrt(np.dot(M1, M1))
    M1 = M1 / sx

    # Then, compute an initial value for the xy shear factor,
    # s_xy = M'_1 * M'_2. (this is too large by the y scaling factor).
    sxy = np.dot(M1, M2)

    # The second row of the matrix is made orthogonal to the first by
    # setting M'_2 = M'_2 - s_xy * M'_1.
    M2 = M2 - sxy * M1

    # Then the y scaling factor, s_y, is the length of the modified
    # second row.
    sy = np.sqrt(np.dot(M2, M2))

    # The second row is normalized, and s_xy is divided by s_y to
    # get its final value.
    #
    # [We divide s_xy by s_x instead, so that the shears returned here
    #  match those accepted by the compose function, which multiplies
    #  the scale matrix by the shear matrix.]
    M2  = M2  / sy
    sxy = sxy / sx

    # The xz and yz shear factors are computed as in the preceding,
    sxz = np.dot(M1, M3)
    syz = np.dot(M2, M3)

    # the third row is made orthogonal to the first two rows,
    M3 = M3 - sxz * M1 - syz * M2

    # the z scaling factor is computed,
    sz = np.sqrt(np.dot(M3, M3))

    # the third row is normalized, and the xz and yz shear factors are
    # rescaled.
    #
    # [As above, s_xz is divided by s_x, and s_yz by s_y.]
    M3  = M3  / sz
    sxz = sxz / sx
    syz = syz / sy

    # The resulting matrix now is a pure rotation matrix, except that it
    # might still include a scale factor of -1. If the determinant of the
    # matrix is -1, negate the matrix and all three scaling factors. Call
    # the resulting matrix R.
    #
    # [We do things different here - if the rotation matrix has negative
    #  determinant, the flip is encoded in the x scaling factor.]
    R = np.array([M1, M2, M3])
    if linalg.det(R) < 0:
        R[0] = -R[0]
        sx   = -sx

    # Finally, we need to decompose the rotation matrix into a sequence
    # of rotations about the x, y, and z axes. [This is done in the
    # rotMatToAxisAngles function]
    #
    # [R is transposed to undo the transpose applied to the input.]
    if angles: rotations = rotMatToAxisAngles(R.T)
    else:      rotations = R.T

    retval = [np.array([sx, sy, sz]), translations, rotations]

    if shears:
        retval.append(np.array((sxy, sxz, syz)))

    return tuple(retval)

287

288
289
290
291
292
293

def rotMatToAffine(rotmat, origin=None):
    """Convenience function which encodes the given ``(3, 3)`` rotation
    matrix into a ``(4, 4)`` affine, with unit scales and zero offsets.

    :arg rotmat: A ``(3, 3)`` rotation matrix.

    :arg origin: Origin of rotation, passed through to :func:`compose`.
    """
    unitScales  = [1, 1, 1]
    zeroOffsets = [0, 0, 0]
    return compose(unitScales, zeroOffsets, rotmat, origin)
294
295
296
297
298
299


def rotMatToAxisAngles(rotmat):
    """Given a ``(3, 3)`` rotation matrix, decomposes the rotations into
    an angle in radians about each axis.

    :arg rotmat: A ``(3, 3)`` ``numpy`` array containing a rotation matrix.

    :returns:    A list ``[xrot, yrot, zrot]`` containing the rotation
                 about each axis, in radians.
    """

    # Magnitude of cos(yrot) - the y rotation is
    # calculated in the same way in both cases below
    cosy = np.sqrt(rotmat[0, 0] ** 2 + rotmat[1, 0] ** 2)
    yrot = np.arctan2(-rotmat[2, 0], cosy)

    # cos(yrot) is (close to) zero - the z rotation
    # is fixed at 0, and the x rotation is taken
    # from the second row of the matrix
    if np.isclose(cosy, 0):
        xrot = np.arctan2(-rotmat[1, 2], rotmat[1, 1])
        zrot = 0
    else:
        xrot = np.arctan2( rotmat[2, 1], rotmat[2, 2])
        zrot = np.arctan2( rotmat[1, 0], rotmat[0, 0])

    return [xrot, yrot, zrot]


def axisAnglesToRotMat(xrot, yrot, zrot):
    """Constructs a ``(3, 3)`` rotation matrix from the given angles, which
    must be specified in radians.

    The returned matrix is the product ``zmat * ymat * xmat`` of the
    rotation matrices about each individual axis.

    :arg xrot: Rotation about the X axis, in radians.
    :arg yrot: Rotation about the Y axis, in radians.
    :arg zrot: Rotation about the Z axis, in radians.

    :returns:  A ``(3, 3)`` rotation matrix.
    """

    xcos, xsin = np.cos(xrot), np.sin(xrot)
    ycos, ysin = np.cos(yrot), np.sin(yrot)
    zcos, zsin = np.cos(zrot), np.sin(zrot)

    xmat = np.eye(3)
    ymat = np.eye(3)
    zmat = np.eye(3)

    xmat[1, 1] =  xcos
    xmat[1, 2] = -xsin
    xmat[2, 1] =  xsin
    xmat[2, 2] =  xcos

    ymat[0, 0] =  ycos
    ymat[0, 2] =  ysin
    ymat[2, 0] = -ysin
    ymat[2, 2] =  ycos

    zmat[0, 0] =  zcos
    zmat[0, 1] = -zsin
    zmat[1, 0] =  zsin
    zmat[1, 1] =  zcos

    return concat(zmat, ymat, xmat)


342
343
344
345
346
347
348
349
def axisBounds(shape,
               xform,
               axes=None,
               origin='centre',
               boundary='high',
               offset=1e-4):
    """Returns the ``(lo, hi)`` bounds of the specified axis/axes in the
    world coordinate system defined by ``xform``.

    If the ``origin`` parameter is set to  ``centre`` (the default),
    this function assumes that voxel indices correspond to the voxel
    centre. For example, the voxel at ``(4, 5, 6)`` covers the space:

      ``[3.5 - 4.5, 4.5 - 5.5, 5.5 - 6.5]``

    So the bounds of the specified shape extends from the corner at

      ``(-0.5, -0.5, -0.5)``

    to the corner at

      ``(shape[0] - 0.5, shape[1] - 0.5, shape[2] - 0.5)``

    If the ``origin`` parameter is set to ``corner``, this function
    assumes that voxel indices correspond to the voxel corner. In this
    case, a voxel  at ``(4, 5, 6)`` covers the space:

      ``[4 - 5, 5 - 6, 6 - 7]``

    So the bounds of the specified shape extends from the corner at

      ``(0, 0, 0)``

    to the corner at

      ``(shape[0], shape[1], shape[2])``.

    If the ``boundary`` parameter is set to ``high``, the high voxel bounds
    are reduced by a small amount (specified by the ``offset`` parameter)
    before they are transformed to the world coordinate system.  If
    ``boundary`` is set to ``low``, the low bounds are increased by a small
    amount.  The ``boundary`` parameter can also be set to ``'both'``, or
    ``None``. This option is provided so that you can ensure that the
    resulting bounds will always be contained within the image space.

    :arg shape:    The ``(x, y, z)`` shape of the data.

    :arg xform:    Transformation matrix which transforms voxel coordinates
                   to the world coordinate system.

    :arg axes:     The world coordinate system axis bounds to calculate.
                   May be a single axis index, or a sequence of indices.
                   Defaults to all three axes.

    :arg origin:   Either ``'centre'`` (the default) or ``'corner'``. The
                   value is case-insensitive, and ``'center'`` is also
                   accepted.

    :arg boundary: Either ``'high'`` (the default), ``'low'``, ``'both'``,
                   or ``None``.

    :arg offset:   Amount by which the boundary voxel coordinates should be
                   offset. Defaults to ``1e-4``.

    :returns:      A tuple containing the ``(low, high)`` bounds for each
                   requested world coordinate system axis. If ``axes`` is
                   a single index, ``low`` and ``high`` are scalars,
                   otherwise they are arrays with one value per axis.

    :raises ValueError: If ``origin`` or ``boundary`` is not one of the
                        values listed above.
    """

    origin = origin.lower()

    # lousy US spelling
    if origin == 'center':
        origin = 'centre'

    if origin not in ('centre', 'corner'):
        raise ValueError('Invalid origin value: {}'.format(origin))
    if boundary not in ('low', 'high', 'both', None):
        raise ValueError('Invalid boundary value: {}'.format(boundary))

    scalar = False

    if axes is None:
        axes = [0, 1, 2]

    elif not isinstance(axes, abc.Iterable):
        scalar = True
        axes   = [axes]

    x, y, z = shape[:3]

    # low/high voxel coordinates along each voxel axis
    if origin == 'centre':
        x0 = -0.5
        y0 = -0.5
        z0 = -0.5
        x -=  0.5
        y -=  0.5
        z -=  0.5
    else:
        x0 = 0
        y0 = 0
        z0 = 0

    # pull the bounds slightly inside the image space
    if boundary in ('low', 'both'):
        x0 += offset
        y0 += offset
        z0 += offset

    if boundary in ('high', 'both'):
        x  -= offset
        y  -= offset
        z  -= offset

    # the eight corners of the bounding
    # box, in voxel coordinates
    corners = [[px, py, pz]
               for px in (x0, x)
               for py in (y0, y)
               for pz in (z0, z)]
    points  = np.array(corners, dtype=np.float32)

    # the bounds are the min/max of the
    # transformed corners along each axis
    tx = transform(points, xform)

    lo = tx[:, axes].min(axis=0)
    hi = tx[:, axes].max(axis=0)

    if scalar: return (lo[0], hi[0])
    else:      return (lo,    hi)
469

470

471
def transform(p, xform, axes=None, vector=False):
    """Transforms the given set of points ``p`` according to the given affine
    transformation ``xform``.


    :arg p:      A sequence or array of points of shape :math:`N \\times  3`.

    :arg xform:  A ``(4, 4)`` affine transformation matrix with which to
                 transform the points in ``p``.

    :arg axes:   If you are only interested in one or two axes, and the source
                 axes are orthogonal to the target axes (see the note below),
                 you may pass in a 1D, ``N*1``, or ``N*2`` array as ``p``, and
                 use this argument to specify which axis/axes that the data in
                 ``p`` correspond to.

    :arg vector: Defaults to ``False``. If ``True``, the points are treated
                 as vectors - the translation component of the transformation
                 is not applied. If you set this flag, you may pass in a
                 ``(3, 3)`` transformation matrix.

    :returns:    The points in ``p``, transformed by ``xform``, as a ``numpy``
                 array. If ``axes`` is specified, only the coordinates along
                 those axes are returned. If the result contains a single
                 value, its first element is returned instead.


    .. note:: The ``axes`` argument should only be used if the source
              coordinate system (the points in ``p``) axes are orthogonal
              to the target coordinate system (defined by the ``xform``).

              In other words, you can only use the ``axes`` argument if
              the ``xform`` matrix consists solely of translations and
              scalings.
    """

    # expand the points to (N, 3) if axes were specified
    p  = _fillPoints(p, axes)

    # apply the rotation/scale/shear component
    t  = np.dot(xform[:3, :3], p.T).T

    # and the translation, unless we
    # have been given vectors
    if not vector:
        t = t + xform[:3, 3]

    # discard the axes that were not requested
    if axes is not None:
        t = t[:, axes]

    if t.size == 1: return t[0]
    else:           return t
516
517


518
519
520
521
522
523
524
525
def transformNormal(p, xform, axes=None):
    """Transforms the given point(s), under the assumption that they
    are normal vectors. In this case, the points are transformed by
    ``invert(xform[:3, :3]).T``.

    :arg p:     The normal vector(s) to transform.

    :arg xform: The affine transformation - only its upper-left
                ``(3, 3)`` component is used.

    :arg axes:  Passed through to :func:`transform`.
    """
    normalXform = invert(xform[:3, :3]).T
    return transform(p, normalXform, axes, vector=True)


526
527
def _fillPoints(p, axes):
    """Used by the :func:`transform` function. Turns the given array p into
    a ``N*3`` array of ``x,y,z`` coordinates. The array p may be a 1D array,
    or an ``N*2`` or ``N*3`` array.

    :arg p:    The points - a scalar, a 1D sequence, or a 2D sequence
               with one column per entry in ``axes``.

    :arg axes: The axis, or sequence of axes, that the columns of ``p``
               correspond to. If ``None``, ``p`` is converted to a
               ``numpy`` array and returned as-is.

    :returns:  A ``numpy`` array of shape ``(N, 3)`` with the columns of
               ``p`` stored in the specified ``axes``, and zeros
               elsewhere.

    :raises ValueError: If ``p`` has more than two dimensions, or if its
                        number of columns does not match the number of
                        ``axes``.
    """

    if not isinstance(p, abc.Iterable): p = [p]

    p = np.array(p)

    if axes is None: return p

    if not isinstance(axes, abc.Iterable): axes = [axes]

    # a 1D array is interpreted as N points along a single axis
    if p.ndim == 1:
        p = p.reshape((len(p), 1))

    if p.ndim != 2:
        raise ValueError('Points array must be either one or two '
                         'dimensions')

    if len(axes) != p.shape[1]:
        raise ValueError('Points array shape does not match specified '
                         'number of axes')

    # copy each column of p into the axis that it
    # corresponds to - all other axes are left at 0
    newp = np.zeros((len(p), 3), dtype=p.dtype)

    for ax, column in zip(axes, p.T):
        newp[:, ax] = column

    return newp
557
558


559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
def rmsdev(T1, T2, R=None, xc=None):
    """Calculates the RMS deviation of the given affine transforms ``T1`` and
    ``T2``. This can be used as a measure of the 'distance' between two
    affines.

    The ``T1`` and ``T2`` arguments may be either full ``(4, 4)`` affines, or
    ``(3, 3)`` rotation matrices.

    See FMRIB technical report TR99MJ1, available at:

    https://www.fmrib.ox.ac.uk/datasets/techrep/

    :arg T1:  First affine
    :arg T2:  Second affine
    :arg R:   Sphere radius (defaults to ``1``)
    :arg xc:  Sphere centre (defaults to ``(0, 0, 0)``)
    :returns: The RMS deviation between ``T1`` and ``T2``.
    """

    if R  is None: R  = 1
    if xc is None: xc = np.zeros(3)

    # (3, 3) inputs contain rotations only,
    # otherwise we have a full affine
    rotOnly = T1.shape == (3, 3)
    size    = 3 if rotOnly else 4

    M = np.dot(T2, invert(T1)) - np.eye(size)
    A = M[:3, :3]

    # there is no translation
    # component in a (3, 3) matrix
    if rotOnly: t = np.zeros(3)
    else:       t = M[:3, 3]

    shift = t + np.dot(A, xc)

    erms = np.dot(shift.T, shift)
    erms = 0.2 * R ** 2 * np.dot(A.T, A).trace() + erms

    return np.sqrt(erms)
603
604


Paul McCarthy's avatar
Paul McCarthy committed
605
def rescale(oldShape, newShape, origin=None):
    """Calculates an affine matrix to use for resampling.

    This function generates an affine transformation matrix that can be used
    to resample an N-D array from ``oldShape`` to ``newShape`` using, for
    example, ``scipy.ndimage.affine_transform``.

    The matrix will contain scaling factors derived from the ``oldShape /
    newShape`` ratio, and an offset determined by the ``origin``.

    The default value for ``origin`` (``'centre'``) causes the corner voxel of
    the output to have the same centre as the corner voxel of the input. If
    the origin is ``'corner'``, we apply an offset which effectively causes
    the voxel grid corners of the input and output to be aligned.

    :arg oldShape: Shape of input data
    :arg newShape: Shape to resample data to
    :arg origin:   Voxel grid alignment - either ``'centre'`` (the default) or
                   ``'corner'``
    :returns:      An affine resampling matrix, of shape
                   ``(ndim + 1, ndim + 1)``

    :raises ValueError: If the shapes have different lengths, or if a
                        rescaling is needed and ``origin`` is not one of
                        the values listed above.
    """

    if origin is None:
        origin = 'centre'

    # The np.float alias was removed in numpy
    # 1.24 - np.float64 is the equivalent type
    oldShape = np.array(oldShape, dtype=np.float64)
    newShape = np.array(newShape, dtype=np.float64)
    ndim     = len(oldShape)

    if len(oldShape) != len(newShape):
        raise ValueError('Shape mismatch')

    # shapes are the same - no rescaling needed
    if np.all(np.isclose(oldShape, newShape)):
        return np.eye(ndim + 1)

    # Otherwise we calculate a scaling
    # matrix from the old/new shape
    # ratio, and specify an offset
    # according to the origin
    ratio = oldShape / newShape
    scale = np.diag(ratio)

    # Calculate an offset from the origin
    if   origin == 'centre': offset = [0] * ndim
    elif origin == 'corner': offset = (ratio - 1) / 2
    else:
        raise ValueError('Invalid origin value: {}'.format(origin))

    # combine the scales and translations
    # to form the final affine
    xform               = np.eye(ndim + 1)
    xform[:ndim, :ndim] = scale
    xform[:ndim, -1]    = offset

    return xform