From 3af1f6110d71c1d279ff6b50401e3c6784afe889 Mon Sep 17 00:00:00 2001
From: Michiel Cottaar <MichielCottaar@gmail.com>
Date: Tue, 14 May 2019 17:17:10 +0100
Subject: [PATCH] ENH: wrote get methods

these return 3D images with the binary or probabilistic mask
in a new image.Image with the name set to the label name
and the same NIFTI header as the atlas
---
 fsl/data/atlases.py   | 47 +++++++++++++++++++++++++++++++++++++++++--
 tests/test_atlases.py | 19 +++++++++++++++--
 2 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/fsl/data/atlases.py b/fsl/data/atlases.py
index d43a09d95..b28eb952c 100644
--- a/fsl/data/atlases.py
+++ b/fsl/data/atlases.py
@@ -586,8 +586,8 @@ class AtlasDescription(object):
     def find(self, index=None, value=None, name=None):
         """Find an :class:`.AtlasLabel` either by ``index``, or by ``value``.
 
-        Exactly one of ``index`` or ``value`` may be specified - a
-        ``ValueError`` is raised otherwise. If an invalid ``index`` or
+        Exactly one of ``index``, ``value``, or ``name`` may be specified - a
+        ``ValueError`` is raised otherwise. If an invalid ``index``, ``name``, or
         ``value`` is specified, an ``IndexError`` or ``KeyError`` will be
         raised.
 
@@ -877,6 +877,28 @@ class LabelAtlas(Atlas):
 
         return values, props
 
+    def get(self, label=None, index=None, value=None, name=None):
+        """
+        Returns the binary image for given label
+
+        Only one of the arguments should be used to define the label
+
+        :arg label: AtlasLabel contained within this atlas
+        :arg index: index of the label
+        :arg value: value of the label
+        :arg name: string of the label
+        :return: image.Image with the mask
+        """
+        if ((label is not None) + (index is not None) +
+            (value is not None) + (name is not None)) != 1:
+            raise ValueError('Only one of label, index, value, or name may be specified')
+        if label is None:
+            label = self.find(index=index, name=name, value=value)
+        elif label not in self.desc.labels:
+            raise ValueError("Unknown label provided")
+        arr = (self.data == label.value).astype(int)
+        return fslimage.Image(arr, name=label.name, header=self.header)
+
 
 class ProbabilisticAtlas(Atlas):
     """A 4D atlas which contains one volume for each region.
@@ -895,6 +917,27 @@ class ProbabilisticAtlas(Atlas):
         """
         Atlas.__init__(self, atlasDesc, resolution, False, **kwargs)
 
+    def get(self, label=None, index=None, value=None, name=None):
+        """
+        Returns the probabilistic image for given label
+
+        Only one of the arguments should be used to define the label
+
+        :arg label: AtlasLabel contained within this atlas
+        :arg index: index of the label
+        :arg value: value of the label
+        :arg name: string of the label
+        :return: image.Image with the probabilistic mask
+        """
+        if ((label is not None) + (index is not None) +
+            (value is not None) + (name is not None)) != 1:
+            raise ValueError('Only one of label, index, value, or name may be specified')
+        if label is None:
+            label = self.find(index=index, value=value, name=name)
+        elif label not in self.desc.labels:
+            raise ValueError("Unknown label provided")
+        arr = self[..., label.index]
+        return fslimage.Image(arr, name=label.name, header=self.header)
 
     def proportions(self, location, *args, **kwargs):
         """Looks up and returns the proportions of of all regions at the given
diff --git a/tests/test_atlases.py b/tests/test_atlases.py
index 8f4136b5d..b2dfec113 100644
--- a/tests/test_atlases.py
+++ b/tests/test_atlases.py
@@ -142,8 +142,6 @@ def test_AtlasDescription():
         registry.getAtlasDescription('non-existent-atlas')
 
 
-
-
 def test_add_remove_atlas():
 
     with tests.testdir() as testdir:
@@ -237,6 +235,23 @@ def test_load_atlas():
     assert isinstance(lblatlas,     atlases.LabelAtlas)
 
 
+def test_get():
+
+    reg = atlases.registry
+    reg.rescanAtlases()
+
+    probatlas = reg.loadAtlas('harvardoxford-cortical')
+    lblatlas = reg.loadAtlas('talairach')
+    for atlas in (probatlas, lblatlas):
+        for idx, label in enumerate(atlas.desc.labels[:10]):
+            target = probatlas[..., idx] if atlas is probatlas else lblatlas.data == label.value
+            assert (target == atlas.get(label).data).all()
+            assert label.name == atlas.get(label).name
+            assert (target == atlas.get(index=label.index).data).all()
+            assert (target == atlas.get(value=label.value).data).all()
+            assert (target == atlas.get(name=label.name).data).all()
+
+
 def test_find():
 
     reg = atlases.registry
-- 
GitLab