From 1bf67e0e5aefbb1a586d09a42a01477baf52a5f2 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Fri, 22 Sep 2017 12:06:36 +0100
Subject: [PATCH] New methods on Atlas image classes for mask-based queries.
 Not tested. Also adjusted resample method to force float output (although
 this might change).

---
 fsl/data/atlases.py | 199 ++++++++++++++++++++++++++++++++++++++------
 fsl/data/image.py   |   4 +
 2 files changed, 179 insertions(+), 24 deletions(-)

diff --git a/fsl/data/atlases.py b/fsl/data/atlases.py
index 6f45caffe..45369444a 100644
--- a/fsl/data/atlases.py
+++ b/fsl/data/atlases.py
@@ -623,23 +623,60 @@ class LabelAtlas(Atlas):
         Atlas.__init__(self, atlasDesc, resolution, True)
 
 
-    def label(self, worldLoc):
-        """Looks up and returns the label of the region at the given world
-        location, or ``None`` if the location is out of bounds.
+    def label(self, location, *args, **kwargs):
+        """Looks up and returns the label of the region at the given
+        location.
+
+        :arg location: Can be one of the following:
+
+                        - A sequence of three values, interpreted as
+                          atlas coordinates. In this case, :meth:`coordLabel`
+                          is called.
+
+                        - An :class:`.Image` which is interpreted as a
+                          weighted mask. In this case, :meth:`maskLabel` is
+                          called.
+
+        All other arguments are passed through to the :meth:`coordLabel` or
+        :meth:`maskLabel` methods.
+
+
+        :returns: The return value of either :meth:`coordLabel` or
+                  :meth:`maskLabel`.
+        """
+
+        if isinstance(location, fslimage.Image):
+            return self.maskLabel(location, *args, **kwargs)
+        else:
+            return self.coordLabel(location, *args, **kwargs)
+
+
+    def coordLabel(self, loc, voxel=False):
+        """Looks up and returns the label at the given location.
+
+        :arg loc:   A sequence of three values, interpreted as atlas
+                    coordinates. In this case, :meth:`coordLabel` is called.
+
+        :arg voxel: Defaults to ``False``. If ``True``, the ``location``
+                    is interpreted as voxel coordinates.
+
+        :returns:   The label at the given coordinates, or ``None`` if the
+                    coordinates are out of bounds.
         """
 
-        voxelLoc = transform.transform([worldLoc], self.worldToVoxMat)[0]
-        voxelLoc = [int(v) for v in voxelLoc.round()]
+        if not voxel:
+            loc = transform.transform([loc], self.worldToVoxMat)[0]
+            loc = [int(v) for v in loc.round()]
 
-        if voxelLoc[0] <  0             or \
-           voxelLoc[1] <  0             or \
-           voxelLoc[2] <  0             or \
-           voxelLoc[0] >= self.shape[0] or \
-           voxelLoc[1] >= self.shape[1] or \
-           voxelLoc[2] >= self.shape[2]:
+        if loc[0] <  0             or \
+           loc[1] <  0             or \
+           loc[2] <  0             or \
+           loc[0] >= self.shape[0] or \
+           loc[1] >= self.shape[1] or \
+           loc[2] >= self.shape[2]:
             return None
 
-        val = self[voxelLoc[0], voxelLoc[1], voxelLoc[2]]
+        val = self[loc[0], loc[1], loc[2]]
 
         if self.desc.atlasType == 'label':
             return val
@@ -648,6 +685,50 @@ class LabelAtlas(Atlas):
             return val - 1
 
 
+    def maskLabel(self, mask):
+        """Looks up and returns the proportions of all regions that are present
+        in the given ``mask``.
+
+        :arg mask: A 3D :class:`.Image`` which is interpreted as a weighted
+                   mask. If the ``mask`` shape does not match that of this
+                   ``LabelAtlas``, it is resampled using
+                   :meth:`.Image.resample`, with linear interpolation.
+
+        :returns:  A tuple containing:
+
+                     - A sequence of all labels which are present in the mask
+                     - A sequence containing the proportion, within the mask,
+                       of each present label.
+        """
+
+        # Make sure that the mask has the
+        # same number of voxels as the
+        # atlas image
+        mask     = mask.resample(self.shape[:3], order=1)
+        boolmask = mask > 0
+
+        # Extract the labels that are in
+        # the mask, and their corresponding
+        # mask weights
+        vals    = self[boolmask]
+        weights = vals * mask[boolmask]
+        labels  = np.unique(vals)
+        props   = []
+
+        for label in labels:
+
+            # Figure out the number of all voxels
+            # in the mask with this label, weighted
+            # by the mask
+            prop = ((vals == label) * weights).sum()
+
+            # Normalise it to be a proportion
+            # of all voxels in the mask
+            props.append(prop / float(len(vals)))
+
+        return labels, props
+
+
 class ProbabilisticAtlas(Atlas):
     """A 4D atlas which contains one volume for each region.
 
@@ -666,28 +747,98 @@ class ProbabilisticAtlas(Atlas):
         Atlas.__init__(self, atlasDesc, resolution, False)
 
 
-    def proportions(self, worldLoc):
+    def proportions(self, location, *args, **kwargs):
+        """Looks up and returns the proportions of of all regions at the given
+        location.
+
+        :arg location: Can be one of the following:
+
+                        - A sequence of three values, interpreted as atlas
+                          coordinates. In this case, :meth:`coordProportions`
+                          is called.
+
+                        - An :class:`.Image` which is interpreted as a
+                          weighted mask. In this case, :meth:`maskProportions`
+                          is called.
+
+        All other arguments are passed through to the :meth:`coordProportions`
+        or :meth:`maskProportions` methods.
+
+
+        :returns: The return value of either :meth:`coordProportions` or
+                  :meth:`maskProportions`.
+        """
+
+        if isinstance(location, fslimage.Image):
+            return self.maskProportions(location, *args, **kwargs)
+        else:
+            return self.coordProportions(location, *args, **kwargs)
+
+
+    def coordProportions(self, loc, voxel=False):
         """Looks up the region probabilities for the given location.
 
-        :arg worldLoc: Location in the world coordinate system.
+        :arg loc:   A sequence of three values, interpreted as atlas
+                    world or voxel coordinates.
+
+        :arg voxel: Defaults to ``False``. If ``True``, the ``loc``
+                    argument is interpreted as voxel coordinates.
 
         :returns: a list of values, one per region, which represent
                   the probability of each region for the specified
                   location. Returns an empty list if the given
                   location is out of bounds.
         """
-        voxelLoc = transform.transform([worldLoc], self.worldToVoxMat)[0]
-        voxelLoc = [int(v) for v in voxelLoc.round()]
-
-        if voxelLoc[0] <  0             or \
-           voxelLoc[1] <  0             or \
-           voxelLoc[2] <  0             or \
-           voxelLoc[0] >= self.shape[0] or \
-           voxelLoc[1] >= self.shape[1] or \
-           voxelLoc[2] >= self.shape[2]:
+
+        if not voxel:
+            loc = transform.transform([loc], self.worldToVoxMat)[0]
+            loc = [int(v) for v in loc.round()]
+
+        if loc[0] <  0             or \
+           loc[1] <  0             or \
+           loc[2] <  0             or \
+           loc[0] >= self.shape[0] or \
+           loc[1] >= self.shape[1] or \
+           loc[2] >= self.shape[2]:
             return []
 
-        return self[voxelLoc[0], voxelLoc[1], voxelLoc[2], :]
+        return self[loc[0], loc[1], loc[2], :]
+
+
+    def maskProportions(self, mask):
+        """Looks up the probabilities of all regions in the given ``mask``.
+
+        :arg mask: A 3D :class:`.Image`` which is interpreted as a weighted
+                   mask. If the ``mask`` shape does not match that of this
+                   ``ProbabilisticAtlas``, it is resampled using
+                   :meth:`.Image.resample`, with linear interpolation.
+
+        :returns:  A tuple containing:
+
+                     - A sequence of all labels which are present in the mask
+                     - A sequence containing the proportion, within the mask,
+                       of each present label.
+        """
+
+        labels = []
+        props  = []
+
+        # Make sure that the mask has the same
+        # number of voxels as the atlas image
+        mask     = mask.resample(self.shape[:3], order=1)
+        boolmask = mask > 0
+
+        for label in range(self.shape[3]):
+
+            vals = self[..., label]
+            vals = vals[boolmask] * mask[boolmask]
+            prop = vals.mean()
+
+            if prop != 0:
+                labels.append(label)
+                props .append(prop)
+
+        return labels, props
 
 
 registry            = AtlasRegistry()
diff --git a/fsl/data/image.py b/fsl/data/image.py
index 7a60d3ec3..d6f418037 100644
--- a/fsl/data/image.py
+++ b/fsl/data/image.py
@@ -1128,6 +1128,9 @@ class Image(Nifti):
 
         All other arguments are passed through to the ``scipy.ndimage.zoom``
         function.
+
+        :returns: A ``numpy`` array of shape ``shape``, containing an
+                  interpolated copy of the data in this ``Image``.
         """
 
         if sliceobj is None:
@@ -1137,6 +1140,7 @@ class Image(Nifti):
         data  = self[sliceobj]
 
         if tuple(data.shape) != tuple(shape):
+            data  = np.array(data, dtype=np.float, copy=False)
             zooms = [float(shape[i]) / data.shape[i] for i in range(ndims)]
             data  = ndimage.zoom(data, zooms, **kwargs)
 
-- 
GitLab