From 0a057f31d992208b9f5fe6ac81a3ae6c9a849e06 Mon Sep 17 00:00:00 2001 From: Paul McCarthy <pauldmccarthy@gmail.com> Date: Fri, 28 Jul 2023 11:23:07 +0100 Subject: [PATCH] TEST: Adjust fslstartup test --- unit_tests/utils/feedsRun.fslStartup | 39 ++++++++++++++++------------ unit_tests/utils/test_fslStartup.cc | 22 +++++++++++----- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/unit_tests/utils/feedsRun.fslStartup b/unit_tests/utils/feedsRun.fslStartup index d307692..b7f8e42 100755 --- a/unit_tests/utils/feedsRun.fslStartup +++ b/unit_tests/utils/feedsRun.fslStartup @@ -6,12 +6,25 @@ import tempfile import subprocess as sp -def run(cmd, **kwargs): +def run(cmd, ompthreads=None, blasthreads=None, fslskipglobal=None, keepenv=False): + + env = os.environ.copy() + + if not keepenv: + blacklist = ['OMP', 'GOTO', 'BLAS', 'FSL'] + + for varname in list(env.keys()): + if any(b in varname for b in blacklist): + env.pop(varname) + + if ompthreads is not None: env['OMP_NUM_THREADS'] = str(ompthreads) + if blasthreads is not None: env['BLAS_NUM_THREADS'] = str(blasthreads) + if fslskipglobal is not None: env['FSL_SKIP_GLOBAL'] = str(fslskipglobal) result = sp.run(shlex.split(cmd), check=True, text=True, - stdout=sp.PIPE, stderr=sp.STDOUT, **kwargs) + stdout=sp.PIPE, stderr=sp.STDOUT, env=env) - print(f'Called {cmd}') + print(f'Called {cmd} {ompthreads} {blasthreads} {fslskipglobal}') print(f' exit code: {result.returncode}') print(f' stdout: {result.stdout.strip()}') @@ -20,23 +33,17 @@ def run(cmd, **kwargs): def main(): - blacklist = ['OMP', 'GOTO', 'BLAS', 'FSL'] - env = os.environ.copy() - for varname in list(env.keys()): - if any(b in varname for b in blacklist): - env.pop(varname) - - env['OMP_NUM_THREADS'] = '8' - env['BLAS_NUM_THREADS'] = '8' - - run('make') + run('make', keepenv=True) # Default behaviour should be: OMP multi-threaded, BLAS single threaded. - assert run('./test_fslStartup', env=env) == '8 1 8' + assert run('./test_fslStartup', 8, 8) == '8 1 8' + assert run('./test_fslStartup', 4, 4) == '4 1 4' + assert run('./test_fslStartup', 1, 1) == '1 1 1' # With FSL_SKIP_GLOBAL, BLAS should be multi-threaded - env['FSL_SKIP_GLOBAL'] = '1' - assert run('./test_fslStartup', env=env) == '8 8 8' + assert run('./test_fslStartup', 8, 8, 1) == '8 8 8' + assert run('./test_fslStartup', 4, 4, 1) == '4 4 4' + assert run('./test_fslStartup', 1, 1, 1) == '1 1 1' if __name__ == '__main__': diff --git a/unit_tests/utils/test_fslStartup.cc b/unit_tests/utils/test_fslStartup.cc index 325509d..7c7f913 100644 --- a/unit_tests/utils/test_fslStartup.cc +++ b/unit_tests/utils/test_fslStartup.cc @@ -22,20 +22,28 @@ int main(int argc, char *argv[]) { int omp_threads; int blas_threads; - int sum = 0; + int sum[16]; - // omp num threads should not be - // affected by the FSL startup logic. - // Sum should be equal to omp num threads. + for (int i = 0; i < 16; i++) { + sum[i] = 0; + } + + // omp num threads should not be affected + // by the FSL startup logic. Sum should + // be equal to omp num threads. #pragma omp parallel { - sum += 1; - omp_threads = omp_get_num_threads(); + sum[omp_get_thread_num()] = 1; + omp_threads = omp_get_num_threads(); } // blas num threads should be controlled // by FSL startup logic. blas_threads = openblas_get_num_threads(); - cout << omp_threads << " " << blas_threads << " " << sum << endl; + for (int i = 1; i < 16; i++) { + sum[0] += sum[i]; + } + + cout << omp_threads << " " << blas_threads << " " << sum[0] << endl; } -- GitLab