From 4574224592ee031693ad66c5fc715f0aa75ae830 Mon Sep 17 00:00:00 2001 From: Paul McCarthy <pauldmccarthy@gmail.com> Date: Mon, 9 Jul 2018 09:18:03 +0100 Subject: [PATCH] TEST: Test passing sequence as input --- tests/test_wrapperutils.py | 55 +++++++++++++++++++++++++++++++++++--- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/tests/test_wrapperutils.py b/tests/test_wrapperutils.py index 90cfc7ff9..9a5d0d264 100644 --- a/tests/test_wrapperutils.py +++ b/tests/test_wrapperutils.py @@ -13,6 +13,7 @@ import textwrap try: from unittest import mock except ImportError: import mock +import six import pytest import numpy as np @@ -309,7 +310,55 @@ def test_fileOrImage(): assert np.all(result.get_data()[:] == expected) -def test_fileOrImage_outprefix(): +def test_fileOrThing_sequence(): + + @wutils.fileOrArray('arrs', 'out') + def func(arrs, out): + + if isinstance(arrs, six.string_types): + arrs = [arrs] + + print('Loading from files', arrs) + + arrs = [np.loadtxt(a) for a in arrs] + + res = np.sum(arrs, axis=0) + + print('result', res) + + np.savetxt(out, res) + + inputs = [np.random.randint(1, 10, (3, 3)) for i in range(4)] + infiles = ['input{}.txt'.format(i) for i in range(len(inputs))] + exp = np.sum(inputs, axis=0) + + with tempdir.tempdir(): + + for ifile, idata in zip(infiles, inputs): + np.savetxt(ifile, idata) + + func(inputs, 'result.txt') + assert np.all(np.loadtxt('result.txt') == exp) + + assert np.all(func(inputs, wutils.LOAD)['out'] == exp) + + func(inputs[0], 'result.txt') + assert np.all(np.loadtxt('result.txt') == inputs[0]) + + assert np.all(func(inputs[0], wutils.LOAD)['out'] == inputs[0]) + + func(infiles, 'result.txt') + assert np.all(np.loadtxt('result.txt') == exp) + + assert np.all(func(infiles, wutils.LOAD)['out'] == exp) + + func(infiles[0], 'result.txt') + assert np.all(np.loadtxt('result.txt') == inputs[0]) + + assert np.all(func(infiles[0], wutils.LOAD)['out'] == inputs[0]) + + +def test_fileOrThing_outprefix(): @wutils.fileOrImage('img', outprefix='output_base') def basefunc(img, output_base): @@ -352,7 +401,7 @@ def test_fileOrImage_outprefix(): cleardir(td, 'myout*') -def test_fileOrImage_outprefix_differentTypes(): +def test_fileOrThing_outprefix_differentTypes(): @wutils.fileOrImage('img', outprefix='outpref') def func(img, outpref): @@ -392,7 +441,7 @@ def test_fileOrImage_outprefix_differentTypes(): cleardir(td, 'myout*') -def test_fileOrImage_outprefix_directory(): +def test_fileOrThing_outprefix_directory(): import logging logging.basicConfig() -- GitLab