Commit 8cbb592b authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added cross-domain x2x y2y x2y options to mapping eval

parent 1c12a511
...@@ -3,6 +3,7 @@ trained_model_path = "saved_models/VA2-1.pth.tar" ...@@ -3,6 +3,7 @@ trained_model_path = "saved_models/VA2-1.pth.tar"
prediction_output_path = "VA2-1_predictions" prediction_output_path = "VA2-1_predictions"
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/" data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz" mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
mapping_targets_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
data_list = "datasets/test.txt" data_list = "datasets/test.txt"
brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz" brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz" rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
...@@ -21,4 +22,6 @@ shrinkage_flag = False ...@@ -21,4 +22,6 @@ shrinkage_flag = False
hard_shrinkage_flag = False hard_shrinkage_flag = False
crop_flag = True crop_flag = True
device = 0 device = 0
exit_on_error = True exit_on_error = True
\ No newline at end of file cross_domain_x2x_flag = False
cross_domain_y2y_flag = False
\ No newline at end of file
...@@ -31,6 +31,7 @@ log = logging.getLogger(__name__) ...@@ -31,6 +31,7 @@ log = logging.getLogger(__name__)
def evaluate_mapping(trained_model_path, def evaluate_mapping(trained_model_path,
data_directory, data_directory,
mapping_data_file, mapping_data_file,
mapping_targets_file,
data_list, data_list,
prediction_output_path, prediction_output_path,
brain_mask_path, brain_mask_path,
...@@ -50,6 +51,8 @@ def evaluate_mapping(trained_model_path, ...@@ -50,6 +51,8 @@ def evaluate_mapping(trained_model_path,
crop_flag, crop_flag,
device=0, device=0,
exit_on_error=False, exit_on_error=False,
cross_domain_x2x_flag=False,
cross_domain_y2y_flag=False,
mode='evaluate'): mode='evaluate'):
"""Model Evaluator """Model Evaluator
...@@ -59,6 +62,7 @@ def evaluate_mapping(trained_model_path, ...@@ -59,6 +62,7 @@ def evaluate_mapping(trained_model_path,
trained_model_path (str): Path to the location of the trained model trained_model_path (str): Path to the location of the trained model
data_directory (str): Path to input data directory data_directory (str): Path to input data directory
mapping_data_file (str): Path to the input file mapping_data_file (str): Path to the input file
mapping_targets_file (str): Path to the target file
data_list (str): Path to a .txt file containing the input files for consideration data_list (str): Path to a .txt file containing the input files for consideration
prediction_output_path (str): Output prediction path prediction_output_path (str): Output prediction path
brain_mask_path (str): Path to the MNI brain mask file brain_mask_path (str): Path to the MNI brain mask file
...@@ -79,6 +83,8 @@ def evaluate_mapping(trained_model_path, ...@@ -79,6 +83,8 @@ def evaluate_mapping(trained_model_path,
device (str/int): Device type used for training (int - GPU id, str- CPU) device (str/int): Device type used for training (int - GPU id, str- CPU)
mode (str): Current run mode or phase mode (str): Current run mode or phase
exit_on_error (bool): Flag that triggers the raising of an exception exit_on_error (bool): Flag that triggers the raising of an exception
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Raises: Raises:
FileNotFoundError: Error in reading the provided file! FileNotFoundError: Error in reading the provided file!
...@@ -117,8 +123,11 @@ def evaluate_mapping(trained_model_path, ...@@ -117,8 +123,11 @@ def evaluate_mapping(trained_model_path,
# Initiate the evaluation # Initiate the evaluation
log.info("rsfMRI Generation Started") log.info("rsfMRI Generation Started")
file_paths, volumes_to_be_used = data_utils.load_file_paths( if cross_domain_y2y_flag == True:
data_directory, data_list, mapping_data_file) # If doing y2y autoencoder, then we load the targets as inputs. In all other cases (x2x & x2y) we load the inputs as inputs.
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file=mapping_targets_file)
else:
file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file)
with torch.no_grad(): with torch.no_grad():
...@@ -148,10 +157,15 @@ def evaluate_mapping(trained_model_path, ...@@ -148,10 +157,15 @@ def evaluate_mapping(trained_model_path,
outlier_flag, outlier_flag,
shrinkage_flag, shrinkage_flag,
hard_shrinkage_flag, hard_shrinkage_flag,
crop_flag) crop_flag,
cross_domain_x2x_flag,
cross_domain_y2y_flag)
if crop_flag == False: if crop_flag == False:
output_nifti_image = Image(predicted_volume, header=header, xform=xform) if cross_domain_y2y_flag == True:
output_nifti_image = Image(predicted_volume, header=header)
else:
output_nifti_image = Image(predicted_volume, header=header, xform=xform)
elif crop_flag == True: elif crop_flag == True:
output_nifti_image = Image(predicted_volume, header=header) output_nifti_image = Image(predicted_volume, header=header)
output_nifti_image = roi(output_nifti_image, ((-9,82),(-10,99),(0,91))) output_nifti_image = roi(output_nifti_image, ((-9,82),(-10,99),(0,91)))
...@@ -166,7 +180,10 @@ def evaluate_mapping(trained_model_path, ...@@ -166,7 +180,10 @@ def evaluate_mapping(trained_model_path,
if mean_regression_flag == True: if mean_regression_flag == True:
if crop_flag == False: if crop_flag == False:
output_complete_nifti_image = Image(predicted_complete_volume, header=header, xform=xform) if cross_domain_y2y_flag == True:
output_nifti_image = Image(predicted_complete_volume, header=header)
else:
output_complete_nifti_image = Image(predicted_complete_volume, header=header, xform=xform)
elif crop_flag == True: elif crop_flag == True:
output_complete_nifti_image = Image(predicted_complete_volume, header=header) output_complete_nifti_image = Image(predicted_complete_volume, header=header)
output_complete_nifti_image = roi(output_complete_nifti_image, ((-9,82),(-10,99),(0,91))) output_complete_nifti_image = roi(output_complete_nifti_image, ((-9,82),(-10,99),(0,91)))
...@@ -220,6 +237,8 @@ def _generate_volume_map(file_path, ...@@ -220,6 +237,8 @@ def _generate_volume_map(file_path,
shrinkage_flag, shrinkage_flag,
hard_shrinkage_flag, hard_shrinkage_flag,
crop_flag, crop_flag,
cross_domain_x2x_flag,
cross_domain_y2y_flag
): ):
"""rsfMRI Volume Generator """rsfMRI Volume Generator
...@@ -246,17 +265,19 @@ def _generate_volume_map(file_path, ...@@ -246,17 +265,19 @@ def _generate_volume_map(file_path,
shrinkage_flag (bool): Flag indicating if shrinkage should be applied. shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied. hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns Returns
predicted_volume (np.array): Array containing the information regarding the generated volume predicted_volume (np.array): Array containing the information regarding the generated volume
header (class): 'nibabel.nifti1.Nifti1Header' class object, containing volume metadata header (class): 'nibabel.nifti1.Nifti1Header' class object, containing volume metadata
""" """
volume, header, xform = data_utils.load_and_preprocess_evaluation(file_path, crop_flag) volume, header, xform = data_utils.load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag)
if mean_regression_flag == True: if mean_regression_flag == True:
if mean_regression_all_flag == True: if mean_regression_all_flag == True:
volume = _regress_input(volume, subject, dmri_mean_mask_path, regression_factors, crop_flag) volume = _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag)
scaling_parameters = [-0.0626, 0.1146, -14.18, 16.9475] scaling_parameters = [-0.0626, 0.1146, -14.18, 16.9475]
else: else:
scaling_parameters = [0.0, 0.2, -14.18, 16.9475] scaling_parameters = [0.0, 0.2, -14.18, 16.9475]
...@@ -266,7 +287,7 @@ def _generate_volume_map(file_path, ...@@ -266,7 +287,7 @@ def _generate_volume_map(file_path,
print('volume range:', np.min(volume), np.max(volume)) print('volume range:', np.min(volume), np.max(volume))
if scale_volumes_flag == True: if scale_volumes_flag == True:
volume = _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag) volume = _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag)
if len(volume.shape) == 5: if len(volume.shape) == 5:
volume = volume volume = volume
...@@ -284,7 +305,7 @@ def _generate_volume_map(file_path, ...@@ -284,7 +305,7 @@ def _generate_volume_map(file_path,
print('output range:', np.min(output), np.max(output)) print('output range:', np.min(output), np.max(output))
output = _rescale_output(output, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag) output = _rescale_output(output, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_x2x_flag)
print('output rescaled:', np.min(output), np.max(output)) print('output rescaled:', np.min(output), np.max(output))
...@@ -295,13 +316,22 @@ def _generate_volume_map(file_path, ...@@ -295,13 +316,22 @@ def _generate_volume_map(file_path,
if mean_regression_flag == True or mean_subtraction_flag == True: if mean_regression_flag == True or mean_subtraction_flag == True:
if crop_flag == False: if cross_domain_x2x_flag == True:
mean_mask = Image(rsfmri_mean_mask_path).data[:, :, :, 0] if crop_flag == False:
elif crop_flag == True: mean_mask = Image(dmri_mean_mask_path).data
mean_mask = roi(Image(rsfmri_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0] elif crop_flag == True:
mean_mask = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data
else:
if crop_flag == False:
mean_mask = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
elif crop_flag == True:
mean_mask = roi(Image(rsfmri_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
if mean_regression_flag == True: if mean_regression_flag == True:
weight = pd.read_pickle(regression_factors).loc[subject]['w_rsfMRI'] if cross_domain_x2x_flag == True:
weight = pd.read_pickle(regression_factors).loc[subject]['w_dMRI']
else:
weight = pd.read_pickle(regression_factors).loc[subject]['w_rsfMRI']
predicted_complete_volume = np.add(output, np.multiply(weight, mean_mask)) predicted_complete_volume = np.add(output, np.multiply(weight, mean_mask))
if mean_subtraction_flag == True: if mean_subtraction_flag == True:
...@@ -327,7 +357,7 @@ def _generate_volume_map(file_path, ...@@ -327,7 +357,7 @@ def _generate_volume_map(file_path,
return predicted_complete_volume, predicted_volume, header, xform return predicted_complete_volume, predicted_volume, header, xform
def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag): def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_y2y_flag):
"""Input Scaling """Input Scaling
This function reads the scaling factors from the saved file and then scales the data. This function reads the scaling factors from the saved file and then scales the data.
...@@ -341,15 +371,22 @@ def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_f ...@@ -341,15 +371,22 @@ def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_f
outlier_flag (bool): Flag indicating if outliers should be set to the min/max values. outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied. shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied. hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns: Returns:
scaled_volume (np.array): Scaled volume scaled_volume (np.array): Scaled volume
""" """
min_value, max_value, _, _ = scaling_parameters if cross_domain_y2y_flag == True:
_, _, min_value, max_value = scaling_parameters
else:
min_value, max_value, _, _ = scaling_parameters
if shrinkage_flag == True: if shrinkage_flag == True:
lambd = 0.003 # Hard coded, equivalent to tht 1p and 99p values across the whole population in UKBB if cross_domain_y2y_flag == True:
lambd = 3.0
else:
lambd = 0.003 # Hard coded, equivalent to tht 1p and 99p values across the whole population in UKBB
if hard_shrinkage_flag == True: if hard_shrinkage_flag == True:
volume = _hard_shrinkage(volume, lambd) volume = _hard_shrinkage(volume, lambd)
...@@ -377,7 +414,7 @@ def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_f ...@@ -377,7 +414,7 @@ def _scale_input(volume, scaling_parameters, normalize_flag, minus_one_scaling_f
return scaled_volume return scaled_volume
def _regress_input(volume, subject, dmri_mean_mask_path, regression_factors, crop_flag): def _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_y2y_flag):
""" Inputn Regression """ Inputn Regression
This function regresse the group mean from the input volume using the saved regression weights. This function regresse the group mean from the input volume using the saved regression weights.
...@@ -388,26 +425,35 @@ def _regress_input(volume, subject, dmri_mean_mask_path, regression_factors, cro ...@@ -388,26 +425,35 @@ def _regress_input(volume, subject, dmri_mean_mask_path, regression_factors, cro
volume (np.array): Unregressed volume volume (np.array): Unregressed volume
subject (str): Subject ID of the subject volume to be regressed subject (str): Subject ID of the subject volume to be regressed
dmri_mean_mask_path (str): Path to the group mean volume dmri_mean_mask_path (str): Path to the group mean volume
rsfmri_mean_mask_path (str): Path to the target group mean volume
regression_factors (str): Path to the linear regression weights file regression_factors (str): Path to the linear regression weights file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns: Returns:
regressed_volume (np.array): Linear regressed volume regressed_volume (np.array): Linear regressed volume
""" """
weight = pd.read_pickle(regression_factors).loc[subject]['w_dMRI'] if cross_domain_y2y_flag == True:
if crop_flag == False: weight = pd.read_pickle(regression_factors).loc[subject]['w_rsfMRI']
group_mean = Image(dmri_mean_mask_path).data if crop_flag == False:
elif crop_flag == True: group_mean = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
group_mean = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data elif crop_flag == True:
group_mean = roi(Image(rsfmri_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
else:
weight = pd.read_pickle(regression_factors).loc[subject]['w_dMRI']
if crop_flag == False:
group_mean = Image(dmri_mean_mask_path).data
elif crop_flag == True:
group_mean = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data
regressed_volume = np.subtract(volume, np.multiply(weight, group_mean)) regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))
return regressed_volume return regressed_volume
def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag): def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scaling_flag, negative_flag, shrinkage_flag, hard_shrinkage_flag, cross_domain_x2x_flag):
"""Output Rescaling """Output Rescaling
This function reads the scaling factors from the saved file and then scales the data. This function reads the scaling factors from the saved file and then scales the data.
...@@ -420,15 +466,23 @@ def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scalin ...@@ -420,15 +466,23 @@ def _rescale_output(volume, scaling_parameters, normalize_flag, minus_one_scalin
negative_flag (bool): Flag indicating if all the negative values should be 0-ed. negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied. shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied. hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
Returns: Returns:
rescaled_volume (np.array): Rescaled volume rescaled_volume (np.array): Rescaled volume
""" """
_, _, min_value, max_value = scaling_parameters if cross_domain_x2x_flag == True:
min_value, max_value, _, _ = scaling_parameters
else:
_, _, min_value, max_value = scaling_parameters
if shrinkage_flag == True: if shrinkage_flag == True:
lambd = 3.0 if cross_domain_x2x_flag == True:
lambd = 0.003
else:
lambd = 3.0
if hard_shrinkage_flag == True: if hard_shrinkage_flag == True:
pass pass
elif hard_shrinkage_flag == False: elif hard_shrinkage_flag == False:
......
...@@ -150,7 +150,7 @@ def load_subjects_from_path(data_directory, data_list): ...@@ -150,7 +150,7 @@ def load_subjects_from_path(data_directory, data_list):
return volumes_to_be_used return volumes_to_be_used
def load_and_preprocess_evaluation(file_path, crop_flag): def load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag):
"""Load & Preprocessing before evaluation """Load & Preprocessing before evaluation
This function loads a nifty file and returns its volume and header information This function loads a nifty file and returns its volume and header information
...@@ -158,6 +158,7 @@ def load_and_preprocess_evaluation(file_path, crop_flag): ...@@ -158,6 +158,7 @@ def load_and_preprocess_evaluation(file_path, crop_flag):
Args: Args:
file_path (str): Path to the desired file file_path (str): Path to the desired file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
cross_domain_y2y_flag
Returns: Returns:
volume (np.array): Array of training image data of data type dtype. volume (np.array): Array of training image data of data type dtype.
...@@ -170,15 +171,25 @@ def load_and_preprocess_evaluation(file_path, crop_flag): ...@@ -170,15 +171,25 @@ def load_and_preprocess_evaluation(file_path, crop_flag):
original_image = Image(file_path[0]) original_image = Image(file_path[0])
if crop_flag == False: if cross_domain_y2y_flag == True:
volume, xform = resampleToPixdims(original_image, (2, 2, 2)) if crop_flag == False:
header = Image(volume, header=original_image.header, xform=xform).header volume = original_image.data[:, :, :, 0]
elif crop_flag == True: header = Image(volume, header=original_image.header).header
resampled, xform = resampleToPixdims(original_image, (2, 2, 2)) elif crop_flag == True:
resampled = Image(resampled, header=original_image.header, xform=xform) cropped = roi(original_image,((9,81),(10,100),(0,77)))
cropped = roi(resampled,((9,81),(10,100),(0,77))) volume = cropped.data[:, :, :, 0]
volume = cropped.data header = cropped.header
header = cropped.header xform = None
else:
if crop_flag == False:
volume, xform = resampleToPixdims(original_image, (2, 2, 2))
header = Image(volume, header=original_image.header, xform=xform).header
elif crop_flag == True:
resampled, xform = resampleToPixdims(original_image, (2, 2, 2))
resampled = Image(resampled, header=original_image.header, xform=xform)
cropped = roi(resampled,((9,81),(10,100),(0,77)))
volume = cropped.data
header = cropped.header
return volume, header, xform return volume, header, xform
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment