Commit 815f5223 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added ability to calculate statistics for test set

parent 01d776d1
......@@ -411,6 +411,7 @@ def evaluate_data(mapping_evaluation_parameters):
prediction_output_path = mapping_evaluation_parameters['prediction_output_path']
prediction_output_database_name = mapping_evaluation_parameters['prediction_output_database_name']
prediction_output_statistics_name = mapping_evaluation_parameters['prediction_output_statistics_name']
dmri_mean_mask_path = mapping_evaluation_parameters['dmri_mean_mask_path']
rsfmri_mean_mask_path = mapping_evaluation_parameters['rsfmri_mean_mask_path']
device = mapping_evaluation_parameters['device']
......@@ -428,6 +429,7 @@ def evaluate_data(mapping_evaluation_parameters):
shrinkage_flag = mapping_evaluation_parameters['shrinkage_flag']
hard_shrinkage_flag = mapping_evaluation_parameters['hard_shrinkage_flag']
crop_flag = mapping_evaluation_parameters['crop_flag']
output_database_flag = mapping_evaluation_parameters['output_database_flag']
cross_domain_x2x_flag = mapping_evaluation_parameters['cross_domain_x2x_flag']
cross_domain_y2y_flag = mapping_evaluation_parameters['cross_domain_y2y_flag']
......@@ -438,6 +440,7 @@ def evaluate_data(mapping_evaluation_parameters):
data_list,
prediction_output_path,
prediction_output_database_name,
prediction_output_statistics_name,
brain_mask_path,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
......@@ -455,6 +458,7 @@ def evaluate_data(mapping_evaluation_parameters):
crop_flag,
device,
exit_on_error,
output_database_flag,
cross_domain_x2x_flag,
cross_domain_y2y_flag
)
......
......@@ -2,12 +2,14 @@
trained_model_path = "saved_models/VA2-1.pth.tar"
prediction_output_path = "VA2-1_predictions"
prediction_output_database_name = "output_test_data.h5"
prediction_output_statistics_name = "output_statistics.pkl"
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
mapping_targets_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
data_list_reduced = "datasets/test_reduced.txt"
data_list_all = "datasets/test_all.txt"
evaluate_all_data = False
output_database_flag = False
brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
......
......@@ -26,6 +26,8 @@ import pandas as pd
from fsl.data.image import Image
from fsl.utils.image.roi import roi
import itertools
from scipy.spatial.distance import cosine
from scipy.stats import pearsonr, spearmanr
log = logging.getLogger(__name__)
......@@ -36,6 +38,7 @@ def evaluate_data(trained_model_path,
data_list,
prediction_output_path,
prediction_output_database_name,
prediction_output_statistics_name,
brain_mask_path,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
......@@ -53,6 +56,7 @@ def evaluate_data(trained_model_path,
crop_flag,
device=0,
exit_on_error=False,
output_database_flag=False,
cross_domain_x2x_flag=False,
cross_domain_y2y_flag=False,
mode='evaluate'):
......@@ -69,6 +73,7 @@ def evaluate_data(trained_model_path,
data_list (str): Path to a .txt file containing the input files for consideration
prediction_output_path (str): Output prediction path
prediction_output_database_name (str): Name of the output database
prediction_output_statistics_name (str): Name of the output statistics database
brain_mask_path (str): Path to the MNI brain mask file
dmri_mean_mask_path (str): Path to the dualreg subject mean mask
rsfmri_mean_mask_path (str): Path to the summed tract mean mask
......@@ -87,6 +92,7 @@ def evaluate_data(trained_model_path,
device (str/int): Device type used for training (int - GPU id, str- CPU)
mode (str): Current run mode or phase
exit_on_error (bool): Flag that triggers the raising of an exception
output_database_flag (bool): Flag indicating if the output maps should be saved to hdf5 database
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
......@@ -129,12 +135,20 @@ def evaluate_data(trained_model_path,
log.info("rsfMRI Generation Started")
if cross_domain_y2y_flag == True:
# If doing y2y autoencoder, then we load the targets as inputs. In all other cases (x2x & x2y) we load the inputs as inputs.
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file=mapping_targets_file)
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file=mapping_targets_file, mapping_targets_file=mapping_targets_file)
elif cross_domain_x2x_flag == True:
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file, mapping_targets_file=mapping_data_file)
else:
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file)
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file, mapping_targets_file)
output_database_path = os.path.join(prediction_output_path, prediction_output_database_name)
output_database_handle = h5py.File(output_database_path, 'w')
if output_database_flag == True:
output_database_path = os.path.join(prediction_output_path, prediction_output_database_name)
if os.path.exists(output_database_path):
os.remove(output_database_path)
output_database_handle = h5py.File(output_database_path, 'w')
output_statistics = {}
output_statistics_path = os.path.join(prediction_output_path, prediction_output_statistics_name)
with torch.no_grad():
......@@ -168,12 +182,27 @@ def evaluate_data(trained_model_path,
cross_domain_x2x_flag,
cross_domain_y2y_flag)
group = output_database_handle.create_group(subject)
group.create_dataset('predicted_complete_volume', data=predicted_complete_volume)
group.create_dataset('predicted_volume', data=predicted_volume)
group.create_dataset('header', data=header)
group.create_dataset('xform', data=xform)
target_volume = _generate_target_volume(file_path,
subject,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
regression_factors,
mean_regression_flag,
mean_regression_all_flag,
mean_subtraction_flag,
crop_flag,
cross_domain_x2x_flag
)
mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b = _statistics_calculator(predicted_volume, target_volume)
output_statistics[subject] = [mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b]
if output_database_flag == True:
group = output_database_handle.create_group(subject)
group.create_dataset('predicted_complete_volume', data=predicted_complete_volume)
group.create_dataset('predicted_volume', data=predicted_volume)
group.create_dataset('header', data=header)
group.create_dataset('xform', data=xform)
log.info("Processed: " + volumes_to_be_used[volume_index] + " " + str(
volume_index + 1) + " out of " + str(len(volumes_to_be_used)))
......@@ -192,10 +221,14 @@ def evaluate_data(trained_model_path,
if exit_on_error:
raise(exception_expression)
output_statistics_df = pd.DataFrame.from_dict(output_statistics, orient='index', columns=['mse', 'mae', 'cel', 'pearson_r', 'pearson_p', 'spearman_r', 'spearman_p', 'regression_w', 'regression_b'])
output_statistics_df.to_pickle(output_statistics_path)
log.info("Output Data Generation Complete")
output_database_handle.close()
def evaluate_mapping(trained_model_path,
data_directory,
mapping_data_file,
......@@ -408,7 +441,7 @@ def _generate_volume_map(file_path,
cross_domain_x2x_flag,
cross_domain_y2y_flag
):
"""rsfMRI Volume Generator
"""Output Volume Generator
This function uses the trained model to generate a new volume
......@@ -616,7 +649,7 @@ def _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path,
elif crop_flag == True:
group_mean = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data
regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))
regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))
return regressed_volume
......@@ -716,6 +749,161 @@ def _soft_shrinkage(volume, lambd):
return volume
def _generate_target_volume(file_path,
                            subject,
                            dmri_mean_mask_path,
                            rsfmri_mean_mask_path,
                            regression_factors,
                            mean_regression_flag,
                            mean_regression_all_flag,
                            mean_subtraction_flag,
                            crop_flag,
                            cross_domain_x2x_flag
                            ):
    """Target Volume Generator

    Loads the ground-truth volume of one subject and applies at most one
    de-meaning step to it, so that it can be compared against the volume
    predicted by the network. Regression takes precedence: the subtraction
    is only considered when mean_regression_flag is not set. When neither
    flag is set, the loaded volume is returned unchanged.

    Args:
        file_path: File entry of the subject, forwarded unchanged to data_utils.load_and_preprocess_targets
            (presumably the [input, target] pair built by load_file_paths - TODO confirm)
        subject (str): Subject ID, used to look up the subject's regression weight
        dmri_mean_mask_path (str): Path to the dMRI group mean volume
        rsfmri_mean_mask_path (str): Path to the rsfMRI group mean volume
        regression_factors (str): Path to the linear regression weights file
        mean_regression_flag (bool): Flag indicating if the target should be de-meaned by regression
        mean_regression_all_flag (bool): Flag forwarded to _regress_target, indicating if both inputs and targets are regressed
        mean_subtraction_flag (bool): Flag indicating if the target should be de-meaned by subtraction
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs

    Returns:
        volume (np.array): Target volume after the optional de-meaning step
    """

    target = data_utils.load_and_preprocess_targets(file_path, crop_flag, cross_domain_x2x_flag)

    # The flags are compared against True explicitly (rather than tested for
    # truthiness) so that non-boolean values are treated exactly as before.
    if mean_regression_flag == True:
        return _regress_target(target,
                               subject,
                               dmri_mean_mask_path,
                               rsfmri_mean_mask_path,
                               regression_factors,
                               crop_flag,
                               cross_domain_x2x_flag,
                               mean_regression_all_flag)

    if mean_subtraction_flag == True:
        return _subtract_target(target,
                                subject,
                                dmri_mean_mask_path,
                                rsfmri_mean_mask_path,
                                crop_flag,
                                cross_domain_x2x_flag)

    return target
def _regress_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_x2x_flag, mean_regression_all_flag):
    """ Target Regression

    Removes the weighted group mean from a target volume:
    regressed = volume - weight * group_mean.

    The weight is read from the pickled table stored at regression_factors
    (row = subject). Which column and which group mean are used depends on
    the run mode:
        - x2x with mean_regression_all_flag set: column 'w_dMRI', dMRI group mean
        - x2x without mean_regression_all_flag: the volume is returned untouched
        - all other cases: column 'w_rsfMRI', first volume of the rsfMRI group mean

    NOTE: the weights file is loaded with pd.read_pickle, so it must come from a trusted source.
    TODO: This function represents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.

    Args:
        volume (np.array): Unregressed volume
        subject (str): Subject ID of the subject volume to be regressed
        dmri_mean_mask_path (str): Path to the dMRI group mean volume
        rsfmri_mean_mask_path (str): Path to the rsfMRI (target) group mean volume
        regression_factors (str): Path to the linear regression weights file
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.

    Returns:
        regressed_volume (np.array): Linear regressed volume
    """

    # Voxel ranges kept along each spatial axis when cropping is enabled
    crop_region = ((9, 81), (10, 100), (0, 77))

    if cross_domain_x2x_flag == True:
        # In x2x mode the target is an input volume; it is only regressed
        # when inputs are being regressed as well.
        if not (mean_regression_all_flag == True):
            return volume

        scaling = pd.read_pickle(regression_factors).loc[subject]['w_dMRI']
        if crop_flag == False:
            mean_map = Image(dmri_mean_mask_path).data
        elif crop_flag == True:
            mean_map = roi(Image(dmri_mean_mask_path), crop_region).data
        return np.subtract(volume, np.multiply(scaling, mean_map))

    scaling = pd.read_pickle(regression_factors).loc[subject]['w_rsfMRI']
    # Only the first volume along the 4th axis of the rsfMRI mean is used
    if crop_flag == False:
        mean_map = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
    elif crop_flag == True:
        mean_map = roi(Image(rsfmri_mean_mask_path), crop_region).data[:, :, :, 0]
    return np.subtract(volume, np.multiply(scaling, mean_map))
def _subtract_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, crop_flag, cross_domain_x2x_flag):
    """ Target Subtraction

    Subtracts the (unweighted) rsfMRI group mean from the target volume.
    In x2x mode the volume is returned unchanged.

    The subject and dmri_mean_mask_path arguments are not used by this
    function; they are accepted so the signature stays parallel to
    _regress_target.

    Args:
        volume (np.array): Volume from which the group mean has not yet been removed
        subject (str): Subject ID of the subject volume (unused)
        dmri_mean_mask_path (str): Path to the dMRI group mean volume (unused)
        rsfmri_mean_mask_path (str): Path to the rsfMRI (target) group mean volume
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs

    Returns:
        subtracted_volume (np.array): Volume with the group mean subtracted
    """

    if cross_domain_x2x_flag == True:
        # Targets are input volumes in x2x mode: nothing to subtract
        return volume

    # Only the first volume along the 4th axis of the rsfMRI mean is used
    if crop_flag == False:
        mean_map = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
    elif crop_flag == True:
        crop_region = ((9, 81), (10, 100), (0, 77))
        mean_map = roi(Image(rsfmri_mean_mask_path), crop_region).data[:, :, :, 0]

    return np.subtract(volume, mean_map)
def _statistics_calculator(volume, target):
    """ Evaluation statistics calculator

    This function calculates the MSE, MAE, cosine distance, Pearson and Spearman R and P, and the linear regression W and B
    for a predicted volume and its target ground truth. Both arrays are flattened before any statistic is computed.

    The regression line target = regression_w * volume + regression_b is fitted with a least-squares solver
    instead of explicitly inverting the normal-equation matrix. This is numerically more stable and does not raise
    for a constant prediction; in that degenerate case the minimum-norm solution is returned
    and the correlation coefficients reported by scipy are NaN.

    Args:
        volume (np.array): Predicted volume
        target (np.array): Ground truth volume

    Returns:
        mse (np.float64): The mean squared error between the prediction and the ground truth; The closer to 0, the better
        mae (np.float64): The mean absolute error between the prediction and the ground truth; The closer to 0, the better
        cel (np.float64): The cosine distance between the prediction and the ground truth; The closer to 0, the better
        pearson_r (np.float64): Pearson's correlation coefficient; The closer to 1, the better
        pearson_p (np.float64): Two-tailed p-value for Pearson's correlation coefficient; the closer to 0, the better
        spearman_r (np.float64): Spearman correlation coefficient; The closer to 1, the better
        spearman_p (np.float64): Two-tailed p-value for Spearman's correlation coefficient; the closer to 0, the better
        regression_w (np.float64): Slope of the linear regression line; The closer to 1 the better
        regression_b (np.float64): Intercept of the linear regression line; The closer to 0, the better
    """

    x = np.reshape(volume, -1)
    y = np.reshape(target, -1)

    # Voxel-wise error, computed once and shared by MSE and MAE
    residual = np.subtract(x, y)
    mse = np.square(residual).mean()
    mae = np.abs(residual).mean()

    # scipy's cosine() already returns a scalar distance
    cel = np.float64(cosine(x, y))

    pearson_r, pearson_p = pearsonr(x, y)
    spearman_r, spearman_p = spearmanr(x, y)

    # Design matrix [1, x]: the first coefficient is the intercept, the second the slope
    design_matrix = np.column_stack((np.ones(len(x)), x))
    coefficients = np.linalg.lstsq(design_matrix, y, rcond=None)[0]
    regression_b, regression_w = coefficients

    return mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b
def _pearson_correlation(volume, target):
"""Calculate Pearson Correlation Coefficient
......@@ -733,3 +921,4 @@ def _pearson_correlation(volume, target):
np.sum(np.power(np.subtract(volume, volume.mean()), 2)), np.sum(np.power(np.subtract(target, target.mean()), 2))))
return r
......@@ -96,7 +96,7 @@ def get_datasets(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag):
)
def load_file_paths(data_directory, data_list, mapping_data_file, targets_directory=None, target_file=None):
def load_file_paths(data_directory, data_list, mapping_data_file, mapping_targets_file=None):
"""File Loader
This function returns a list of combined file paths for the input and output data.
......@@ -105,7 +105,7 @@ def load_file_paths(data_directory, data_list, mapping_data_file, targets_direct
data_directory (str): Path to input data directory
data_list (str): Path to a .txt file containing the input files for consideration
mapping_data_file (str): Path to the input files
targets_directory (str): Path to labelled data (Y-equivalent); None if during evaluation.
mapping_targets_file (str): Path to the target files
Returns:
file_paths (list): List containing the input data and target labelled output data
......@@ -117,12 +117,12 @@ def load_file_paths(data_directory, data_list, mapping_data_file, targets_direct
volumes_to_be_used = load_subjects_from_path(data_directory, data_list)
if targets_directory == None or target_file == None:
if mapping_targets_file == None:
file_paths = [[os.path.join(data_directory, volume, mapping_data_file)]
for volume in volumes_to_be_used]
else:
file_paths = [[os.path.join(data_directory, volume, mapping_data_file), os.path.join(
targets_directory, volume)] for volume in volumes_to_be_used]
file_paths = [[os.path.join(data_directory, volume, mapping_data_file),
os.path.join(data_directory, volume, mapping_targets_file)] for volume in volumes_to_be_used]
return file_paths, volumes_to_be_used
......@@ -158,10 +158,10 @@ def load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag):
Args:
file_path (str): Path to the desired file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_y2y_flag
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns:
volume (np.array): Array of training image data of data type dtype.
volume (np.array): Array of training image data
header (class): 'nibabel.nifti1.Nifti1Header' class object, containing image metadata
xform (np.array): Array of shape (4, 4), containing the adjusted voxel-to-world transformation for the spatial dimensions of the resampled data
......@@ -194,27 +194,35 @@ def load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag):
return volume, header, xform
def load_and_preprocess_targets(target_path, mean_mask_path):
def load_and_preprocess_targets(file_path, crop_flag, cross_domain_x2x_flag):
"""Load & Preprocessing targets before evaluation
This function loads a nifty file and returns its volume, a de-meaned volume and header information
This function loads a nifty file and returns its volume information
Args:
file_path (str): Path to the desired target file
mean_mask_path (str): Path to the dualreg subject mean mask
file_path (str): Path to the desired file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
Returns:
target (np.array): Array of training image data of data type dtype.
target_demeaned (np.array): Array of training data from which the group mean has been subtracted
volume (np.array): Array of target image intensities.
Raises:
ValueError: "Orientation value is invalid. It must be either >>coronal<<, >>axial<< or >>sagital<< "
"""
"""
original_image = Image(file_path[1])
target = Image(target_path[0]).data[:, :, :, 0]
target_demeaned = np.subtract(
target, Image(mean_mask_path).data[:, :, :, 0])
if cross_domain_x2x_flag == True:
if crop_flag == False:
volume, _ = resampleToPixdims(original_image, (2, 2, 2))
elif crop_flag == True:
resampled, xform = resampleToPixdims(original_image, (2, 2, 2))
volume = roi(Image(resampled, header=original_image.header, xform=xform),((9,81),(10,100),(0,77))).data
else:
if crop_flag == False:
volume = original_image.data[:, :, :, 0]
elif crop_flag == True:
volume = roi(original_image,((9,81),(10,100),(0,77))).data[:, :, :, 0]
return target, target_demeaned
return volume
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment