Commit b22c8428 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥

added cross-validation option to the training function and solver

parent 0b755689
......@@ -23,7 +23,8 @@ Usage:
mode=train # For training the model
mode=evaluate-score # For evaluating the model score
mode=evaluate-mapping # For evaluating the model mapping
# For clearing the experiments and logs directories of the last experiment
mode=clear-experiment
mode=clear-all # For clearing all the files from the experiments and logs directories
"""
......@@ -74,10 +75,10 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
"""Training Function
This function trains a given model using the provided training data.
The data loader is currently configured to use multiple sub-processes.
A high enough number of workers ensures that CPU-side data preparation is managed efficiently, i.e. that the bottleneck is the network's forward and backward passes on the GPU rather than data generation.
Loader memory is also pinned, to speed up data transfer from CPU to GPU by using page-locked memory.
Train data is also re-shuffled at each training epoch.
Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles.
......@@ -101,7 +102,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
'final_model_output_file': 'path/to/model'
}
network_parameters (dict): Dictionary containing the relevant network parameters.
misc_parameters (dict): Dictionary of additional hyperparameters.
misc_parameters = {
......@@ -113,60 +114,101 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
}
"""
train_data, validation_data = load_data(data_parameters)
def _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters):
"""Wrapper for the training operation
This function wraps the training operation for the network
Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles.
training_parameters(dict): Dictionary containing relevant hyperparameters for training the network.
network_parameters (dict): Dictionary containing the relevant network parameters.
misc_parameters (dict): Dictionary of additional hyperparameters.
"""
train_data, validation_data = load_data(data_parameters)
train_loader = data.DataLoader(
dataset=train_data,
batch_size=training_parameters['training_batch_size'],
shuffle=True,
num_workers=4,
pin_memory=True
)
validation_loader = data.DataLoader(
dataset=validation_data,
batch_size=training_parameters['validation_batch_size'],
shuffle=False,
num_workers=4,
pin_memory=True
)
if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(
training_parameters['pre_trained_path'])
else:
BrainMapperModel = BrainMapperUNet3D(network_parameters)
train_loader = data.DataLoader(
dataset=train_data,
batch_size=training_parameters['training_batch_size'],
shuffle=True,
num_workers=4,
pin_memory=True
)
solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'],
number_of_classes=network_parameters['number_of_classes'],
experiment_name=training_parameters['experiment_name'],
optimizer_arguments={'lr': training_parameters['learning_rate'],
'betas': training_parameters['optimizer_beta'],
'eps': training_parameters['optimizer_epsilon'],
'weight_decay': training_parameters['optimizer_weigth_decay']
},
model_name=misc_parameters['model_name'],
number_epochs=training_parameters['number_of_epochs'],
loss_log_period=training_parameters['loss_log_period'],
learning_rate_scheduler_step_size=training_parameters[
'learning_rate_scheduler_step_size'],
learning_rate_scheduler_gamma=training_parameters['learning_rate_scheduler_gamma'],
use_last_checkpoint=training_parameters['use_last_checkpoint'],
experiment_directory=misc_parameters['experiments_directory'],
logs_directory=misc_parameters['logs_directory'],
checkpoint_directory=misc_parameters['checkpoint_directory']
)
validation_loader = data.DataLoader(
dataset=validation_data,
batch_size=training_parameters['validation_batch_size'],
shuffle=False,
num_workers=4,
pin_memory=True
)
validation_loss = solver.train(train_loader, validation_loader)
if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
else:
BrainMapperModel = BrainMapperUNet3D(network_parameters)
model_output_path = os.path.join(
misc_parameters['save_model_directory'], training_parameters['final_model_output_file'])
create_folder(misc_parameters['save_model_directory'])
BrainMapperModel.save(model_output_path)
solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'],
number_of_classes=network_parameters['number_of_classes'],
experiment_name=training_parameters['experiment_name'],
optimizer_arguments={'lr': training_parameters['learning_rate'],
'betas': training_parameters['optimizer_beta'],
'eps': training_parameters['optimizer_epsilon'],
'weight_decay': training_parameters['optimizer_weigth_decay']
},
model_name=misc_parameters['model_name'],
number_epochs=training_parameters['number_of_epochs'],
loss_log_period=training_parameters['loss_log_period'],
learning_rate_scheduler_step_size=training_parameters[
'learning_rate_scheduler_step_size'],
learning_rate_scheduler_gamma=training_parameters['learning_rate_scheduler_gamma'],
use_last_checkpoint=training_parameters['use_last_checkpoint'],
experiment_directory=misc_parameters['experiments_directory'],
logs_directory=misc_parameters['logs_directory']
)
print("Final Model Saved in: {}".format(model_output_path))
solver.train(train_loader, validation_loader)
del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver
torch.cuda.empty_cache()
return validation_loss
if data_parameters['k_fold'] is None:
    _ = _train_runner(data_parameters, training_parameters,
                      network_parameters, misc_parameters)
else:
    k_fold_losses = []
    final_model_output_file = training_parameters['final_model_output_file']
    for k in range(data_parameters['k_fold']):
        # data_folder_name is expected to be defined in the surrounding settings/module scope
        data_parameters['train_list'] = os.path.join(
            data_folder_name, 'train' + str(k+1) + '.txt')
        data_parameters['validation_list'] = os.path.join(
            data_folder_name, 'validation' + str(k+1) + '.txt')
        training_parameters['final_model_output_file'] = final_model_output_file.replace(
            ".pth.tar", str(k+1) + ".pth.tar")
        validation_loss = _train_runner(
            data_parameters, training_parameters, network_parameters, misc_parameters)
        k_fold_losses.append(validation_loss)
    mean_k_fold_loss = sum(k_fold_losses) / len(k_fold_losses)
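As a self-contained illustration of the pattern used in the else-branch above, the sketch below stubs out _train_runner and only shows the per-fold list bookkeeping and the averaging of the returned validation losses; the stub and its fake loss value are invented for the example and are not code from this repository.
# Standalone sketch of the k-fold loop above; _train_runner is replaced by a
# stub so the snippet runs on its own (the 0.5 loss is illustrative only).
import os


def _train_runner_stub(train_list, validation_list):
    print("would train with", train_list, "and", validation_list)
    return 0.5  # stand-in for the validation loss returned by solver.train()


data_folder_name = "datasets"
k_fold = 5
k_fold_losses = []
for k in range(k_fold):
    train_list = os.path.join(data_folder_name, "train" + str(k + 1) + ".txt")
    validation_list = os.path.join(data_folder_name, "validation" + str(k + 1) + ".txt")
    k_fold_losses.append(_train_runner_stub(train_list, validation_list))

mean_k_fold_loss = sum(k_fold_losses) / len(k_fold_losses)
print("mean validation loss across folds:", mean_k_fold_loss)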
def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
"""Mapping Score Evaluator
......@@ -180,7 +222,7 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
'experiment_name': 'experiment_name'
}
network_parameters (dict): Dictionary containing the relevant network parameters.
network_parameters= {
'number_of_classes': 1
}
......@@ -205,17 +247,17 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
# TODO - NEED TO UPDATE THE DATA FUNCTIONS!
logWriter = LogWriter(number_of_classes=network_parameters['number_of_classes'],
                      logs_directory=misc_parameters['logs_directory'],
                      experiment_name=training_parameters['experiment_name']
                      )
prediction_output_path = os.path.join(misc_parameters['experiments_directory'],
                                      training_parameters['experiment_name'],
                                      evaluation_parameters['saved_predictions_directory']
                                      )
_ = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'],
number_of_classes=network_parameters['number_of_classes'],
data_directory=evaluation_parameters['data_directory'],
targets_directory=evaluation_parameters[
......
......@@ -7,9 +7,9 @@ k_fold = None
data_split_flag = False
test_percentage = 5
subject_number = None
train_list = "train.txt"
validation_list = "validation.txt"
test_list = "test.txt"
train_list = "datasets/train.txt"
validation_list = "datasets/validation.txt"
test_list = "datasets/test.txt"
train_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
train_output_targets = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
validation_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
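For reference, enabling cross-validation should only require changing k_fold in this block, since train() derives the per-fold list names (train1.txt, validation1.txt, ...) itself as shown above; the value 5 below is an assumed example, not part of this commit.
k_fold = 5
train_list = "datasets/train.txt"
validation_list = "datasets/validation.txt"
test_list = "datasets/test.txt"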
......
......@@ -23,7 +23,6 @@ from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler
checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'
......@@ -71,7 +70,8 @@ class Solver():
learning_rate_scheduler_gamma=0.5,
use_last_checkpoint=True,
experiment_directory='experiments',
logs_directory='logs',
checkpoint_directory='checkpoints'
):
self.model = model
......@@ -99,10 +99,12 @@ class Solver():
experiment_directory, experiment_name)
self.experiment_directory_path = experiment_directory_path
self.checkpoint_directory = checkpoint_directory
create_folder(experiment_directory)
create_folder(experiment_directory_path)
create_folder(os.path.join(
experiment_directory_path, checkpoint_directory))
experiment_directory_path, self.checkpoint_directory))
self.start_epoch = 1
self.start_iteration = 1
......@@ -220,6 +222,7 @@ class Solver():
early_stop, save_checkpoint = self.EarlyStopping(np.mean(losses))
self.early_stop = early_stop
if save_checkpoint == True:
validation_loss = np.mean(losses)
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
......@@ -227,7 +230,7 @@ class Solver():
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
},
filename=os.path.join(self.experiment_directory_path, checkpoint_directory,
filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
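The utils.early_stopping module itself is not part of this diff; below is a minimal sketch of a helper compatible with the call above, i.e. one invoked with the mean validation loss and returning the tuple (early_stop, save_checkpoint). The constructor arguments and the patience logic are assumptions for illustration, not the repository's implementation.
# Minimal early-stopping sketch matching the call
# early_stop, save_checkpoint = self.EarlyStopping(np.mean(losses))
# Constructor arguments and thresholds are assumed, not taken from this repo.
import numpy as np


class EarlyStoppingSketch:
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # minimum decrease that counts as improvement
        self.best_loss = np.inf
        self.counter = 0

    def __call__(self, validation_loss):
        if validation_loss < self.best_loss - self.min_delta:
            # Improvement: reset the counter and ask the caller to checkpoint.
            self.best_loss = validation_loss
            self.counter = 0
            return False, True
        self.counter += 1
        # Stop once the loss has not improved for `patience` epochs.
        return self.counter >= self.patience, False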
......@@ -251,6 +254,8 @@ class Solver():
print('Training Duration: {}'.format(end_time - start_time))
print('****************************************************************')
return validation_loss
def save_checkpoint(self, state, filename):
"""General Checkpoint Save
......@@ -273,11 +278,11 @@ class Solver():
if epoch is not None:
checkpoint_file_path = os.path.join(
self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self._checkpoint_reader(checkpoint_file_path)
else:
universal_path = os.path.join(
self.experiment_directory_path, checkpoint_directory, '*.' + checkpoint_extension)
self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
files_in_universal_path = glob.glob(universal_path)
# We will sort through all the files in path to see which one is most recent
......@@ -289,7 +294,7 @@ class Solver():
else:
self.LogWriter.log("No Checkpoint found at {}".format(
os.path.join(self.experiment_directory_path, checkpoint_directory)))
os.path.join(self.experiment_directory_path, self.checkpoint_directory)))
def _checkpoint_reader(self, checkpoint_file_path):
"""Checkpoint Reader
......
......@@ -81,7 +81,7 @@ def data_file_reader(data_file_path):
return subDirectoryList
def data_test_train_validation_split(data_folder_name, test_percentage, subject_number, data_directory=None, data_file=None, K_fold=None):
"""Produces lists of train, test and validation data
This function looks at the list of all available directories and returns three lists of sub-directories.
......@@ -110,30 +110,37 @@ def data_test_train_validation_split(data_folder_name, test_percentage, subject_
subDirectoryList = np.array(subDirectoryList)
create_folder(data_folder_name)
train_data, test = train_test_split(
    subDirectoryList, test_size=test_percentage/100, random_state=42, shuffle=True)
np.savetxt(os.path.join(data_folder_name, 'test.txt'), test, fmt='%s')
print("Test={}".format(test))
if K_fold is None:
train, validation = train_test_split(
    train_data, test_size=int(len(test)), random_state=42, shuffle=True)
np.savetxt(os.path.join(data_folder_name, 'train.txt'), train, fmt='%s')
np.savetxt(os.path.join(data_folder_name, 'validation.txt'), validation, fmt='%s')
print("Train={}, Validation={}".format(train, validation))
else:
k_fold = KFold(n_splits=K_fold)
k = 0
for train_index, validation_index in k_fold.split(train_data):
train, validation = train_data[train_index], train_data[validation_index]
np.savetxt(os.path.join(data_folder_name, 'train' + str(k+1) + '.txt'), train, fmt='%s')
np.savetxt(os.path.join(data_folder_name, 'validation' + str(k+1) + '.txt'), validation, fmt='%s')
print("K={}, Train={}, Validation={}".format(k, train, validation))
k += 1
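To make the fold membership concrete, here is a toy run of the same sklearn KFold split on ten made-up subject IDs; it mirrors what the loop above writes into train1.txt/validation1.txt through train5.txt/validation5.txt. The subject IDs are invented for the example.
# Toy illustration of the KFold split used above (subject IDs invented).
import numpy as np
from sklearn.model_selection import KFold

subjects = np.array(['sub-%02d' % index for index in range(10)])
for k, (train_index, validation_index) in enumerate(KFold(n_splits=5).split(subjects)):
    print("fold", k + 1,
          "train:", subjects[train_index],
          "validation:", subjects[validation_index])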
def update_shuffling_flag(file_name):
""" Update shuffling flag
......@@ -637,12 +644,10 @@ if __name__ == '__main__':
K_fold = None
subject_number = None
data_directory = "../well/win-biobank/projects/imaging/data/data3/subjectsAll/"
data_test_train_validation_split(
    data_folder_name, test_percentage, subject_number, data_directory=data_directory, K_fold=K_fold)
# data_test_train_validation_split_Kfold_cross_validation(data_folder_name, K_fold, subject_number, data_directory = data_directory)
# data = np.arange(23)
# K = 10
# test_size = int(len(data)/K)
......@@ -655,7 +660,6 @@ if __name__ == '__main__':
# remainder[(k-1) * test_size: k *test_size], test_slice = test_slice, remainder[(k-1) * test_size: k * test_size].copy()
# print("k= {}, test_slice={}, remainder={}".format(k, test_slice, remainder))
# print('SKLEARN TIME!')
# from sklearn.model_selection import KFold, train_test_split
......@@ -670,4 +674,4 @@ if __name__ == '__main__':
# for train_index, test_index in kf.split(train_data):
# train, test = train_data[train_index], train_data[test_index]
# print("k= {}, val_slice={}, train={}".format(k, test, train))
# k+=1
\ No newline at end of file
# k+=1