From 68edbe723b30a8d9d6b007f7097bb5318fdbebeb Mon Sep 17 00:00:00 2001 From: Paul McCarthy <pauldmccarthy@gmail.com> Date: Sat, 3 Mar 2018 15:41:33 +0000 Subject: [PATCH] basic unit tests for fileOrArray/Image --- tests/test_wrapperutils.py | 113 +++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/tests/test_wrapperutils.py b/tests/test_wrapperutils.py index d2b323212..77029e73a 100644 --- a/tests/test_wrapperutils.py +++ b/tests/test_wrapperutils.py @@ -5,10 +5,15 @@ # Author: Paul McCarthy <pauldmccarthy@gmail.com> # +import os import shlex import pytest +import numpy as np +import nibabel as nib + +import fsl.utils.tempdir as tempdir import fsl.wrappers.wrapperutils as wutils @@ -157,3 +162,111 @@ def test_namedPositionals(): # TODO # - test _FileOrImage LOAD tuple order + + + +def test_fileOrArray(): + + @wutils.fileOrArray('arr1', 'other', 'output') + def func(arr1, **kwargs): + arr1 = np.loadtxt(arr1) + other = np.loadtxt(kwargs['other']) + np.savetxt(kwargs['output'], (arr1 * other)) + + with tempdir.tempdir(): + + arr1 = np.array([[1, 2], [ 3, 4]]) + other = np.array([[5, 6], [ 7, 8]]) + expected = np.array([[5, 12], [21, 32]]) + np.savetxt('arr1.txt', arr1) + np.savetxt('other.txt', other) + + # file file file + func('arr1.txt', other='other.txt', output='output.txt') + assert np.all(np.loadtxt('output.txt') == expected) + os.remove('output.txt') + + # file file array + result = func('arr1.txt', other='other.txt', output=wutils.LOAD)['output'] + assert np.all(result == expected) + + # file array file + func('arr1.txt', other=other, output='output.txt') + assert np.all(np.loadtxt('output.txt') == expected) + os.remove('output.txt') + + # file array array + result = func('arr1.txt', other=other, output=wutils.LOAD)['output'] + assert np.all(result == expected) + + # array file file + func(arr1, other='other.txt', output='output.txt') + assert np.all(np.loadtxt('output.txt') == expected) + os.remove('output.txt') + + # array file array + result = func(arr1, other='other.txt', output=wutils.LOAD)['output'] + assert np.all(result == expected) + + # array array file + func(arr1, other=other, output='output.txt') + assert np.all(np.loadtxt('output.txt') == expected) + os.remove('output.txt') + + # array array array + result = func(arr1, other=other, output=wutils.LOAD)['output'] + assert np.all(result == expected) + + +def test_fileOrImage(): + + @wutils.fileOrImage('img1', 'img2', 'output') + def func(img1, **kwargs): + img1 = nib.load(img1).get_data() + img2 = nib.load(kwargs['img2']).get_data() + output = nib.nifti1.Nifti1Image(img1 * img2, np.eye(4)) + nib.save(output, kwargs['output']) + + with tempdir.tempdir(): + + img1 = nib.nifti1.Nifti1Image(np.array([[1, 2], [ 3, 4]]), np.eye(4)) + img2 = nib.nifti1.Nifti1Image(np.array([[5, 6], [ 7, 8]]), np.eye(4)) + expected = np.array([[5, 12], [21, 32]]) + nib.save(img1, 'img1.nii') + nib.save(img2, 'img2.nii') + + # file file file + func('img1.nii', img2='img2.nii', output='output.nii') + assert np.all(nib.load('output.nii').get_data() == expected) + os.remove('output.nii') + + # file file array + result = func('img1.nii', img2='img2.nii', output=wutils.LOAD)['output'] + assert np.all(result.get_data() == expected) + + # file array file + func('img1.nii', img2=img2, output='output.nii') + assert np.all(nib.load('output.nii').get_data() == expected) + os.remove('output.nii') + + # file array array + result = func('img1.nii', img2=img2, output=wutils.LOAD)['output'] + assert np.all(result.get_data() == expected) + + # array file file + func(img1, img2='img2.nii', output='output.nii') + assert np.all(nib.load('output.nii').get_data() == expected) + os.remove('output.nii') + + # array file array + result = func(img1, img2='img2.nii', output=wutils.LOAD)['output'] + assert np.all(result.get_data() == expected) + + # array array file + func(img1, img2=img2, output='output.nii') + assert np.all(nib.load('output.nii').get_data() == expected) + os.remove('output.nii') + + # array array array + result = func(img1, img2=img2, output=wutils.LOAD)['output'] + assert np.all(result.get_data() == expected) -- GitLab