Skip to content
Snippets Groups Projects
Commit 0a057f31 authored by Paul McCarthy's avatar Paul McCarthy :mountain_bicyclist: Committed by Matthew Webster
Browse files

TEST: Adjust fslstartup test

parent f3f98799
No related branches found
No related tags found
1 merge request!57TEST: Test fslStartup logic
......@@ -6,12 +6,25 @@ import tempfile
import subprocess as sp
def run(cmd, **kwargs):
def run(cmd, ompthreads=None, blasthreads=None, fslskipglobal=None, keepenv=False):
env = os.environ.copy()
if not keepenv:
blacklist = ['OMP', 'GOTO', 'BLAS', 'FSL']
for varname in list(env.keys()):
if any(b in varname for b in blacklist):
env.pop(varname)
if ompthreads is not None: env['OMP_NUM_THREADS'] = str(ompthreads)
if blasthreads is not None: env['BLAS_NUM_THREADS'] = str(blasthreads)
if fslskipglobal is not None: env['FSL_SKIP_GLOBAL'] = str(fslskipglobal)
result = sp.run(shlex.split(cmd), check=True, text=True,
stdout=sp.PIPE, stderr=sp.STDOUT, **kwargs)
stdout=sp.PIPE, stderr=sp.STDOUT, env=env)
print(f'Called {cmd}')
print(f'Called {cmd} {ompthreads} {blasthreads} {fslskipglobal}')
print(f' exit code: {result.returncode}')
print(f' stdout: {result.stdout.strip()}')
......@@ -20,23 +33,17 @@ def run(cmd, **kwargs):
def main():
blacklist = ['OMP', 'GOTO', 'BLAS', 'FSL']
env = os.environ.copy()
for varname in list(env.keys()):
if any(b in varname for b in blacklist):
env.pop(varname)
env['OMP_NUM_THREADS'] = '8'
env['BLAS_NUM_THREADS'] = '8'
run('make')
run('make', keepenv=True)
# Default behaviour should be: OMP multi-threaded, BLAS single threaded.
assert run('./test_fslStartup', env=env) == '8 1 8'
assert run('./test_fslStartup', 8, 8) == '8 1 8'
assert run('./test_fslStartup', 4, 4) == '4 1 4'
assert run('./test_fslStartup', 1, 1) == '1 1 1'
# With FSL_SKIP_GLOBAL, BLAS should be multi-threaded
env['FSL_SKIP_GLOBAL'] = '1'
assert run('./test_fslStartup', env=env) == '8 8 8'
assert run('./test_fslStartup', 8, 8, 1) == '8 8 8'
assert run('./test_fslStartup', 4, 4, 1) == '4 4 4'
assert run('./test_fslStartup', 1, 1, 1) == '1 1 1'
if __name__ == '__main__':
......
......@@ -22,20 +22,28 @@ int main(int argc, char *argv[]) {
int omp_threads;
int blas_threads;
int sum = 0;
int sum[16];
// omp num threads should not be
// affected by the FSL startup logic.
// Sum should be equal to omp num threads.
for (int i = 0; i < 16; i++) {
sum[i] = 0;
}
// omp num threads should not be affected
// by the FSL startup logic. Sum should
// be equal to omp num threads.
#pragma omp parallel
{
sum += 1;
omp_threads = omp_get_num_threads();
sum[omp_get_thread_num()] = 1;
omp_threads = omp_get_num_threads();
}
// blas num threads should be controlled
// by FSL startup logic.
blas_threads = openblas_get_num_threads();
cout << omp_threads << " " << blas_threads << " " << sum << endl;
for (int i = 1; i < 16; i++) {
sum[0] += sum[i];
}
cout << omp_threads << " " << blas_threads << " " << sum[0] << endl;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment