Commit 8cbb592b authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added cross-domain x2x y2y x2y options to mapping eval

parent 1c12a511
......@@ -3,6 +3,7 @@ trained_model_path = "saved_models/VA2-1.pth.tar"
prediction_output_path = "VA2-1_predictions"
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
mapping_targets_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
data_list = "datasets/test.txt"
brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
......@@ -21,4 +22,6 @@ shrinkage_flag = False
hard_shrinkage_flag = False
crop_flag = True
device = 0
exit_on_error = True
\ No newline at end of file
exit_on_error = True
cross_domain_x2x_flag = False
cross_domain_y2y_flag = False
\ No newline at end of file
......@@ -31,6 +31,7 @@ log = logging.getLogger(__name__)
def evaluate_mapping(trained_model_path,
data_directory,
mapping_data_file,
mapping_targets_file,
data_list,
prediction_output_path,
brain_mask_path,
......@@ -50,6 +51,8 @@ def evaluate_mapping(trained_model_path,
crop_flag,
device=0,
exit_on_error=False,
cross_domain_x2x_flag=False,
cross_domain_y2y_flag=False,
mode='evaluate'):
"""Model Evaluator
......@@ -59,6 +62,7 @@ def evaluate_mapping(trained_model_path,
trained_model_path (str): Path to the location of the trained model
data_directory (str): Path to input data directory
mapping_data_file (str): Path to the input file
mapping_targets_file (str): Path to the target file
data_list (str): Path to a .txt file containing the input files for consideration
prediction_output_path (str): Output prediction path
brain_mask_path (str): Path to the MNI brain mask file
......@@ -79,6 +83,8 @@ def evaluate_mapping(trained_model_path,
device (str/int): Device type used for training (int - GPU id, str- CPU)
mode (str): Current run mode or phase
exit_on_error (bool): Flag that triggers the raising of an exception
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Raises:
FileNotFoundError: Error in reading the provided file!
......@@ -117,8 +123,11 @@ def evaluate_mapping(trained_model_path,
# Initiate the evaluation
log.info("rsfMRI Generation Started")
file_paths, volumes_to_be_used = data_utils.load_file_paths(
data_directory, data_list, mapping_data_file)
if cross_domain_y2y_flag == True:
# If doing y2y autoencoder, then we load the targets as inputs. In all other cases (x2x & x2y) we load the inputs as inputs.
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file=mapping_targets_file)
else:
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file)
with torch.no_grad():
......@@ -148,10 +157,15 @@ def evaluate_mapping(trained_model_path,
outlier_flag,
shrinkage_flag,
hard_shrinkage_flag,
crop_flag)
crop_flag,
cross_domain_x2x_flag,
cross_domain_y2y_flag)
if crop_flag == False:
output_nifti_image = Image(predicted_volume, header=header, xform=xform)
if cross_domain_y2y_flag == True:
output_nifti_image = Image(predicted_volume, header=header)
else:
output_nifti_image = Image(predicted_volume, header=header, xform=xform)
elif crop_flag == True:
output_nifti_image = Image(predicted_volume, header=header)
output_nifti_image = roi(output_nifti_image, ((-9,82),(-10,99),(0,91)))
......@@ -166,7 +180,10 @@ def evaluate_mapping(trained_model_path,
if mean_regression_flag == True:
if crop_flag == False:
output_complete_nifti_image = Image(predicted_complete_volume, header=header, xform=xform)
if cross_domain_y2y_flag == True:
output_nifti_image = Image(predicted_complete_volume, header=header)
else:
output_complete_nifti_image = Image(predicted_complete_volume, header=header, xform=xform)
elif crop_flag == True:
output_complete_nifti_image = Image(predicted_complete_volume, header=header)
output_complete_nifti_image = roi(output_complete_nifti_image, ((-9,82),(-10,99),(0,91)))
......@@ -220,6 +237,8 @@ def _generate_volume_map(file_path,
shrinkage_flag,
hard_shrinkage_flag,
crop_flag,
cross_domain_x2x_flag,
cross_domain_y2y_flag
):
"""rsfMRI Volume Generator
......@@ -246,17 +265,19 @@ def _generate_volume_map(file_path,
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns
predicted_volume (np.array): Array containing the information regarding the generated volume
header (class): 'nibabel.nifti1.Nifti1Header' class object, containing volume metadata
"""
volume, header, xform = data_utils.load_and_preprocess_evaluation(file_path, crop_flag)
volume, header, xform = data_utils.load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag)
if mean_regression_flag == True:
if mean_regression_all_flag == True:
volume = _regress_input(volume, subject, dmri_mean_mask_path, regression_factors, crop_flag)
volume = _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag)
scaling_parameters = [-0.0626, 0.1146, -14.18, 16.9475]
else:
scaling_parameters = [0.0, 0.2, -14.18, 16.9475]
......@@ -266,7 +287,7 @@ def _generate_volume_map(file_path,
print('volume range:', np.min(volume), np.max(volume))
if scale_volumes_flag == True:
volume = _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag)
volume = _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag)
if len(volume.shape) == 5:
volume = volume
......@@ -284,7 +305,7 @@ def _generate_volume_map(file_path,
print('output range:', np.min(output), np.max(output))
output = _rescale_output(output, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag)
output = _rescale_output(output, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_x2x_flag)
print('output rescaled:', np.min(output), np.max(output))
......@@ -295,13 +316,22 @@ def _generate_volume_map(file_path,
if mean_regression_flag == True or mean_subtraction_flag == True:
if crop_flag == False:
mean_mask = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
elif crop_flag == True:
mean_mask = roi(Image(rsfmri_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
if cross_domain_x2x_flag == True:
if crop_flag == False:
mean_mask = Image(dmri_mean_mask_path).data
elif crop_flag == True:
mean_mask = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data
else:
if crop_flag == False:
mean_mask = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
elif crop_flag == True:
mean_mask = roi(Image(rsfmri_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
if mean_regression_flag == True:
weight = pd.read_pickle(regression_factors).loc[subject]['w_rsfMRI']
if cross_domain_x2x_flag == True:
weight = pd.read_pickle(regression_factors).loc[subject]['w_dMRI']
else:
weight = pd.read_pickle(regression_factors).loc[subject]['w_rsfMRI']
predicted_complete_volume = np.add(output, np.multiply(weight, mean_mask))
if mean_subtraction_flag == True:
......@@ -327,7 +357,7 @@ def _generate_volume_map(file_path,
return predicted_complete_volume, predicted_volume, header, xform
def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag):
def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag):
"""Input Scaling
This function reads the scaling factors from the saved file and then scales the data.
......@@ -341,15 +371,22 @@ def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_f
outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns:
scaled_volume (np.array): Scaled volume
"""
min_value, max_value, _, _ = scaling_parameters
if cross_domain_y2y_flag == True:
_, _, min_value, max_value = scaling_parameters
else:
min_value, max_value, _, _ = scaling_parameters
if shrinkage_flag == True:
lambd = 0.003 # Hard coded, equivalent to tht 1p and 99p values across the whole population in UKBB
if cross_domain_y2y_flag == True:
lambd = 3.0
else:
lambd = 0.003 # Hard coded, equivalent to tht 1p and 99p values across the whole population in UKBB
if hard_shrinkage_flag == True:
volume = _hard_shrinkage(volume, lambd)
......@@ -377,7 +414,7 @@ def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_f
return scaled_volume
def _regress_input(volume, subject, dmri_mean_mask_path, regression_factors, crop_flag):
def _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag):
""" Inputn Regression
This function regresse the group mean from the input volume using the saved regression weights.
......@@ -388,26 +425,35 @@ def _regress_input(volume, subject, dmri_mean_mask_path, regression_factors, cro
volume (np.array): Unregressed volume
subject (str): Subject ID of the subject volume to be regressed
dmri_mean_mask_path (str): Path to the group mean volume
rsfmri_mean_mask_path (str): Path to the target group mean volume
regression_factors (str): Path to the linear regression weights file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns:
regressed_volume (np.array): Linear regressed volume
"""
weight = pd.read_pickle(regression_factors).loc[subject]['w_dMRI']
if crop_flag == False:
group_mean = Image(dmri_mean_mask_path).data
elif crop_flag == True:
group_mean = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data
if cross_domain_y2y_flag == True:
weight = pd.read_pickle(regression_factors).loc[subject]['w_rsfMRI']
if crop_flag == False:
group_mean = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
elif crop_flag == True:
group_mean = roi(Image(rsfmri_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
else:
weight = pd.read_pickle(regression_factors).loc[subject]['w_dMRI']
if crop_flag == False:
group_mean = Image(dmri_mean_mask_path).data
elif crop_flag == True:
group_mean = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data
regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))
regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))
return regressed_volume
def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag):
def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_x2x_flag):
"""Output Rescaling
This function reads the scaling factors from the saved file and then scales the data.
......@@ -420,15 +466,23 @@ def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scalin
negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
Returns:
rescaled_volume (np.array): Rescaled volume
"""
_, _, min_value, max_value = scaling_parameters
if cross_domain_x2x_flag == True:
min_value, max_value, _, _ = scaling_parameters
else:
_, _, min_value, max_value = scaling_parameters
if shrinkage_flag == True:
lambd = 3.0
if cross_domain_x2x_flag == True:
lambd = 0.003
else:
lambd = 3.0
if hard_shrinkage_flag == True:
pass
elif hard_shrinkage_flag == False:
......
......@@ -150,7 +150,7 @@ def load_subjects_from_path(data_directory, data_list):
return volumes_to_be_used
def load_and_preprocess_evaluation(file_path, crop_flag):
def load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag):
"""Load & Preprocessing before evaluation
This function loads a nifty file and returns its volume and header information
......@@ -158,6 +158,7 @@ def load_and_preprocess_evaluation(file_path, crop_flag):
Args:
file_path (str): Path to the desired file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_y2y_flag
Returns:
volume (np.array): Array of training image data of data type dtype.
......@@ -170,15 +171,25 @@ def load_and_preprocess_evaluation(file_path, crop_flag):
original_image = Image(file_path[0])
if crop_flag == False:
volume, xform = resampleToPixdims(original_image, (2, 2, 2))
header = Image(volume, header=original_image.header, xform=xform).header
elif crop_flag == True:
resampled, xform = resampleToPixdims(original_image, (2, 2, 2))
resampled = Image(resampled, header=original_image.header, xform=xform)
cropped = roi(resampled,((9,81),(10,100),(0,77)))
volume = cropped.data
header = cropped.header
if cross_domain_y2y_flag == True:
if crop_flag == False:
volume = original_image.data[:, :, :, 0]
header = Image(volume, header=original_image.header).header
elif crop_flag == True:
cropped = roi(original_image,((9,81),(10,100),(0,77)))
volume = cropped.data[:, :, :, 0]
header = cropped.header
xform = None
else:
if crop_flag == False:
volume, xform = resampleToPixdims(original_image, (2, 2, 2))
header = Image(volume, header=original_image.header, xform=xform).header
elif crop_flag == True:
resampled, xform = resampleToPixdims(original_image, (2, 2, 2))
resampled = Image(resampled, header=original_image.header, xform=xform)
cropped = roi(resampled,((9,81),(10,100),(0,77)))
volume = cropped.data
header = cropped.header
return volume, header, xform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment