test_transform.py 9.89 KB
Newer Older
Paul McCarthy's avatar
Paul McCarthy committed
1
2
3
4
5
6
7
#!/usr/bin/env python
#
# test_transform.py -
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#

8
9
10

from __future__ import division

11
12
13
14
import              glob
import os.path   as op
import itertools as it
import numpy     as np
15

16
17
import six

18
import pytest
Paul McCarthy's avatar
Paul McCarthy committed
19
20

import fsl.utils.transform as transform
21
import fsl.data.image      as fslimage
Paul McCarthy's avatar
Paul McCarthy committed
22

23
24
25
26
27
28
29

datadir = op.join(op.dirname(__file__), 'testdata')


def readlines(filename):
    """Read a test data file and return its significant lines.

    Every line is stripped of surrounding whitespace, and lines which are
    empty or which start with ``#`` are dropped.

    Under Python 3 the surviving lines are returned as ASCII-encoded
    ``bytes`` rather than ``str`` (see the note below); callers which
    parse the lines by hand therefore ``decode`` them first.
    """
    with open(filename, 'rt') as f:
        stripped = [line.strip() for line in f]
        lines    = [line for line in stripped
                    if line != '' and not line.startswith('#')]

        # numpy.genfromtxt is busted in python 3 (at least in the numpy
        # versions this was written against).
        # Pass it [str, str, ...], and it complains:
        #
        #   TypeError: must be str or None, not bytes
        #
        # Pass it [bytes, bytes, ...], and it works
        # fine.
        if six.PY3:
            lines = [line.encode('ascii') for line in lines]

    return lines


