Commit f480a89a authored by Paul McCarthy's avatar Paul McCarthy 🚵
Browse files

unit tests for atlasq

parent 9c5ce793
#!/usr/bin/env python
#
# test_list_summary.py -
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
import os
import itertools as it
import fsl.data.atlases as fslatlases
import fsl.atlasq.atlasq as fslatlasq
from . import CaptureStdout
def setup_module():
if os.environ.get('FSLDIR', None) is None:
raise Exception('FSLDIR is not set - atlas tests cannot be run')
def test_list():
fslatlases.rescanAtlases()
adescs = fslatlases.listAtlases()
capture = CaptureStdout()
tests = ['list', 'list --extended']
extendeds = [False, True]
for test, extended in zip(tests, extendeds):
capture.reset()
with capture:
fslatlasq.main(test.split())
stdout = capture.stdout
for desc in adescs:
assert desc.atlasID in stdout
assert desc.name in stdout
assert (desc.specPath in stdout) == extended
for image in it.chain(desc.images, desc.summaryImages):
assert (image in stdout) == extended
assert (image in stdout) == extended
def test_summary():
fslatlases.rescanAtlases()
adescs = fslatlases.listAtlases()
capture = CaptureStdout()
for desc in adescs:
tests = [desc.atlasID, desc.name]
for test in tests:
capture.reset()
with capture:
fslatlasq.main(['summary', test])
stdout = capture.stdout
assert desc.atlasID in stdout
assert desc.name in stdout
assert desc.specPath in stdout
assert desc.atlasType in stdout
for image in it.chain(desc.images, desc.summaryImages):
assert image in stdout
for label in desc.labels:
assert label.name in stdout
#!/usr/bin/env python
#
# test_ohi.py - Test the fsl.atlasq ohi interface, which mimics the behaviour
# of the old atlasquery tool.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
import os.path as op
import os
import shlex
import numpy as np
import fsl.atlasq.atlasq as fslatlasq
import fsl.data.atlases as fslatlases
from . import (tempdir,
make_random_mask,
CaptureStdout)
def setup_module():
if os.environ.get('FSLDIR', None) is None:
raise Exception('FSLDIR is not set - atlas tests cannot be run')
def test_dumpatlases():
"""Test the ohi --dumpatlases option. """
capture = CaptureStdout()
with capture:
fslatlasq.main('ohi --dumpatlases'.split())
atlases = fslatlases.listAtlases()
atlases = [a.name for a in atlases]
atlases = sorted(atlases)
assert capture.stdout.strip() == '\n'.join(atlases)
def test_coords(seed):
"""Test the ohi -a "atlas" -c "coords" mode. """
def expectedProbOutput(atlas, coords):
probs = atlas.proportions(coords)
expected = '<b>{}</b><br>'.format(atlas.desc.name)
nzprobs = []
for i, p in enumerate(probs):
if p > 0:
label = atlas.desc.labels[i].name
nzprobs.append((p, label))
if len(nzprobs) > 0:
nzprobs = reversed(sorted(nzprobs, key=lambda b: b[0]))
nzprobs = ['{:d}% {}'.format(int(round(p)), l) for
(p, l) in nzprobs]
expected += ', '.join(nzprobs)
else:
expected += 'No label found!'
return expected
def expectedLabelOutput(atlas, coords):
label = atlas.label(coords)
expected = '<b>{}</b><br>'.format(atlas.desc.name)
if label is None:
return expected + 'Unclassified'
else:
return expected + atlas.desc.labels[int(label)].name
capture = CaptureStdout()
# random coordinates in MNI152 space,
# with some coordinates out of bounds
ncoords = 50
xc = -100 + 190 * np.random.random(ncoords)
yc = -130 + 220 * np.random.random(ncoords)
zc = -80 + 120 * np.random.random(ncoords)
coords = np.vstack((xc, yc, zc)).T
fslatlases.rescanAtlases()
atlases = fslatlases.listAtlases()
for ad in atlases:
# atlasquery/ohi always uses 2mm resolution
atlas = fslatlases.loadAtlas(ad.atlasID, resolution=2)
print(ad.name)
for x, y, z in coords:
cmd = 'ohi -a "{}" -c "{},{},{}"'.format(ad.name, x, y, z)
capture.reset()
with capture:
fslatlasq.main(shlex.split(cmd))
if isinstance(atlas, fslatlases.ProbabilisticAtlas):
expected = expectedProbOutput(atlas, (x, y, z))
# LabelAtlas
else:
expected = expectedLabelOutput(atlas, (x, y, z))
assert capture.stdout.strip() == expected.strip()
def test_bad_atlas():
"""Test the ohi -a "atlas" ..., with a non-existent atlas. """
capture = CaptureStdout()
atlases = fslatlases.listAtlases()
atlases = sorted([a.name for a in atlases])
expected = ['Invalid atlas name. Try one of:'] + atlases
expected = '\n'.join(expected)
cmds = ['ohi -a "non-existent atlas" -c "0,0,0"',
'ohi -a "non-existent atlas" -m "nomask"']
for cmd in cmds:
capture.reset()
with capture:
fslatlasq.main(shlex.split(cmd))
assert capture.stdout.strip() == expected
def test_mask(seed):
"""Test the ohi -a "atlas" -m "mask" mode, with label and probabilistic
atlases.
"""
def expectedLabelOutput(mask, atlas):
labels, props = atlas.maskLabel(mask)
exp = []
for lbl, prop in zip(labels, props):
exp.append('{}:{:0.4f}'.format(desc.labels[int(lbl)].name,
prop))
return '\n'.join(exp)
def expectedProbOutput(mask, atlas):
props = atlas.maskProportions(mask)
labels = [l.index for l in atlas.desc.labels]
exp = []
for lbl, prop in zip(labels, props):
if prop > 0:
exp.append('{}:{:0.4f}'.format(desc.labels[int(lbl)].name,
prop))
return '\n'.join(exp)
fslatlases.rescanAtlases()
capture = CaptureStdout()
atlases = fslatlases.listAtlases()
with tempdir() as td:
maskfile = op.join(td, 'mask.nii')
for desc in atlases:
# atlasquery always uses 2mm
# resolution versions of atlases
atlas2mm = fslatlases.loadAtlas(desc.atlasID, resolution=2)
# Test with 1mm and 2mm masks
for res in [1, 2]:
atlasimg = fslatlases.loadAtlas(desc.atlasID, resolution=res)
maskimg = make_random_mask(maskfile,
atlasimg.shape[:3],
atlasimg.voxToWorldMat)
cmd = 'ohi -a "{}" -m {}'.format(desc.name, maskfile)
print(cmd)
capture.reset()
with capture:
fslatlasq.main(shlex.split(cmd))
if isinstance(atlasimg, fslatlases.LabelAtlas):
expected = expectedLabelOutput(maskimg, atlas2mm)
elif isinstance(atlasimg, fslatlases.ProbabilisticAtlas):
expected = expectedProbOutput(maskimg, atlas2mm)
assert capture.stdout.strip() == expected
#!/usr/bin/env python
#
# test_query.py -
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
import os
import os.path as op
import re
import itertools as it
import six
import numpy as np
import scipy.ndimage as ndi
import fsl.utils.transform as transform
import fsl.data.atlases as fslatlases
import fsl.data.image as fslimage
import fsl.atlasq.atlasq as fslatlasq
from . import (tempdir,
make_random_mask,
CaptureStdout)
def setup_module():
if os.environ.get('FSLDIR', None) is None:
raise Exception('FSLDIR is not set - atlas tests cannot be run')
atlases = ['harvardoxford-cortical', 'talairach']
# False: do not use the --label flag
# True: use the --label flag
#
# ('label', True) should be equivalent to ('label', False)
use_labels = [True, False]
# mask_1 == 1mm mask image
# mask_2 == 2mm mask image
query_types = ['coordinate',
'voxel',
'mask_1',
'mask_2']
# 'in': query is inside the atlas (results expected)
# 'zero': query is inside the atlas, but out of any region (no results
# expected)
# 'out': query is outside the atlas (no results expected)
query_is_in = ['in', 'zero', 'out']
resolutions = [1, 2]
output_types = ['normal', 'short']
tests = list(it.product(atlases,
use_labels,
query_types,
query_is_in,
resolutions,
output_types))
# side-test - multiple queries together,
# in both short and normal output formats
def test_query_voxel(seed):
_test_query([t for t in tests if t[2] == 'voxel'])
def test_query_coord(seed):
_test_query([t for t in tests if t[2] == 'coordinate'])
def test_query_mask(seed):
_test_query([t for t in tests if t[2].startswith('mask')])
def _test_query(tests):
fslatlases.rescanAtlases()
capture = CaptureStdout()
print()
for atlas, use_label, q_type, q_in, res, o_type in tests:
with tempdir() as td:
if q_type in ('voxel', 'coordinate'):
genfunc = _gen_coord_voxel_query
evalfunc = _eval_coord_voxel_query
else:
genfunc = _gen_mask_query
evalfunc = _eval_mask_query
print('Test: {} {}mm label={} type={} in={} type={}'.format(
atlas, res, use_label, q_type, q_in, o_type))
query = genfunc(atlas, use_label, q_type, q_in, res)
cmd = _build_command_line(
atlas, query, use_label, q_type, res, o_type)
print('fslatlasq {}'.format(' '.join(cmd)))
capture.reset()
with capture:
fslatlasq.main(cmd)
evalfunc(capture.stdout,
atlas,
query,
use_label,
q_type,
q_in,
res,
o_type)
_atlases = {}
def _get_atlas(aid, use_label, res):
atlas = _atlases.get((aid, use_label, res), None)
if atlas is None:
atlas = fslatlases.loadAtlas(aid,
loadSummary=use_label,
resolution=res)
_atlases[aid] = atlas
return atlas
_zero_masks = {}
def _get_zero_mask(a_img, atlas, use_label, res):
# Make a mask which tells us which
# voxels in the atlas are all zeros
zmask = _zero_masks.get((atlas, use_label, res), None)
if zmask is None:
if isinstance(a_img, fslatlases.LabelAtlas):
zmask = a_img[:] == 0
elif isinstance(a_img, fslatlases.ProbabilisticAtlas):
zmask = np.all(a_img[:] == 0, axis=-1)
_zero_masks[atlas, use_label, res] = zmask
return zmask
def _gen_coord_voxel_query(atlas, use_label, q_type, q_in, res):
a_img = _get_atlas(atlas, use_label, res)
voxel = q_type == 'voxel'
if voxel: dtype = int
else: dtype = float
if q_in == 'out':
if voxel:
dlo = (0, 0, 0)
dhi = a_img.shape
else:
dlo, dhi = transform.axisBounds(a_img.shape, a_img.voxToWorldMat)
dlen = [hi - lo for lo, hi in zip(dlo, dhi)]
coords = []
for d in range(3):
# over
if np.random.random() > 0.5:
coords.append(dlo[d] + dlen[d] + dlen[d] * np.random.random())
# or under
else:
coords.append(dlo[d] - dlen[d] * np.random.random())
coords = np.array(coords, dtype=dtype)
else:
# Make a mask which tells us which
# voxels in the atlas are all zeros
zmask = _get_zero_mask(a_img, atlas, use_label, res)
# get indices to voxels which are
# either all zero, or which are
# not all all zero, depending on
# the value of q_in
if q_in == 'in': zidxs = np.where(zmask == 0)
else: zidxs = np.where(zmask)
# Randomly choose a voxel
cidx = np.random.randint(0, len(zidxs[0]))
coords = [zidxs[0][cidx], zidxs[1][cidx], zidxs[2][cidx]]
coords = np.array(coords, dtype=dtype)
if not voxel:
coords = transform.transform(coords, a_img.voxToWorldMat)
return tuple([dtype(c) for c in coords])
def _eval_coord_voxel_query(
stdout, atlas, query, use_label, q_type, q_in, res, o_type):
a_img = _get_atlas(atlas, use_label, res)
voxel = q_type == 'voxel'
prob = a_img.desc.atlasType == 'probabilistic'
x, y, z = query
if voxel: squery = '{:0.0f} {:0.0f} {:0.0f}'.format(*query)
else: squery = '{:0.2f} {:0.2f} {:0.2f}'.format(*query)
if voxel: lsquery = 'voxel {}' .format(squery)
else: lsquery = 'coordinate {}'.format(squery)
def evalLabelNormalOutput(explabel):
assert lsquery in stdout
# all label atlases have an entry for 0
if q_in == 'in' or (q_in == 'zero' and not prob):
if prob: labelobj = a_img.desc.labels[explabel - 1]
else: labelobj = a_img.desc.labels[explabel]
assert labelobj.name in stdout
assert ' {} '.format(explabel) in stdout
if prob:
assert ' {} '.format(labelobj.index) in stdout
else:
assert 'No label' in stdout
def evalLabelShortOutput(explabel):
if q_in == 'in' or (q_in == 'zero' and not prob):
if prob: labelobj = a_img.desc.labels[explabel - 1]
else: labelobj = a_img.desc.labels[explabel]
exp = [q_type, squery, labelobj.name]
else:
exp = [q_type, squery, 'No label']
_stdout = re.sub('\s+', ' ', stdout).strip()
assert _stdout == ' '.join(exp)
def evalProbNormalOutput(expprops):
assert lsquery in stdout
if q_in == 'in':
lines = stdout.split('\n')
explabels = [a_img.desc.labels[i] for i in range(len(expprops))]
for explabel, expprop in zip(explabels, expprops):
if expprop == 0:
continue
hits = [l for l in lines if explabel.name in l]
assert len(hits) == 1
line = hits[0]
assert ' {} ' .format(explabel.index) in line
assert ' {} ' .format(explabel.index + 1) in line
assert '{:0.4f}'.format(expprop) in line
else:
assert 'No results' in stdout
def evalProbShortOutput(expprops):
if q_in == 'in':
exp = [q_type, squery]
labels = [a_img.desc.labels[i].name for i in range(len(expprops))]
for expprop, explabel in reversed(sorted(zip(expprops, labels))):
if expprop == 0:
break
exp.append('{} {:0.4f}'.format(explabel, expprop))
else:
exp = [q_type, squery]
_stdout = re.sub('\s+', ' ', stdout).strip()
assert _stdout == ' '.join(exp)
if isinstance(a_img, fslatlases.LabelAtlas):
explabel = a_img.label(query, voxel=voxel)
if o_type == 'normal': evalLabelNormalOutput(explabel)
else: evalLabelShortOutput(explabel)
elif isinstance(a_img, fslatlases.ProbabilisticAtlas):
expprops = a_img.proportions(query, voxel=voxel)
if o_type == 'normal': evalProbNormalOutput(expprops)
else: evalProbShortOutput(expprops)
def _gen_mask_query(atlas, use_label, q_type, q_in, res):
maskres = int(q_type[-1])
maskfile = 'mask.nii.gz'
a_img = _get_atlas(atlas, use_label, res)
if q_in == 'out':
make_random_mask(maskfile, (20, 20, 20), np.eye(4))
else:
zmask = _get_zero_mask(a_img, atlas, use_label, res)
if q_in == 'in':
zmask = zmask == 0
mask = make_random_mask(
maskfile, a_img.shape[:3], a_img.voxToWorldMat, zmask)
mask.save('/Users/paulmc/Projects/fsl-atlasq/mask_before_resample.nii.gz')
mask.save('/Users/paulmc/Projects/fsl-atlasq/mask.nii.gz')
fslimage.Image(np.array(zmask, dtype=np.uint8), xform=a_img.voxToWorldMat).save('/Users/paulmc/Projects/fsl-atlasq/zmask.nii.gz')
if maskres != res:
zmask = ndi.binary_erosion(zmask, iterations=3)
fslimage.Image(np.array(zmask, dtype=np.uint8), xform=a_img.voxToWorldMat).save('/Users/paulmc/Projects/fsl-atlasq/zmask.nii.gz')
mask[zmask == 0] = 0
a = _get_atlas(atlas, True, maskres)
# Make sure that when the mask gets
# resampled into the atlas resolution,
# it is still either in or out of the
# atlas space
mask, xform = mask.resample(a