Skip to content
Snippets Groups Projects
Commit 430bba8b authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

Merge branch 'enh_atlases' into 'master'

Enh atlases

See merge request fsl/fslpy!126
parents e2d72e42 3af1f611
No related branches found
No related tags found
No related merge requests found
Pipeline #3745 passed
...@@ -366,6 +366,15 @@ class AtlasLabel(object): ...@@ -366,6 +366,15 @@ class AtlasLabel(object):
""" """
return self.index < other.index return self.index < other.index
def __repr__(self):
"""
Represents AtlasLabel as string
"""
return '{}({}, index={}, value={})'.format(
self.__class__.__name__, self.name,
self.index, self.value,
)
class AtlasDescription(object): class AtlasDescription(object):
"""An ``AtlasDescription`` instance parses and stores the information """An ``AtlasDescription`` instance parses and stores the information
...@@ -574,11 +583,11 @@ class AtlasDescription(object): ...@@ -574,11 +583,11 @@ class AtlasDescription(object):
self.labels = list(sorted(self.labels)) self.labels = list(sorted(self.labels))
def find(self, index=None, value=None): def find(self, index=None, value=None, name=None):
"""Find an :class:`.AtlasLabel` either by ``index``, or by ``value``. """Find an :class:`.AtlasLabel` either by ``index``, or by ``value``.
Exactly one of ``index`` or ``value`` may be specified - a Exactly one of ``index``, ``value``, or ``name`` may be specified - a
``ValueError`` is raised otherwise. If an invalid ``index`` or ``ValueError`` is raised otherwise. If an invalid ``index``, ``name``, or
``value`` is specified, an ``IndexError`` or ``KeyError`` will be ``value`` is specified, an ``IndexError`` or ``KeyError`` will be
raised. raised.
...@@ -586,12 +595,25 @@ class AtlasDescription(object): ...@@ -586,12 +595,25 @@ class AtlasDescription(object):
labels, and a 3D ``LabelAtlas`` may have more values labels, and a 3D ``LabelAtlas`` may have more values
than labels. than labels.
""" """
if (index is None and value is None) or \ if ((index is not None) + (value is not None) + (name is not None)) != 1:
(index is not None and value is not None): raise ValueError('Only one of index, value, or name may be specified')
raise ValueError('Only one of index or value may be specified') if index is not None: return self.labels[ index]
elif value is not None: return self.__labelsByValue[int(value)]
else:
matches = [l for l in self.labels if l.name == name]
if len(matches) == 0:
# look for partial matches only if there are no full matches
matches = [l for l in self.labels if l.name[:len(name)] == name]
if len(matches) == 0:
raise IndexError('No match for {} found in labels {}'.format(
name, tuple(l.name for l in self.labels)
))
elif len(matches) > 1:
raise IndexError('Multiple matches for {} found in labels {}'.format(
name, tuple(l.name for l in self.labels)
))
return matches[0]
if index is not None: return self.labels[ index]
else: return self.__labelsByValue[int(value)]
def __eq__(self, other): def __eq__(self, other):
...@@ -612,6 +634,12 @@ class AtlasDescription(object): ...@@ -612,6 +634,12 @@ class AtlasDescription(object):
""" """
return self.name.lower() < other.name.lower() return self.name.lower() < other.name.lower()
def __repr__(self, ):
"""
String representation of AtlasDescription
"""
return '{}({})'.format(self.__class__.__name__, self.atlasID)
class Atlas(fslimage.Image): class Atlas(fslimage.Image):
"""This is the base class for the :class:`LabelAtlas` and """This is the base class for the :class:`LabelAtlas` and
...@@ -849,6 +877,28 @@ class LabelAtlas(Atlas): ...@@ -849,6 +877,28 @@ class LabelAtlas(Atlas):
return values, props return values, props
def get(self, label=None, index=None, value=None, name=None):
"""
Returns the binary image for given label
Only one of the arguments should be used to define the label
:arg label: AtlasLabel contained within this atlas
:arg index: index of the label
:arg value: value of the label
:arg name: string of the label
:return: image.Image with the mask
"""
if ((label is not None) + (index is not None) +
(value is not None) + (name is not None)) != 1:
raise ValueError('Only one of label, index, value, or name may be specified')
if label is None:
label = self.find(index=index, name=name, value=value)
elif label not in self.desc.labels:
raise ValueError("Unknown label provided")
arr = (self.data == label.value).astype(int)
return fslimage.Image(arr, name=label.name, header=self.header)
class ProbabilisticAtlas(Atlas): class ProbabilisticAtlas(Atlas):
"""A 4D atlas which contains one volume for each region. """A 4D atlas which contains one volume for each region.
...@@ -867,6 +917,27 @@ class ProbabilisticAtlas(Atlas): ...@@ -867,6 +917,27 @@ class ProbabilisticAtlas(Atlas):
""" """
Atlas.__init__(self, atlasDesc, resolution, False, **kwargs) Atlas.__init__(self, atlasDesc, resolution, False, **kwargs)
def get(self, label=None, index=None, value=None, name=None):
"""
Returns the probabilistic image for given label
Only one of the arguments should be used to define the label
:arg label: AtlasLabel contained within this atlas
:arg index: index of the label
:arg value: value of the label
:arg name: string of the label
:return: image.Image with the probabilistic mask
"""
if ((label is not None) + (index is not None) +
(value is not None) + (name is not None)) != 1:
raise ValueError('Only one of label, index, value, or name may be specified')
if label is None:
label = self.find(index=index, value=value, name=name)
elif label not in self.desc.labels:
raise ValueError("Unknown label provided")
arr = self[..., label.index]
return fslimage.Image(arr, name=label.name, header=self.header)
def proportions(self, location, *args, **kwargs): def proportions(self, location, *args, **kwargs):
"""Looks up and returns the proportions of of all regions at the given """Looks up and returns the proportions of of all regions at the given
......
basename=fdt basename=fdt_paths
probtrackx.log (log_cmd) probtrackx.log (log_cmd)
{basename}.log (log_settings) {basename}.log (log_settings)
......
...@@ -99,6 +99,8 @@ def test_AtlasDescription(): ...@@ -99,6 +99,8 @@ def test_AtlasDescription():
tal = registry.getAtlasDescription('talairach') tal = registry.getAtlasDescription('talairach')
cort = registry.getAtlasDescription('harvardoxford-cortical') cort = registry.getAtlasDescription('harvardoxford-cortical')
assert str(tal) == 'AtlasDescription(talairach)'
assert str(cort) == 'AtlasDescription(harvardoxford-cortical)'
assert tal.atlasID == 'talairach' assert tal.atlasID == 'talairach'
assert tal.name == 'Talairach Daemon Labels' assert tal.name == 'Talairach Daemon Labels'
...@@ -140,8 +142,6 @@ def test_AtlasDescription(): ...@@ -140,8 +142,6 @@ def test_AtlasDescription():
registry.getAtlasDescription('non-existent-atlas') registry.getAtlasDescription('non-existent-atlas')
def test_add_remove_atlas(): def test_add_remove_atlas():
with tests.testdir() as testdir: with tests.testdir() as testdir:
...@@ -235,6 +235,23 @@ def test_load_atlas(): ...@@ -235,6 +235,23 @@ def test_load_atlas():
assert isinstance(lblatlas, atlases.LabelAtlas) assert isinstance(lblatlas, atlases.LabelAtlas)
def test_get():
reg = atlases.registry
reg.rescanAtlases()
probatlas = reg.loadAtlas('harvardoxford-cortical')
lblatlas = reg.loadAtlas('talairach')
for atlas in (probatlas, lblatlas):
for idx, label in enumerate(atlas.desc.labels[:10]):
target = probatlas[..., idx] if atlas is probatlas else lblatlas.data == label.value
assert (target == atlas.get(label).data).all()
assert label.name == atlas.get(label).name
assert (target == atlas.get(index=label.index).data).all()
assert (target == atlas.get(value=label.value).data).all()
assert (target == atlas.get(name=label.name).data).all()
def test_find(): def test_find():
reg = atlases.registry reg = atlases.registry
...@@ -252,16 +269,31 @@ def test_find(): ...@@ -252,16 +269,31 @@ def test_find():
assert atlas .find(value=label.value) == label assert atlas .find(value=label.value) == label
assert atlas .find(index=label.index) == label assert atlas .find(index=label.index) == label
assert atlas .find(name=label.name) == label
assert atlas.desc.find(value=label.value) == label assert atlas.desc.find(value=label.value) == label
assert atlas.desc.find(index=label.index) == label assert atlas.desc.find(index=label.index) == label
assert atlas.desc.find(name=label.name) == label
if atlas is not lblatlas:
# lblatlas has a lot of very similar label names
assert atlas .find(name=label.name[:-2]) == label
assert atlas.desc.find(name=label.name[:-2]) == label
with pytest.raises(ValueError): with pytest.raises(ValueError):
atlas.find() atlas.find()
with pytest.raises(ValueError): with pytest.raises(ValueError):
atlas.find(index=1, value=1) atlas.find(index=1, value=1)
with pytest.raises(ValueError):
atlas.find(index=1, name=1)
with pytest.raises(ValueError):
atlas.find(value=1, name=1)
with pytest.raises(IndexError): with pytest.raises(IndexError):
atlas.find(index=len(labels)) atlas.find(index=len(labels))
with pytest.raises(IndexError):
atlas.find(name='InvalidROI')
with pytest.raises(IndexError):
atlas.find(name='')
maxval = max([l.value for l in labels]) maxval = max([l.value for l in labels])
with pytest.raises(KeyError): with pytest.raises(KeyError):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment