# Skip to content
# Snippets Groups Projects
# Forked from FSL / fslpy
# 1224 commits behind the upstream repository.
# test_atlases_query.py 10.53 KiB
#!/usr/bin/env python
#
# test_atlases_query.py - Tests for the atlas query methods in fsl.data.atlases.
#
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#

import                    os
import itertools       as it
import numpy           as np
import                    pytest


import fsl.data.atlases         as fslatlases
import fsl.data.image           as fslimage
import fsl.transform            as transform
import fsl.utils.image.resample as resample
import fsl.utils.cache          as cache

from . import (testdir, make_random_mask)


# Apply the 'fsltest' marker to every test in this module - these tests
# need an FSL installation ($FSLDIR must be set, see setup_module).
pytestmark = pytest.mark.fsltest


def setup_module():
    """Module-level pytest setup for the atlas query tests.

    Fails fast with an ``Exception`` when the ``FSLDIR`` environment
    variable is absent. Otherwise it calls ``fslatlases.rescanAtlases()``,
    presumably so that the atlas list reflects the current ``$FSLDIR``.
    """
    have_fsldir = 'FSLDIR' in os.environ
    if not have_fsldir:
        raise Exception('FSLDIR is not set - atlas tests cannot be run')
    fslatlases.rescanAtlases()


# Lazily yield every element of an iterable n times in a row,
# e.g. _repeat('ab', 2) -> 'a', 'a', 'b', 'b'. (itertools.repeat
# only repeats a single object, so it is applied per element.)
def _repeat(iterator, n):
    """Yield each element of ``iterator`` ``n`` consecutive times."""
    for item in iterator:
        yield from it.repeat(item, n)


# Cache of loaded atlases, keyed on (atlasID, resolution, summary).
_atlases = cache.Cache()
def _get_atlas(atlasID, res, summary=False):
    """Return the atlas ``atlasID`` at resolution ``res``, loading it on
    first use and caching it afterwards.

    :arg atlasID: Atlas identifier passed to ``fslatlases.loadAtlas``.
    :arg res:     Requested resolution.
    :arg summary: If ``True``, load the summary image
                  (``loadSummary=True``).
    """
    key    = (atlasID, res, summary)
    cached = _atlases.get(key, default=None)
    if cached is not None:
        return cached

    # Summary images and these specific atlases use loadAtlas' default
    # options. All others are loaded with loadData=False and
    # calcRange=False (presumably to keep the tests fast / memory low).
    use_defaults = summary or atlasID in ('talairach',
                                          'striatum-structural',
                                          'jhu-labels',
                                          'smatt')
    if use_defaults:
        load_opts = {}
    else:
        load_opts = {'loadData'  : False,
                     'calcRange' : False}

    loaded = fslatlases.loadAtlas(atlasID,
                                  loadSummary=summary,
                                  resolution=res,
                                  **load_opts)
    _atlases.put(key, loaded)
    return loaded

def _random_atlas(atype, res, summary=False):
    """Load a randomly chosen atlas of the given type.

    :arg atype:   Atlas type - ``'prob'`` is accepted as shorthand for
                  ``'probabilistic'``. Any other value is compared
                  verbatim against each atlas' ``atlasType``.
    :arg res:     Resolution to load the atlas at.
    :arg summary: Passed through to ``_get_atlas``.
    """
    wanted     = 'probabilistic' if atype == 'prob' else atype
    candidates = [desc for desc in fslatlases.listAtlases()
                  if desc.atlasType == wanted]
    chosen     = candidates[np.random.randint(0, len(candidates))]
    return _get_atlas(chosen.atlasID, res, summary)


# Cache of "all-zero voxel" masks, keyed on (atlasID, summary, resolution).
# Bounded to a few entries, as each one is a full 3D boolean image.
_zero_masks = cache.Cache(maxsize=5)
def _get_zero_mask(aimg):
    """Return a 3D boolean mask which is ``True`` where ``aimg`` is zero.

    For a ``LabelAtlas`` (which includes the summary image of a
    probabilistic atlas) a voxel is ``True`` when its value is 0. For a
    ``ProbabilisticAtlas`` a voxel is ``True`` only when it is 0 in every
    volume. Results are cached in ``_zero_masks``.

    :arg aimg: A ``fslatlases.LabelAtlas`` or
               ``fslatlases.ProbabilisticAtlas``.
    """
    atlasID = aimg.desc.atlasID
    res     = aimg.pixdim[0]
    # A label image of a probabilistic atlas is its summary image
    summary = (isinstance(aimg, fslatlases.LabelAtlas) and
               aimg.desc.atlasType == 'probabilistic')
    key     = (atlasID, summary, res)

    zmask = _zero_masks.get(key, None)

    if zmask is None:
        if isinstance(aimg, fslatlases.LabelAtlas):
            zmask = aimg[:] == 0
        elif isinstance(aimg, fslatlases.ProbabilisticAtlas):

            # Keep memory usage down by testing one volume at a time,
            # rather than comparing the whole 4D image at once.
            # The builtin bool is used as the dtype - the np.bool alias
            # was deprecated in NumPy 1.20 and removed in NumPy 1.24.
            zmask = np.ones(aimg.shape[:3], dtype=bool)
            for vol in range(aimg.shape[-1]):
                zmask = np.logical_and(zmask, aimg[..., vol] == 0)

        _zero_masks[key] = zmask

    return zmask


def test_label_coord_query(seed):
    """Label atlas queried with world coordinates."""
    _test_query(qtype='coord', atype='label')


def test_label_voxel_query(seed):
    """Label atlas queried with voxel coordinates."""
    _test_query(qtype='voxel', atype='label')


@pytest.mark.longtest
def test_label_mask_query(seed):
    """Label atlas queried with mask images."""
    _test_query(qtype='mask', atype='label')


def test_summary_coord_query(seed):
    """Probabilistic atlas summary image queried with world coordinates."""
    _test_query(qtype='coord', atype='prob', summary=True)


def test_summary_voxel_query(seed):
    """Probabilistic atlas summary image queried with voxel coordinates."""
    _test_query(qtype='voxel', atype='prob', summary=True)


@pytest.mark.longtest
def test_summary_mask_query(seed):
    """Probabilistic atlas summary image queried with mask images."""
    _test_query(qtype='mask', atype='prob', summary=True)


def test_prob_coord_query(seed):
    """Probabilistic atlas queried with world coordinates."""
    _test_query(qtype='coord', atype='prob')


def test_prob_voxel_query(seed):
    """Probabilistic atlas queried with voxel coordinates."""
    _test_query(qtype='voxel', atype='prob')


@pytest.mark.longtest
def test_prob_mask_query(seed):
    """Probabilistic atlas queried with mask images."""
    _test_query(qtype='mask', atype='prob')


# qtype: (voxel|coord|mask)
# atype: (label|prob)
def _test_query(qtype, atype, summary=False):
    """Run a batch of randomised atlas queries of one kind.

    Every combination of query location (``'in'``, ``'zero'``, ``'out'``),
    atlas resolution (1, 2) and mask resolution (1 and 2 for mask queries,
    otherwise just 1) is run 5 times in a row. Each run uses a randomly
    chosen atlas inside a fresh ``testdir()``.

    :arg qtype:   ``'voxel'`` or ``'coord'`` for point queries. Any other
                  value (in practice ``'mask'``) selects mask queries.
    :arg atype:   Atlas type passed to ``_random_atlas``.
    :arg summary: Passed to ``_random_atlas``.
    """
    # The generator/evaluator pair depends only on qtype
    point_query = qtype in ('voxel', 'coord')
    if point_query:
        genfunc, evalfunc = _gen_coord_voxel_query, _eval_coord_voxel_query
    else:
        genfunc, evalfunc = _gen_mask_query, _eval_mask_query

    maskreses = [1, 2] if qtype == 'mask' else [1]
    combos    = it.product(['in', 'zero', 'out'], [1, 2], maskreses)

    for qin, res, maskres in _repeat(combos, 5):

        atlas = _random_atlas(atype, res=res, summary=summary)

        with testdir():
            print('Test: {} {}mm type={} in={}'.format(
                atlas.desc.atlasID, res, qtype, qin))

            query = genfunc(atlas, qtype, qin, maskres=maskres)
            evalfunc(atlas, query, qtype, qin)