def test_invert():
    """``transform.invert`` reproduces the inverses stored in the data file.

    Each 4-row chunk of the file holds a 4x4 matrix in columns 0-3 and its
    expected inverse in columns 4-7.
    """
    testfile = op.join(datadir, 'test_transform_test_invert.txt')
    testdata = np.loadtxt(testfile)

    # Only whole 4-row chunks are used; any trailing partial chunk is ignored.
    lastRow = (testdata.shape[0] // 4) * 4

    for start in range(0, lastRow, 4):
        chunk    = testdata[start:start + 4]
        xform    = chunk[:, 0:4]
        expected = chunk[:, 4:8]

        assert np.all(np.isclose(expected, transform.invert(xform)))


def test_concat():
    """``transform.concat`` composes the matrices listed in the data file.

    Each 4-line chunk of the file holds N input 4x4 matrices side by side,
    followed by the expected result in the last four columns.
    """
    testfile = op.join(datadir, 'test_transform_test_concat.txt')
    lines    = readlines(testfile)
    lastLine = (len(lines) // 4) * 4

    # Parse every case up front, then run them.
    cases = []
    for start in range(0, lastLine, 4):
        data     = np.genfromtxt(lines[start:start + 4])
        nmats    = data.shape[1] // 4 - 1
        matrices = [data[:, col:col + 4] for col in range(0, nmats * 4, 4)]
        cases.append((matrices, data[:, -4:]))

    for matrices, expected in cases:
        assert np.all(np.isclose(transform.concat(*matrices), expected))


Paul McCarthy's avatar
Paul McCarthy committed
93
94
def test_scaleOffsetXform():
    """``transform.scaleOffsetXform`` builds the expected affine.

    Each 5-line record in the data file is a ``"<scales>, <offsets>"``
    header line followed by the four rows of the expected matrix. The
    function is called with both list and tuple arguments.
    """
    testfile = op.join(datadir, 'test_transform_test_scaleoffsetxform.txt')
    lines    = readlines(testfile)
    lastLine = (len(lines) // 5) * 5

    for base in range(0, lastLine, 5):
        header              = lines[base].decode('ascii')
        scaleStr, offsetStr = header.split(',')

        scales   = list(map(float, scaleStr.split()))
        offsets  = list(map(float, offsetStr.split()))
        expected = np.array([[float(v) for v in row.split()]
                             for row in lines[base + 1:base + 5]])

        fromLists  = transform.scaleOffsetXform(scales, offsets)
        fromTuples = transform.scaleOffsetXform(tuple(scales), tuple(offsets))

        assert np.all(np.isclose(fromLists,  expected))
        assert np.all(np.isclose(fromTuples, expected))
Paul McCarthy's avatar
Paul McCarthy committed
116
117


118
def test_compose_and_decompose():
    """``compose(*decompose(x))`` round-trips every affine in the data file.

    Also checks that ``compose`` accepts a rotation matrix in place of
    axis angles.
    """

    testfile = op.join(datadir, 'test_transform_test_compose.txt')
    lines    = readlines(testfile)
    ntests   = len(lines) // 4

    for i in range(ntests):

        xform                      = lines[i * 4: i * 4 + 4]
        xform                      = np.genfromtxt(xform)

        scales, offsets, rotations = transform.decompose(xform)
        result = transform.compose(scales, offsets, rotations)

        assert np.all(np.isclose(xform, result, atol=1e-5))

        # The decompose function does not support a
        # different rotation origin, but we test
        # explicitly passing the origin for
        # completeness
        scales, offsets, rotations = transform.decompose(xform)
        result = transform.compose(scales, offsets, rotations, [0, 0, 0])

        assert np.all(np.isclose(xform, result, atol=1e-5))

    # compose should also accept a rotation matrix
    rots  = [np.pi / 5, np.pi / 4, np.pi / 3]
    rmat  = transform.axisAnglesToRotMat(*rots)
    xform = transform.compose([1, 1, 1], [0, 0, 0], rmat)
    sc, of, rot = transform.decompose(xform)
    sc  = np.array(sc)
    of  = np.array(of)
    rot = np.array(rot)

    # rmat is built from floating point trigonometry, so the
    # values recovered by decompose are only guaranteed to be
    # close to the inputs - exact equality is fragile.
    assert np.all(np.isclose(sc,  [1, 1, 1]))
    assert np.all(np.isclose(of,  [0, 0, 0]))
    assert np.all(np.isclose(rot, rots))

156
157
158
159

def test_axisBounds():
    """``transform.axisBounds`` reproduces the bounds in the data file.

    The check is made for every scalar axis and every ordered selection of
    axes. Invalid ``origin``/``boundary`` values must raise ``ValueError``,
    and the US spelling ``'center'`` must be accepted.
    """
    testfile = op.join(datadir, 'test_transform_test_axisBounds.txt')
    lines    = readlines(testfile)
    ntests   = len(lines) // 6

    def readTest(testnum):
        """Parse 6-line record ``testnum`` of the data file.

        The record layout is:
          - a ``"<shape>, <origin>, <boundary>"`` header line,
          - four affine rows,
          - one row holding the three expected lows followed by the
            three expected highs.
        """
        tlines   = lines[testnum * 6: testnum * 6 + 6]
        params   = [p.strip() for p in tlines[0].decode('ascii').split(',')]
        shape    = [int(s) for s in params[0].split()]
        origin   = params[1]
        boundary = None if params[2] == 'None' else params[2]
        xform    = np.genfromtxt(tlines[1:5])
        expected = np.genfromtxt([tlines[5]])
        expected = (expected[:3], expected[3:])

        return shape, origin, boundary, xform, expected

    # Every scalar axis index (range(3) -> 0, 1, 2),
    # plus every ordered selection of one, two or
    # three axes.
    allAxes  = list(it.chain(
        range(3),
        it.permutations((0, 1, 2), 1),
        it.permutations((0, 1, 2), 2),
        it.permutations((0, 1, 2), 3)))

    for i in range(ntests):

        shape, origin, boundary, xform, expected = readTest(i)

        for axes in allAxes:
            result = transform.axisBounds(shape,
                                          xform,
                                          axes=axes,
                                          origin=origin,
                                          boundary=boundary)

            # Wrapping axes in a tuple makes a tuple of axes act as a
            # fancy index (selecting those elements, in order), while
            # a scalar axis selects a single element.
            exp = expected[0][(axes,)], expected[1][(axes,)]

            assert np.all(np.isclose(exp, result))

    # Do some parameter checks on the first test
    # in the file which has origin == centre. Fail
    # loudly rather than silently reusing whichever
    # test happened to be read last.
    for i in range(ntests):
        shape, origin, boundary, xform, expected = readTest(i)
        if origin == 'centre':
            break
    else:
        pytest.fail('No test with origin == centre in {}'.format(testfile))

    # US-spelling
    assert np.all(np.isclose(
        expected,
        transform.axisBounds(
            shape, xform, origin='center', boundary=boundary)))

    # Bad origin/boundary values
    with pytest.raises(ValueError):
        transform.axisBounds(shape, xform, origin='Blag', boundary=boundary)
    with pytest.raises(ValueError):
        transform.axisBounds(shape, xform, origin=origin, boundary='Blufu')
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241


def test_transform():
    """``transform.transform`` maps the reference coordinates as expected.

    Each ``test_transform_test_transform_??.txt`` data file supplies an
    affine and the expected output. For scale/translate-only affines, the
    ``axes`` argument is also checked. Badly shaped inputs must be rejected.
    """

    def is_orthogonal(xform):
        """Returns ``True`` if the non-zero entries of ``xform`` are exactly
        its diagonal and its translation column.

        Such an xform consists solely of (non-zero) scales and
        translations, so each output axis depends only on the
        corresponding input axis.
        """

        # The builtin bool is used as the dtype, because the
        # np.bool alias was removed in numpy 1.24.
        mask = np.array([[1, 0, 0, 1],
                         [0, 1, 0, 1],
                         [0, 0, 1, 1],
                         [0, 0, 0, 1]], dtype=bool)

        return np.all((xform != 0) == mask)

    coordfile   = op.join(datadir, 'test_transform_test_transform_coords.txt')
    testcoords  = np.loadtxt(coordfile)

    testpattern = op.join(datadir, 'test_transform_test_transform_??.txt')
    testfiles   = glob.glob(testpattern)

    # Guard against the test passing vacuously
    # if no data files can be found.
    assert len(testfiles) > 0, 'No test files match {}'.format(testpattern)

    # Every scalar axis index (range(3) -> 0, 1, 2),
    # plus every ordered selection of one, two or
    # three axes.
    allAxes  = list(it.chain(
        range(3),
        it.permutations((0, 1, 2), 1),
        it.permutations((0, 1, 2), 2),
        it.permutations((0, 1, 2), 3)))

    for testfile in testfiles:

        lines    = readlines(testfile)
        xform    = np.genfromtxt(lines[:4])
        expected = np.genfromtxt(lines[ 4:])
        result   = transform.transform(testcoords, xform)

        assert np.all(np.isclose(expected, result))

        # Transforming a subset of axes is only
        # checked when the axes are independent.
        if not is_orthogonal(xform):
            continue

        for axes in allAxes:
            atestcoords = testcoords[:, axes]
            aexpected   = expected[  :, axes]
            aresult     = transform.transform(atestcoords, xform, axes=axes)

            assert np.all(np.isclose(aexpected, aresult))

    # Pass in some bad data, expect an error
    xform     = np.eye(4)
    badxform  = np.eye(3)
    badcoords = np.random.randint(1, 10, (10, 4))
    coords    = badcoords[:, :3]

    # 3x3 affine
    with pytest.raises(IndexError):
        transform.transform(coords, badxform)

    # 4 coordinate columns
    with pytest.raises(ValueError):
        transform.transform(badcoords, xform)

    # 3-D coordinate array
    with pytest.raises(ValueError):
        transform.transform(badcoords.reshape(5, 2, 4), xform)

    with pytest.raises(ValueError):
        transform.transform(badcoords.reshape(5, 2, 4), xform, axes=1)

    # 3 coordinate columns, but only 2 axes
    with pytest.raises(ValueError):
        transform.transform(badcoords[:, (1, 2, 3)], xform, axes=[1, 2])
282
283
284


def test_flirtMatrixToSform():
    """``transform.flirtMatrixToSform`` reproduces the data file's matrices.

    Each 18-line record in the data file holds:
      - the source shape (line 0),
      - the source affine (lines 1-4),
      - the reference shape (line 5),
      - the reference affine (lines 6-9),
      - the FLIRT matrix (lines 10-13),
      - the expected result (lines 14-17).
    """
    testfile = op.join(datadir, 'test_transform_test_flirtMatrixToSform.txt')
    lines    = readlines(testfile)
    lastLine = (len(lines) // 18) * 18

    for first in range(0, lastLine, 18):
        rec = lines[first:first + 18]

        srcShape = list(map(int, rec[0].split()))
        refShape = list(map(int, rec[5].split()))
        srcXform, refXform, flirtMat, expected = [
            np.genfromtxt(rec[lo:lo + 4]) for lo in (1, 6, 10, 14)]

        srcImg = fslimage.Image(np.zeros(srcShape), xform=srcXform)
        refImg = fslimage.Image(np.zeros(refShape), xform=refXform)

        result = transform.flirtMatrixToSform(flirtMat, srcImg, refImg)

        assert np.all(np.isclose(result, expected))


def test_sformToFlirtMatrix():
    """``transform.sformToFlirtMatrix`` recovers the FLIRT matrices.

    This reuses the 18-line records of the flirtMatrixToSform data file:
      - Lines 10-13 (the FLIRT matrix there) are the expected result here.
      - Lines 14-17 (the expected result there) are supplied as the
        source affine override.

    The override is passed once as an explicit argument and once by
    assigning it to the source image's ``voxToWorldMat``.
    """
    testfile = op.join(datadir, 'test_transform_test_flirtMatrixToSform.txt')
    lines    = readlines(testfile)
    lastLine = (len(lines) // 18) * 18

    for first in range(0, lastLine, 18):
        rec = lines[first:first + 18]

        srcShape = [int(tok) for tok in rec[0].split()]
        refShape = [int(tok) for tok in rec[5].split()]
        srcXform, refXform, expected, srcXformOvr = [
            np.genfromtxt(rec[lo:lo + 4]) for lo in (1, 6, 10, 14)]

        # explicitSrc keeps its original affine and receives the override
        # as an argument; assignedSrc has the override assigned directly.
        explicitSrc = fslimage.Image(np.zeros(srcShape), xform=srcXform)
        assignedSrc = fslimage.Image(np.zeros(srcShape), xform=srcXform)
        refImg      = fslimage.Image(np.zeros(refShape), xform=refXform)

        assignedSrc.voxToWorldMat = srcXformOvr

        viaArgument = transform.sformToFlirtMatrix(explicitSrc, refImg,
                                                   srcXformOvr)
        viaImage    = transform.sformToFlirtMatrix(assignedSrc, refImg)

        assert np.all(np.isclose(viaArgument, expected))
        assert np.all(np.isclose(viaImage,    expected))