diff --git a/unit_tests/fdt/dtifit/feedsRun b/unit_tests/fdt/dtifit/feedsRun index 41565733b68f280a7131b8f71651efcc422ff812..d2e3a03c73db70fba313ad5b035e0dba9375996f 100755 --- a/unit_tests/fdt/dtifit/feedsRun +++ b/unit_tests/fdt/dtifit/feedsRun @@ -38,7 +38,7 @@ def gen_data(): # eigen-values can not be the same or eigen-vectors will be ill-defined eigen_values = np.array([ [1.2, 1., 0.4], - [1., 0.5, 0.], + [1., 0.9, 0.], [0.8, .6, .3], ])[None, :, None, :, None] * 1e-3 @@ -70,7 +70,7 @@ def gen_data(): if kurt == 2: beig = bvals[:, None] * eigen_values[..., None, :, 0] beig[..., 0] -= 0.1 * (eigen_values[..., None, 0, 0] * bvals) ** 2 / 6 - beig[..., 1:] -= 0.1 * (np.mean(eigen_values[..., None, 1:, 0], -1) * bvals)[..., None] ** 2 / 6 + beig[..., 1:] -= 0.05 * (np.mean(eigen_values[..., None, 1:, 0], -1) * bvals)[..., None] ** 2 / 6 data = S0[..., None] * np.exp(np.sum( -beig * np.sum(bvecs[:, None, :] * eigen_vectors[..., None, :, :], -1) ** 2, axis=-1 @@ -155,34 +155,36 @@ for directory, multi_shell, kurt in gen_data(): print(base_output) if ( (kurt == 0 and (multi_shell or not (fkurt or fkurtdir))) or - kurt == 1 and (fkurt or fkurtdir) or - kurt == 2 and fkurtdir + (kurt == 1 and fkurt and not fkurtdir) or + (kurt == 2 and fkurtdir) ): - print('This fit should be valid') def compare(name): ref = nib.load(f'{directory}/ref_{name}.nii.gz').get_fdata() fit = nib.load(f'{base_output}_{name}.nii.gz').get_fdata() assert ref.shape == fit.shape, f'incorrect NIFTI image shape for {name}' - testing.assert_allclose(ref, fit, atol=1e-6, + print(name, ref[0, 0, 0], fit[0, 0, 0]) + testing.assert_allclose(ref, fit, atol=1e-3 if fkurtdir else 1e-6, rtol=0.1 if fkurtdir else 1e-3, err_msg=f'mismatch in {name}') for idx in (1, 2, 3): - compare(f'L{idx}') - ref = nib.load(f'{directory}/ref_V{idx}.nii.gz').get_fdata() fit = nib.load(f'{base_output}_V{idx}.nii.gz').get_fdata() assert ref.shape == fit.shape inner = (ref * fit).sum(-1) - testing.assert_allclose(abs(inner), 1.) + testing.assert_allclose(abs(inner), 1., rtol=1e-4) + + compare(f'L{idx}') compare('S0') compare('FA') - compare('MO') + if not fkurtdir: + compare('MO') compare('MD') - compare('tensor') + if not fkurtdir: + compare('tensor') sse = nib.load(f'{base_output}_sse.nii.gz').get_fdata() - testing.assert_allclose(sse, 0., atol=1e-8) + testing.assert_allclose(sse, 0., atol=1e-3 if fkurtdir else 1e-8) if fkurt: kurt_fit = nib.load(f'{base_output}_kurt.nii.gz').get_fdata() @@ -201,17 +203,18 @@ for directory, multi_shell, kurt in gen_data(): assert not np.allclose(kurt, 0.1, rtol=0.01) assert not np.allclose(kurt, 0.1, rtol=0.01) elif kurt == 2: - testing.assert_allclose(kurt_para, 0.1, rtol=1e-5) - testing.assert_allclose(kurt_perp, 0.05, rtol=1e-5) - # for kurt == 1; kurt_para + testing.assert_allclose(kurt_para, 0.1, rtol=0.1) + testing.assert_allclose(kurt_perp, 0.05, rtol=0.1) else: assert not os.path.isfile(f'{base_output}_kurt_para.nii.gz') assert not os.path.isfile(f'{base_output}_kurt_perp.nii.gz') else: + if fkurtdir and (not multi_shell or kurt == 1): + continue # unpredictable whether this works or not print('This fit should be invalid') ref = nib.load(f'{directory}/ref_L1.nii.gz').get_fdata() fit = nib.load(f'{base_output}_L1.nii.gz').get_fdata() - assert not np.allclose(ref, fit, rtol=0.01) + assert not np.allclose(ref, fit, rtol=1e-3, atol=1e-6)