From 2d0249c13f9b9e255947a89e0e20097a16ca3adb Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Fri, 29 Mar 2019 11:43:10 +0000
Subject: [PATCH] RF,BF: New asarray option to query method which defaults to
 false - query will default to returning a flat list of matches.
 subtree-related fixes to scan function

---
 fsl/utils/filetree/query.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/fsl/utils/filetree/query.py b/fsl/utils/filetree/query.py
index 18b86b92e..b1f1f46ed 100644
--- a/fsl/utils/filetree/query.py
+++ b/fsl/utils/filetree/query.py
@@ -189,21 +189,27 @@ class FileTreeQuery(object):
         return list(self.__shortnamevars.keys())
 
 
-    def query(self, short_name, **variables):
+    def query(self, short_name, asarray=False, **variables):
         """Search for files of the given ``short_name``, which match
         the specified ``variables``. All hits are returned for variables
         that are unspecified.
 
-        :arg short_name: Short name of files to search for.
+        :arg short_name:  Short name of files to search for.
+
+        :arg asarray: If ``True``, the relevant :class:`Match` objects are
+                      returned in a in a ND ``numpy.array`` where each
+                      dimension corresponds to a variable for the
+                      ``short_name`` in question (as returned by
+                      :meth:`axes`). Otherwise (the default), they are
+                      returned in a list.
 
         All other arguments are assumed to be ``variable=value`` pairs,
         used to restrict which matches are returned. All values are returned
         for variables that are not specified, or variables which are given a
         value of ``'*'``.
 
-        :returns: A ``numpy.array`` of ``Match`` objects, with axes
-                  corresponding to the labels returned by the :meth:`axes`
-                  method.
+        :returns: A list  of ``Match`` objects, (or a ``numpy.array`` if
+                  ``asarray=True``).
         """
 
         varnames    = list(variables.keys())
@@ -225,7 +231,10 @@ class FileTreeQuery(object):
             if val == '*': slc.append(slice(None))
             else:          slc.extend([np.newaxis, varidxs[var][val]])
 
-        return matcharray[tuple(slc)]
+        result = matcharray[tuple(slc)]
+
+        if asarray: return result
+        else:       return [m for m in result.flat if isinstance(m, Match)]
 
 
 class Match(object):
@@ -306,8 +315,8 @@ def scan(tree : FileTree) -> List[Match]:
 
             matches.append(Match(filename, template, variables))
 
-    for tree_name, sub_tree in tree.sub_trees:
-        matches.extend(Match.scan(sub_tree))
+    for tree_name, sub_tree in tree.sub_trees.items():
+        matches.extend(scan(sub_tree))
 
     return matches
 
-- 
GitLab