From d45b3ce968c6fcfe85880f42cb1b4e9bb0457c1d Mon Sep 17 00:00:00 2001 From: Paul McCarthy <pauldmccarthy@gmail.com> Date: Fri, 27 Sep 2024 17:09:21 +0100 Subject: [PATCH] TEST: Test loadLabelFile with classification probabilities --- fsl/tests/test_fixlabels.py | 101 +++++++++++++++++++++++++++++++++--- 1 file changed, 95 insertions(+), 6 deletions(-) diff --git a/fsl/tests/test_fixlabels.py b/fsl/tests/test_fixlabels.py index 308883e9..4155f2d3 100644 --- a/fsl/tests/test_fixlabels.py +++ b/fsl/tests/test_fixlabels.py @@ -5,8 +5,9 @@ # Author: Paul McCarthy <pauldmccarthy@gmail.com> # -import os.path as op -import textwrap +import math +import os.path as op +import textwrap as tw import pytest @@ -178,10 +179,29 @@ path/to/analysis.ica ['Signal']], [1, 2])) +# Classification probabilities +goodfiles.append((""" +path/to/analysis.ica +1, Unclassified noise, True, 0.2 +2, Unclassified noise, True, 0.1 +3, Signal, False, 0.8 +[1, 2] +""", +'path/to/analysis.ica', +[['Unclassified noise'], + ['Unclassified noise'], + ['Signal']], +[1, 2], +[0.2, 0.1, 0.8])) + def test_loadLabelFile_good(): - for filecontents, expMelDir, expLabels, expIdxs in goodfiles: + for test in goodfiles: + filecontents, expMelDir, expLabels, expIdxs = test[:4] + + if len(test) > 4: probs = test[4] + else: probs = None with tests.testdir() as testdir: @@ -206,6 +226,11 @@ def test_loadLabelFile_good(): for exp, res in zip(expLabels, resLabels): assert exp == res + if probs is not None: + resMelDir, resLabels, resProbs = fixlabels.loadLabelFile( + fname, returnProbabilities=True) + assert resProbs == probs + @@ -316,7 +341,8 @@ def test_loadLabelFile_bad(): def test_loadLabelFile_customLabels(): included = [2, 3, 4, 5] - contents = '[{}]\n'.format([i + 1 for i in included]) + contents = ','.join([str(i + 1) for i in included]) + contents = f'[{contents}]\n' defIncLabel = 'Unclassified noise' defExcLabel = 'Signal' @@ -350,6 +376,69 @@ def test_loadLabelFile_customLabels(): assert ilbls[0] == excLabel +def test_loadLabelFile_probabilities(): + + def lists_equal(a, b): + if len(a) != len(b): + return False + for av, bv in zip(a, b): + if av == bv: + continue + if math.isnan(av) and math.isnan(bv): + continue + if math.isnan(av) and (not math.isnan(bv)): + return False + if (not math.isnan(av)) and math.isnan(bv): + return False + + return True + + nan = math.nan + + testcases = [ + (""" + analysis.ica + 1, Signal, False + 2, Unclassified noise, True + 3, Signal, False + [2] + """, [nan, nan, nan]), + (""" + analysis.ica + 1, Signal, False, 0.1 + 2, Unclassified noise, True, 0.2 + 3, Signal, False, 0.3 + [2] + """, [0.1, 0.2, 0.3]), + (""" + analysis.ica + 1, Signal, False, 0.1 + 2, Unclassified noise, True + 3, Signal, False, 0.3 + [2] + """, [0.1, nan, 0.3]), + (""" + [1, 2, 3] + """, [nan, nan, nan]), + ] + + for contents, expprobs in testcases: + with tests.testdir() as testdir: + fname = op.join(testdir, 'labels.txt') + + with open(fname, 'wt') as f: + f.write(tw.dedent(contents).strip()) + + _, _, gotprobs = fixlabels.loadLabelFile( + fname, returnProbabilities=True) + + assert lists_equal(gotprobs, expprobs) + + + + + + def test_saveLabelFile(): @@ -359,7 +448,7 @@ def test_saveLabelFile(): ['Label1'], ['Unknown']] - expected = textwrap.dedent(""" + expected = tw.dedent(""" 1, Label1, Label2, Label3, True 2, Signal, False 3, Noise, True @@ -391,7 +480,7 @@ def test_saveLabelFile(): # Custom signal labels sigLabels = ['Label1'] - exp = textwrap.dedent(""" + exp = tw.dedent(""" . 1, Label1, Label2, Label3, False 2, Signal, True -- GitLab