diff --git a/tests/test_wrapperutils.py b/tests/test_wrapperutils.py index 90cfc7ff97b440393b89f8d0077f2b463f7cacea..9a5d0d26407d29113bdb09308473d373c1bfdc9d 100644 --- a/tests/test_wrapperutils.py +++ b/tests/test_wrapperutils.py @@ -13,6 +13,7 @@ import textwrap try: from unittest import mock except ImportError: import mock +import six import pytest import numpy as np @@ -309,7 +310,55 @@ def test_fileOrImage(): assert np.all(result.get_data()[:] == expected) -def test_fileOrImage_outprefix(): +def test_fileOrThing_sequence(): + + @wutils.fileOrArray('arrs', 'out') + def func(arrs, out): + + if isinstance(arrs, six.string_types): + arrs = [arrs] + + print('Loading from files', arrs) + + arrs = [np.loadtxt(a) for a in arrs] + + res = np.sum(arrs, axis=0) + + print('result', res) + + np.savetxt(out, res) + + inputs = [np.random.randint(1, 10, (3, 3)) for i in range(4)] + infiles = ['input{}.txt'.format(i) for i in range(len(inputs))] + exp = np.sum(inputs, axis=0) + + with tempdir.tempdir(): + + for ifile, idata in zip(infiles, inputs): + np.savetxt(ifile, idata) + + func(inputs, 'result.txt') + assert np.all(np.loadtxt('result.txt') == exp) + + assert np.all(func(inputs, wutils.LOAD)['out'] == exp) + + func(inputs[0], 'result.txt') + assert np.all(np.loadtxt('result.txt') == inputs[0]) + + assert np.all(func(inputs[0], wutils.LOAD)['out'] == inputs[0]) + + func(infiles, 'result.txt') + assert np.all(np.loadtxt('result.txt') == exp) + + assert np.all(func(infiles, wutils.LOAD)['out'] == exp) + + func(infiles[0], 'result.txt') + assert np.all(np.loadtxt('result.txt') == inputs[0]) + + assert np.all(func(infiles[0], wutils.LOAD)['out'] == inputs[0]) + + +def test_fileOrThing_outprefix(): @wutils.fileOrImage('img', outprefix='output_base') def basefunc(img, output_base): @@ -352,7 +401,7 @@ def test_fileOrImage_outprefix(): cleardir(td, 'myout*') -def test_fileOrImage_outprefix_differentTypes(): +def test_fileOrThing_outprefix_differentTypes(): @wutils.fileOrImage('img', outprefix='outpref') def func(img, outpref): @@ -392,7 +441,7 @@ def test_fileOrImage_outprefix_differentTypes(): cleardir(td, 'myout*') -def test_fileOrImage_outprefix_directory(): +def test_fileOrThing_outprefix_directory(): import logging logging.basicConfig()