Commit 8d308d99 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

Added volume cropping to reduce storage and network requirements and to increase speed

parent 2d2cd6eb
......@@ -178,7 +178,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
logs_directory=misc_parameters['logs_directory'],
checkpoint_directory=misc_parameters['checkpoint_directory'],
save_model_directory=misc_parameters['save_model_directory'],
final_model_output_file=training_parameters['final_model_output_file']
final_model_output_file=training_parameters['final_model_output_file'],
crop_flag = data_parameters['crop_flag']
)
validation_loss = solver.train(train_loader, validation_loader)
......@@ -192,66 +193,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
_ = _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters)
def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
    """Mapping Score Evaluator

    This function evaluates a given trained model by calculating its dice score prediction.

    Args:
        training_parameters (dict): Dictionary containing relevant hyperparameters for training the network.
            training_parameters = {
                'experiment_name': 'experiment_name'
            }

        network_parameters (dict): Contains information relevant parameters
            network_parameters = {
                'number_of_classes': 1
            }

        misc_parameters (dict): Dictionary of additional hyperparameters
            misc_parameters = {
                'logs_directory': 'log-directory'
                'device': 1
                'experiments_directory': 'experiments-directory'
            }

        evaluation_parameters (dict): Dictionary of parameters useful during evaluation.
            evaluation_parameters = {
                'trained_model_path': 'path/to/model'
                'data_directory': 'path/to/data'
                'mapping_data_file': 'path/to/mapping/data'
                'targets_directory': 'path/to/targets'
                'data_list': 'path/to/datalist.txt'
                'brain_mask_path': 'path/to/brain/mask'
                'rsfmri_mean_mask_path': 'path/to/rsfmri/mean/mask'
                'dmri_mean_mask_path': 'path/to/dmri/mean/mask'
                'mean_regression': False
                'scaling_factors': 'path/to/scaling/factors'
                'regression_factors': 'path/to/regression/factors'
                'saved_predictions_directory': 'directory-of-saved-predictions'
            }
    """

    # TODO - NEED TO UPDATE THE DATA FUNCTIONS!

    prediction_output_path = os.path.join(misc_parameters['experiments_directory'],
                                          training_parameters['experiment_name'],
                                          evaluation_parameters['saved_predictions_directory']
                                          )

    # NOTE(review): the original body read these keys from an undefined name
    # `mapping_evaluation_parameters` (a copy-paste from evaluate_mapping),
    # which raised NameError at runtime; all keys now come from the
    # `evaluation_parameters` argument actually passed to this function.
    evaluations.evaluate_correlation(trained_model_path=evaluation_parameters['trained_model_path'],
                                     data_directory=evaluation_parameters['data_directory'],
                                     mapping_data_file=evaluation_parameters['mapping_data_file'],
                                     target_data_file=evaluation_parameters['targets_directory'],
                                     data_list=evaluation_parameters['data_list'],
                                     prediction_output_path=prediction_output_path,
                                     brain_mask_path=evaluation_parameters['brain_mask_path'],
                                     rsfmri_mean_mask_path=evaluation_parameters[
                                         'rsfmri_mean_mask_path'],
                                     dmri_mean_mask_path=evaluation_parameters[
                                         'dmri_mean_mask_path'],
                                     mean_regression=evaluation_parameters['mean_regression'],
                                     scaling_factors=evaluation_parameters['scaling_factors'],
                                     regression_factors=evaluation_parameters['regression_factors'],
                                     device=misc_parameters['device'],
                                     )
