diff --git a/tests/test_fslsub.py b/tests/test_fslsub.py index 93e96a2366cb5e98623129eed602cd7838166d7c..40d7668285c813b332c8da91574cc34ffa671681 100644 --- a/tests/test_fslsub.py +++ b/tests/test_fslsub.py @@ -147,28 +147,31 @@ def test_add_to_parser(): ('-F', ), ('-s', 'pename,thread') ] - for flag in test_flags: - for include in (None, [flag[0]]): - parser = argparse.ArgumentParser("test parser") - fslsub.SubmitParams.add_to_parser(parser, include=include) - args = parser.parse_args(flag) - submitter = fslsub.SubmitParams.from_args(args) - assert submitter.as_flags() == flag - - parser = argparse.ArgumentParser("test parser") - parser.add_argument('some_input') - fslsub.SubmitParams.add_to_parser(parser, include=None) - all_flags = tuple(part for flag in test_flags for part in flag) - args = parser.parse_args(('input', ) + all_flags) - assert args.some_input == 'input' - submitter = fslsub.SubmitParams.from_args(args) - assert len(all_flags) == len(submitter.as_flags()) - for flag in test_flags: - res_flags = submitter.as_flags() - assert flag[0] in res_flags - start_index = res_flags.index(flag[0]) - for idx, part in enumerate(flag): - assert res_flags[idx + start_index] == part + with fslsub_mockFSLDIR(): + for flag in test_flags: + for include in (None, [flag[0]]): + parser = argparse.ArgumentParser("test parser") + fslsub.SubmitParams.add_to_parser(parser, include=include) + args = parser.parse_args(flag) + submitter = fslsub.SubmitParams.from_args(args) + assert submitter.as_flags() == flag + + with fslsub_mockFSLDIR(): + parser = argparse.ArgumentParser("test parser") + parser.add_argument('some_input') + fslsub.SubmitParams.add_to_parser(parser, include=None) + all_flags = tuple(part for flag in test_flags for part in flag) + args = parser.parse_args(('input', ) + all_flags) + assert args.some_input == 'input' + submitter = fslsub.SubmitParams.from_args(args) + assert len(all_flags) == len(submitter.as_flags()) + + for flag in test_flags: + res_flags = submitter.as_flags() + assert flag[0] in res_flags + start_index = res_flags.index(flag[0]) + for idx, part in enumerate(flag): + assert res_flags[idx + start_index] == part def myfunc():