Commit b22c8428 authored by Andrei-Claudiu Roibu

added cross-validation option to the training function and solver

parent 0b755689
@@ -23,7 +23,8 @@ Usage:
    mode=train # For training the model
    mode=evaluate-score # For evaluating the model score
    mode=evaluate-mapping # For evaluating the model mapping
    # For clearing the experiments and logs directories of the last experiment
    mode=clear-experiment
    mode=clear-all # For clearing all the files from the experiments and logs directories
"""
@@ -74,10 +75,10 @@ def train(data_parameters, training_parameters, network_parameters, misc_parameters):
"""Training Function """Training Function
This function trains a given model using the provided training data. This function trains a given model using the provided training data.
Currently, the data loaded is set to have multiple sub-processes. Currently, the data loaded is set to have multiple sub-processes.
A high enough number of workers assures that CPU computations are efficiently managed, i.e. that the bottleneck is indeed the neural network's forward and backward operations on the GPU (and not data generation) A high enough number of workers assures that CPU computations are efficiently managed, i.e. that the bottleneck is indeed the neural network's forward and backward operations on the GPU (and not data generation)
Loader memory is also pinned, to speed up data transfer from CPU to GPU by using the page-locked memory. Loader memory is also pinned, to speed up data transfer from CPU to GPU by using the page-locked memory.
Train data is also re-shuffled at each training epoch. Train data is also re-shuffled at each training epoch.
Args: Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles. data_parameters (dict): Dictionary containing relevant information for the datafiles.
@@ -101,7 +102,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_parameters):
            'final_model_output_file': 'path/to/model'
        }

        network_parameters (dict): Contains the relevant network parameters
        misc_parameters (dict): Dictionary of additional hyperparameters

        misc_parameters = {
@@ -113,60 +114,101 @@ def train(data_parameters, training_parameters, network_parameters, misc_parameters):
        }
    """
    def _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters):
        """Wrapper for the training operation

        This function wraps the training operation for the network.

        Args:
            data_parameters (dict): Dictionary containing relevant information for the datafiles.
            training_parameters (dict): Dictionary containing relevant hyperparameters for training the network.
            network_parameters (dict): Contains the relevant network parameters
            misc_parameters (dict): Dictionary of additional hyperparameters
        """

        train_data, validation_data = load_data(data_parameters)

        train_loader = data.DataLoader(
            dataset=train_data,
            batch_size=training_parameters['training_batch_size'],
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        validation_loader = data.DataLoader(
            dataset=validation_data,
            batch_size=training_parameters['validation_batch_size'],
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        if training_parameters['use_pre_trained']:
            BrainMapperModel = torch.load(
                training_parameters['pre_trained_path'])
        else:
            BrainMapperModel = BrainMapperUNet3D(network_parameters)

        solver = Solver(model=BrainMapperModel,
                        device=misc_parameters['device'],
                        number_of_classes=network_parameters['number_of_classes'],
                        experiment_name=training_parameters['experiment_name'],
                        optimizer_arguments={'lr': training_parameters['learning_rate'],
                                             'betas': training_parameters['optimizer_beta'],
                                             'eps': training_parameters['optimizer_epsilon'],
                                             'weight_decay': training_parameters['optimizer_weigth_decay']
                                             },
                        model_name=misc_parameters['model_name'],
                        number_epochs=training_parameters['number_of_epochs'],
                        loss_log_period=training_parameters['loss_log_period'],
                        learning_rate_scheduler_step_size=training_parameters['learning_rate_scheduler_step_size'],
                        learning_rate_scheduler_gamma=training_parameters['learning_rate_scheduler_gamma'],
                        use_last_checkpoint=training_parameters['use_last_checkpoint'],
                        experiment_directory=misc_parameters['experiments_directory'],
                        logs_directory=misc_parameters['logs_directory'],
                        checkpoint_directory=misc_parameters['checkpoint_directory']
                        )

        validation_loss = solver.train(train_loader, validation_loader)

        model_output_path = os.path.join(
            misc_parameters['save_model_directory'], training_parameters['final_model_output_file'])

        create_folder(misc_parameters['save_model_directory'])

        BrainMapperModel.save(model_output_path)

        print("Final Model Saved in: {}".format(model_output_path))

        # free memory between runs, so successive folds start from a clean slate
        del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver
        torch.cuda.empty_cache()

        return validation_loss
    if data_parameters['k_fold'] is None:
        _ = _train_runner(data_parameters, training_parameters,
                          network_parameters, misc_parameters)
    else:
        k_fold_losses = []

        # NOTE: assumes data_parameters supplies the folder holding the split
        # lists (e.g. 'datasets') and that final_model_output_file ends in
        # '.pth.tar'; the original name is captured once so every fold derives
        # its own suffix from it.
        data_folder_name = data_parameters['data_folder_name']
        final_model_output_file = training_parameters['final_model_output_file']

        for k in range(data_parameters['k_fold']):
            data_parameters['train_list'] = os.path.join(
                data_folder_name, 'train' + str(k+1) + '.txt')
            data_parameters['validation_list'] = os.path.join(
                data_folder_name, 'validation' + str(k+1) + '.txt')
            training_parameters['final_model_output_file'] = final_model_output_file.replace(
                ".pth.tar", str(k+1) + ".pth.tar")

            validation_loss = _train_runner(
                data_parameters, training_parameters, network_parameters, misc_parameters)
            k_fold_losses.append(validation_loss)

        mean_k_fold_loss = sum(k_fold_losses) / len(k_fold_losses)
        print("Mean K-fold validation loss: {}".format(mean_k_fold_loss))
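A self-contained sketch of the k-fold driver pattern used above, with a stub in place of the real training call (all names here are illustrative, not the module's API):

import os

def run_k_fold(k_fold, data_folder_name, train_fold):
    # one training run per fold, then average the per-fold validation losses
    fold_losses = []
    for k in range(k_fold):
        train_list = os.path.join(data_folder_name, 'train' + str(k+1) + '.txt')
        validation_list = os.path.join(data_folder_name, 'validation' + str(k+1) + '.txt')
        fold_losses.append(train_fold(train_list, validation_list))
    return sum(fold_losses) / len(fold_losses)

# usage with a stub training function that returns a fixed loss:
mean_loss = run_k_fold(5, 'datasets', lambda train_list, validation_list: 0.1)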
def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
"""Mapping Score Evaluator """Mapping Score Evaluator
@@ -180,7 +222,7 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
            'experiment_name': 'experiment_name'
        }

        network_parameters (dict): Contains the relevant network parameters

        network_parameters = {
            'number_of_classes': 1
        }
@@ -205,17 +247,17 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
    # TODO - NEED TO UPDATE THE DATA FUNCTIONS!
    logWriter = LogWriter(number_of_classes=network_parameters['number_of_classes'],
                          logs_directory=misc_parameters['logs_directory'],
                          experiment_name=training_parameters['experiment_name']
                          )

    prediction_output_path = os.path.join(misc_parameters['experiments_directory'],
                                          training_parameters['experiment_name'],
                                          evaluation_parameters['saved_predictions_directory']
                                          )

    _ = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'],
                                        number_of_classes=network_parameters['number_of_classes'],
                                        data_directory=evaluation_parameters['data_directory'],
                                        targets_directory=evaluation_parameters[
...
@@ -7,9 +7,9 @@ k_fold = None
data_split_flag = False
test_percentage = 5
subject_number = None
train_list = "datasets/train.txt"
validation_list = "datasets/validation.txt"
test_list = "datasets/test.txt"
train_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
train_output_targets = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
validation_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
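The list files referenced above are plain text with one subject directory per line (the format np.savetxt(..., fmt='%s') produces); a minimal reader sketch, with the helper name assumed:

def load_subject_list(list_path):
    # one subject directory name per line, blank lines skipped
    with open(list_path) as list_file:
        return [line.strip() for line in list_file if line.strip()]

# e.g. subjects = load_subject_list('datasets/train.txt')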
...
@@ -23,7 +23,6 @@ from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler

checkpoint_extension = 'path.tar'
@@ -71,7 +70,8 @@ class Solver():
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs',
                 checkpoint_directory='checkpoints'
                 ):
        self.model = model
@@ -99,10 +99,12 @@ class Solver():
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1
@@ -220,6 +222,7 @@ class Solver():
                early_stop, save_checkpoint = self.EarlyStopping(np.mean(losses))
                self.early_stop = early_stop

                # assigned unconditionally so the value returned at the end of
                # train() is defined even when no checkpoint is saved
                validation_loss = np.mean(losses)

                if save_checkpoint:
                    self.save_checkpoint(state={'epoch': epoch + 1,
                                                'start_iteration': iteration + 1,
                                                'arch': self.model_name,
@@ -227,7 +230,7 @@ class Solver():
                                                'optimizer': optimizer.state_dict(),
                                                'scheduler': learning_rate_scheduler.state_dict()
                                                },
                                         filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory,
                                                               'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                                         )
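The call above relies on an early-stopping helper that returns an (early_stop, save_checkpoint) pair; a generic sketch of that contract (the patience value and internals are assumptions, not the module's actual implementation):

class EarlyStopping:
    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0

    def __call__(self, validation_loss):
        if validation_loss < self.best_loss:
            self.best_loss = validation_loss
            self.counter = 0
            return False, True   # no stop; loss improved, so checkpoint
        self.counter += 1
        return self.counter >= self.patience, False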
@@ -251,6 +254,8 @@ class Solver():
        print('Training Duration: {}'.format(end_time - start_time))
        print('****************************************************************')

        return validation_loss
    def save_checkpoint(self, state, filename):
        """General Checkpoint Save
@@ -273,11 +278,11 @@ class Solver():
        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # We will sort through all the files in path to see which one is most recent
@@ -289,7 +294,7 @@ class Solver():
        else:
            self.LogWriter.log("No Checkpoint found at {}".format(
                os.path.join(self.experiment_directory_path, self.checkpoint_directory)))
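Picking the most recent file from that glob can be done by parsing the epoch number out of each file name; a sketch assuming the checkpoint_epoch_<N> naming used above (the helper name is illustrative):

import glob

def latest_checkpoint(universal_path):
    files = glob.glob(universal_path)
    if not files:
        return None
    # 'checkpoint_epoch_<N>.path.tar' -> sort numerically by <N>
    return max(files, key=lambda f: int(f.split('checkpoint_epoch_')[1].split('.')[0]))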
    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader
...
@@ -81,7 +81,7 @@ def data_file_reader(data_file_path):
    return subDirectoryList
def data_test_train_validation_split(data_folder_name, test_percentage, subject_number, data_directory=None, data_file=None, K_fold=None):
"""Produces lists of train, test and validation data """Produces lists of train, test and validation data
This function looks at the list of all available directories and returns three lists of dsub-directories. This function looks at the list of all available directories and returns three lists of dsub-directories.
@@ -110,30 +110,37 @@ def data_test_train_validation_split(data_folder_name, test_percentage, subject_number, data_directory=None, data_file=None, K_fold=None):
    subDirectoryList = np.array(subDirectoryList)

    create_folder(data_folder_name)

    train_data, test = train_test_split(
        subDirectoryList, test_size=test_percentage/100, random_state=42, shuffle=True)

    np.savetxt(os.path.join(data_folder_name, 'test.txt'), test, fmt='%s')
    print("Test={}".format(test))

    if K_fold is None:
        train, validation = train_test_split(
            train_data, test_size=int(len(test)), random_state=42, shuffle=True)

        np.savetxt(os.path.join(data_folder_name, 'train.txt'), train, fmt='%s')
        np.savetxt(os.path.join(data_folder_name, 'validation.txt'), validation, fmt='%s')

        print("Train={}, Validation={}".format(train, validation))
    else:
        k_fold = KFold(n_splits=K_fold)
        k = 0
        for train_index, validation_index in k_fold.split(train_data):
            train, validation = train_data[train_index], train_data[validation_index]

            np.savetxt(os.path.join(data_folder_name, 'train' + str(k+1) + '.txt'), train, fmt='%s')
            np.savetxt(os.path.join(data_folder_name, 'validation' + str(k+1) + '.txt'), validation, fmt='%s')

            print("K={}, Train={}, Validation={}".format(k, train, validation))

            k += 1
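On toy data, the split logic above behaves as follows; a runnable sketch using the same scikit-learn calls, with made-up subject names:

import numpy as np
from sklearn.model_selection import KFold, train_test_split

subjects = np.array(['subject_{:02d}'.format(i) for i in range(10)])
train_data, test = train_test_split(subjects, test_size=0.2, random_state=42, shuffle=True)
for k, (train_index, validation_index) in enumerate(KFold(n_splits=4).split(train_data)):
    print("K={}, Train={}, Validation={}".format(
        k, train_data[train_index], train_data[validation_index]))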
def update_shuffling_flag(file_name):
    """ Update shuffling flag
@@ -637,12 +644,10 @@ if __name__ == '__main__':
    K_fold = None
    subject_number = None
    data_directory = "../well/win-biobank/projects/imaging/data/data3/subjectsAll/"

    data_test_train_validation_split(
        data_folder_name, test_percentage, subject_number, data_directory=data_directory, K_fold=K_fold)
    # data_test_train_validation_split_Kfold_cross_validation(data_folder_name, K_fold, subject_number, data_directory = data_directory)

    # data = np.arange(23)
    # K = 10
    # test_size = int(len(data)/K)
@@ -655,7 +660,6 @@ if __name__ == '__main__':
    # remainder[(k-1) * test_size: k * test_size], test_slice = test_slice, remainder[(k-1) * test_size: k * test_size].copy()
    # print("k= {}, test_slice={}, remainder={}".format(k, test_slice, remainder))
    # print('SKLEARN TIME!')
    # from sklearn.model_selection import KFold, train_test_split
@@ -670,4 +674,4 @@ if __name__ == '__main__':
    # for train_index, test_index in kf.split(train_data):
    #     train, test = train_data[train_index], train_data[test_index]
    #     print("k= {}, val_slice={}, train={}".format(k, test, train))
    #     k += 1
\ No newline at end of file