From d878d612d9cf9caf5934d408829f0f256fee240a Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Tue, 21 Jul 2020 11:25:16 +0100
Subject: [PATCH] TEST: Test loadVestFile

---
 tests/test_vest.py | 56 ++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 49 insertions(+), 7 deletions(-)

diff --git a/tests/test_vest.py b/tests/test_vest.py
index 422031f86..a49f5fc3a 100644
--- a/tests/test_vest.py
+++ b/tests/test_vest.py
@@ -6,17 +6,20 @@
 #
 
 
-import os.path as op
-import            io
-import            shutil
-import            tempfile
-import            warnings
+import os.path  as op
+import textwrap as tw
+import             io
+import             shutil
+import             tempfile
+import             warnings
 
-import numpy   as np
-import            pytest
+import numpy    as np
+import             pytest
 
 import fsl.data.vest as vest
 
+from tests import tempdir
+
 
 testfile1 = """%!VEST-LUT
 %%BeginInstance
@@ -248,3 +251,42 @@ def test_generateVest():
 
         assert expshape == gotshape
         assert np.all(np.isclose(data, gotdata))
+
+
+def test_loadVestFile():
+    def genvest(data, path, shapeOverride=None):
+        if shapeOverride is None:
+            nrows, ncols = data.shape
+        else:
+            nrows, ncols = shapeOverride
+
+        with open(path, 'wt') as f:
+            f.write(f'/NumWaves {ncols}\n')
+            f.write(f'/NumPoints {nrows}\n')
+            f.write( '/Matrix\n')
+
+            if np.issubdtype(data.dtype, np.integer): fmt = '%d'
+            else:                                     fmt = '%0.12f'
+
+            np.savetxt(f, data, fmt=fmt)
+
+    with tempdir():
+        data = np.random.randint(1, 100, (10, 20))
+        genvest(data, 'data.vest')
+        assert np.all(data == vest.loadVestFile('data.vest'))
+
+        data = np.random.random((20, 10))
+        genvest(data, 'data.vest')
+        assert np.all(np.isclose(data, vest.loadVestFile('data.vest')))
+
+        # should pass
+        vest.loadVestFile('data.vest', ignoreHeader=False)
+
+        # invalid VEST header
+        genvest(data, 'data.vest', (10, 20))
+
+        # default behaviour - ignore header
+        assert np.all(np.isclose(data, vest.loadVestFile('data.vest')))
+
+        with pytest.raises(ValueError):
+            vest.loadVestFile('data.vest', ignoreHeader=False)
-- 
GitLab