#!/usr/bin/env python
#
# test_atlases_query.py - Tests for coordinate, voxel and mask queries
#                         on atlases loaded through fsl.data.atlases.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#

import                    os
import itertools       as it
import numpy           as np
import                    pytest


import fsl.data.atlases         as fslatlases
import fsl.data.image           as fslimage
import fsl.transform.affine     as affine
import fsl.utils.image.resample as resample
import fsl.utils.cache          as cache

from . import (testdir, make_random_mask)


# Module-level pytest mark: every test collected
# from this file carries the ``fsltest`` marker.
pytestmark = pytest.mark.fsltest


def setup_module():
    """Module-level setup hook, run by pytest before the tests in this file.

    Checks that the ``FSLDIR`` environment variable is set, and then asks
    ``fsl.data.atlases`` to rescan its atlases.

    :raises Exception: if ``FSLDIR`` is not set.
    """
    if os.environ.get('FSLDIR', None) is None:
        raise Exception('FSLDIR is not set - atlas tests cannot be run')
    fslatlases.rescanAtlases()


# why this is not built into
# in itertools i don't even
def _repeat(iterator, n):
    for elem in iterator:
        for i in range(n):
            yield elem


# Atlases which have already been loaded, keyed
# by (atlasID, resolution, summary), so that each
# one is only loaded once per test session.
_atlases = cache.Cache()


def _get_atlas(atlasID, res, summary=False):
    """Return the atlas with the given ID/resolution, loading it if needed.

    Loaded atlases are stored in the module-level ``_atlases`` cache and
    re-used on subsequent calls with the same arguments.

    :param atlasID: ID of the atlas to load.
    :param res:     Requested resolution, passed to ``loadAtlas`` as
                    ``resolution``.
    :param summary: Passed to ``loadAtlas`` as ``loadSummary``.
    :returns:       The object returned by ``fslatlases.loadAtlas``.
    """
    key   = (atlasID, res, summary)
    atlas = _atlases.get(key, default=None)

    if atlas is None:

        # Summary images, and the atlases listed here, are loaded
        # with the default loadAtlas options. All others are loaded
        # with loadData/calcRange disabled.
        # NOTE(review): presumably this keeps the large 4D
        # probabilistic images out of memory - confirm.
        if summary or atlasID in ('talairach', 'striatum-structural',
                                  'jhu-labels', 'smatt'):
            kwargs = {}
        else:
            kwargs = {'loadData'  : False,
                      'calcRange' : False}

        atlas = fslatlases.loadAtlas(atlasID,
                                     loadSummary=summary,
                                     resolution=res,
                                     **kwargs)
        _atlases.put(key, atlas)

    return atlas

def _random_atlas(atype, res, summary=False):
    """Load a randomly chosen atlas of the requested type.

    ``atype`` is an atlas type as reported by ``atlasType``; the
    shorthand ``'prob'`` is accepted for ``'probabilistic'``. One of the
    matching atlases is picked with ``np.random.randint`` and loaded
    through ``_get_atlas`` at resolution ``res``.
    """
    wanted     = 'probabilistic' if atype == 'prob' else atype
    candidates = [d for d in fslatlases.listAtlases()
                  if d.atlasType == wanted]
    chosen     = candidates[np.random.randint(0, len(candidates))]

    return _get_atlas(chosen.atlasID, res, summary)


# Generate a mask which tells us which
# voxels in the atlas are all zeros.
# Only a handful of masks are cached.
_zero_masks = cache.Cache(maxsize=5)


def _get_zero_mask(aimg):
    """Return a boolean mask of the voxels in ``aimg`` which are all zero.

    For a ``LabelAtlas`` a voxel is selected when its value is 0. For a
    ``ProbabilisticAtlas`` a voxel is selected when it is 0 in every
    volume along the last axis. Masks are cached in ``_zero_masks``,
    keyed by ``(atlasID, summary, resolution)``.

    :param aimg: A ``LabelAtlas`` or ``ProbabilisticAtlas``.
    :returns:    Boolean array; for any other type of ``aimg`` the
                 result is ``None``.
    """

    atlasID = aimg.desc.atlasID
    res     = aimg.pixdim[0]

    # A label image which belongs to a probabilistic atlas
    # is the summary image of that atlas, and must be cached
    # separately from the probabilistic image itself.
    summary = isinstance(aimg, fslatlases.LabelAtlas) \
              and aimg.desc.atlasType == 'probabilistic'

    zmask = _zero_masks.get((atlasID, summary, res), None)

    if zmask is None:
        if isinstance(aimg, fslatlases.LabelAtlas):
            zmask = aimg[:] == 0
        elif isinstance(aimg, fslatlases.ProbabilisticAtlas):

            # Keep memory usage down by accumulating
            # the mask one volume at a time
            zmask = np.ones(aimg.shape[:3], dtype=bool)
            for vol in range(aimg.shape[-1]):
                zmask = np.logical_and(zmask, aimg[..., vol] == 0)

        _zero_masks[atlasID, summary, res] = zmask

    return zmask


