#!/usr/bin/env python

import os
import shlex
import tempfile

import subprocess as sp

def run(cmd, ompthreads=None, blasthreads=None, fslskipglobal=None, keepenv=False):
    """Run ``cmd`` in a controlled environment and return its stripped output.

    Parameters
    ----------
    cmd : str
        Command line; split with :func:`shlex.split` and executed directly
        (no shell is involved).
    ompthreads, blasthreads, fslskipglobal : optional
        When not ``None``, exported (converted with ``str``) as
        ``OMP_NUM_THREADS``, ``BLAS_NUM_THREADS`` and ``FSL_SKIP_GLOBAL``
        respectively. Only applied when ``keepenv`` is false.
    keepenv : bool
        If true, run with an unmodified copy of the current environment and
        ignore the three settings above. If false, first drop every variable
        whose name *contains* ``OMP``, ``GOTO``, ``BLAS`` or ``FSL``, so that
        inherited settings cannot influence the result.

    Returns
    -------
    str
        The command's combined stdout/stderr, with surrounding whitespace
        stripped.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero status. The exit code and the
        captured output are printed before the exception propagates.
    """

    env = os.environ.copy()

    if not keepenv:
        # Substring match on purpose: this also removes variants such as
        # OPENBLAS_NUM_THREADS / GOTO_NUM_THREADS and any FSL* settings.
        blacklist = ['OMP', 'GOTO', 'BLAS', 'FSL']
        env = {name: value for name, value in env.items()
               if not any(b in name for b in blacklist)}

        overrides = {'OMP_NUM_THREADS':  ompthreads,
                     'BLAS_NUM_THREADS': blasthreads,
                     'FSL_SKIP_GLOBAL':  fslskipglobal}
        for name, value in overrides.items():
            if value is not None:
                env[name] = str(value)

    print(f'Called {cmd} {ompthreads} {blasthreads} {fslskipglobal}')

    try:
        result = sp.run(shlex.split(cmd), check=True, text=True,
                        stdout=sp.PIPE, stderr=sp.STDOUT, env=env)
    except sp.CalledProcessError as e:
        # check=True raises before anything else is reported, and the
        # exception message omits the captured output; show it so the
        # failure is diagnosable, then propagate unchanged.
        print(f'  exit code: {e.returncode}')
        print(f'  stdout:    {(e.output or "").strip()}')
        raise

    output = result.stdout.strip()
    print(f'  exit code: {result.returncode}')
    print(f'  stdout:    {output}')

    return output


def main():
    """Build ``test_fslStartup`` and check the thread settings it reports.

    ``make`` is run with the caller's environment untouched. Each case then
    runs ``./test_fslStartup`` via ``run`` (which scrubs OMP/GOTO/BLAS/FSL
    variables and exports the requested settings) and compares its stripped
    output with the expected string.

    Explicit checks are used rather than ``assert`` statements so the test
    is not silently skipped under ``python -O``, and so a mismatch reports
    the output actually produced.

    Raises
    ------
    AssertionError
        On the first case whose output differs from the expected value.
    """

    run('make', keepenv=True)

    # (ompthreads, blasthreads, fslskipglobal, expected output)
    # NOTE(review): the output looks like three space-separated thread
    # counts; confirm the meaning of each field against test_fslStartup.
    cases = [
        # Default behaviour should be: OMP multi-threaded, BLAS single-threaded.
        (8, 8, None, '8 1 8'),
        (4, 4, None, '4 1 4'),
        (1, 1, None, '1 1 1'),
        # With FSL_SKIP_GLOBAL set, BLAS should be multi-threaded.
        (8, 8, 1, '8 8 8'),
        (4, 4, 1, '4 4 4'),
        (1, 1, 1, '1 1 1'),
    ]

    for omp, blas, skip, expected in cases:
        got = run('./test_fslStartup', omp, blas, skip)
        if got != expected:
            raise AssertionError(
                f'./test_fslStartup (omp={omp}, blas={blas}, '
                f'fslskipglobal={skip}): expected {expected!r}, got {got!r}')


if __name__ == '__main__':
    # Build and run the checks only when executed as a script; importing
    # this module must not trigger `make` or launch any subprocesses.
    main()