diff --git a/fsl/data/atlases.py b/fsl/data/atlases.py index d43a09d954a636414f67faed284318e1048e20cc..b28eb952c98ccbda29f2097ca8eda9ae660cba4c 100644 --- a/fsl/data/atlases.py +++ b/fsl/data/atlases.py @@ -586,8 +586,8 @@ class AtlasDescription(object): def find(self, index=None, value=None, name=None): """Find an :class:`.AtlasLabel` either by ``index``, or by ``value``. - Exactly one of ``index`` or ``value`` may be specified - a - ``ValueError`` is raised otherwise. If an invalid ``index`` or + Exactly one of ``index``, ``value``, or ``name`` may be specified - a + ``ValueError`` is raised otherwise. If an invalid ``index``, ``name``, or ``value`` is specified, an ``IndexError`` or ``KeyError`` will be raised. @@ -877,6 +877,28 @@ class LabelAtlas(Atlas): return values, props + def get(self, label=None, index=None, value=None, name=None): + """ + Returns the binary image for given label + + Only one of the arguments should be used to define the label + + :arg label: AtlasLabel contained within this atlas + :arg index: index of the label + :arg value: value of the label + :arg name: string of the label + :return: image.Image with the mask + """ + if ((label is not None) + (index is not None) + + (value is not None) + (name is not None)) != 1: + raise ValueError('Only one of label, index, value, or name may be specified') + if label is None: + label = self.find(index=index, name=name, value=value) + elif label not in self.desc.labels: + raise ValueError("Unknown label provided") + arr = (self.data == label.value).astype(int) + return fslimage.Image(arr, name=label.name, header=self.header) + class ProbabilisticAtlas(Atlas): """A 4D atlas which contains one volume for each region. @@ -895,6 +917,27 @@ class ProbabilisticAtlas(Atlas): """ Atlas.__init__(self, atlasDesc, resolution, False, **kwargs) + def get(self, label=None, index=None, value=None, name=None): + """ + Returns the probabilistic image for given label + + Only one of the arguments should be used to define the label + + :arg label: AtlasLabel contained within this atlas + :arg index: index of the label + :arg value: value of the label + :arg name: string of the label + :return: image.Image with the probabilistic mask + """ + if ((label is not None) + (index is not None) + + (value is not None) + (name is not None)) != 1: + raise ValueError('Only one of label, index, value, or name may be specified') + if label is None: + label = self.find(index=index, value=value, name=name) + elif label not in self.desc.labels: + raise ValueError("Unknown label provided") + arr = self[..., label.index] + return fslimage.Image(arr, name=label.name, header=self.header) def proportions(self, location, *args, **kwargs): """Looks up and returns the proportions of of all regions at the given diff --git a/tests/test_atlases.py b/tests/test_atlases.py index 8f4136b5dfb6dbaaa46a1e6061275c3c549a6f02..b2dfec113a00349dea0601f781f565561fa18b2c 100644 --- a/tests/test_atlases.py +++ b/tests/test_atlases.py @@ -142,8 +142,6 @@ def test_AtlasDescription(): registry.getAtlasDescription('non-existent-atlas') - - def test_add_remove_atlas(): with tests.testdir() as testdir: @@ -237,6 +235,23 @@ def test_load_atlas(): assert isinstance(lblatlas, atlases.LabelAtlas) +def test_get(): + + reg = atlases.registry + reg.rescanAtlases() + + probatlas = reg.loadAtlas('harvardoxford-cortical') + lblatlas = reg.loadAtlas('talairach') + for atlas in (probatlas, lblatlas): + for idx, label in enumerate(atlas.desc.labels[:10]): + target = probatlas[..., idx] if atlas is probatlas else lblatlas.data == label.value + assert (target == atlas.get(label).data).all() + assert label.name == atlas.get(label).name + assert (target == atlas.get(index=label.index).data).all() + assert (target == atlas.get(value=label.value).data).all() + assert (target == atlas.get(name=label.name).data).all() + + def test_find(): reg = atlases.registry