# Generate a random voxel/world space
# coordinate to query the given atlas.
def _gen_coord_voxel_query(atlas, qtype, qin, **kwargs):
    """Generate a random voxel or world coordinate with which to query
    ``atlas``.

    :arg atlas:  The atlas to be queried.
    :arg qtype:  ``'voxel'`` for a voxel coordinate. Anything else (in
                 practice ``'coord'``) gives a world coordinate.
    :arg qin:    ``'out'`` for a coordinate outside of the atlas bounds,
                 ``'in'`` for a voxel which is not all-zero in the atlas,
                 or any other value (in practice ``'zero'``) for a voxel
                 which is all-zero.
    :arg kwargs: Ignored. Accepted so this function can be called the same
                 way as ``_gen_mask_query`` (which takes ``maskres``).
    :returns:    A tuple of three ``int`` (voxel) or ``float`` (world)
                 values.
    """
    voxel = qtype == 'voxel'
    dtype = int if voxel else float

    if qin == 'out':

        if voxel:
            dlo = (0, 0, 0)
            dhi = atlas.shape
        else:
            dlo, dhi = transform.axisBounds(atlas.shape, atlas.voxToWorldMat)

        # zip stops after 3 axes, so a 4D shape is fine here
        dlen = [hi - lo for lo, hi in zip(dlo, dhi)]

        coords = []
        for d in range(3):

            # over the upper bound ...
            if np.random.random() > 0.5:
                coords.append(dlo[d] + dlen[d] + dlen[d] * np.random.random())
            # ... or under the lower bound
            else:
                coords.append(dlo[d] - dlen[d] * np.random.random())

        # Voxel coordinates must be rounded *down*. A plain int
        # conversion truncates towards zero, which would turn an
        # "under" value in (-1, 0) into voxel 0 - inside the image.
        if voxel:
            coords = np.floor(coords)

        coords = np.array(coords, dtype=dtype)

    else:

        # Mask of the voxels which are all-zero in the atlas
        zmask = _get_zero_mask(atlas)

        # Indices of the voxels which are not all-zero (qin == 'in')
        # or which are all-zero (otherwise)
        if qin == 'in': zidxs = np.where(zmask == 0)
        else:           zidxs = np.where(zmask)

        # Randomly choose one of those voxels
        cidx   = np.random.randint(0, len(zidxs[0]))
        coords = [zidxs[0][cidx], zidxs[1][cidx], zidxs[2][cidx]]
        coords = np.array(coords, dtype=dtype)

        if not voxel:
            coords = transform.transform(coords, atlas.voxToWorldMat)

    return tuple(dtype(c) for c in coords)


def _eval_coord_voxel_query(atlas, query, qtype, qin):
    """Check the result of a voxel/world coordinate query on ``atlas``.

    For ``'in'``/``'zero'`` queries the expected result is read straight
    from the atlas data at the (rounded) voxel. For ``'out'`` queries it is
    ``None`` for label atlases and ``[]`` for probabilistic atlases. Both
    query methods of the atlas must return the expected result.

    :arg atlas: The queried atlas.
    :arg query: The coordinate tuple generated by ``_gen_coord_voxel_query``.
    :arg qtype: ``'voxel'`` if ``query`` is a voxel coordinate, otherwise
                it is treated as a world coordinate.
    :arg qin:   ``'in'``, ``'zero'`` or ``'out'``.
    """
    voxel = qtype == 'voxel'

    if voxel:
        vcoords = query
    else:
        vcoords = transform.transform(query, atlas.worldToVoxMat)

    vx, vy, vz = (int(round(c)) for c in vcoords)

    def check_label():
        expected = atlas[vx, vy, vz] if qin in ('in', 'zero') else None

        assert atlas.label(     query, voxel=voxel) == expected
        assert atlas.coordLabel(query, voxel=voxel) == expected

    def check_prob():
        if qin in ('in', 'zero'):
            voxvals  = atlas[vx, vy, vz, :]
            expected = [voxvals[lbl.index] for lbl in atlas.desc.labels]
        elif qin == 'out':
            expected = []

        assert atlas.proportions(     query, voxel=voxel) == expected
        assert atlas.coordProportions(query, voxel=voxel) == expected

    if isinstance(atlas, fslatlases.LabelAtlas):
        check_label()
    elif isinstance(atlas, fslatlases.ProbabilisticAtlas):
        check_prob()


def _gen_mask_query(atlas, qtype, qin, maskres):
    """Write a random mask image with which to query ``atlas``.

    :arg atlas:   The atlas to be queried.
    :arg qtype:   Unused. Accepted so this function can be called the same
                  way as ``_gen_coord_voxel_query``.
    :arg qin:     ``'out'`` for a 20x20x20 mask with an identity affine
                  (presumably not in the atlas' space), ``'in'`` for a
                  mask within the atlas' not-all-zero voxels, anything
                  else (in practice ``'zero'``) for a mask within its
                  all-zero voxels.
    :arg maskres: Resolution of the mask. When it differs from the atlas
                  resolution, the mask is resampled onto the grid of the
                  same atlas loaded (as a summary image) at ``maskres``.
    :returns:     The mask file name, ``'mask.nii.gz'``.
    """
    maskfile = 'mask.nii.gz'
    res      = atlas.pixdim[0]

    if qin == 'out':
        make_random_mask(maskfile, (20, 20, 20), np.eye(4))
        return maskfile

    # Region the random mask is restricted to
    region = _get_zero_mask(atlas)
    if qin == 'in':
        region = region == 0

    mask = make_random_mask(
        maskfile, atlas.shape[:3], atlas.voxToWorldMat, region)

    if maskres == res:
        return maskfile

    # The mask must still lie entirely in (or out of) the atlas'
    # non-zero region once it is on the maskres grid, so it is also
    # restricted by the equivalent region of the maskres atlas.
    other_atlas  = _get_atlas(atlas.desc.atlasID, maskres, True)
    other_region = _get_zero_mask(other_atlas)
    if qin == 'in':
        other_region = other_region == 0

    # Linear interpolation followed by an aggressive threshold (only
    # values of at least 1 survive), so that no voxel straddles the
    # boundary between the two resolutions.
    resampled, xform = resample.resample(
        mask, other_atlas.shape[:3], dtype=np.float32, order=1)

    resampled[resampled    < 1.0] = 0
    resampled[other_region == 0]  = 0

    resampled = np.array(resampled, dtype=np.uint8)
    fslimage.Image(resampled, xform=xform).save(maskfile)

    return maskfile


def _eval_mask_query(atlas, query, qtype, qin):
    """Check the results of a mask query on ``atlas``.

    For ``'out'`` masks, every mask query method must raise
    ``fslatlases.MaskError``. Otherwise the two equivalent query methods
    must agree with each other. For label atlases the values/proportions
    must also match those computed here from the atlas data (``'in'``) or
    from the atlas labels (``'zero'``).

    :arg atlas: The queried atlas.
    :arg query: Mask file name, as returned by ``_gen_mask_query``.
    :arg qtype: Unused. Accepted so this function can be called the same
                way as ``_eval_coord_voxel_query``.
    :arg qin:   ``'in'``, ``'zero'`` or ``'out'``.
    """
    mask    = fslimage.Image(query)
    # For a LabelAtlas this identifies the summary image of a
    # probabilistic atlas
    prob    = atlas.desc.atlasType == 'probabilistic'
    maskres = mask.pixdim[0]
    res     = atlas.pixdim[0]

    # Put the mask onto the atlas voxel grid (order=0 interpolation)
    # if it was generated at a different resolution
    if maskres == res:
        rmask = mask[:]
    else:
        rmask = resample.resample(
            mask, atlas.shape[:3], dtype=np.float32, order=0)[0]

    # The builtin bool is used as the dtype - the np.bool alias was
    # deprecated in NumPy 1.20 and removed in NumPy 1.24.
    rmask = np.array(rmask, dtype=bool)

    def evalLabel():

        # An 'out' mask must be rejected
        if qin == 'out':
            with pytest.raises(fslatlases.MaskError): atlas.maskLabel(mask)
            with pytest.raises(fslatlases.MaskError): atlas.label(mask)
            return

        if qin == 'in':

            # Count how many mask voxels have each atlas value
            voxels    = np.array(np.where(rmask)).T
            maxval    = int(atlas[:].max())
            valcounts = np.zeros((maxval + 1, ))
            nvoxels   = voxels.shape[0]

            for x, y, z in voxels:
                x, y, z = [int(v) for v in [x, y, z]]
                valcounts[int(atlas[x, y, z])] += 1.0

            # make sure the values are sorted
            # according to their atlas ordering
            expvals   = np.where(valcounts > 0)[0]
            explabels = []

            # There may be more values in an image than are listed in
            # the atlas spec - values without a label are skipped
            for v in expvals:
                try:             explabels.append(atlas.find(value=int(v)))
                except KeyError: pass
            explabels = list(sorted(explabels))
            expvals   = [l.value for l in explabels]
            expprops  = [100 * valcounts[v] / nvoxels for v in expvals]

        else:
            # 'zero' mask: nothing for summary images of probabilistic
            # atlases, otherwise 100% label 0 if the atlas defines one
            if prob:
                expvals  = []
                expprops = []
            else:
                allvals = [l.value for l in atlas.desc.labels]
                if 0 in allvals:
                    expvals  = [0]
                    expprops = [100]
                else:
                    expvals  = []
                    expprops = []

        vals,  props  = atlas.label(mask)
        vals2, props2 = atlas.maskLabel(mask)

        assert np.all(np.isclose(vals,  vals2))
        assert np.all(np.isclose(props, props2))
        assert np.all(np.isclose(vals,  expvals))
        assert np.all(np.isclose(props, expprops))

    def evalProb():

        # An 'out' mask must be rejected
        if qin == 'out':
            with pytest.raises(fslatlases.MaskError):
                atlas.maskProportions(mask)
            with pytest.raises(fslatlases.MaskError):
                atlas.proportions(mask)
            return

        props  = atlas.proportions(mask)
        props2 = atlas.maskProportions(mask)

        assert np.all(np.isclose(props, props2))

    if   isinstance(atlas, fslatlases.LabelAtlas):         evalLabel()
    elif isinstance(atlas, fslatlases.ProbabilisticAtlas): evalProb()