From f8f758ae6bc29d2d8b02094859182a9e701ba406 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Mon, 2 Jul 2018 14:05:28 +0100
Subject: [PATCH] TEST: Update tests

---
 fsl/scripts/extract_noise.py |  2 +-
 tests/test_extract_noise.py  | 42 ++++++++++++++++++++++--------------
 2 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/fsl/scripts/extract_noise.py b/fsl/scripts/extract_noise.py
index b9590e686..5e59493c4 100644
--- a/fsl/scripts/extract_noise.py
+++ b/fsl/scripts/extract_noise.py
@@ -211,7 +211,7 @@ def main(argv=None):
     if argv is None:
         argv = sys.argv[1:]
 
-    args  = parseArgs(argv)
+    args = parseArgs(argv)
 
     try:
         comps = genComponentIndexList(args.components)
diff --git a/tests/test_extract_noise.py b/tests/test_extract_noise.py
index 768b98a91..ea3bdef12 100644
--- a/tests/test_extract_noise.py
+++ b/tests/test_extract_noise.py
@@ -8,6 +8,8 @@
 
 import numpy as np
 
+import pytest
+
 import fsl.utils.tempdir         as tempdir
 import fsl.scripts.extract_noise as extn
 
@@ -17,8 +19,6 @@ def test_genComponentIndexList():
     with tempdir.tempdir():
 
         # sequence of 1-indexed integers/file paths
-        # both potentially containing larger than
-        # the actual number of components
         icomps  = [1, 5, 28, 12, 42, 54]
         fcomps1 = [1, 4, 6, 3, 7]
         fcomps2 = [12, 42, 31, 1, 4, 8]
@@ -35,12 +35,10 @@ def test_genComponentIndexList():
 
         assert extn.genComponentIndexList(comps, ncomps) == expcomps
 
-        ncomps   = 40
-        comps    = icomps + ['comps1.txt', 'comps2.txt'] + [0, -1]
-        expcomps = list(sorted(set(icomps + fcomps1 + fcomps2)))
-        expcomps = [c - 1 for c in expcomps if c <= ncomps]
-
-        assert extn.genComponentIndexList(comps, ncomps) == expcomps
+        with pytest.raises(ValueError):
+            extn.genComponentIndexList(comps + [-1], 60)
+        with pytest.raises(ValueError):
+            extn.genComponentIndexList(comps, 40)
 
 
 def test_loadConfoundFiles():
@@ -50,17 +48,17 @@ def test_loadConfoundFiles():
         confs = [
             np.random.randint(1, 100, (50, 10)),
             np.random.randint(1, 100, (50, 1)),
-            np.random.randint(1, 100, (50, 5)),
+            np.random.randint(1, 100, (50, 5))]
+
+        badconfs = [
             np.random.randint(1, 100, (40, 10)),
             np.random.randint(1, 100, (60, 10))]
 
-        expected            = np.empty((50, 36), dtype=np.float64)
-        expected[:,   :]      = np.nan
-        expected[:,   :10]    = confs[0]
-        expected[:,    10:11] = confs[1]
-        expected[:,    11:16] = confs[2]
-        expected[:40,  16:26] = confs[3]
-        expected[:,    26:36] = confs[4][:50, :]
+        expected            = np.empty((50, 16), dtype=np.float64)
+        expected[:, :]      = np.nan
+        expected[:, :10]    = confs[0]
+        expected[:,  10:11] = confs[1]
+        expected[:,  11:16] = confs[2]
 
         conffiles = []
         for i, c in enumerate(confs):
@@ -74,3 +72,15 @@ def test_loadConfoundFiles():
         assert np.all(~np.isnan(result) == amask)
         assert np.all(result[amask]     == expected[amask])
         assert np.all(result[amask]     == expected[amask])
+
+        badconfs = [
+            np.random.randint(1, 100, (40, 10)),
+            np.random.randint(1, 100, (60, 10))]
+        conffiles = []
+        for i, c in enumerate(badconfs):
+            fname = 'conf{}.txt'.format(i)
+            conffiles.append(fname)
+            np.savetxt(fname, c)
+
+        with pytest.raises(ValueError):
+            extn.loadConfoundFiles(conffiles, npts)
-- 
GitLab