Skip to content
Snippets Groups Projects
Commit 68edbe72 authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist:
Browse files

basic unit tests for fileOrArray/Image

parent e9d71e92
No related branches found
No related tags found
No related merge requests found
......@@ -5,10 +5,15 @@
# Author: Paul McCarthy <pauldmccarthy@gmail.com>
#
import os
import shlex
import pytest
import numpy as np
import nibabel as nib
import fsl.utils.tempdir as tempdir
import fsl.wrappers.wrapperutils as wutils
......@@ -157,3 +162,111 @@ def test_namedPositionals():
# TODO
# - test _FileOrImage LOAD tuple order
def test_fileOrArray():
@wutils.fileOrArray('arr1', 'other', 'output')
def func(arr1, **kwargs):
arr1 = np.loadtxt(arr1)
other = np.loadtxt(kwargs['other'])
np.savetxt(kwargs['output'], (arr1 * other))
with tempdir.tempdir():
arr1 = np.array([[1, 2], [ 3, 4]])
other = np.array([[5, 6], [ 7, 8]])
expected = np.array([[5, 12], [21, 32]])
np.savetxt('arr1.txt', arr1)
np.savetxt('other.txt', other)
# file file file
func('arr1.txt', other='other.txt', output='output.txt')
assert np.all(np.loadtxt('output.txt') == expected)
os.remove('output.txt')
# file file array
result = func('arr1.txt', other='other.txt', output=wutils.LOAD)['output']
assert np.all(result == expected)
# file array file
func('arr1.txt', other=other, output='output.txt')
assert np.all(np.loadtxt('output.txt') == expected)
os.remove('output.txt')
# file array array
result = func('arr1.txt', other=other, output=wutils.LOAD)['output']
assert np.all(result == expected)
# array file file
func(arr1, other='other.txt', output='output.txt')
assert np.all(np.loadtxt('output.txt') == expected)
os.remove('output.txt')
# array file array
result = func(arr1, other='other.txt', output=wutils.LOAD)['output']
assert np.all(result == expected)
# array array file
func(arr1, other=other, output='output.txt')
assert np.all(np.loadtxt('output.txt') == expected)
os.remove('output.txt')
# array array array
result = func(arr1, other=other, output=wutils.LOAD)['output']
assert np.all(result == expected)
def test_fileOrImage():
@wutils.fileOrImage('img1', 'img2', 'output')
def func(img1, **kwargs):
img1 = nib.load(img1).get_data()
img2 = nib.load(kwargs['img2']).get_data()
output = nib.nifti1.Nifti1Image(img1 * img2, np.eye(4))
nib.save(output, kwargs['output'])
with tempdir.tempdir():
img1 = nib.nifti1.Nifti1Image(np.array([[1, 2], [ 3, 4]]), np.eye(4))
img2 = nib.nifti1.Nifti1Image(np.array([[5, 6], [ 7, 8]]), np.eye(4))
expected = np.array([[5, 12], [21, 32]])
nib.save(img1, 'img1.nii')
nib.save(img2, 'img2.nii')
# file file file
func('img1.nii', img2='img2.nii', output='output.nii')
assert np.all(nib.load('output.nii').get_data() == expected)
os.remove('output.nii')
# file file array
result = func('img1.nii', img2='img2.nii', output=wutils.LOAD)['output']
assert np.all(result.get_data() == expected)
# file array file
func('img1.nii', img2=img2, output='output.nii')
assert np.all(nib.load('output.nii').get_data() == expected)
os.remove('output.nii')
# file array array
result = func('img1.nii', img2=img2, output=wutils.LOAD)['output']
assert np.all(result.get_data() == expected)
# array file file
func(img1, img2='img2.nii', output='output.nii')
assert np.all(nib.load('output.nii').get_data() == expected)
os.remove('output.nii')
# array file array
result = func(img1, img2='img2.nii', output=wutils.LOAD)['output']
assert np.all(result.get_data() == expected)
# array array file
func(img1, img2=img2, output='output.nii')
assert np.all(nib.load('output.nii').get_data() == expected)
os.remove('output.nii')
# array array array
result = func(img1, img2=img2, output=wutils.LOAD)['output']
assert np.all(result.get_data() == expected)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment