From ac4abe63906dbd1486dfbbf12e76bad67af14380 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Mon, 20 Jul 2020 17:29:20 +0100
Subject: [PATCH] TEST: Unit test for generateVest function

---
 tests/test_vest.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/tests/test_vest.py b/tests/test_vest.py
index 407c502ef..422031f86 100644
--- a/tests/test_vest.py
+++ b/tests/test_vest.py
@@ -7,6 +7,7 @@
 
 
 import os.path as op
+import            io
 import            shutil
 import            tempfile
 import            warnings
@@ -214,3 +215,36 @@ def test_loadVestLutFile():
 
     finally:
         shutil.rmtree(testdir)
+
+
+def test_generateVest():
+    def readvest(vstr):
+        lines = vstr.split('\n')
+        nrows = [l for l in lines if 'NumPoints' in l][0]
+        ncols = [l for l in lines if 'NumWaves'  in l][0]
+        nrows = int(nrows.split()[1])
+        ncols = int(ncols.split()[1])
+        data  = '\n'.join(lines[3:])
+        data  = np.loadtxt(io.StringIO(data)).reshape((nrows, ncols))
+
+        return ((nrows, ncols), data)
+
+    # shape, expectedshape
+    tests = [
+        ((10,   ), ( 1, 10)),
+        ((10,  1), (10,  1)),
+        (( 1, 10), ( 1, 10)),
+        (( 3,  5), ( 3,  5)),
+        (( 5,  3), ( 5,  3))
+    ]
+
+    for shape, expshape in tests:
+        data = np.random.random(shape)
+        vstr = vest.generateVest(data)
+
+        gotshape, gotdata = readvest(vstr)
+
+        data = data.reshape(expshape)
+
+        assert expshape == gotshape
+        assert np.all(np.isclose(data, gotdata))
-- 
GitLab