From d45b3ce968c6fcfe85880f42cb1b4e9bb0457c1d Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Fri, 27 Sep 2024 17:09:21 +0100
Subject: [PATCH] TEST: Test loadLabelFile with classification probabilities

---
 fsl/tests/test_fixlabels.py | 101 +++++++++++++++++++++++++++++++++---
 1 file changed, 95 insertions(+), 6 deletions(-)

diff --git a/fsl/tests/test_fixlabels.py b/fsl/tests/test_fixlabels.py
index 308883e9..4155f2d3 100644
--- a/fsl/tests/test_fixlabels.py
+++ b/fsl/tests/test_fixlabels.py
@@ -5,8 +5,9 @@
 # Author: Paul McCarthy <pauldmccarthy@gmail.com>
 #
 
-import os.path as op
-import            textwrap
+import             math
+import os.path  as op
+import textwrap as tw
 
 import pytest
 
@@ -178,10 +179,29 @@ path/to/analysis.ica
  ['Signal']],
 [1, 2]))
 
+# Classification probabilities
+goodfiles.append(("""
+path/to/analysis.ica
+1, Unclassified noise, True, 0.2
+2, Unclassified noise, True, 0.1
+3, Signal, False, 0.8
+[1, 2]
+""",
+'path/to/analysis.ica',
+[['Unclassified noise'],
+ ['Unclassified noise'],
+ ['Signal']],
+[1, 2],
+[0.2, 0.1, 0.8]))
+
 
 def test_loadLabelFile_good():
 
-    for filecontents, expMelDir, expLabels, expIdxs in goodfiles:
+    for test in goodfiles:
+        filecontents, expMelDir, expLabels, expIdxs  = test[:4]
+
+        if len(test) > 4: probs = test[4]
+        else:             probs = None
 
         with tests.testdir() as testdir:
 
@@ -206,6 +226,11 @@ def test_loadLabelFile_good():
             for exp, res in zip(expLabels, resLabels):
                 assert exp == res
 
+            if probs is not None:
+                resMelDir, resLabels, resProbs = fixlabels.loadLabelFile(
+                    fname, returnProbabilities=True)
+                assert resProbs == probs
+
 
 
 
@@ -316,7 +341,8 @@ def test_loadLabelFile_bad():
 def test_loadLabelFile_customLabels():
 
     included = [2, 3, 4, 5]
-    contents = '[{}]\n'.format([i + 1 for i in included])
+    contents = ','.join([str(i + 1) for i in included])
+    contents = f'[{contents}]\n'
 
     defIncLabel = 'Unclassified noise'
     defExcLabel = 'Signal'
@@ -350,6 +376,69 @@ def test_loadLabelFile_customLabels():
                 assert ilbls[0] == excLabel
 
 
+def test_loadLabelFile_probabilities():
+
+    def lists_equal(a, b):
+        if len(a) != len(b):
+            return False
+        for av, bv in zip(a, b):
+            if av == bv:
+                continue
+            if math.isnan(av) and math.isnan(bv):
+                continue
+            if math.isnan(av) and (not math.isnan(bv)):
+                return False
+            if (not math.isnan(av)) and math.isnan(bv):
+                return False
+
+        return True
+
+    nan = math.nan
+
+    testcases = [
+        ("""
+         analysis.ica
+         1, Signal, False
+         2, Unclassified noise, True
+         3, Signal, False
+         [2]
+         """, [nan, nan, nan]),
+        ("""
+         analysis.ica
+         1, Signal, False, 0.1
+         2, Unclassified noise, True, 0.2
+         3, Signal, False, 0.3
+         [2]
+         """, [0.1, 0.2, 0.3]),
+        ("""
+         analysis.ica
+         1, Signal, False, 0.1
+         2, Unclassified noise, True
+         3, Signal, False, 0.3
+         [2]
+         """, [0.1, nan, 0.3]),
+        ("""
+         [1, 2, 3]
+         """, [nan, nan, nan]),
+    ]
+
+    for contents, expprobs in testcases:
+        with tests.testdir() as testdir:
+            fname = op.join(testdir, 'labels.txt')
+
+            with open(fname, 'wt') as f:
+                f.write(tw.dedent(contents).strip())
+
+            _, _, gotprobs = fixlabels.loadLabelFile(
+                fname, returnProbabilities=True)
+
+            assert lists_equal(gotprobs, expprobs)
+
+
+
+
+
+
 def test_saveLabelFile():
 
 
@@ -359,7 +448,7 @@ def test_saveLabelFile():
               ['Label1'],
               ['Unknown']]
 
-    expected = textwrap.dedent("""
+    expected = tw.dedent("""
     1, Label1, Label2, Label3, True
     2, Signal, False
     3, Noise, True
@@ -391,7 +480,7 @@ def test_saveLabelFile():
 
         # Custom signal labels
         sigLabels = ['Label1']
-        exp = textwrap.dedent("""
+        exp = tw.dedent("""
         .
         1, Label1, Label2, Label3, False
         2, Signal, True
-- 
GitLab