From 6ebae621cd7f69ff73844e4019b7ad830981207c Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Fri, 29 Mar 2019 11:44:29 +0000
Subject: [PATCH] TEST: more query tests

---
 tests/test_filetree/test_query.py | 194 +++++++++++++++++++++++-------
 1 file changed, 153 insertions(+), 41 deletions(-)

diff --git a/tests/test_filetree/test_query.py b/tests/test_filetree/test_query.py
index 094686d9e..66e50c1f8 100644
--- a/tests/test_filetree/test_query.py
+++ b/tests/test_filetree/test_query.py
@@ -84,32 +84,36 @@ def _expected_matches(short_name, **kwargs):
     return matches
 
 
-def _run_and_check_query(query, short_name, **vars):
+def _run_and_check_query(query, short_name, asarray=False, **vars):
 
-    gotmatches = query.query(      short_name, **vars)
+    gotmatches = query.query(      short_name, asarray=asarray, **vars)
     expmatches = _expected_matches(short_name, **{k : [v]
                                                   for k, v
                                                   in vars.items()})
-    snvars     = query.variables(short_name)
 
-    assert len(snvars) == len(gotmatches.shape)
+    if not asarray:
+        assert len(gotmatches) == len(expmatches)
+        for got, exp in zip(sorted(gotmatches), sorted(expmatches)):
+            assert got == exp
+    else:
+        snvars = query.variables(short_name)
 
-    for i, var in enumerate(sorted(snvars.keys())):
-        if var not in vars or vars[var] == '*':
-            assert len(snvars[var]) == gotmatches.shape[i]
-        else:
-            assert gotmatches.shape[i] == 1
+        assert len(snvars) == len(gotmatches.shape)
 
-    for expmatch in expmatches:
-        slc = []
-        for var in query.axes(short_name):
+        for i, var in enumerate(sorted(snvars.keys())):
             if var not in vars or vars[var] == '*':
-                vidx = snvars[var].index(expmatch.variables[var])
-                slc.append(vidx)
+                assert len(snvars[var]) == gotmatches.shape[i]
             else:
-                slc.append(0)
+                assert gotmatches.shape[i] == 1
 
-        assert expmatch == gotmatches[tuple(slc)]
+        for expmatch in expmatches:
+            slc = []
+            for var in query.axes(short_name):
+                if var not in vars or vars[var] == '*':
+                    vidx = snvars[var].index(expmatch.variables[var])
+                    slc.append(vidx)
+                else:
+                    slc.append(0)
 
 
 def test_query_properties():
@@ -180,7 +184,6 @@ def test_query_optional_var_folder():
         assert query.variables()['session'] == [None, '1', '2']
 
         m = query.query('T1w', participant='01')
-        m = [ma for ma in m.flatten() if isinstance(ma, ftquery.Match)]
         assert len(m) == 1
         assert m[0].filename == op.join('subj-01', 'T1w.nii.gz')
 
@@ -212,22 +215,19 @@ def test_query_optional_var_filename():
         assert qvars['subject']  == ['01', '02', '03', '04']
         assert qvars['modality'] == [None, 't1', 't2']
 
-        m = query.query('image', modality=None)
-        m = [ma.filename for ma in m.flatten()
-             if isinstance(ma, ftquery.Match)]
-        assert m == [op.join('sub-01', 'img.nii.gz'),
-                     op.join('sub-04', 'img.nii.gz')]
+        got = query.query('image', modality=None)
+        assert [m.filename for m in sorted(got)] == [
+            op.join('sub-01', 'img.nii.gz'),
+            op.join('sub-04', 'img.nii.gz')]
 
-        m = query.query('image', modality='t1')
-        m = [ma.filename for ma in m.flatten()
-             if isinstance(ma, ftquery.Match)]
-        assert m == [op.join('sub-02', 'img-t1.nii.gz'),
-                     op.join('sub-03', 'img-t1.nii.gz')]
+        got = query.query('image', modality='t1')
+        assert [m.filename for m in sorted(got)] == [
+            op.join('sub-02', 'img-t1.nii.gz'),
+            op.join('sub-03', 'img-t1.nii.gz')]
 