# Label atlas queries, by world coordinate, voxel
# and mask. The mask query is marked as a long test.
# The seed argument is a fixture defined outside of
# this file (presumably it seeds the RNG - confirm).
def test_label_coord_query(  seed): _test_query('coord', 'label')
def test_label_voxel_query(  seed): _test_query('voxel', 'label')
@pytest.mark.longtest
def test_label_mask_query(   seed): _test_query('mask',  'label')
# Queries on the summary (label) image of a probabilistic
# atlas, by world coordinate, voxel and mask. The mask
# query is marked as a long test.
def test_summary_coord_query(seed): _test_query('coord', 'prob', summary=True)
def test_summary_voxel_query(seed): _test_query('voxel', 'prob', summary=True)
@pytest.mark.longtest
def test_summary_mask_query( seed): _test_query('mask',  'prob', summary=True)
# Probabilistic atlas queries, by world coordinate, voxel
# and mask. The mask query is marked as a long test.
def test_prob_coord_query(   seed): _test_query('coord', 'prob')
def test_prob_voxel_query(   seed): _test_query('voxel', 'prob')
@pytest.mark.longtest
def test_prob_mask_query(    seed): _test_query('mask',  'prob')


# qtype: (voxel|coord|mask)
# atype: (label|prob)
def _test_query(qtype, atype, summary=False):
    """Run a batch of randomly generated queries against random atlases.

    Every combination of query location (``'in'``, ``'zero'``, ``'out'``),
    atlas resolution (1, 2) and mask resolution (1 and 2 for mask queries,
    otherwise just 1) is tested five times, each time on a randomly
    chosen atlas of type ``atype``.

    :param qtype:   Query type - ``'voxel'``, ``'coord'`` or ``'mask'``.
    :param atype:   Atlas type - ``'label'`` or ``'prob'``.
    :param summary: Passed through to ``_random_atlas``.
    """

    qins  = ['in', 'zero', 'out']
    reses = [1, 2]

    if qtype == 'mask': maskreses = [1, 2]
    else:               maskreses = [1]

    # The query generator/evaluator only depend
    # on the query type, so pick them once.
    if qtype in ('voxel', 'coord'):
        genfunc  = _gen_coord_voxel_query
        evalfunc = _eval_coord_voxel_query
    else:
        genfunc  = _gen_mask_query
        evalfunc = _eval_mask_query

    tests = _repeat(it.product(qins, reses, maskreses), 5)

    for qin, res, maskres in tests:

        atlas = _random_atlas(atype, res=res, summary=summary)

        # Mask queries write a mask file using a relative
        # path, so each test runs inside testdir().
        with testdir():

            print('Test: {} {}mm type={} in={}'.format(
                atlas.desc.atlasID, res, qtype, qin))

            query = genfunc(atlas, qtype, qin, maskres=maskres)
            evalfunc(atlas, query, qtype, qin)