def evaluate_mapping(mapping_evaluation_parameters):
"""Mapping Evaluator
......@@ -280,10 +221,17 @@ def evaluate_mapping(mapping_evaluation_parameters):
device = mapping_evaluation_parameters['device']
exit_on_error = mapping_evaluation_parameters['exit_on_error']
brain_mask_path = mapping_evaluation_parameters['brain_mask_path']
mean_regression = mapping_evaluation_parameters['mean_regression']
mean_subtraction = mapping_evaluation_parameters['mean_subtraction']
scaling_factors = mapping_evaluation_parameters['scaling_factors']
regression_factors = mapping_evaluation_parameters['regression_factors']
mean_regression_flag = mapping_evaluation_parameters['mean_regression_flag']
mean_regression_all_flag = mapping_evaluation_parameters['mean_regression_all_flag']
mean_subtraction_flag = mapping_evaluation_parameters['mean_subtraction_flag']
scale_volumes_flag = mapping_evaluation_parameters['scale_volumes_flag']
normalize_flag = mapping_evaluation_parameters['normalize_flag']
negative_flag = mapping_evaluation_parameters['negative_flag']
outlier_flag = mapping_evaluation_parameters['outlier_flag']
shrinkage_flag = mapping_evaluation_parameters['shrinkage_flag']
hard_shrinkage_flag = mapping_evaluation_parameters['hard_shrinkage_flag']
crop_flag = mapping_evaluation_parameters['crop_flag']
evaluations.evaluate_mapping(trained_model_path,
data_directory,
......@@ -293,13 +241,19 @@ def evaluate_mapping(mapping_evaluation_parameters):
brain_mask_path,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
mean_regression,
mean_subtraction,
scaling_factors,
regression_factors,
device=device,
exit_on_error=exit_on_error)
mean_regression_flag,
mean_regression_all_flag,
mean_subtraction_flag,
scale_volumes_flag,
normalize_flag,
negative_flag,
outlier_flag,
shrinkage_flag,
hard_shrinkage_flag,
crop_flag,
device,
exit_on_error)
def delete_files(folder):
""" Clear Folder Contents
......@@ -344,9 +298,6 @@ if __name__ == '__main__':
# NOTE: THE EVAL FUNCTIONS HAVE NOT YET BEEN DEBUGGED (16/04/20)
elif arguments.mode == 'evaluate-score':
evaluate_score(training_parameters,
network_parameters, misc_parameters, evaluation_parameters)
elif arguments.mode == 'evaluate-mapping':
logging.basicConfig(filename='evaluate-mapping-error.log')
if arguments.settings_path is not None:
......
......@@ -4,6 +4,7 @@ input_data_train = "input_data_train.h5"
target_data_train = "target_data_train.h5"
input_data_validation = "input_data_validation.h5"
target_data_validation = "target_data_validation.h5"
crop_flag = False
[TRAINING]
experiment_name = "VA2-1"
......
[MAPPING]
trained_model_path = "saved_models/VA2.pth.tar"
prediction_output_path = "VA2_predictions"
trained_model_path = "saved_models/VA2-1.pth.tar"
prediction_output_path = "VA2-1_predictions"
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
data_list = "datasets/test.txt"
......@@ -8,9 +8,16 @@ brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
mean_mask_path = "utils/mean_dr_stage2.nii.gz"
scaling_factors = "datasets/scaling_factors.pkl"
regression_factors = "datasets/regression_weights.pkl"
mean_regression = False
mean_subtraction = True
mean_regression_flag = False
mean_regression_all_flag = False
mean_subtraction_flag = True
scale_volumes_flag = True
normalize_flag = True
negative_flag = True
outlier_flag = True
shrinkage_flag = False
hard_shrinkage_flag = False
crop_flag = False
device = 0
exit_on_error = True
exit_on_error = True
\ No newline at end of file
......@@ -17,6 +17,7 @@ import torch
import glob
from fsl.data.image import Image
from fsl.utils.image.roi import roi
from datetime import datetime
from utils.losses import MSELoss
from utils.common_utils import create_folder
......@@ -49,6 +50,7 @@ class Solver():
use_last_checkpoint (bool): Flag for loading the previous checkpoint
experiment_directory (str): Experiment output directory name
logs_directory (str): Directory for outputing training logs
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
Returns:
trained model - working on this!
......@@ -76,7 +78,8 @@ class Solver():
logs_directory='logs',
checkpoint_directory='checkpoints',
save_model_directory='saved_models',
final_model_output_file='finetuned_alldata.pth.tar'
final_model_output_file='finetuned_alldata.pth.tar',
crop_flag = False
):
self.model = model
......@@ -125,8 +128,10 @@ class Solver():
self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
self.early_stop = False
self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
if crop_flag == False:
self.MNI152_T1_2mm_brain_mask = torch.from_numpy(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
elif crop_flag == True:
self.MNI152_T1_2mm_brain_mask = torch.from_numpy(roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'),((9,81),(10,100),(0,77))).data)
self.save_model_directory = save_model_directory
self.final_model_output_file = final_model_output_file
......
This diff is collapsed.
......@@ -19,6 +19,7 @@ import torch.utils.data as data
import h5py
from fsl.data.image import Image
from fsl.utils.image.resample import resampleToPixdims
from fsl.utils.image.roi import roi
class DataMapper(data.Dataset):
"""Data Mapper Class
......@@ -137,13 +138,14 @@ def load_subjects_from_path(data_directory, data_list):
return volumes_to_be_used
def load_and_preprocess_evaluation(file_path):
def load_and_preprocess_evaluation(file_path, crop_flag):
"""Load & Preprocessing before evaluation
This function loads a nifty file and returns its volume and header information
Args:
file_path (str): Path to the desired file
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
Returns:
volume (np.array): Array of training image data of data type dtype.
......@@ -155,8 +157,16 @@ def load_and_preprocess_evaluation(file_path):
"""
original_image = Image(file_path[0])
volume, xform = resampleToPixdims(original_image, (2, 2, 2))
header = Image(volume, header=original_image.header, xform=xform).header
if crop_flag == False:
volume, xform = resampleToPixdims(original_image, (2, 2, 2))
header = Image(volume, header=original_image.header, xform=xform).header
elif crop_flag == True:
resampled, xform = resampleToPixdims(original_image, (2, 2, 2))
resampled = Image(resampled, header=original_image.header, xform=xform)
cropped = roi(resampled,((9,81),(10,100),(0,77)))
volume = cropped.data
header = cropped.header
return volume, header, xform
......
......@@ -83,7 +83,8 @@ def convert_hdf5(data_parameters, file_information):
negative_flag = data_parameters['negative_flag'],
outlier_flag = data_parameters['outlier_flag'],
shrinkage_flag = data_parameters['shrinkage_flag'],
hard_shrinkage_flag = data_parameters['hard_shrinkage_flag']
hard_shrinkage_flag = data_parameters['hard_shrinkage_flag'],
crop_flag = data_parameters['crop_flag']
)
write_hdf5(train_dMRI, train_rsfMRI, file_information, mode='train')
......@@ -107,7 +108,8 @@ def convert_hdf5(data_parameters, file_information):
negative_flag = data_parameters['negative_flag'],
outlier_flag = data_parameters['outlier_flag'],
shrinkage_flag = data_parameters['shrinkage_flag'],
hard_shrinkage_flag = data_parameters['hard_shrinkage_flag']
hard_shrinkage_flag = data_parameters['hard_shrinkage_flag'],
crop_flag = data_parameters['crop_flag']
)
write_hdf5(validation_dMRI, validation_rsfMRI, file_information, mode='validation')
......
......@@ -23,6 +23,7 @@ negative_flag = True
outlier_flag = True
shrinkage_flag = False
hard_shrinkage_flag = False
crop_flag = False
input_data_train = "input_data_train.h5"
target_data_train = "target_data_train.h5"
input_data_validation = "input_data_validation.h5"
......
......@@ -19,6 +19,7 @@ import configparser
import pandas as pd
from fsl.data.image import Image
from fsl.utils.image.resample import resampleToPixdims
from fsl.utils.image.roi import roi
from sklearn.model_selection import train_test_split
from common_utils import create_folder
......@@ -273,7 +274,7 @@ def weight_calculator(data_directory, subject, train_inputs, train_targets, rsfM
def load_datasets(subjects, data_directory, input_file, output_target, mean_regression_flag, mean_regression_all_flag, regression_weights_path,
dMRI_mean_mask_path, rsfMRI_mean_mask_path, mean_subtraction_flag, scale_volumes_flag, normalize_flag, negative_flag,
outlier_flag, shrinkage_flag, hard_shrinkage_flag):
outlier_flag, shrinkage_flag, hard_shrinkage_flag, crop_flag):
""" Dataset loader and pre-processor
This function acts as a wrapper for loading and pre-processing of the datasets.
......@@ -295,6 +296,7 @@ def load_datasets(subjects, data_directory, input_file, output_target, mean_regr
outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
Returns:
input_volumes (list): List of all the input volumes.
......@@ -311,7 +313,7 @@ def load_datasets(subjects, data_directory, input_file, output_target, mean_regr
input_volume, target_volume = load_and_preprocess(subject, data_directory, input_file, output_target, mean_regression_flag, mean_regression_all_flag, regression_weights_path,
dMRI_mean_mask_path, rsfMRI_mean_mask_path, mean_subtraction_flag, scale_volumes_flag, normalize_flag, negative_flag,
outlier_flag, shrinkage_flag, hard_shrinkage_flag)
outlier_flag, shrinkage_flag, hard_shrinkage_flag, crop_flag)
input_volumes.append(input_volume)
target_volumes.append(target_volume)
......@@ -323,7 +325,7 @@ def load_datasets(subjects, data_directory, input_file, output_target, mean_regr
def load_and_preprocess(subject, data_directory, input_file, output_target, mean_regression_flag, mean_regression_all_flag, regression_weights_path,
dMRI_mean_mask_path, rsfMRI_mean_mask_path, mean_subtraction_flag, scale_volumes_flag, normalize_flag, negative_flag,
outlier_flag, shrinkage_flag, hard_shrinkage_flag):
outlier_flag, shrinkage_flag, hard_shrinkage_flag, crop_flag):
""" Subject loader and pre-processor
This function acts as a wrapper for loading and pre-processing individual subjects.
......@@ -345,19 +347,20 @@ def load_and_preprocess(subject, data_directory, input_file, output_target, mean
outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
Returns:
input_volume (np.array): Numpy array representing the preprocessed input volume.
target_volume (np.array) Numpy array representing the preprocessed target volume.
"""
input_volume, target_volume = load_data(subject, data_directory, input_file, output_target)
input_volume, target_volume = load_data(subject, data_directory, input_file, output_target, crop_flag)
input_volume, target_volume = preprocess(input_volume, target_volume, subject, mean_regression_flag, mean_regression_all_flag, regression_weights_path,
dMRI_mean_mask_path, rsfMRI_mean_mask_path, mean_subtraction_flag, scale_volumes_flag, normalize_flag, negative_flag,
outlier_flag, shrinkage_flag, hard_shrinkage_flag)
outlier_flag, shrinkage_flag, hard_shrinkage_flag, crop_flag)
return input_volume, target_volume
def load_data(subject, data_directory, input_file, output_target):
def load_data(subject, data_directory, input_file, output_target, crop_flag=False):
""" Load subject data
This function generates relevant paths for the input and target files for each subject, and then loads them as numpy arrays.
......@@ -367,6 +370,7 @@ def load_data(subject, data_directory, input_file, output_target):
data_directory (str): Directory where the various subjects are stored.
input_file (str): Intenal path for each subject to the relevant normalized summed dMRI tracts
output_target (str): Internal path for each subject to the relevant rsfMRI data
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
Returns:
input_volume
......@@ -376,15 +380,21 @@ def load_data(subject, data_directory, input_file, output_target):
input_path = os.path.join(os.path.expanduser("~"), data_directory, subject, input_file)
target_path = os.path.join(os.path.expanduser("~"), data_directory, subject, output_target)
input_volume, _ = resampleToPixdims(Image(input_path), (2,2,2))
target_volume = Image(target_path).data[:, :, :, 0]
if crop_flag == False:
input_volume, _ = resampleToPixdims(Image(input_path), (2,2,2))
target_volume = Image(target_path).data[:, :, :, 0]
elif crop_flag == True:
input_image = Image(input_path)
resampled_volume, xform = resampleToPixdims(input_image, (2,2,2))
input_volume = roi(Image(resampled_volume, header=input_image.header, xform=xform),((9,81),(10,100),(0,77))).data
target_volume = roi(Image(target_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
return input_volume, target_volume
def preprocess(input_volume, target_volume, subject, mean_regression_flag, mean_regression_all_flag, regression_weights_path,
dMRI_mean_mask_path, rsfMRI_mean_mask_path, mean_subtraction_flag, scale_volumes_flag, normalize_flag, negative_flag,
outlier_flag, shrinkage_flag, hard_shrinkage_flag):
outlier_flag, shrinkage_flag, hard_shrinkage_flag, crop_flag):
"""Conducts pre-processing based on arguments
Function which wraps the various pre-processing subfunctions for every volume.
......@@ -404,7 +414,8 @@ def preprocess(input_volume, target_volume, subject, mean_regression_flag, mean_
negative_flag (bool): Flag indicating if all the negative values should be 0-ed.
outlier_flag (bool): Flag indicating if outliers should be set to the min/max values.
shrinkage_flag (bool): Flag indicating if shrinkage should be applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
hard_shrinkage_flag (bool): Flag indicating if hard shrinkage should be applied. If False, soft shrinkage is applied.
crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training
Returns:
input_volume
......@@ -414,18 +425,18 @@ def preprocess(input_volume, target_volume, subject, mean_regression_flag, mean_
if mean_regression_flag == True:
if mean_regression_all_flag == True:
# Regress both inputs and targets
input_volume = linear_regress_mean(input_volume, subject, regression_weights_path, target_flag=False, dMRI_mean_mask_path=dMRI_mean_mask_path)
target_volume = linear_regress_mean(target_volume, subject, regression_weights_path, target_flag=True, rsfMRI_mean_mask_path=rsfMRI_mean_mask_path)
input_volume = linear_regress_mean(input_volume, subject, regression_weights_path, crop_flag, target_flag=False, dMRI_mean_mask_path=dMRI_mean_mask_path)
target_volume = linear_regress_mean(target_volume, subject, regression_weights_path, crop_flag, target_flag=True, rsfMRI_mean_mask_path=rsfMRI_mean_mask_path)
# Set scaling parameters to Andrei Scaling
scaling_parameters = [-0.0539, 0.0969, -12.094, 14.6319]
else:
# Regress only targets, leave inputs as they are
target_volume = linear_regress_mean(target_volume, subject, regression_weights_path, target_flag=True, rsfMRI_mean_mask_path=rsfMRI_mean_mask_path)
target_volume = linear_regress_mean(target_volume, subject, regression_weights_path, crop_flag, target_flag=True, rsfMRI_mean_mask_path=rsfMRI_mean_mask_path)
# Set scaling parameters to Mixed Scaling
scaling_parameters = [0.0, 0.2, -12.094, 14.6319]
elif mean_subtraction_flag == True:
# Subtract the mean from targets, leave inputs as they are
target_volume = subtract_mean(target_volume, rsfMRI_mean_mask_path)
target_volume = subtract_mean(target_volume, crop_flag, rsfMRI_mean_mask_path)
# Set Scaling parameters to Steve Scaling
scaling_parameters = [0.0, 0.2, 0.0, 10.0]
else:
......@@ -438,7 +449,7 @@ def preprocess(input_volume, target_volume, subject, mean_regression_flag, mean_
return input_volume, target_volume
def linear_regress_mean(volume, subject, regression_weights_path, target_flag, dMRI_mean_mask_path=None, rsfMRI_mean_mask_path=None):
def linear_regress_mean(volume, subject, regression_weights_path, crop_flag, target_flag, dMRI_mean_mask_path=None, rsfMRI_mean_mask_path=None):
""" Linear regressed mean subtraction
Helper function which substracts or regressed the dual mean subject mask
......@@ -456,10 +467,16 @@ def linear_regress_mean(volume, subject, regression_weights_path, target_flag, d
"""
if target_flag == False:
group_mean = Image(dMRI_mean_mask_path).data
if crop_flag == False:
group_mean = Image(dMRI_mean_mask_path).data
elif crop_flag == True:
group_mean = roi(Image(dMRI_mean_mask_path),((9,81),(10,100),(0,77))).data
dataframe_key = 'w_dMRI'
elif target_flag == True:
group_mean = Image(rsfMRI_mean_mask_path).data[:, :, :, 0]
if crop_flag == False:
group_mean = Image(rsfMRI_mean_mask_path).data[:, :, :, 0]
elif crop_flag == True:
group_mean = roi(Image(rsfMRI_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
dataframe_key = 'w_rsfMRI'
weight = pd.read_pickle(regression_weights_path).loc[subject][dataframe_key]
......@@ -469,7 +486,7 @@ def linear_regress_mean(volume, subject, regression_weights_path, target_flag, d
return volume
def subtract_mean(volume, rsfMRI_mean_mask_path):
def subtract_mean(volume, crop_flag, rsfMRI_mean_mask_path):
"""Mean Mask Substraction
Helper function which substracts the dualreg mean subject mask
......@@ -481,8 +498,10 @@ def subtract_mean(volume, rsfMRI_mean_mask_path):
Returns:
subtracted_volume (np.array): Numpy array representation of the subtracted volume data
"""
dualreg_subject_mean = Image(rsfMRI_mean_mask_path).data[:, :, :, 0]
if crop_flag == False:
dualreg_subject_mean = Image(rsfMRI_mean_mask_path).data[:, :, :, 0]
elif crop_flag == True:
dualreg_subject_mean = roi(Image(rsfMRI_mean_mask_path),((9,81),(10,100),(0,77))).data[:, :, :, 0]
volume = np.subtract(volume, dualreg_subject_mean)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment