From 6ebae621cd7f69ff73844e4019b7ad830981207c Mon Sep 17 00:00:00 2001 From: Paul McCarthy <pauldmccarthy@gmail.com> Date: Fri, 29 Mar 2019 11:44:29 +0000 Subject: [PATCH] TEST: more query tests --- tests/test_filetree/test_query.py | 194 +++++++++++++++++++++++------- 1 file changed, 153 insertions(+), 41 deletions(-) diff --git a/tests/test_filetree/test_query.py b/tests/test_filetree/test_query.py index 094686d9e..66e50c1f8 100644 --- a/tests/test_filetree/test_query.py +++ b/tests/test_filetree/test_query.py @@ -84,32 +84,36 @@ def _expected_matches(short_name, **kwargs): return matches -def _run_and_check_query(query, short_name, **vars): +def _run_and_check_query(query, short_name, asarray=False, **vars): - gotmatches = query.query( short_name, **vars) + gotmatches = query.query( short_name, asarray=asarray, **vars) expmatches = _expected_matches(short_name, **{k : [v] for k, v in vars.items()}) - snvars = query.variables(short_name) - assert len(snvars) == len(gotmatches.shape) + if not asarray: + assert len(gotmatches) == len(expmatches) + for got, exp in zip(sorted(gotmatches), sorted(expmatches)): + assert got == exp + else: + snvars = query.variables(short_name) - for i, var in enumerate(sorted(snvars.keys())): - if var not in vars or vars[var] == '*': - assert len(snvars[var]) == gotmatches.shape[i] - else: - assert gotmatches.shape[i] == 1 + assert len(snvars) == len(gotmatches.shape) - for expmatch in expmatches: - slc = [] - for var in query.axes(short_name): + for i, var in enumerate(sorted(snvars.keys())): if var not in vars or vars[var] == '*': - vidx = snvars[var].index(expmatch.variables[var]) - slc.append(vidx) + assert len(snvars[var]) == gotmatches.shape[i] else: - slc.append(0) + assert gotmatches.shape[i] == 1 - assert expmatch == gotmatches[tuple(slc)] + for expmatch in expmatches: + slc = [] + for var in query.axes(short_name): + if var not in vars or vars[var] == '*': + vidx = snvars[var].index(expmatch.variables[var]) + slc.append(vidx) + else: + slc.append(0) def test_query_properties(): @@ -180,7 +184,6 @@ def test_query_optional_var_folder(): assert query.variables()['session'] == [None, '1', '2'] m = query.query('T1w', participant='01') - m = [ma for ma in m.flatten() if isinstance(ma, ftquery.Match)] assert len(m) == 1 assert m[0].filename == op.join('subj-01', 'T1w.nii.gz') @@ -212,22 +215,19 @@ def test_query_optional_var_filename(): assert qvars['subject'] == ['01', '02', '03', '04'] assert qvars['modality'] == [None, 't1', 't2'] - m = query.query('image', modality=None) - m = [ma.filename for ma in m.flatten() - if isinstance(ma, ftquery.Match)] - assert m == [op.join('sub-01', 'img.nii.gz'), - op.join('sub-04', 'img.nii.gz')] + got = query.query('image', modality=None) + assert [m.filename for m in sorted(got)] == [ + op.join('sub-01', 'img.nii.gz'), + op.join('sub-04', 'img.nii.gz')] - m = query.query('image', modality='t1') - m = [ma.filename for ma in m.flatten() - if isinstance(ma, ftquery.Match)] - assert m == [op.join('sub-02', 'img-t1.nii.gz'), - op.join('sub-03', 'img-t1.nii.gz')] + got = query.query('image', modality='t1') + assert [m.filename for m in sorted(got)] == [ + op.join('sub-02', 'img-t1.nii.gz'), + op.join('sub-03', 'img-t1.nii.gz')] - m = query.query('image', modality='t2') - m = [ma.filename for ma in m.flatten() - if isinstance(ma, ftquery.Match)] - assert m == [op.join('sub-02', 'img-t2.nii.gz')] + got = query.query('image', modality='t2') + assert len(got) == 1 + assert got[0].filename == op.join('sub-02', 'img-t2.nii.gz') def test_query_missing_files(): @@ -242,21 +242,18 @@ def test_query_missing_files(): tree = filetree.FileTree.read('_test_tree.tree', '.') query = filetree.FileTreeQuery(tree) - m = query.query('T1w', session='1') - m = [ma.filename for ma in m.flatten() if isinstance(ma, ftquery.Match)] - assert sorted(m) == [ + got = query.query('T1w', session='1') + assert [m.filename for m in sorted(got)] == [ op.join('subj-02', 'ses-1', 'T1w.nii.gz'), op.join('subj-03', 'ses-1', 'T1w.nii.gz')] - m = query.query('T2w', session='2') - m = [ma.filename for ma in m.flatten() if isinstance(ma, ftquery.Match)] - assert sorted(m) == [ + got = query.query('T2w', session='2') + assert [m.filename for m in sorted(got)] == [ op.join('subj-01', 'ses-2', 'T2w.nii.gz'), op.join('subj-03', 'ses-2', 'T2w.nii.gz')] - m = query.query('surface', session='1', hemi='L') - m = [ma.filename for ma in m.flatten() if isinstance(ma, ftquery.Match)] - assert sorted(m) == [ + got = query.query('surface', session='1', hemi='L') + assert [m.filename for m in sorted(got)] == [ op.join('subj-01', 'ses-1', 'L.midthickness.gii'), op.join('subj-01', 'ses-1', 'L.pial.gii'), op.join('subj-01', 'ses-1', 'L.white.gii'), @@ -265,7 +262,122 @@ def test_query_missing_files(): op.join('subj-02', 'ses-1', 'L.white.gii')] - pass +def test_query_asarray(): + with _test_data(): + tree = filetree.FileTree.read('_test_tree.tree', '.') + query = filetree.FileTreeQuery(tree) + + _run_and_check_query(query, 'T1w', asarray=True) + _run_and_check_query(query, 'T1w', asarray=True, participant='01') + _run_and_check_query(query, 'T1w', asarray=True, session='2') + _run_and_check_query(query, 'T1w', asarray=True, participant='02', session='1') + _run_and_check_query(query, 'T2w', asarray=True) + _run_and_check_query(query, 'T2w', asarray=True, participant='01') + _run_and_check_query(query, 'T2w', asarray=True, session='2') + _run_and_check_query(query, 'T2w', asarray=True, participant='02', session='1') + _run_and_check_query(query, 'surface', asarray=True) + _run_and_check_query(query, 'surface', asarray=True, hemi='L') + _run_and_check_query(query, 'surface', asarray=True, surf='midthickness') + _run_and_check_query(query, 'surface', asarray=True, hemi='R', surf='pial') + _run_and_check_query(query, 'surface', asarray=True, participant='03', surf='pial') + _run_and_check_query(query, 'surface', asarray=True, participant='03', sssion='2') + + +def test_query_subtree(): + tree1 = tw.dedent(""" + subj-{participant} + T1w.nii.gz (T1w) + surf + ->surface (surfdir) + """) + tree2 = tw.dedent(""" + {hemi}.{surf}.gii (surface) + """) + + files = [ + op.join('subj-01', 'T1w.nii.gz'), + op.join('subj-01', 'surf', 'L.pial.gii'), + op.join('subj-01', 'surf', 'R.pial.gii'), + op.join('subj-01', 'surf', 'L.white.gii'), + op.join('subj-01', 'surf', 'R.white.gii'), + op.join('subj-02', 'T1w.nii.gz'), + op.join('subj-02', 'surf', 'L.pial.gii'), + op.join('subj-02', 'surf', 'R.pial.gii'), + op.join('subj-02', 'surf', 'L.white.gii'), + op.join('subj-02', 'surf', 'R.white.gii'), + op.join('subj-03', 'T1w.nii.gz'), + op.join('subj-03', 'surf', 'L.pial.gii'), + op.join('subj-03', 'surf', 'R.pial.gii'), + op.join('subj-03', 'surf', 'L.white.gii'), + op.join('subj-03', 'surf', 'R.white.gii')] + + with testdir(files): + with open('tree1.tree', 'wt') as f: f.write(tree1) + with open('surface.tree', 'wt') as f: f.write(tree2) + + tree = filetree.FileTree.read('tree1.tree', '.') + query = filetree.FileTreeQuery(tree) + + assert sorted(query.short_names) == ['T1w', 'surface'] + + qvars = query.variables() + assert sorted(qvars.keys()) == ['hemi', 'participant', 'surf'] + assert qvars['hemi'] == ['L', 'R'] + assert qvars['participant'] == ['01', '02', '03'] + assert qvars['surf'] == ['pial', 'white'] + + qvars = query.variables('T1w') + assert sorted(qvars.keys()) == ['participant'] + assert qvars['participant'] == ['01', '02', '03'] + + qvars = query.variables('surface') + assert sorted(qvars.keys()) == ['hemi', 'participant', 'surf'] + assert qvars['hemi'] == ['L', 'R'] + assert qvars['participant'] == ['01', '02', '03'] + assert qvars['surf'] == ['pial', 'white'] + + got = query.query('T1w') + assert [m.filename for m in sorted(got)] == [ + op.join('subj-01', 'T1w.nii.gz'), + op.join('subj-02', 'T1w.nii.gz'), + op.join('subj-03', 'T1w.nii.gz')] + + got = query.query('T1w', participant='01') + assert [m.filename for m in sorted(got)] == [ + op.join('subj-01', 'T1w.nii.gz')] + + got = query.query('surface') + assert [m.filename for m in sorted(got)] == [ + op.join('subj-01', 'surf', 'L.pial.gii'), + op.join('subj-01', 'surf', 'L.white.gii'), + op.join('subj-01', 'surf', 'R.pial.gii'), + op.join('subj-01', 'surf', 'R.white.gii'), + op.join('subj-02', 'surf', 'L.pial.gii'), + op.join('subj-02', 'surf', 'L.white.gii'), + op.join('subj-02', 'surf', 'R.pial.gii'), + op.join('subj-02', 'surf', 'R.white.gii'), + op.join('subj-03', 'surf', 'L.pial.gii'), + op.join('subj-03', 'surf', 'L.white.gii'), + op.join('subj-03', 'surf', 'R.pial.gii'), + op.join('subj-03', 'surf', 'R.white.gii')] + + got = query.query('surface', hemi='L') + assert [m.filename for m in sorted(got)] == [ + op.join('subj-01', 'surf', 'L.pial.gii'), + op.join('subj-01', 'surf', 'L.white.gii'), + op.join('subj-02', 'surf', 'L.pial.gii'), + op.join('subj-02', 'surf', 'L.white.gii'), + op.join('subj-03', 'surf', 'L.pial.gii'), + op.join('subj-03', 'surf', 'L.white.gii')] + + got = query.query('surface', surf='white') + assert [m.filename for m in sorted(got)] == [ + op.join('subj-01', 'surf', 'L.white.gii'), + op.join('subj-01', 'surf', 'R.white.gii'), + op.join('subj-02', 'surf', 'L.white.gii'), + op.join('subj-02', 'surf', 'R.white.gii'), + op.join('subj-03', 'surf', 'L.white.gii'), + op.join('subj-03', 'surf', 'R.white.gii')] def test_scan(): -- GitLab