# Generate a random voxel/world space
# coordinate to query the given atlas.
def _gen_coord_voxel_query(atlas, qtype, qin, **kwargs):
    """Generate a random coordinate with which to query ``atlas``.

    :param atlas:  The atlas to be queried.
    :param qtype:  ``'voxel'`` to generate integer voxel coordinates,
                   anything else to generate floating point world
                   coordinates.
    :param qin:    ``'out'`` for a location outside of the atlas bounds,
                   ``'in'`` for a voxel which is not all zeros, anything
                   else (``'zero'``) for a voxel which is all zeros.
    :param kwargs: Ignored - accepted so that this function can be called
                   in the same way as ``_gen_mask_query``.
    :returns:      A tuple of three ``int`` or ``float`` values.
    """

    voxel = qtype == 'voxel'

    if voxel: dtype = int
    else:     dtype = float

    if qin == 'out':

        if voxel:
            dlo = (0, 0, 0)
            dhi = atlas.shape
        else:
            dlo, dhi = affine.axisBounds(atlas.shape, atlas.voxToWorldMat)

        dlen = [hi - lo for lo, hi in zip(dlo, dhi)]

        # NOTE(review): for voxel queries the conversion to int
        # truncates towards zero, so an "under" value in (-1, 0]
        # becomes 0 on that axis. The query is only in bounds if
        # that happens on all three axes, which looks extremely
        # unlikely - confirm this is acceptable.
        coords = []
        for d in range(3):

            # over
            if np.random.random() > 0.5:
                coords.append(dlo[d] + dlen[d] + dlen[d] * np.random.random())
            # or under
            else:
                coords.append(dlo[d] - dlen[d] * np.random.random())

        coords = np.array(coords, dtype=dtype)

    else:

        # Make a mask which tells us which
        # voxels in the atlas are all zeros
        zmask = _get_zero_mask(atlas)

        # get indices to voxels which are
        # either all zero, or which are
        # not all zero, depending on
        # the value of qin
        if qin == 'in': zidxs = np.where(zmask == 0)
        else:           zidxs = np.where(zmask)

        # Randomly choose a voxel
        cidx   = np.random.randint(0, len(zidxs[0]))
        coords = [zidxs[0][cidx], zidxs[1][cidx], zidxs[2][cidx]]
        coords = np.array(coords, dtype=dtype)

        if not voxel:
            coords = affine.transform(coords, atlas.voxToWorldMat)

    # Return native python numbers, not numpy scalars
    return tuple(dtype(c) for c in coords)


def _eval_coord_voxel_query(atlas, query, qtype, qin):
    """Check the result of a voxel/coordinate query against ``atlas``.

    The expected result is read directly from the atlas image at the
    (rounded) voxel corresponding to ``query``, and compared with what
    the atlas query methods return.

    :param atlas: The atlas which is being queried.
    :param query: Coordinates generated by ``_gen_coord_voxel_query``.
    :param qtype: ``'voxel'`` if ``query`` is in voxel coordinates,
                  otherwise it is treated as world coordinates.
    :param qin:   ``'in'``, ``'zero'`` or ``'out'`` - where the query
                  lies with respect to the atlas.
    """

    voxel = qtype == 'voxel'

    if voxel: vx, vy, vz = query
    else:     vx, vy, vz = affine.transform(query, atlas.worldToVoxMat)

    vx, vy, vz = [int(round(v)) for v in [vx, vy, vz]]

    def evalLabel():
        # Out-of-bounds queries on a label atlas should return None
        if qin in ('in', 'zero'): expval = atlas[vx, vy, vz]
        else:                     expval = None

        assert atlas.label(     query, voxel=voxel) == expval
        assert atlas.coordLabel(query, voxel=voxel) == expval

    def evalProb():
        # In-bounds queries should return one value per label,
        # in label order; out-of-bounds queries an empty list.
        if qin in ('in', 'zero'):
            expval = atlas[vx, vy, vz, :]
            expval = [expval[lbl.index] for lbl in atlas.desc.labels]
        elif qin == 'out':
            expval = []

        assert atlas.values(     query, voxel=voxel) == expval
        assert atlas.coordValues(query, voxel=voxel) == expval

    if   isinstance(atlas, fslatlases.LabelAtlas):         evalLabel()
    elif isinstance(atlas, fslatlases.ProbabilisticAtlas): evalProb()


def _gen_mask_query(atlas, qtype, qin, maskres):
    """Generate a random mask image file with which to query ``atlas``.

    The mask is saved to ``mask.nii.gz`` in the current directory.

    :param atlas:   The atlas to be queried.
    :param qtype:   Unused - accepted so that this function can be called
                    in the same way as ``_gen_coord_voxel_query``.
    :param qin:     ``'out'`` for a mask which is not in the atlas space,
                    ``'in'`` for a mask restricted to voxels which are
                    not all zeros, anything else (``'zero'``) for a mask
                    restricted to voxels which are all zeros.
    :param maskres: Resolution that the mask should have. If it differs
                    from the atlas resolution, the mask is resampled to
                    the grid of the same atlas at ``maskres``.
    :returns:       The name of the mask file.
    """

    maskfile = 'mask.nii.gz'
    res      = atlas.pixdim[0]

    # A 20x20x20 mask with an identity affine
    # is used for the out-of-bounds case
    if qin == 'out':
        make_random_mask(maskfile, (20, 20, 20), np.eye(4))
        return maskfile

    zmask = _get_zero_mask(atlas)

    if qin == 'in':
        zmask = zmask == 0

    mask = make_random_mask(
        maskfile, atlas.shape[:3], atlas.voxToWorldMat, zmask)

    # Make sure that when the mask gets
    # resampled into the atlas resolution,
    # it is still either in or out of the
    # atlas space
    if maskres != res:
        a       = _get_atlas(atlas.desc.atlasID, maskres, True)
        a_zmask = _get_zero_mask(a)

        if qin == 'in':
            a_zmask = a_zmask == 0

        # use linear interp and threshold
        # aggressively to make sure there
        # is no overlap between the different
        # resolutions
        mask, xform = resample.resample(
            mask, a.shape[:3], dtype=np.float32, order=1)

        mask[mask   < 1.0] = 0
        mask[a_zmask == 0] = 0

        mask = np.array(mask, dtype=np.uint8)
        mask = fslimage.Image(mask, xform=xform)

        # Overwrite the mask file written by make_random_mask
        mask.save(maskfile)

    return maskfile


def _eval_mask_query(atlas, query, qtype, qin):
    """Check the result of a mask query against ``atlas``.

    :param atlas: The atlas which is being queried.
    :param query: Name of a mask image file, as generated by
                  ``_gen_mask_query``.
    :param qtype: Unused - accepted so that this function can be called
                  in the same way as ``_eval_coord_voxel_query``.
    :param qin:   ``'in'``, ``'zero'`` or ``'out'`` - where the mask
                  lies with respect to the atlas.
    """

    mask    = fslimage.Image(query)
    prob    = atlas.desc.atlasType == 'probabilistic'
    maskres = mask .pixdim[0]
    res     = atlas.pixdim[0]

    # Bring the mask onto the atlas voxel grid if necessary
    if maskres == res:
        rmask = mask[:]
    else:
        rmask = resample.resample(
            mask, atlas.shape[:3], dtype=np.float32, order=0)[0]

    # The np.bool alias was removed in numpy
    # 1.24 - the builtin bool is equivalent.
    rmask = np.array(rmask, dtype=bool)

    def evalLabel():

        if qin == 'out':
            with pytest.raises(fslatlases.MaskError): atlas.maskLabel(mask)
            with pytest.raises(fslatlases.MaskError): atlas.label(    mask)
            return

        if qin == 'in':

            # Count how many mask voxels
            # fall into each atlas value
            voxels    = np.array(np.where(rmask)).T
            maxval    = int(atlas[:].max())
            valcounts = np.zeros((maxval + 1, ))
            nvoxels   = voxels.shape[0]

            for x, y, z in voxels:
                x, y, z = [int(v) for v in [x, y, z]]
                valcounts[int(atlas[x, y, z])] += 1.0

            # make sure the values are sorted
            # according to their atlas ordering
            expvals   = np.where(valcounts > 0)[0]
            explabels = []

            # There may be more values in an image
            # than are listed in the atlas spec :(
            # Such values are deliberately skipped.
            for v in expvals:
                try:             explabels.append(atlas.find(value=int(v)))
                except KeyError: pass
            explabels = sorted(explabels)
            expvals   = [lbl.value for lbl in explabels]
            expprops  = [100 * valcounts[v] / nvoxels for v in expvals]

        else:
            # A mask covering only all-zero voxels: the summary
            # image of a probabilistic atlas should report nothing,
            # and a label atlas should report 100% of label 0, if
            # it has such a label.
            if prob:
                expvals  = []
                expprops = []
            else:
                allvals = [lbl.value for lbl in atlas.desc.labels]
                if 0 in allvals:
                    expvals  = [0]
                    expprops = [100]
                else:
                    expvals  = []
                    expprops = []

        vals,  props  = atlas.    label(mask)
        vals2, props2 = atlas.maskLabel(mask)

        assert np.all(np.isclose(vals,  vals2))
        assert np.all(np.isclose(props, props2))
        assert np.all(np.isclose(vals,  expvals))
        assert np.all(np.isclose(props, expprops))

    def evalProb():

        if qin == 'out':
            with pytest.raises(fslatlases.MaskError):
                atlas.maskValues(mask)
            with pytest.raises(fslatlases.MaskError):
                atlas.values(    mask)
            return

        # Only check that the two query
        # methods agree with each other
        props  = atlas.    values(mask)
        props2 = atlas.maskValues(mask)

        assert np.all(np.isclose(props, props2))

    if   isinstance(atlas, fslatlases.LabelAtlas):         evalLabel()
    elif isinstance(atlas, fslatlases.ProbabilisticAtlas): evalProb()