-        m = query.query('image', modality='t2')
-        m = [ma.filename for ma in m.flatten()
-             if isinstance(ma, ftquery.Match)]
-        assert m == [op.join('sub-02', 'img-t2.nii.gz')]
+        got = query.query('image', modality='t2')
+        assert len(got) == 1
+        assert got[0].filename == op.join('sub-02', 'img-t2.nii.gz')
 
 
 def test_query_missing_files():
@@ -242,21 +242,18 @@ def test_query_missing_files():
         tree  = filetree.FileTree.read('_test_tree.tree', '.')
         query = filetree.FileTreeQuery(tree)
 
-        m = query.query('T1w', session='1')
-        m = [ma.filename for ma in m.flatten() if isinstance(ma, ftquery.Match)]
-        assert sorted(m) == [
+        got = query.query('T1w', session='1')
+        assert [m.filename for m in sorted(got)] == [
             op.join('subj-02', 'ses-1', 'T1w.nii.gz'),
             op.join('subj-03', 'ses-1', 'T1w.nii.gz')]
 
-        m = query.query('T2w', session='2')
-        m = [ma.filename for ma in m.flatten() if isinstance(ma, ftquery.Match)]
-        assert sorted(m) == [
+        got = query.query('T2w', session='2')
+        assert [m.filename for m in sorted(got)] == [
             op.join('subj-01', 'ses-2', 'T2w.nii.gz'),
             op.join('subj-03', 'ses-2', 'T2w.nii.gz')]
 
-        m = query.query('surface', session='1', hemi='L')
-        m = [ma.filename for ma in m.flatten() if isinstance(ma, ftquery.Match)]
-        assert sorted(m) == [
+        got = query.query('surface', session='1', hemi='L')
+        assert [m.filename for m in sorted(got)] == [
             op.join('subj-01', 'ses-1', 'L.midthickness.gii'),
             op.join('subj-01', 'ses-1', 'L.pial.gii'),
             op.join('subj-01', 'ses-1', 'L.white.gii'),
@@ -265,7 +262,122 @@ def test_query_missing_files():
             op.join('subj-02', 'ses-1', 'L.white.gii')]
 
 
-    pass
+def test_query_asarray():
+    with _test_data():
+        tree  = filetree.FileTree.read('_test_tree.tree', '.')
+        query = filetree.FileTreeQuery(tree)
+
+        _run_and_check_query(query, 'T1w', asarray=True)
+        _run_and_check_query(query, 'T1w', asarray=True, participant='01')
+        _run_and_check_query(query, 'T1w', asarray=True, session='2')
+        _run_and_check_query(query, 'T1w', asarray=True, participant='02', session='1')
+        _run_and_check_query(query, 'T2w', asarray=True)
+        _run_and_check_query(query, 'T2w', asarray=True, participant='01')
+        _run_and_check_query(query, 'T2w', asarray=True, session='2')
+        _run_and_check_query(query, 'T2w', asarray=True, participant='02', session='1')
+        _run_and_check_query(query, 'surface', asarray=True)
+        _run_and_check_query(query, 'surface', asarray=True, hemi='L')
+        _run_and_check_query(query, 'surface', asarray=True, surf='midthickness')
+        _run_and_check_query(query, 'surface', asarray=True, hemi='R', surf='pial')
+        _run_and_check_query(query, 'surface', asarray=True, participant='03', surf='pial')
+        _run_and_check_query(query, 'surface', asarray=True, participant='03', sssion='2')
+
+
+def test_query_subtree():
+    tree1 = tw.dedent("""
+    subj-{participant}
+        T1w.nii.gz (T1w)
+        surf
+            ->surface (surfdir)
+    """)
+    tree2 = tw.dedent("""
+    {hemi}.{surf}.gii (surface)
+    """)
+
+    files = [
+        op.join('subj-01', 'T1w.nii.gz'),
+        op.join('subj-01', 'surf', 'L.pial.gii'),
+        op.join('subj-01', 'surf', 'R.pial.gii'),
+        op.join('subj-01', 'surf', 'L.white.gii'),
+        op.join('subj-01', 'surf', 'R.white.gii'),
+        op.join('subj-02', 'T1w.nii.gz'),
+        op.join('subj-02', 'surf', 'L.pial.gii'),
+        op.join('subj-02', 'surf', 'R.pial.gii'),
+        op.join('subj-02', 'surf', 'L.white.gii'),
+        op.join('subj-02', 'surf', 'R.white.gii'),
+        op.join('subj-03', 'T1w.nii.gz'),
+        op.join('subj-03', 'surf', 'L.pial.gii'),
+        op.join('subj-03', 'surf', 'R.pial.gii'),
+        op.join('subj-03', 'surf', 'L.white.gii'),
+        op.join('subj-03', 'surf', 'R.white.gii')]
+
+    with testdir(files):
+        with open('tree1.tree',   'wt') as f: f.write(tree1)
+        with open('surface.tree', 'wt') as f: f.write(tree2)
+
+        tree = filetree.FileTree.read('tree1.tree', '.')
+        query = filetree.FileTreeQuery(tree)
+
+        assert sorted(query.short_names) == ['T1w', 'surface']
+
+        qvars = query.variables()
+        assert sorted(qvars.keys()) == ['hemi', 'participant', 'surf']
+        assert qvars['hemi']        == ['L', 'R']
+        assert qvars['participant'] == ['01', '02', '03']
+        assert qvars['surf']        == ['pial', 'white']
+
+        qvars = query.variables('T1w')
+        assert sorted(qvars.keys()) == ['participant']
+        assert qvars['participant'] == ['01', '02', '03']
+
+        qvars = query.variables('surface')
+        assert sorted(qvars.keys()) == ['hemi', 'participant', 'surf']
+        assert qvars['hemi']        == ['L', 'R']
+        assert qvars['participant'] == ['01', '02', '03']
+        assert qvars['surf']        == ['pial', 'white']
+
+        got = query.query('T1w')
+        assert [m.filename for m in sorted(got)] == [
+            op.join('subj-01', 'T1w.nii.gz'),
+            op.join('subj-02', 'T1w.nii.gz'),
+            op.join('subj-03', 'T1w.nii.gz')]
+
+        got = query.query('T1w', participant='01')
+        assert [m.filename for m in sorted(got)] == [
+            op.join('subj-01', 'T1w.nii.gz')]
+
+        got = query.query('surface')
+        assert [m.filename for m in sorted(got)] == [
+            op.join('subj-01', 'surf', 'L.pial.gii'),
+            op.join('subj-01', 'surf', 'L.white.gii'),
+            op.join('subj-01', 'surf', 'R.pial.gii'),
+            op.join('subj-01', 'surf', 'R.white.gii'),
+            op.join('subj-02', 'surf', 'L.pial.gii'),
+            op.join('subj-02', 'surf', 'L.white.gii'),
+            op.join('subj-02', 'surf', 'R.pial.gii'),
+            op.join('subj-02', 'surf', 'R.white.gii'),
+            op.join('subj-03', 'surf', 'L.pial.gii'),
+            op.join('subj-03', 'surf', 'L.white.gii'),
+            op.join('subj-03', 'surf', 'R.pial.gii'),
+            op.join('subj-03', 'surf', 'R.white.gii')]
+
+        got = query.query('surface', hemi='L')
+        assert [m.filename for m in sorted(got)] == [
+            op.join('subj-01', 'surf', 'L.pial.gii'),
+            op.join('subj-01', 'surf', 'L.white.gii'),
+            op.join('subj-02', 'surf', 'L.pial.gii'),
+            op.join('subj-02', 'surf', 'L.white.gii'),
+            op.join('subj-03', 'surf', 'L.pial.gii'),
+            op.join('subj-03', 'surf', 'L.white.gii')]
+
+        got = query.query('surface', surf='white')
+        assert [m.filename for m in sorted(got)] == [
+            op.join('subj-01', 'surf', 'L.white.gii'),
+            op.join('subj-01', 'surf', 'R.white.gii'),
+            op.join('subj-02', 'surf', 'L.white.gii'),
+            op.join('subj-02', 'surf', 'R.white.gii'),
+            op.join('subj-03', 'surf', 'L.white.gii'),
+            op.join('subj-03', 'surf', 'R.white.gii')]
 
 
 def test_scan():
-- 
GitLab