Commit 72fbf62a authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added option to control if the test statistiscts assume a null or mean output

parent db3c0e36
......@@ -190,6 +190,11 @@ def evaluate_data(trained_model_path,
os.remove(output_database_path)
output_database_handle = h5py.File(output_database_path, 'w')
if mapping_evaluation_parameters['control'] == 'mean':
prediction_output_statistics_name = "output_statistics_mean_target.csv"
elif mapping_evaluation_parameters['control'] == 'null':
prediction_output_statistics_name = "output_statistics_null_target.csv"
output_statistics = {}
output_statistics_path = os.path.join(prediction_output_path, prediction_output_statistics_name)
......@@ -232,18 +237,26 @@ def evaluate_data(trained_model_path,
if crop_flag == True:
predicted_volume = roi(Image(predicted_volume, header=header), ((-5,86),(-6,103),(0,91))).data
target_volume = _generate_target_volume(file_path,
subject,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
regression_factors,
mean_regression_flag,
mean_regression_all_flag,
mean_subtraction_flag,
crop_flag,
cross_domain_x2x_flag,
vbm_flag
)
if mapping_evaluation_parameters['control'] == 'mean':
target_volume = _load_mean(dmri_mean_mask_path, rsfmri_mean_mask_path,
mean_regression_flag, mean_regression_all_flag,
cross_domain_x2x_flag, vbm_flag
)
elif mapping_evaluation_parameters['control'] == 'null':
target_volume = np.zeros((91, 109, 91)).astype('float32')
else:
target_volume = _generate_target_volume(file_path,
subject,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
regression_factors,
mean_regression_flag,
mean_regression_all_flag,
mean_subtraction_flag,
crop_flag,
cross_domain_x2x_flag,
vbm_flag
)
mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b, covariance = _statistics_calculator(predicted_volume, target_volume)
output_statistics[subject] = [mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b, covariance]
......@@ -1048,3 +1061,34 @@ def _pearson_correlation(volume, target):
return r
def _load_mean(dmri_mean_mask_path,
rsfmri_mean_mask_path,
mean_regression_flag,
mean_regression_all_flag,
cross_domain_x2x_flag,
vbm_flag):
if mean_regression_flag == True:
if cross_domain_x2x_flag == True:
if mean_regression_all_flag == True:
group_mean = Image("utils/mean_tractsNormSummed_downsampled_regressed.nii.gz").data
else:
if vbm_flag == True:
group_mean = Image("utils/mean_T1_GM_to_template_GM_mod_regressed.nii.gz").data
else:
group_mean = Image("utils/mean_dr_stage2_regressed.nii.gz").data
else:
if cross_domain_x2x_flag == True:
group_mean = Image(dmri_mean_mask_path).data
else:
if vbm_flag == True:
group_mean = Image("utils/mean_T1_GM_to_template_GM_mod.nii.gz").data
factor = 2.5/(2 * np.sqrt(2 * np.log(2)))
group_mean = gaussian_filter(group_mean, sigma=factor)
else:
group_mean = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
group_mean = np.float32(group_mean)
return group_mean
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment