Commit 815f5223 authored by Andrei Roibu

added ability to calculate statistics for test set

parent 01d776d1
@@ -411,6 +411,7 @@ def evaluate_data(mapping_evaluation_parameters):
     prediction_output_path = mapping_evaluation_parameters['prediction_output_path']
     prediction_output_database_name = mapping_evaluation_parameters['prediction_output_database_name']
+    prediction_output_statistics_name = mapping_evaluation_parameters['prediction_output_statistics_name']
     dmri_mean_mask_path = mapping_evaluation_parameters['dmri_mean_mask_path']
     rsfmri_mean_mask_path = mapping_evaluation_parameters['rsfmri_mean_mask_path']
     device = mapping_evaluation_parameters['device']
@@ -428,6 +429,7 @@ def evaluate_data(mapping_evaluation_parameters):
     shrinkage_flag = mapping_evaluation_parameters['shrinkage_flag']
     hard_shrinkage_flag = mapping_evaluation_parameters['hard_shrinkage_flag']
     crop_flag = mapping_evaluation_parameters['crop_flag']
+    output_database_flag = mapping_evaluation_parameters['output_database_flag']
     cross_domain_x2x_flag = mapping_evaluation_parameters['cross_domain_x2x_flag']
     cross_domain_y2y_flag = mapping_evaluation_parameters['cross_domain_y2y_flag']
@@ -438,6 +440,7 @@ def evaluate_data(mapping_evaluation_parameters):
                   data_list,
                   prediction_output_path,
                   prediction_output_database_name,
+                  prediction_output_statistics_name,
                   brain_mask_path,
                   dmri_mean_mask_path,
                   rsfmri_mean_mask_path,
@@ -455,6 +458,7 @@ def evaluate_data(mapping_evaluation_parameters):
                   crop_flag,
                   device,
                   exit_on_error,
+                  output_database_flag,
                   cross_domain_x2x_flag,
                   cross_domain_y2y_flag
                   )
...
@@ -2,12 +2,14 @@
 trained_model_path = "saved_models/VA2-1.pth.tar"
 prediction_output_path = "VA2-1_predictions"
 prediction_output_database_name = "output_test_data.h5"
+prediction_output_statistics_name = "output_statistics.pkl"
 data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
 mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
 mapping_targets_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
 data_list_reduced = "datasets/test_reduced.txt"
 data_list_all = "datasets/test_all.txt"
 evaluate_all_data = False
+output_database_flag = False
 brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
 rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
 dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
...
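For orientation only, a minimal sketch of how the two new keys could surface in the mapping_evaluation_parameters dictionary consumed above. This assumes an INI-style settings file parsed with Python's configparser and ast.literal_eval; the project's actual Settings loader may differ, and the section name is invented.

# Hypothetical illustration, not part of the commit: the project may use its own Settings class.
import ast
import configparser

parser = configparser.ConfigParser()
parser.read('settings_evaluation.ini')        # assumed file name
section = parser['MAPPING']                   # assumed section name

mapping_evaluation_parameters = {key: ast.literal_eval(value) for key, value in section.items()}
# The two keys added by this commit would then resolve to:
# mapping_evaluation_parameters['prediction_output_statistics_name'] -> "output_statistics.pkl"
# mapping_evaluation_parameters['output_database_flag'] -> False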
@@ -26,6 +26,8 @@ import pandas as pd
 from fsl.data.image import Image
 from fsl.utils.image.roi import roi
 import itertools
+from scipy.spatial.distance import cosine
+from scipy.stats import pearsonr, spearmanr

 log = logging.getLogger(__name__)
@@ -36,6 +38,7 @@ def evaluate_data(trained_model_path,
                   data_list,
                   prediction_output_path,
                   prediction_output_database_name,
+                  prediction_output_statistics_name,
                   brain_mask_path,
                   dmri_mean_mask_path,
                   rsfmri_mean_mask_path,
@@ -53,6 +56,7 @@ def evaluate_data(trained_model_path,
                   crop_flag,
                   device=0,
                   exit_on_error=False,
+                  output_database_flag=False,
                   cross_domain_x2x_flag=False,
                   cross_domain_y2y_flag=False,
                   mode='evaluate'):
@@ -69,6 +73,7 @@ def evaluate_data(trained_model_path,
         data_list (str): Path to a .txt file containing the input files for consideration
         prediction_output_path (str): Output prediction path
         prediction_output_database_name (str): Name of the output database
+        prediction_output_statistics_name (str): Name of the output statistics database
         brain_mask_path (str): Path to the MNI brain mask file
         dmri_mean_mask_path (str): Path to the dualreg subject mean mask
         rsfmri_mean_mask_path (str): Path to the summed tract mean mask
@@ -87,6 +92,7 @@ def evaluate_data(trained_model_path,
         device (str/int): Device type used for training (int - GPU id, str - CPU)
         mode (str): Current run mode or phase
         exit_on_error (bool): Flag that triggers the raising of an exception
+        output_database_flag (bool): Flag indicating if the output maps should be saved to an HDF5 database
         cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
         cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
@@ -129,12 +135,20 @@ def evaluate_data(trained_model_path,
     log.info("rsfMRI Generation Started")

     if cross_domain_y2y_flag == True:
         # If doing y2y autoencoder, then we load the targets as inputs. In all other cases (x2x & x2y) we load the inputs as inputs.
-        file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file=mapping_targets_file)
+        file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file=mapping_targets_file, mapping_targets_file=mapping_targets_file)
+    elif cross_domain_x2x_flag == True:
+        file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file, mapping_targets_file=mapping_data_file)
     else:
-        file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file)
+        file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file, mapping_targets_file)

-    output_database_path = os.path.join(prediction_output_path, prediction_output_database_name)
-    output_database_handle = h5py.File(output_database_path, 'w')
+    if output_database_flag == True:
+        output_database_path = os.path.join(prediction_output_path, prediction_output_database_name)
+        if os.path.exists(output_database_path):
+            os.remove(output_database_path)
+        output_database_handle = h5py.File(output_database_path, 'w')
+
+    output_statistics = {}
+    output_statistics_path = os.path.join(prediction_output_path, prediction_output_statistics_name)

     with torch.no_grad():
@@ -168,12 +182,27 @@ def evaluate_data(trained_model_path,
                                            cross_domain_x2x_flag,
                                            cross_domain_y2y_flag)

-                group = output_database_handle.create_group(subject)
-                group.create_dataset('predicted_complete_volume', data=predicted_complete_volume)
-                group.create_dataset('predicted_volume', data=predicted_volume)
-                group.create_dataset('header', data=header)
-                group.create_dataset('xform', data=xform)
+                target_volume = _generate_target_volume(file_path,
+                                                        subject,
+                                                        dmri_mean_mask_path,
+                                                        rsfmri_mean_mask_path,
+                                                        regression_factors,
+                                                        mean_regression_flag,
+                                                        mean_regression_all_flag,
+                                                        mean_subtraction_flag,
+                                                        crop_flag,
+                                                        cross_domain_x2x_flag
+                                                        )
+
+                mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b = _statistics_calculator(predicted_volume, target_volume)
+                output_statistics[subject] = [mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b]
+
+                if output_database_flag == True:
+                    group = output_database_handle.create_group(subject)
+                    group.create_dataset('predicted_complete_volume', data=predicted_complete_volume)
+                    group.create_dataset('predicted_volume', data=predicted_volume)
+                    group.create_dataset('header', data=header)
+                    group.create_dataset('xform', data=xform)

                 log.info("Processed: " + volumes_to_be_used[volume_index] + " " + str(
                     volume_index + 1) + " out of " + str(len(volumes_to_be_used)))
@@ -192,10 +221,14 @@ def evaluate_data(trained_model_path,
                 if exit_on_error:
                     raise(exception_expression)

+    output_statistics_df = pd.DataFrame.from_dict(output_statistics, orient='index', columns=['mse', 'mae', 'cel', 'pearson_r', 'pearson_p', 'spearman_r', 'spearman_p', 'regression_w', 'regression_b'])
+    output_statistics_df.to_pickle(output_statistics_path)
+
     log.info("Output Data Generation Complete")

     output_database_handle.close()
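For downstream analysis, the pickled statistics table can be read back with pandas. A minimal sketch (not part of the commit), assuming the default output names from the settings file above:

# Load the per-subject statistics written by evaluate_data and summarise them.
import os
import pandas as pd

stats_path = os.path.join("VA2-1_predictions", "output_statistics.pkl")
stats_df = pd.read_pickle(stats_path)                                   # index = subject IDs

print(stats_df[['mse', 'mae', 'pearson_r']].describe())                 # per-metric summary
print(stats_df.sort_values('pearson_r', ascending=False).head())        # best-correlated subjects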
 def evaluate_mapping(trained_model_path,
                      data_directory,
                      mapping_data_file,
@@ -408,7 +441,7 @@ def _generate_volume_map(file_path,
                          cross_domain_x2x_flag,
                          cross_domain_y2y_flag
                          ):
-    """rsfMRI Volume Generator
+    """Output Volume Generator

     This function uses the trained model to generate a new volume
@@ -616,7 +649,7 @@ def _regress_input(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path,
     elif crop_flag == True:
         group_mean = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data
     regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))

     return regressed_volume
@@ -716,6 +749,161 @@ def _soft_shrinkage(volume, lambd):
     return volume
+def _generate_target_volume(file_path,
+                            subject,
+                            dmri_mean_mask_path,
+                            rsfmri_mean_mask_path,
+                            regression_factors,
+                            mean_regression_flag,
+                            mean_regression_all_flag,
+                            mean_subtraction_flag,
+                            crop_flag,
+                            cross_domain_x2x_flag
+                            ):
+    """Target Volume Generator
+
+    This function loads and preprocesses a target volume for comparison with the network-predicted volumes.
+
+    Args:
+        file_path (str): Path to the desired file
+        subject (str): Subject ID of the subject volume to be regressed
+        dmri_mean_mask_path (str): Path to the group mean volume
+        rsfmri_mean_mask_path (str): Path to the dualreg subject mean mask
+        regression_factors (str): Path to the linear regression weights file
+        mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean_mask_path
+        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
+        mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean_mask_path
+        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
+        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
+
+    Returns:
+        volume (np.array): Array containing the information regarding the target volume
+    """
+
+    volume = data_utils.load_and_preprocess_targets(file_path, crop_flag, cross_domain_x2x_flag)
+
+    if mean_regression_flag == True:
+        volume = _regress_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_x2x_flag, mean_regression_all_flag)
+    elif mean_subtraction_flag == True:
+        volume = _subtract_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, crop_flag, cross_domain_x2x_flag)
+
+    return volume
+def _regress_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, crop_flag, cross_domain_x2x_flag, mean_regression_all_flag):
+    """Target Regression
+
+    This function regresses the group mean from the target volume using the saved regression weights.
+    TODO: This function represents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.
+
+    Args:
+        volume (np.array): Unregressed volume
+        subject (str): Subject ID of the subject volume to be regressed
+        dmri_mean_mask_path (str): Path to the group mean volume
+        rsfmri_mean_mask_path (str): Path to the target group mean volume
+        regression_factors (str): Path to the linear regression weights file
+        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
+        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
+        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
+
+    Returns:
+        regressed_volume (np.array): Linearly regressed volume
+    """
+
+    if cross_domain_x2x_flag == True:
+        if mean_regression_all_flag == True:
+            weight = pd.read_pickle(regression_factors).loc[subject]['w_dMRI']
+            if crop_flag == False:
+                group_mean = Image(dmri_mean_mask_path).data
+            elif crop_flag == True:
+                group_mean = roi(Image(dmri_mean_mask_path),((9,81),(10,100),(0,77))).data
+            regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))
+        else:
+            regressed_volume = volume
+    else:
+        weight = pd.read_pickle(regression_factors).loc[subject]['w_rsfMRI']
+        if crop_flag == False:
+            group_mean = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
+        elif crop_flag == True:
+            group_mean = roi(Image(rsfmri_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
+        regressed_volume = np.subtract(volume, np.multiply(weight, group_mean))
+
+    return regressed_volume
+def _subtract_target(volume, subject, dmri_mean_mask_path, rsfmri_mean_mask_path, crop_flag, cross_domain_x2x_flag):
+    """Target Subtraction
+
+    This function subtracts the group mean from the target volume.
+    TODO: This function represents only a temporary solution. For deployment, a NN needs to be trained which predicts the relevant scaling factors.
+
+    Args:
+        volume (np.array): Volume before mean subtraction
+        subject (str): Subject ID of the subject volume to be processed
+        dmri_mean_mask_path (str): Path to the group mean volume
+        rsfmri_mean_mask_path (str): Path to the target group mean volume
+        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
+        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
+
+    Returns:
+        subtracted_volume (np.array): Volume with the group mean subtracted
+    """
+
+    if cross_domain_x2x_flag == True:
+        subtracted_volume = volume
+    else:
+        if crop_flag == False:
+            group_mean = Image(rsfmri_mean_mask_path).data[:, :, :, 0]
+        elif crop_flag == True:
+            group_mean = roi(Image(rsfmri_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
+        subtracted_volume = np.subtract(volume, group_mean)
+
+    return subtracted_volume
+def _statistics_calculator(volume, target):
+    """Training statistics calculator
+
+    This function calculates the MSE, MAE, cosine distance (CEL), Pearson and Spearman correlations with their p-values, and the linear regression slope (W) and intercept (B) between a predicted volume and its target ground truth.
+
+    Args:
+        volume (np.array): Predicted volume
+        target (np.array): Ground truth volume
+
+    Returns:
+        mse (np.float64): The mean squared error between the prediction and the ground truth; the closer to 0, the better
+        mae (np.float64): The mean absolute error between the prediction and the ground truth; the closer to 0, the better
+        cel (np.float64): The cosine distance between the prediction and the ground truth; the closer to 0, the better
+        pearson_r (np.float64): Pearson's correlation coefficient; the closer to 1, the better
+        pearson_p (np.float64): Two-tailed p-value for Pearson's correlation coefficient; the closer to 0, the better
+        spearman_r (np.float64): Spearman correlation coefficient; the closer to 1, the better
+        spearman_p (np.float64): Two-tailed p-value for Spearman's correlation coefficient; the closer to 0, the better
+        regression_w (np.float64): Slope of the linear regression line; the closer to 1, the better
+        regression_b (np.float64): Intercept of the linear regression line; the closer to 0, the better
+    """
+
+    x = np.reshape(volume, -1)
+    y = np.reshape(target, -1)
+
+    mse = np.square(np.subtract(x, y)).mean()
+    mae = np.abs(np.subtract(x, y)).mean()
+    cel = np.mean(cosine(x, y))
+    pearson_r, pearson_p = pearsonr(x, y)
+    spearman_r, spearman_p = spearmanr(x, y)
+
+    x_matrix = np.vstack((np.ones(len(x)), x)).T
+    regression_b, regression_w = np.linalg.inv(x_matrix.T.dot(x_matrix)).dot(x_matrix.T).dot(y)
+
+    return mse, mae, cel, pearson_r, pearson_p, spearman_r, spearman_p, regression_w, regression_b
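The slope and intercept above come from the ordinary least squares normal equations, [b, w] = (X^T X)^-1 X^T y with X = [1, x]. A minimal standalone sketch on synthetic data (illustrative only, not part of the commit) showing that this agrees with standard least-squares routines:

# Check that the normal-equations fit matches np.polyfit and scipy.stats.linregress.
import numpy as np
from scipy.stats import linregress

rng = np.random.default_rng(0)
y_true = rng.normal(size=1000)                                     # stand-in "target" voxels
y_pred = 0.9 * y_true + 0.05 + rng.normal(scale=0.1, size=1000)    # stand-in "prediction"

x_matrix = np.vstack((np.ones(len(y_pred)), y_pred)).T
b, w = np.linalg.inv(x_matrix.T @ x_matrix) @ x_matrix.T @ y_true  # normal equations

w_poly, b_poly = np.polyfit(y_pred, y_true, 1)                     # slope, intercept
fit = linregress(y_pred, y_true)

assert np.allclose([w, b], [w_poly, b_poly])
assert np.allclose([w, b], [fit.slope, fit.intercept])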
 def _pearson_correlation(volume, target):
     """Calculate Pearson Correlation Coefficient
@@ -733,3 +921,4 @@ def _pearson_correlation(volume, target):
             np.sum(np.power(np.subtract(volume, volume.mean()), 2)), np.sum(np.power(np.subtract(target, target.mean()), 2))))

     return r
@@ -96,7 +96,7 @@ def get_datasets(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag):
     )

-def load_file_paths(data_directory, data_list, mapping_data_file, targets_directory=None, target_file=None):
+def load_file_paths(data_directory, data_list, mapping_data_file, mapping_targets_file=None):
     """File Loader

     This function returns a list of combined file paths for the input and output data.
@@ -105,7 +105,7 @@ def load_file_paths(data_directory, data_list, mapping_data_file, targets_direct
         data_directory (str): Path to input data directory
         data_list (str): Path to a .txt file containing the input files for consideration
         mapping_data_file (str): Path to the input files
-        targets_directory (str): Path to labelled data (Y-equivalent); None if during evaluation.
+        mapping_targets_file (str): Path to the target files

     Returns:
         file_paths (list): List containing the input data and target labelled output data
@@ -117,12 +117,12 @@ def load_file_paths(data_directory, data_list, mapping_data_file, targets_direct
     volumes_to_be_used = load_subjects_from_path(data_directory, data_list)

-    if targets_directory == None or target_file == None:
+    if mapping_targets_file == None:
         file_paths = [[os.path.join(data_directory, volume, mapping_data_file)]
                       for volume in volumes_to_be_used]
     else:
-        file_paths = [[os.path.join(data_directory, volume, mapping_data_file), os.path.join(
-            targets_directory, volume)] for volume in volumes_to_be_used]
+        file_paths = [[os.path.join(data_directory, volume, mapping_data_file),
+                       os.path.join(data_directory, volume, mapping_targets_file)] for volume in volumes_to_be_used]

     return file_paths, volumes_to_be_used
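For orientation, a minimal illustration of the nested list load_file_paths now returns when a mapping_targets_file is supplied. The directory layout follows the settings file shown earlier in this commit; the subject ID is invented for the example.

# Hypothetical illustration only (invented subject ID).
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
mapping_targets_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"

# file_paths, volumes_to_be_used = load_file_paths(data_directory, "datasets/test_reduced.txt",
#                                                  mapping_data_file, mapping_targets_file)
# file_paths would then look like:
# [['/well/.../subjectsAll/1000001/dMRI/autoptx_preproc/tractsNormSummed.nii.gz',
#   '/well/.../subjectsAll/1000001/fMRI/rfMRI_25.dr/dr_stage2.nii.gz'],
#  ...]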
@@ -158,10 +158,10 @@ def load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag):
     Args:
         file_path (str): Path to the desired file
         crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
-        cross_domain_y2y_flag
+        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets

     Returns:
-        volume (np.array): Array of training image data of data type dtype.
+        volume (np.array): Array of training image data
         header (class): 'nibabel.nifti1.Nifti1Header' class object, containing image metadata
         xform (np.array): Array of shape (4, 4), containing the adjusted voxel-to-world transformation for the spatial dimensions of the resampled data
@@ -194,27 +194,35 @@ def load_and_preprocess_evaluation(file_path, crop_flag, cross_domain_y2y_flag):
     return volume, header, xform
-def load_and_preprocess_targets(target_path, mean_mask_path):
+def load_and_preprocess_targets(file_path, crop_flag, cross_domain_x2x_flag):
     """Load & Preprocessing targets before evaluation

-    This function loads a nifty file and returns its volume, a de-meaned volume and header information
+    This function loads a NIfTI file and returns its volume information

     Args:
-        file_path (str): Path to the desired target file
-        mean_mask_path (str): Path to the dualreg subject mean mask
+        file_path (str): Path to the desired file
+        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
+        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs

     Returns:
-        target (np.array): Array of training image data of data type dtype.
-        target_demeaned (np.array): Array of training data from which the group mean has been subtracted
-
-    Raises:
-        ValueError: "Orientation value is invalid. It must be either >>coronal<<, >>axial<< or >>sagital<< "
+        volume (np.array): Array of target image intensities.
     """

-    target = Image(target_path[0]).data[:, :, :, 0]
-    target_demeaned = np.subtract(
-        target, Image(mean_mask_path).data[:, :, :, 0])
+    original_image = Image(file_path[1])
+
+    if cross_domain_x2x_flag == True:
+        if crop_flag == False:
+            volume, _ = resampleToPixdims(original_image, (2, 2, 2))
+        elif crop_flag == True:
+            resampled, xform = resampleToPixdims(original_image, (2, 2, 2))
+            volume = roi(Image(resampled, header=original_image.header, xform=xform),((9,81),(10,100),(0,77))).data
+    else:
+        if crop_flag == False:
+            volume = original_image.data[:, :, :, 0]
+        elif crop_flag == True:
+            volume = roi(original_image,((9,81),(10,100),(0,77))).data[:, :, :, 0]

-    return target, target_demeaned
+    return volume
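A small aside on the crop used throughout this commit: the ROI bounds ((9,81),(10,100),(0,77)) take the 91x109x91 MNI 2mm grid down to 72x90x77, as the docstrings state. A minimal sketch with a synthetic array (illustrative only, using plain NumPy slicing rather than fsl's roi helper):

# Verify the crop bounds reduce a 91x109x91 grid to 72x90x77.
import numpy as np

volume = np.zeros((91, 109, 91))
cropped = volume[9:81, 10:100, 0:77]
print(cropped.shape)   # (72, 90, 77)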