From fe82a4ace9358bcea44001c514dd9ed5c8393d81 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Thu, 7 Dec 2023 16:39:27 +0000
Subject: [PATCH] TEST: Basic unit tests for evalImage/evalVectorImage

---
 pyfeeds/tests/__init__.py      |  7 ++++++
 pyfeeds/tests/test_evaluate.py | 41 ++++++++++++++++++++++++++++++++++
 2 files changed, 48 insertions(+)

diff --git a/pyfeeds/tests/__init__.py b/pyfeeds/tests/__init__.py
index db6dbb5..7feda33 100644
--- a/pyfeeds/tests/__init__.py
+++ b/pyfeeds/tests/__init__.py
@@ -113,6 +113,13 @@ def maketest(filename, returnCode=0, inputs=None, outputs=None, stdout=None):
 
 
 def makepyfeeds(**kwargs):
+
+    # simplest way of creating a dummy pyfeeds object
+    if len(kwargs) == 0:
+        kwargs['command']      = 'compare'
+        kwargs['inputDir']     = os.getcwd()
+        kwargs['benchmarkDir'] = os.getcwd()
+
     args = argparse.Namespace(**kwargs)
     cfg  = argparse.Namespace()
     return main.Pyfeeds(args, cfg)
diff --git a/pyfeeds/tests/test_evaluate.py b/pyfeeds/tests/test_evaluate.py
index 96ff097..c016e78 100644
--- a/pyfeeds/tests/test_evaluate.py
+++ b/pyfeeds/tests/test_evaluate.py
@@ -8,6 +8,9 @@
 import os
 import os.path as op
 
+import numpy as np
+import nibabel as nib
+
 from . import tempdir, makepaths, maketest, makepyfeeds, CaptureStdout
 
 from pyfeeds import testing, evaluate
@@ -57,3 +60,41 @@ def test_evaluateTestAgainstBenchmark():
             assert len(lines) == 3
             for l in lines[1:]:
                 assert 'PASS' in l
+
+
+def test_evalVectorImage():
+
+    vecarr1 = -1 + 2 * np.random.random((10, 10, 10, 3))
+    vecarr2 = -1 + 2 * np.random.random((10, 10, 10, 3))
+
+    with tempdir():
+
+        pyf    = makepyfeeds()
+        fname1 = 'image1.nii.gz'
+        fname2 = 'image2.nii.gz'
+
+        nib.Nifti1Image(vecarr1, np.eye(4)).to_filename(fname1)
+        nib.Nifti1Image(vecarr2, np.eye(4)).to_filename(fname2)
+
+        assert evaluate.evalVectorImage(pyf, fname1, fname1) == 0
+        assert evaluate.evalVectorImage(pyf, fname2, fname2) == 0
+        assert evaluate.evalVectorImage(pyf, fname1, fname2) != 0
+
+
+def test_evalImage():
+
+    arr1 = -1 + 2 * np.random.random((10, 10, 10, 10))
+    arr2 = -1 + 2 * np.random.random((10, 10, 10, 10))
+
+    with tempdir():
+
+        pyf    = makepyfeeds()
+        fname1 = 'image1.nii.gz'
+        fname2 = 'image2.nii.gz'
+
+        nib.Nifti1Image(arr1, np.eye(4)).to_filename(fname1)
+        nib.Nifti1Image(arr2, np.eye(4)).to_filename(fname2)
+
+        assert evaluate.evalImage(pyf, fname1, fname1) == 0
+        assert evaluate.evalImage(pyf, fname2, fname2) == 0
+        assert evaluate.evalImage(pyf, fname1, fname2) != 0
-- 
GitLab