From 4574224592ee031693ad66c5fc715f0aa75ae830 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Mon, 9 Jul 2018 09:18:03 +0100
Subject: [PATCH] TEST: Test passing sequence as input

---
 tests/test_wrapperutils.py | 55 +++++++++++++++++++++++++++++++++++---
 1 file changed, 52 insertions(+), 3 deletions(-)

diff --git a/tests/test_wrapperutils.py b/tests/test_wrapperutils.py
index 90cfc7ff9..9a5d0d264 100644
--- a/tests/test_wrapperutils.py
+++ b/tests/test_wrapperutils.py
@@ -13,6 +13,7 @@ import            textwrap
 try: from unittest import mock
 except ImportError: import mock
 
+import six
 import pytest
 
 import numpy as np
@@ -309,7 +310,55 @@ def test_fileOrImage():
         assert np.all(result.get_data()[:] == expected)
 
 
-def test_fileOrImage_outprefix():
+def test_fileOrThing_sequence():
+
+    @wutils.fileOrArray('arrs', 'out')
+    def func(arrs, out):
+
+        if isinstance(arrs, six.string_types):
+            arrs = [arrs]
+
+        print('Loading from files', arrs)
+
+        arrs = [np.loadtxt(a) for a in arrs]
+
+        res  = np.sum(arrs, axis=0)
+
+        print('result', res)
+
+        np.savetxt(out, res)
+
+    inputs  = [np.random.randint(1, 10, (3, 3)) for i in range(4)]
+    infiles = ['input{}.txt'.format(i) for i in range(len(inputs))]
+    exp     = np.sum(inputs, axis=0)
+
+    with tempdir.tempdir():
+
+        for ifile, idata in zip(infiles, inputs):
+            np.savetxt(ifile, idata)
+
+        func(inputs, 'result.txt')
+        assert np.all(np.loadtxt('result.txt') == exp)
+
+        assert np.all(func(inputs, wutils.LOAD)['out'] == exp)
+
+        func(inputs[0], 'result.txt')
+        assert np.all(np.loadtxt('result.txt') == inputs[0])
+
+        assert np.all(func(inputs[0], wutils.LOAD)['out'] == inputs[0])
+
+        func(infiles, 'result.txt')
+        assert np.all(np.loadtxt('result.txt') == exp)
+
+        assert np.all(func(infiles, wutils.LOAD)['out'] == exp)
+
+        func(infiles[0], 'result.txt')
+        assert np.all(np.loadtxt('result.txt') == inputs[0])
+
+        assert np.all(func(infiles[0], wutils.LOAD)['out'] == inputs[0])
+
+
+def test_fileOrThing_outprefix():
 
     @wutils.fileOrImage('img', outprefix='output_base')
     def basefunc(img, output_base):
@@ -352,7 +401,7 @@ def test_fileOrImage_outprefix():
         cleardir(td, 'myout*')
 
 
-def test_fileOrImage_outprefix_differentTypes():
+def test_fileOrThing_outprefix_differentTypes():
 
     @wutils.fileOrImage('img', outprefix='outpref')
     def func(img, outpref):
@@ -392,7 +441,7 @@ def test_fileOrImage_outprefix_differentTypes():
         cleardir(td, 'myout*')
 
 
-def test_fileOrImage_outprefix_directory():
+def test_fileOrThing_outprefix_directory():
 
     import logging
     logging.basicConfig()
-- 
GitLab