From 0a057f31d992208b9f5fe6ac81a3ae6c9a849e06 Mon Sep 17 00:00:00 2001
From: Paul McCarthy <pauldmccarthy@gmail.com>
Date: Fri, 28 Jul 2023 11:23:07 +0100
Subject: [PATCH] TEST: Adjust fslstartup test

---
 unit_tests/utils/feedsRun.fslStartup | 39 ++++++++++++++++------------
 unit_tests/utils/test_fslStartup.cc  | 22 +++++++++++-----
 2 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/unit_tests/utils/feedsRun.fslStartup b/unit_tests/utils/feedsRun.fslStartup
index d307692..b7f8e42 100755
--- a/unit_tests/utils/feedsRun.fslStartup
+++ b/unit_tests/utils/feedsRun.fslStartup
@@ -6,12 +6,25 @@ import tempfile
 
 import subprocess as sp
 
-def run(cmd, **kwargs):
+def run(cmd, ompthreads=None, blasthreads=None, fslskipglobal=None, keepenv=False):
+
+    env = os.environ.copy()
+
+    if not keepenv:
+        blacklist = ['OMP', 'GOTO', 'BLAS', 'FSL']
+
+        for varname in list(env.keys()):
+            if any(b in varname for b in blacklist):
+                env.pop(varname)
+
+        if ompthreads    is not None: env['OMP_NUM_THREADS']  = str(ompthreads)
+        if blasthreads   is not None: env['BLAS_NUM_THREADS'] = str(blasthreads)
+        if fslskipglobal is not None: env['FSL_SKIP_GLOBAL']  = str(fslskipglobal)
 
     result = sp.run(shlex.split(cmd), check=True, text=True,
-                  stdout=sp.PIPE, stderr=sp.STDOUT, **kwargs)
+                    stdout=sp.PIPE, stderr=sp.STDOUT, env=env)
 
-    print(f'Called {cmd}')
+    print(f'Called {cmd} {ompthreads} {blasthreads} {fslskipglobal}')
     print(f'  exit code: {result.returncode}')
     print(f'  stdout:    {result.stdout.strip()}')
 
@@ -20,23 +33,17 @@ def run(cmd, **kwargs):
 
 def main():
 
-    blacklist = ['OMP', 'GOTO', 'BLAS', 'FSL']
-    env       = os.environ.copy()
-    for varname in list(env.keys()):
-        if any(b in varname for b in blacklist):
-            env.pop(varname)
-
-    env['OMP_NUM_THREADS']  = '8'
-    env['BLAS_NUM_THREADS'] = '8'
-
-    run('make')
+    run('make', keepenv=True)
 
     # Default behaviour should be: OMP multi-threaded, BLAS single threaded.
-    assert run('./test_fslStartup', env=env) == '8 1 8'
+    assert run('./test_fslStartup', 8, 8) == '8 1 8'
+    assert run('./test_fslStartup', 4, 4) == '4 1 4'
+    assert run('./test_fslStartup', 1, 1) == '1 1 1'
 
     # With FSL_SKIP_GLOBAL, BLAS should be multi-threaded
-    env['FSL_SKIP_GLOBAL'] = '1'
-    assert run('./test_fslStartup', env=env) == '8 8 8'
+    assert run('./test_fslStartup', 8, 8, 1) == '8 8 8'
+    assert run('./test_fslStartup', 4, 4, 1) == '4 4 4'
+    assert run('./test_fslStartup', 1, 1, 1) == '1 1 1'
 
 
 if __name__ == '__main__':
diff --git a/unit_tests/utils/test_fslStartup.cc b/unit_tests/utils/test_fslStartup.cc
index 325509d..7c7f913 100644
--- a/unit_tests/utils/test_fslStartup.cc
+++ b/unit_tests/utils/test_fslStartup.cc
@@ -22,20 +22,28 @@ int main(int argc, char *argv[]) {
 
     int omp_threads;
     int blas_threads;
-    int sum = 0;
+    int sum[16];
 
-    // omp num threads should not be
-    // affected by the FSL startup logic.
-    // Sum should be equal to omp num threads.
+    for (int i = 0; i < 16; i++) {
+      sum[i] = 0;
+    }
+
+    // omp num threads should not be affected
+    // by the FSL startup logic.  Sum should
+    // be equal to omp num threads.
     #pragma omp parallel
     {
-        sum        += 1;
-        omp_threads = omp_get_num_threads();
+        sum[omp_get_thread_num()] = 1;
+        omp_threads               = omp_get_num_threads();
     }
 
     // blas num threads should be controlled
     // by FSL startup logic.
     blas_threads = openblas_get_num_threads();
 
-    cout << omp_threads << " " << blas_threads << " " << sum << endl;
+    for (int i = 1; i < 16; i++) {
+      sum[0] += sum[i];
+    }
+
+    cout << omp_threads << " " << blas_threads << " " << sum[0] << endl;
 }
-- 
GitLab