Commit 01d776d1 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

Add evaluate_data function, which evaluates all the listed data and saves the outputs to an HDF5 (.h5) database

parent 8248443a
......@@ -326,7 +326,12 @@ def evaluate_mapping(mapping_evaluation_parameters):
data_directory = mapping_evaluation_parameters['data_directory']
mapping_data_file = mapping_evaluation_parameters['mapping_data_file']
mapping_targets_file = mapping_evaluation_parameters['mapping_targets_file']
data_list = mapping_evaluation_parameters['data_list']
if mapping_evaluation_parameters['evaluate_all_data'] == False:
data_list = mapping_evaluation_parameters['data_list_reduced']
elif mapping_evaluation_parameters['evaluate_all_data'] == True:
data_list = mapping_evaluation_parameters['data_list_all']
prediction_output_path = mapping_evaluation_parameters['prediction_output_path']
dmri_mean_mask_path = mapping_evaluation_parameters['dmri_mean_mask_path']
rsfmri_mean_mask_path = mapping_evaluation_parameters['rsfmri_mean_mask_path']
......@@ -376,6 +381,85 @@ def evaluate_mapping(mapping_evaluation_parameters):
)
def evaluate_data(mapping_evaluation_parameters):
    """Mapping Evaluator

    This function unpacks the mapping evaluation parameters and forwards them,
    in positional order, to evaluations.evaluate_data.

    Args:
        mapping_evaluation_parameters (dict): Dictionary of parameters useful during mapping evaluation.
            The following keys are read:

            mapping_evaluation_parameters = {
                'trained_model_path': 'path/to/model'
                'data_directory': 'path/to/data'
                'mapping_data_file': 'relative/path/to/input/file'
                'mapping_targets_file': 'relative/path/to/target/file'
                'evaluate_all_data': False
                'data_list_reduced': 'path/to/datalist_reduced.txt'  (read when evaluate_all_data is False)
                'data_list_all': 'path/to/datalist_all.txt'  (read when evaluate_all_data is True)
                'prediction_output_path': 'directory-of-saved-predictions'
                'prediction_output_database_name': 'name-of-output-database'
                'brain_mask_path', 'dmri_mean_mask_path', 'rsfmri_mean_mask_path',
                'regression_factors', 'mean_regression_flag', 'mean_regression_all_flag',
                'mean_subtraction_flag', 'scale_volumes_flag', 'normalize_flag',
                'minus_one_scaling_flag', 'negative_flag', 'outlier_flag',
                'shrinkage_flag', 'hard_shrinkage_flag', 'crop_flag',
                'device': 0
                'exit_on_error': True
                'cross_domain_x2x_flag', 'cross_domain_y2y_flag'
            }

    Raises:
        ValueError: 'evaluate_all_data' is neither True nor False
        KeyError: One of the required keys is missing from the dictionary
    """

    parameters = mapping_evaluation_parameters

    # Select the subject list. Anything other than a boolean-equivalent value is
    # rejected explicitly; previously such a value left data_list unassigned and
    # only surfaced later as an UnboundLocalError.
    evaluate_all_data = parameters['evaluate_all_data']
    if evaluate_all_data not in (True, False):
        raise ValueError(
            "'evaluate_all_data' must be True or False, got: {!r}".format(evaluate_all_data))

    if evaluate_all_data:
        data_list = parameters['data_list_all']
    else:
        data_list = parameters['data_list_reduced']

    # Arguments are passed positionally, so their order must not be changed.
    evaluations.evaluate_data(parameters['trained_model_path'],
                              parameters['data_directory'],
                              parameters['mapping_data_file'],
                              parameters['mapping_targets_file'],
                              data_list,
                              parameters['prediction_output_path'],
                              parameters['prediction_output_database_name'],
                              parameters['brain_mask_path'],
                              parameters['dmri_mean_mask_path'],
                              parameters['rsfmri_mean_mask_path'],
                              parameters['regression_factors'],
                              parameters['mean_regression_flag'],
                              parameters['mean_regression_all_flag'],
                              parameters['mean_subtraction_flag'],
                              parameters['scale_volumes_flag'],
                              parameters['normalize_flag'],
                              parameters['minus_one_scaling_flag'],
                              parameters['negative_flag'],
                              parameters['outlier_flag'],
                              parameters['shrinkage_flag'],
                              parameters['hard_shrinkage_flag'],
                              parameters['crop_flag'],
                              parameters['device'],
                              parameters['exit_on_error'],
                              parameters['cross_domain_x2x_flag'],
                              parameters['cross_domain_y2y_flag']
                              )
def delete_files(folder):
""" Clear Folder Contents
......
[MAPPING]
trained_model_path = "saved_models/VA2-1.pth.tar"
prediction_output_path = "VA2-1_predictions"
prediction_output_database_name = "output_test_data.h5"
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
mapping_targets_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
data_list = "datasets/test.txt"
data_list_reduced = "datasets/test_reduced.txt"
data_list_all = "datasets/test_all.txt"
evaluate_all_data = False
brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
......
......@@ -19,6 +19,7 @@ import pickle
import numpy as np
import torch
import logging
import h5py
import utils.data_utils as data_utils
from utils.common_utils import create_folder
import pandas as pd
......@@ -28,6 +29,173 @@ import itertools
log = logging.getLogger(__name__)
def evaluate_data(trained_model_path,
                  data_directory,
                  mapping_data_file,
                  mapping_targets_file,
                  data_list,
                  prediction_output_path,
                  prediction_output_database_name,
                  brain_mask_path,
                  dmri_mean_mask_path,
                  rsfmri_mean_mask_path,
                  regression_factors,
                  mean_regression_flag,
                  mean_regression_all_flag,
                  mean_subtraction_flag,
                  scale_volumes_flag,
                  normalize_flag,
                  minus_one_scaling_flag,
                  negative_flag,
                  outlier_flag,
                  shrinkage_flag,
                  hard_shrinkage_flag,
                  crop_flag,
                  device=0,
                  exit_on_error=False,
                  cross_domain_x2x_flag=False,
                  cross_domain_y2y_flag=False,
                  mode='evaluate'):
    """Model Evaluator

    This function generates the rsfMRI arrays for the given inputs and stores them in a
    single HDF5 database. One group is created per subject, holding the datasets
    'predicted_complete_volume', 'predicted_volume', 'header' and 'xform'.

    Args:
        trained_model_path (str): Path to the location of the trained model
        data_directory (str): Path to input data directory
        mapping_data_file (str): Path to the input file
        mapping_targets_file (str): Path to the target file
        data_list (str): Path to a .txt file containing the input files for consideration
        prediction_output_path (str): Output prediction path
        prediction_output_database_name (str): Name of the output database
        brain_mask_path (str): Path to the MNI brain mask file
        dmri_mean_mask_path (str): Path to the summed tract mean mask
        rsfmri_mean_mask_path (str): Path to the dualreg subject mean mask
        regression_factors (str): Path to the linear regression weights file
        mean_regression_flag (bool): Flag indicating if the volumes should be de-meaned by regression using the mean_mask_path
        mean_regression_all_flag (bool): Flag indicating if both the input and target volumes should be regressed. If False, only targets are regressed.
        mean_subtraction_flag (bool): Flag indicating if the targets should be de-meaned by subtraction using the mean_mask_path
        scale_volumes_flag (bool): Flag indicating if the volumes should be scaled.
        normalize_flag (bool): Flag signaling if the volume should be normalized ([0,1] if True) or scaled to [-1,1] if False.
        minus_one_scaling_flag (bool): Flag signaling if the volume should be scaled to [-1,1] if True
        negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
        outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
        shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
        hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
        crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
        device (str/int): Device type used for evaluation (int - GPU id, str- CPU)
        mode (str): Current run mode or phase (not used by this function)
        exit_on_error (bool): Flag that triggers the raising of an exception
        cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
        cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets

    Raises:
        FileNotFoundError: Error in reading the provided file! Raised unconditionally if data_list
            cannot be opened; for per-volume failures only re-raised when exit_on_error is True.
        Exception: Error code execution! Only re-raised when exit_on_error is True.
    """

    log.info(
        "Started Evaluation. Check tensorboard for plots (if a LogWriter is provided)")

    # This list is replaced further down by the one returned from
    # data_utils.load_file_paths; opening the file here still surfaces a missing
    # data_list before the model is loaded.
    with open(data_list) as data_list_file:
        volumes_to_be_used = data_list_file.read().splitlines()

    # Test if cuda is available and attempt to run on GPU
    cuda_available = torch.cuda.is_available()

    if isinstance(device, int) and not cuda_available:
        log.warning(
            "CUDA not available. Switching to CPU. Investigate behaviour!")
        device = 'cpu'

    if isinstance(device, int):
        # A GPU id was requested and CUDA is available
        model = torch.load(trained_model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        model = torch.load(trained_model_path,
                           map_location=torch.device(device))

    model.eval()

    # Create the prediction path folder if this is not available
    create_folder(prediction_output_path)

    # Initiate the evaluation
    log.info("rsfMRI Generation Started")

    if cross_domain_y2y_flag == True:
        # If doing y2y autoencoder, then we load the targets as inputs. In all other cases (x2x & x2y) we load the inputs as inputs.
        file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file=mapping_targets_file)
    else:
        file_paths, volumes_to_be_used = data_utils.load_file_paths(data_directory, data_list, mapping_data_file)

    output_database_path = os.path.join(prediction_output_path, prediction_output_database_name)

    # The context manager guarantees that the database is closed even when an
    # exception is re-raised because exit_on_error is set.
    with h5py.File(output_database_path, 'w') as output_database_handle, torch.no_grad():
        for volume_index, file_path in enumerate(file_paths):
            try:
                print("Mapping Volume {}/{}".format(volume_index+1, len(file_paths)))

                # Generate volume & header
                subject = volumes_to_be_used[volume_index]

                predicted_complete_volume, predicted_volume, header, xform = _generate_volume_map(file_path,
                                                                                                  subject,
                                                                                                  model,
                                                                                                  device,
                                                                                                  cuda_available,
                                                                                                  brain_mask_path,
                                                                                                  dmri_mean_mask_path,
                                                                                                  rsfmri_mean_mask_path,
                                                                                                  regression_factors,
                                                                                                  mean_regression_flag,
                                                                                                  mean_regression_all_flag,
                                                                                                  mean_subtraction_flag,
                                                                                                  scale_volumes_flag,
                                                                                                  normalize_flag,
                                                                                                  minus_one_scaling_flag,
                                                                                                  negative_flag,
                                                                                                  outlier_flag,
                                                                                                  shrinkage_flag,
                                                                                                  hard_shrinkage_flag,
                                                                                                  crop_flag,
                                                                                                  cross_domain_x2x_flag,
                                                                                                  cross_domain_y2y_flag)

                # One group per subject, holding everything returned for that volume
                group = output_database_handle.create_group(subject)
                group.create_dataset('predicted_complete_volume', data=predicted_complete_volume)
                group.create_dataset('predicted_volume', data=predicted_volume)
                group.create_dataset('header', data=header)
                group.create_dataset('xform', data=xform)

                log.info("Processed: " + subject + " " + str(
                    volume_index + 1) + " out of " + str(len(volumes_to_be_used)))

                print("Mapped Volumes saved in: ", prediction_output_path)

            except FileNotFoundError as exception_expression:
                log.error("Error in reading the provided file!")
                log.exception(exception_expression)
                if exit_on_error:
                    # Bare raise keeps the original traceback intact
                    raise

            except Exception as exception_expression:
                log.error("Error code execution!")
                log.exception(exception_expression)
                if exit_on_error:
                    raise

    log.info("Output Data Generation Complete")
def evaluate_mapping(trained_model_path,
data_directory,
mapping_data_file,
......@@ -523,6 +691,7 @@ def _hard_shrinkage(volume, lambd):
return volume
def _soft_shrinkage(volume, lambd):
""" Soft Shrinkage
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment