Skip to content
Snippets Groups Projects
Commit e1ecfe82 authored by Matthew Webster's avatar Matthew Webster
Browse files

ENH: allow dtifit to crash for invalid tests

parent 84f094c1
No related branches found
No related tags found
No related merge requests found
......@@ -6,7 +6,7 @@ No noise is added, so we expect (near-)perfect fits
"""
import sys
import os
from subprocess import run
from subprocess import run, CalledProcessError
import numpy as np
from numpy import testing
import nibabel as nib
......@@ -141,16 +141,13 @@ def fit_data(directory):
:yield: tuple with
- base name of the output files
- base name of the output files (CalledProcessError if dtifit crashed)
- boolean indicating whether the --kurt flag was set
- boolean indicating whether the --kudtdir flag was set
"""
for kurtdir in (False, True):
for main_kurt in (False, True):
for wls in (False, True): # in this noise-free data the --wls flag should not matter
if wls and kurtdir:
# weighted least square not implemented for kurtdir
continue
cmd = [
'dtifit',
'-k', f'{directory}/ref_data.nii.gz',
......@@ -160,6 +157,7 @@ def fit_data(directory):
'--sse',
'--save_tensor',
]
base_output = f'{directory}/dti'
if wls:
cmd += ['--wls']
......@@ -171,19 +169,29 @@ def fit_data(directory):
cmd += ['--kurtdir']
base_output += '_kurtdir'
cmd.extend(['-o', base_output])
run(cmd, check=True)
yield base_output, main_kurt, kurtdir
try:
run(cmd, check=True)
except CalledProcessError as e:
yield e, main_kurt, kurtdir, wls
else:
yield base_output, main_kurt, kurtdir, wls
# Loops through multiple ways to generate and fit the data and check the output
for directory, multi_shell, kurt in gen_data():
for base_output, fkurt, fkurtdir in fit_data(directory):
for base_output, fkurt, fkurtdir, fwls in fit_data(directory):
print('testing', base_output)
if (
(kurt == 0 and (multi_shell or not (fkurt or fkurtdir))) or
(kurt == 1 and fkurt and not fkurtdir) or
(kurt == 2 and fkurtdir)
(
(kurt == 0 and (multi_shell or not (fkurt or fkurtdir))) or
(kurt == 1 and fkurt and not fkurtdir) or
(kurt == 2 and fkurtdir)
) and (not (fwls and fkurtdir))
):
if isinstance(base_output, CalledProcessError): # dtifit crashed
print("dtifit crashed on what should be a valid run. Error message printed above.")
raise base_output
def compare(name):
ref = nib.load(f'{directory}/ref_{name}.nii.gz').get_fdata()
fit = nib.load(f'{base_output}_{name}.nii.gz').get_fdata()
......@@ -243,6 +251,9 @@ for directory, multi_shell, kurt in gen_data():
assert not os.path.isfile(f'{base_output}_kurt3.nii.gz')
assert not os.path.isfile(f'{base_output}_MK.nii.gz')
else:
if isinstance(base_output, CalledProcessError): # dtifit crashed
print("dtifit crashed on an invalid run, which is fine")
continue
if fkurtdir and (not multi_shell or kurt == 1):
continue # unpredictable whether this works or not
print('This fit should be invalid')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment