Commit e9c81ad6 authored by Paul McCarthy's avatar Paul McCarthy 🚵
Browse files

Merge branch 'enh/statlas' into 'master'

Statistic atlases

See merge request fsl/fslpy!160
parents 41ad8b85 bc1b4a0d
Pipeline #4317 passed with stages
in 18 minutes and 23 seconds
......@@ -11,6 +11,7 @@ Added
* New :meth:`.Image.iscomplex` attribute.
* Support for a new ``Statistic`` atlas type.
Changed
......@@ -21,6 +22,20 @@ Changed
as a least-recently-used cache.
* The :mod:`.filetree` module has been refactored to make it easier for the
:mod:`.query` module to work with file tree hierarchies.
* The :meth:`.LabelAtlas.get` method has a new ``binary`` flag, allowing
either a binary mask, or a mask with the original label value, to be
returned.
Deprecated
^^^^^^^^^^
* :meth:`.ProbabilisticAtlas.proportions`,
:meth:`.ProbabilisticAtlas.maskProportions`, and
:meth:`.ProbabilisticAtlas.labelProportions` have been deprecated in favour
of :meth:`.StatisticAtlas.values`, :meth:`.StatisticAtlas.maskValues`, and
:meth:`.StatisticAtlas.labelValues`
2.5.0 (Tuesday 6th August 2019)
......
......@@ -36,27 +36,29 @@ load an atlas image, which will be one of the following atlas-specific
:nosignatures:
LabelAtlas
StatisticAtlas
ProbabilisticAtlas
"""
from __future__ import division
import xml.etree.ElementTree as et
import os.path as op
import glob
import bisect
import logging
import xml.etree.ElementTree as et
import os.path as op
import glob
import bisect
import logging
import numpy as np
import numpy as np
import fsl.data.image as fslimage
import fsl.data.constants as constants
from fsl.utils.platform import platform as platform
import fsl.utils.image.resample as resample
import fsl.transform.affine as affine
import fsl.utils.notifier as notifier
import fsl.utils.settings as fslsettings
import fsl.data.image as fslimage
import fsl.data.constants as constants
from fsl.utils.platform import platform
import fsl.utils.image.resample as resample
import fsl.transform.affine as affine
import fsl.utils.notifier as notifier
import fsl.utils.settings as fslsettings
import fsl.utils.deprecated as deprecated
log = logging.getLogger(__name__)
......@@ -322,9 +324,9 @@ class AtlasLabel(object):
========= ================================================================
``name`` Region name
``index`` The index of this label into the list of all labels in the
``AtlasDescription`` that owns it. For probabilistic atlases,
this is also the index into the 4D atlas image of the volume
that corresponds to this region.
``AtlasDescription`` that owns it. For statistic/probabilistic
atlases, this is also the index into the 4D atlas image of the
volume that corresponds to this region.
``value`` For label atlases and summary images, the value of voxels that
are in this region.
``x`` X coordinate of the region in world space
......@@ -386,8 +388,13 @@ class AtlasDescription(object):
<atlas>
<header>
<name></name> # Atlas name
<type></type> # 'Probabilistic' or 'Label'
<name></name> # Atlas name
<type></type> # 'Statistic', 'Probabilistic' or 'Label'
<statistic></statistic> # Optional. Type of statistic
<units></units> # Optional. Units of measurement
<precision></precision> # Optional. Decimal precision to report
<upper></upper> # Optional. Upper threshold
<lower></lower> # Optional. Lower threshold
<images>
<imagefile>
</imagefile> # If type is Probabilistic, path
......@@ -412,11 +419,12 @@ class AtlasDescription(object):
</header>
<data>
# index - For probabilistic atlases, index of corresponding volume in
# 4D image file. For label images, the value of voxels which
# are in the corresponding region. For probabilistic atlases,
# it is assumed that the value for each region in the summary
# image(s) are equal to ``index + 1``.
# index - For statistic/probabilistic atlases, index of corresponding
# volume in 4D image file. For label images, the value of
# voxels which are in the corresponding region. For
# statistic/probabilistic atlases, it is assumed that the
# value for each region in the summary image(s) are equal to
# ``index + 1``.
#
#
# x |
......@@ -452,7 +460,18 @@ class AtlasDescription(object):
``specPath`` Path to the atlas XML specification file.
``atlasType`` Atlas type - either *probabilistic* or *label*.
``atlasType`` Atlas type - either *statistic*, *probabilistic* or
*label*.
``statistic`` Type of statistic, for statistic atlases.
``units`` Unit of measurement, for statistic atlases.
``precision`` Reporting precision, for statistic atlases.
``upper`` Upper threshold, for statistic atlases.
``lower`` Lower threshold, for statistic atlases.
``images`` A list of images available for this atlas - usually
:math:`1mm^3` and :math:`2mm^3` images are present.
......@@ -500,6 +519,29 @@ class AtlasDescription(object):
if self.atlasType == 'probabalistic':
self.atlasType = 'probabilistic'
if self.atlasType == 'statistic':
fields = ['statistic', 'units', 'lower', 'upper', 'precision']
values = {}
for field in fields:
elem = header.find(field)
if elem is not None and elem.text is not None:
values[field] = elem.text.strip()
self.statistic = values.get('statistic', '')
self.units = values.get('units', '')
self.lower = float(values.get('lower', 0))
self.upper = float(values.get('upper', 100))
self.precision = int( values.get('precision', 2))
elif self.atlasType == 'probabilistic':
self.statistic = ''
self.units = '%'
self.lower = 5
self.upper = 100
self.precision = 0
images = header.findall('images')
self.images = []
self.summaryImages = []
......@@ -661,7 +703,7 @@ class Atlas(fslimage.Image):
:arg resolution: Desired isotropic resolution in millimetres.
:arg isLabel: Pass in ``True`` for label atlases, ``False`` for
probabilistic atlases.
statistic/probabilistic atlases.
All other arguments are passed to :meth:`.Image.__init__`.
"""
......@@ -708,7 +750,7 @@ class Atlas(fslimage.Image):
"""Makes sure that the given mask has the same resolution as this
atlas, so it can be used for querying. Used by the
:meth:`.LabelAtlas.maskLabels` and
:meth:`.ProbabilisticAtlas.maskProportions` methods.
:meth:`.StatisticAtlas.maskValues` methods.
:arg mask: A :class:`.Image`
......@@ -738,13 +780,11 @@ class Atlas(fslimage.Image):
return mask
class MaskError(Exception):
"""Exception raised by the :meth:`LabelAtlas.maskLabel` and
:meth:`ProbabilisticAtlas.maskProportions` when a mask is provided which
:meth:`StatisticAtlas.maskValues` when a mask is provided which
does not match the atlas space.
"""
pass
class LabelAtlas(Atlas):
......@@ -877,17 +917,20 @@ class LabelAtlas(Atlas):
return values, props
def get(self, label=None, index=None, value=None, name=None):
"""
Returns the binary image for given label
def get(self, label=None, index=None, value=None, name=None, binary=True):
"""Returns the binary image for the given label.
Only one of the arguments should be used to define the label
:arg label: AtlasLabel contained within this atlas
:arg index: index of the label
:arg value: value of the label
:arg name: string of the label
:return: image.Image with the mask
:arg label: :class:`AtlasLabel` contained within this atlas
:arg index: index of the label
:arg value: value of the label
:arg name: string of the label
:arg binary: If ``True`` (the default), the image will contain 1s in
the label region. Otherwise the image will contain the
label value.
:return: :class:`.Image` with the mask
"""
if ((label is not None) + (index is not None) +
(value is not None) + (name is not None)) != 1:
......@@ -896,19 +939,27 @@ class LabelAtlas(Atlas):
label = self.find(index=index, name=name, value=value)
elif label not in self.desc.labels:
raise ValueError("Unknown label provided")
arr = (self.data == label.value).astype(int)
arr = (self.data == label.value).astype(np.int32)
if not binary:
arr[arr > 0] = label.value
return fslimage.Image(arr, name=label.name, header=self.header)
class ProbabilisticAtlas(Atlas):
"""A 4D atlas which contains one volume for each region.
class StatisticAtlas(Atlas):
"""A ``StatisticAtlas`` is a 4D image which contains one volume for
each region in the atlas; each volume contains some statistic value
for the corresponding region.
The ``ProbabilisticAtlas`` provides the :meth`proportions` method,
which makes looking up region probabilities easy.
The :class:`ProbabilisticAtlas` is a specialisation of the
``StatisticAtlas``
"""
def __init__(self, atlasDesc, resolution=None, **kwargs):
"""Create a ``ProbabilisticAtlas`` instance.
"""Create a ``StatisticAtlas`` instance.
:arg atlasDesc: The :class:`AtlasDescription` instance describing
the atlas.
......@@ -917,17 +968,18 @@ class ProbabilisticAtlas(Atlas):
"""
Atlas.__init__(self, atlasDesc, resolution, False, **kwargs)
def get(self, label=None, index=None, value=None, name=None):
"""
Returns the probabilistic image for given label
"""Returns the statistic image at the given label.
Only one of the arguments should be used to define the label
:arg label: AtlasLabel contained within this atlas
:arg label: :class:`AtlasLabel` contained within this atlas
:arg index: index of the label
:arg value: value of the label
:arg name: string of the label
:return: image.Image with the probabilistic mask
:arg name: string of the label
:return: :class:`.Image` with the statistic values for the
specified label.
"""
if ((label is not None) + (index is not None) +
(value is not None) + (name is not None)) != 1:
......@@ -939,36 +991,37 @@ class ProbabilisticAtlas(Atlas):
arr = self[..., label.index]
return fslimage.Image(arr, name=label.name, header=self.header)
def proportions(self, location, *args, **kwargs):
"""Looks up and returns the proportions of of all regions at the given
def values(self, location, *args, **kwargs):
"""Looks up and returns the values of of all regions at the given
location.
:arg location: Can be one of the following:
- A sequence of three values, interpreted as atlas
coordinates. In this case, :meth:`coordProportions`
coordinates. In this case, :meth:`coordValues`
is called.
- An :class:`.Image` which is interpreted as a
weighted mask. In this case, :meth:`maskProportions`
weighted mask. In this case, :meth:`maskValues`
is called.
All other arguments are passed through to the :meth:`coordProportions`
or :meth:`maskProportions` methods.
All other arguments are passed through to the :meth:`coordValues`
or :meth:`maskValues` methods.
:returns: The return value of either :meth:`coordProportions` or
:meth:`maskProportions`.
:returns: The return value of either :meth:`coordValues` or
:meth:`maskValues`.
"""
if isinstance(location, fslimage.Image):
return self.maskProportions(location, *args, **kwargs)
return self.maskValues(location, *args, **kwargs)
else:
return self.coordProportions(location, *args, **kwargs)
return self.coordValues(location, *args, **kwargs)
def coordProportions(self, loc, voxel=False):
"""Looks up the region probabilities for the given location.
def coordValues(self, loc, voxel=False):
"""Looks up the region values for the given location.
:arg loc: A sequence of three values, interpreted as atlas
world or voxel coordinates.
......@@ -976,10 +1029,8 @@ class ProbabilisticAtlas(Atlas):
:arg voxel: Defaults to ``False``. If ``True``, the ``loc``
argument is interpreted as voxel coordinates.
:returns: a list of values, one per region, which represent
the probability of each region for the specified
location. Returns an empty list if the given
location is out of bounds.
:returns: a list of values, one per region. Returns an empty
list if the given location is out of bounds.
"""
if not voxel:
......@@ -994,30 +1045,27 @@ class ProbabilisticAtlas(Atlas):
loc[2] >= self.shape[2]:
return []
props = self[loc[0], loc[1], loc[2], :]
vals = self[loc[0], loc[1], loc[2], :]
# We only return labels for this atlas -
# the underlying image may have more
# volumes than this atlas has labels.
return [props[l.index] for l in self.desc.labels]
return [vals[l.index] for l in self.desc.labels]
def maskProportions(self, mask):
"""Looks up the probabilities of all regions in the given ``mask``.
def maskValues(self, mask):
"""Looks up the average values of all regions in the given ``mask``.
:arg mask: A 3D :class:`.Image`` which is interpreted as a weighted
mask. If the ``mask`` shape does not match that of this
``ProbabilisticAtlas``, it is resampled using
:meth:`.Image.resample`, with nearest-neighbour
interpolation.
``StatisticAtlas``, it is resampled using
:meth:`Atlas.prepareMask`.
:returns: A sequence containing the proportion, within the mask,
of all regions in the atlas. The proportions are returned as
values between 0 and 100.
:returns: A sequence containing the average value, within the mask,
of all regions in the atlas.
"""
props = []
avgvals = []
mask = self.prepareMask(mask)
boolmask = mask > 0
weights = mask[boolmask]
......@@ -1030,11 +1078,35 @@ class ProbabilisticAtlas(Atlas):
vals = self[..., label.index]
vals = vals[boolmask] * weights
prop = vals.sum() / weightsum
val = vals.sum() / weightsum
avgvals.append(val)
props.append(prop)
return avgvals
return props
@deprecated.deprecated('2.6.0', '3.0.0', 'Use values instead')
def proportions(self, *args, **kwargs):
"""Deprecated - use :meth:`values` instead. """
return self.values(*args, **kwargs)
@deprecated.deprecated('2.6.0', '3.0.0', 'Use coordValues instead')
def coordProportions(self, *args, **kwargs):
"""Deprecated - use :meth:`coordValues` instead. """
return self.coordValues(*args, **kwargs)
@deprecated.deprecated('2.6.0', '3.0.0', 'Use maskValues instead')
def maskProportions(self, *args, **kwargs):
"""Deprecated - use :meth:`maskValues` instead. """
return self.maskValues(*args, **kwargs)
class ProbabilisticAtlas(StatisticAtlas):
"""A 4D atlas which contains one volume for each region. Each volume
contains probabiliy values for one region, between 0 and 100.
"""
registry = AtlasRegistry()
......
......@@ -381,7 +381,7 @@ def maskQuery(atlas, masks, *args, **kwargs):
labels = []
props = []
zprops = atlas.maskProportions(mask)
zprops = atlas.maskValues(mask)
for i in range(len(zprops)):
if zprops[i] > 0:
......@@ -405,7 +405,7 @@ def coordQuery(atlas, coords, voxel, *args, **kwargs):
if isinstance(atlas, fslatlases.ProbabilisticAtlas):
props = atlas.proportions(coord, voxel=voxel)
props = atlas.values(coord, voxel=voxel)
labels = []
nzprops = []
......
......@@ -40,7 +40,8 @@ dummy_atlas_desc = """<?xml version="1.0" encoding="ISO-8859-1"?>
<header>
<name>{name}</name>
<shortname>{shortname}</shortname>
<type>Label</type>
<type>{atlastype}</type>
{extraheader}
<images>
<imagefile>/{shortname}/{filename}</imagefile>
<summaryimagefile>/{shortname}/My{filename}</summaryimagefile>
......@@ -52,7 +53,8 @@ dummy_atlas_desc = """<?xml version="1.0" encoding="ISO-8859-1"?>
</data>
</atlas>
"""
def _make_dummy_atlas(savedir, name, shortName, filename):
def _make_dummy_atlas(
savedir, name, shortName, filename, atlastype='Label', extraheader=''):
mladir = op.join(savedir, shortName)
mlaxmlfile = op.join(savedir, '{}.xml'.format(shortName))
mlaimgfile = op.join(savedir, shortName, '{}.nii.gz'.format(filename))
......@@ -70,7 +72,9 @@ def _make_dummy_atlas(savedir, name, shortName, filename):
desc = dummy_atlas_desc.format(
name=name,
shortname=shortName,
filename=filename)
filename=filename,
atlastype=atlastype,
extraheader=extraheader)
f.write(desc)
return mlaxmlfile
......@@ -142,6 +146,28 @@ def test_AtlasDescription():
registry.getAtlasDescription('non-existent-atlas')
def test_StatisticHeader():
with tests.testdir() as testdir:
hdr = '<statistic>T</statistic>' \
'<units></units>' \
'<precision>3</precision>' \
'<upper>75</upper>'
xmlfile = _make_dummy_atlas(testdir,
'statlas',
'STA',
'StAtlas',
atlastype='Statistic',
extraheader=hdr)
desc = atlases.AtlasDescription(xmlfile, 'StAtlas')
assert desc.atlasType == 'statistic'
assert desc.statistic == 'T'
assert desc.units == ''
assert desc.precision == 3
assert desc.lower == 0
assert desc.upper == 75
def test_add_remove_atlas():
with tests.testdir() as testdir:
......@@ -250,6 +276,9 @@ def test_get():
assert (target == atlas.get(index=label.index).data).all()
assert (target == atlas.get(value=label.value).data).all()
assert (target == atlas.get(name=label.name).data).all()
if atlas is lblatlas:
target = target * label.value
assert (target == atlas.get(value=label.value, binary=False).data).all()
def test_find():
......
......@@ -218,8 +218,8 @@ def _eval_coord_voxel_query(atlas, query, qtype, qin):
elif qin == 'out':
expval = []
assert atlas.proportions( query, voxel=voxel) == expval
assert atlas.coordProportions(query, voxel=voxel) == expval
assert atlas.values( query, voxel=voxel) == expval
assert atlas.coordValues(query, voxel=voxel) == expval
if isinstance(atlas, fslatlases.LabelAtlas): evalLabel()
elif isinstance(atlas, fslatlases.ProbabilisticAtlas): evalProb()
......@@ -343,13 +343,13 @@ def _eval_mask_query(atlas, query, qtype, qin):
if qin == 'out':
with pytest.raises(fslatlases.MaskError):
atlas.maskProportions(mask)
atlas.maskValues(mask)
with pytest.raises(fslatlases.MaskError):
atlas.proportions( mask)
atlas.values( mask)
return
props = atlas. proportions(mask)
props2 = atlas.maskProportions(mask)
props = atlas. values(mask)
props2 = atlas.maskValues(mask)
assert np.all(np.isclose(props, props2))
......
......@@ -51,7 +51,7 @@ def test_coords(seed):
"""Test the ohi -a "atlas" -c "coords" mode. """
def expectedProbOutput(atlas, coords):
probs = atlas.proportions(coords)
probs = atlas.values(coords)
expected = '<b>{}</b><br>'.format(atlas.desc.name)
nzprobs = []
......@@ -161,7 +161,7 @@ def test_mask(seed):
def expectedProbOutput(mask, atlas):
props = atlas.maskProportions(mask)
props = atlas.maskValues(mask)
labels = [l.index for l in atlas.desc.labels]
exp = []
......
......@@ -294,7 +294,7 @@ def _eval_coord_voxel_query(
if o_type == 'normal': evalLabelNormalOutput(explabel)
else: evalLabelShortOutput(explabel)
elif isinstance(a_img, fslatlases.ProbabilisticAtlas):
expprops = a_img.proportions(query, voxel=voxel)
expprops = a_img.values(query, voxel=voxel)
if o_type == 'normal': evalProbNormalOutput(expprops)
else: evalProbShortOutput(expprops)
......@@ -400,7 +400,7 @@ def _eval_mask_query(
explabels, expprops = [], []
elif isinstance(aimg, fslatlases.ProbabilisticAtlas):
try:
expprops = aimg.maskProportions(maskimg)
expprops = aimg.maskValues(maskimg)
explabels = aimg.desc.labels
except fslatlases.MaskError:
explabels = []
......
......@@ -132,7 +132,7 @@ def test_convertDeformationType():
gotconvrel2 = nonlinear.convertDeformationType(relfield, 'absolute')
gotconvabs2 = nonlinear.convertDeformationType(absfield, 'relative')
tol = dict(atol=1e-5, rtol=1e-5)
tol = dict(atol=1e-3, rtol=1e-3)
assert np.all(np.isclose(gotconvrel1, absfield.data, **tol))
assert np.all(np.isclose(gotconvabs1, relfield.data, **tol))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment