Commit b22c8428 authored by Andrei-Claudiu Roibu

added cross-validation option to the training function and solver

parent 0b755689
@@ -23,7 +23,8 @@ Usage:
mode=train # For training the model
mode=evaluate-score # For evaluating the model score
mode=evaluate-mapping # For evaluating the model mapping
mode=clear-experiment # For clearing the experiments and logs directories of the last experiment
# For clearing the experiments and logs directories of the last experiment
mode=clear-experiment
mode=clear-all # For clearing all the files from the experiments and logs directories
"""
@@ -113,6 +114,18 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
}
"""
def _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters):
"""Wrapper for the training operation
This function wraps the training operation for the network
Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles.
training_parameters(dict): Dictionary containing relevant hyperparameters for training the network.
network_parameters (dict): Contains information relevant parameters
misc_parameters (dict): Dictionary of aditional hyperparameters
"""
train_data, validation_data = load_data(data_parameters)
train_loader = data.DataLoader(
@@ -132,7 +145,8 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
)
if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
BrainMapperModel = torch.load(
training_parameters['pre_trained_path'])
else:
BrainMapperModel = BrainMapperUNet3D(network_parameters)
@@ -153,10 +167,11 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
learning_rate_scheduler_gamma=training_parameters['learning_rate_scheduler_gamma'],
use_last_checkpoint=training_parameters['use_last_checkpoint'],
experiment_directory=misc_parameters['experiments_directory'],
logs_directory=misc_parameters['logs_directory']
logs_directory=misc_parameters['logs_directory'],
checkpoint_directory=misc_parameters['checkpoint_directory']
)
solver.train(train_loader, validation_loader)
validation_loss = solver.train(train_loader, validation_loader)
model_output_path = os.path.join(
misc_parameters['save_model_directory'], training_parameters['final_model_output_file'])
@@ -167,6 +182,33 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
print("Final Model Saved in: {}".format(model_output_path))
del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver
torch.cuda.empty_cache()
return validation_loss

if data_parameters['k_fold'] is None:
    _ = _train_runner(data_parameters, training_parameters,
                      network_parameters, misc_parameters)
else:
    k_fold_losses = []
    # Keep the original output file name so each fold's suffix is applied
    # to it rather than compounding across folds.
    final_model_output_file = training_parameters['final_model_output_file']
    for k in range(data_parameters['k_fold']):
        # data_folder_name is assumed to be defined in this module as the
        # directory holding the fold lists (e.g. 'datasets').
        data_parameters['train_list'] = os.path.join(
            data_folder_name, 'train' + str(k+1) + '.txt')
        data_parameters['validation_list'] = os.path.join(
            data_folder_name, 'validation' + str(k+1) + '.txt')
        training_parameters['final_model_output_file'] = final_model_output_file.replace(
            ".pth.tar", str(k+1) + ".pth.tar")

        validation_loss = _train_runner(
            data_parameters, training_parameters, network_parameters, misc_parameters)
        k_fold_losses.append(validation_loss)

    # Plain-Python mean, so no extra import is needed here.
    mean_k_fold_loss = sum(k_fold_losses) / len(k_fold_losses)
    print("Mean K-fold validation loss: {}".format(mean_k_fold_loss))
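The fold lists consumed above (train<k>.txt, validation<k>.txt) come from data_test_train_validation_split in the data utilities, whose updated signature appears later in this diff. A sketch of the expected pairing, with an illustrative data directory path:

# Generate the per-fold lists once (paths and K value here are illustrative),
# then the k-fold branch above trains one model per
# train<k>.txt / validation<k>.txt pair.
data_test_train_validation_split('datasets', test_percentage=5, subject_number=None,
                                 data_directory='/path/to/subjects/', K_fold=5)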
def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
"""Mapping Score Evaluator
@@ -205,17 +247,17 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
# TODO - NEED TO UPDATE THE DATA FUNCTIONS!
logWriter = LogWriter(number_of_classes=network_parameters['number_of_classes'],
logs_directory=misc_parameters['logs_directory'],
experiment_name=training_parameters['experiment_name']
logWriter = LogWriter(number_of_classes=network_parameters['number_of_classes'],
                      logs_directory=misc_parameters['logs_directory'],
                      experiment_name=training_parameters['experiment_name']
)
prediction_output_path = os.path.join(misc_parameters['experiments_directory'],
prediction_output_path = os.path.join(misc_parameters['experiments_directory'],
training_parameters['experiment_name'],
evaluation_parameters['saved_predictions_directory']
)
_ = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'],
_ = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'],
number_of_classes=network_parameters['number_of_classes'],
data_directory=evaluation_parameters['data_directory'],
targets_directory=evaluation_parameters[
@@ -7,9 +7,9 @@ k_fold = None
data_split_flag = False
test_percentage = 5
subject_number = None
train_list = "train.txt"
validation_list = "validation.txt"
test_list = "test.txt"
train_list = "datasets/train.txt"
validation_list = "datasets/validation.txt"
test_list = "datasets/test.txt"
train_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
train_output_targets = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
validation_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
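For illustration, enabling cross-validation in this settings file might look as follows. This is a sketch assuming k_fold takes the integer number of folds; the training driver then overrides train_list and validation_list per fold, so the values below only act as defaults:

k_fold = 5
train_list = "datasets/train1.txt"
validation_list = "datasets/validation1.txt"
test_list = "datasets/test.txt"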
@@ -23,7 +23,6 @@ from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler
checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'
@@ -71,7 +70,8 @@ class Solver():
learning_rate_scheduler_gamma=0.5,
use_last_checkpoint=True,
experiment_directory='experiments',
logs_directory='logs'
logs_directory='logs',
checkpoint_directory='checkpoints'
):
self.model = model
@@ -99,10 +99,12 @@ class Solver():
experiment_directory, experiment_name)
self.experiment_directory_path = experiment_directory_path
self.checkpoint_directory = checkpoint_directory
create_folder(experiment_directory)
create_folder(experiment_directory_path)
create_folder(os.path.join(
experiment_directory_path, checkpoint_directory))
experiment_directory_path, self.checkpoint_directory))
self.start_epoch = 1
self.start_iteration = 1
@@ -220,6 +222,7 @@ class Solver():
early_stop, save_checkpoint = self.EarlyStopping(np.mean(losses))
self.early_stop = early_stop
if save_checkpoint:
validation_loss = np.mean(losses)
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
@@ -227,7 +230,7 @@ class Solver():
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
},
filename=os.path.join(self.experiment_directory_path, checkpoint_directory,
filename=os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
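For context, the loop above unpacks (early_stop, save_checkpoint) from self.EarlyStopping and only records validation_loss when a checkpoint is saved. A minimal sketch of an early-stopping helper exposing that two-flag interface; this is an assumption for illustration, as the actual utils.early_stopping implementation is not shown in this diff:

import numpy as np

class EarlyStopping:
    """Sketch: tracks the best validation loss and signals when to stop."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum change that counts as improvement
        self.best_loss = np.inf
        self.counter = 0

    def __call__(self, validation_loss):
        if validation_loss < self.best_loss - self.min_delta:
            self.best_loss = validation_loss
            self.counter = 0
            return False, True        # keep training, save a checkpoint
        self.counter += 1
        # Stop once `patience` epochs pass without improvement; no save.
        return self.counter >= self.patience, False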
@@ -251,6 +254,8 @@ class Solver():
print('Training Duration: {}'.format(end_time - start_time))
print('****************************************************************')
return validation_loss
def save_checkpoint(self, state, filename):
"""General Checkpoint Save
@@ -273,11 +278,11 @@ class Solver():
if epoch is not None:
checkpoint_file_path = os.path.join(
self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self._checkpoint_reader(checkpoint_file_path)
else:
universal_path = os.path.join(
self.experiment_directory_path, checkpoint_directory, '*.' + checkpoint_extension)
self.experiment_directory_path, self.checkpoint_directory, '*.' + checkpoint_extension)
files_in_universal_path = glob.glob(universal_path)
# We will sort through all the files in path to see which one is most recent
@@ -289,7 +294,7 @@ class Solver():
else:
self.LogWriter.log("No Checkpoint found at {}".format(
os.path.join(self.experiment_directory_path, checkpoint_directory)))
os.path.join(self.experiment_directory_path, self.checkpoint_directory)))
def _checkpoint_reader(self, checkpoint_file_path):
"""Checkpoint Reader
@@ -81,7 +81,7 @@ def data_file_reader(data_file_path):
return subDirectoryList
def data_test_train_validation_split(data_folder_name, test_percentage, subject_number, data_directory= None, data_file= None, K_fold= None):
def data_test_train_validation_split(data_folder_name, test_percentage, subject_number, data_directory=None, data_file=None, K_fold=None):
"""Produces lists of train, test and validation data
This function looks at the list of all available directories and returns three lists of sub-directories.
@@ -110,30 +110,37 @@ def data_test_train_validation_split(data_folder_name, test_percentage, subject_
subDirectoryList = np.array(subDirectoryList)
create_folder(data_folder_name)
train_data, test = train_test_split(subDirectoryList, test_size= test_percentage/100, random_state= 42, shuffle= True)
train_data, test = train_test_split(
subDirectoryList, test_size=test_percentage/100, random_state=42, shuffle=True)
np.savetxt(os.path.join(data_folder_name, 'test.txt'), test, fmt='%s')
print("Test={}".format(test))
if K_fold is None:
train, validation = train_test_split(train_data, test_size= int(len(test)), random_state= 42, shuffle= True)
train, validation = train_test_split(
train_data, test_size=int(len(test)), random_state=42, shuffle=True)
np.savetxt(os.path.join(data_folder_name, 'train.txt'), train, fmt='%s')
np.savetxt(os.path.join(data_folder_name, 'validation.txt'), validation, fmt='%s')
np.savetxt(os.path.join(data_folder_name,
'train.txt'), train, fmt='%s')
np.savetxt(os.path.join(data_folder_name, 'validation.txt'),
validation, fmt='%s')
print("Train={}, Validation={}".format(train, validation))
else:
k_fold = KFold(n_splits= K_fold)
k_fold = KFold(n_splits=K_fold)
k = 0
for train_index, validation_index in k_fold.split(train_data):
train, validation = train_data[train_index], train_data[validation_index]
np.savetxt(os.path.join(data_folder_name, 'train'+str(k+1)+'.txt'), train, fmt='%s')
np.savetxt(os.path.join(data_folder_name, 'validation'+str(k+1)+'.txt'), validation, fmt='%s')
np.savetxt(os.path.join(data_folder_name, 'train' +
str(k+1)+'.txt'), train, fmt='%s')
np.savetxt(os.path.join(data_folder_name, 'validation' +
str(k+1)+'.txt'), validation, fmt='%s')
print("K={}, Train={}, Validation={}".format(k, train, validation))
k += 1
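For reference, a minimal, self-contained illustration of the KFold split semantics used above (the subject IDs are made up):

import numpy as np
from sklearn.model_selection import KFold

subjects = np.array(['sub-01', 'sub-02', 'sub-03', 'sub-04', 'sub-05', 'sub-06'])
for k, (train_index, validation_index) in enumerate(KFold(n_splits=3).split(subjects)):
    # Each fold holds out a distinct slice of the subjects for validation.
    print("K={}, Train={}, Validation={}".format(
        k + 1, subjects[train_index], subjects[validation_index]))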
def update_shuffling_flag(file_name):
""" Update shuffling flag
@@ -637,12 +644,10 @@ if __name__ == '__main__':
K_fold = None
subject_number = None
data_directory = "../well/win-biobank/projects/imaging/data/data3/subjectsAll/"
data_test_train_validation_split(data_folder_name, test_percentage, subject_number, data_directory = data_directory, K_fold= K_fold)
data_test_train_validation_split(
data_folder_name, test_percentage, subject_number, data_directory=data_directory, K_fold=K_fold)
# data_test_train_validation_split_Kfold_cross_validation(data_folder_name, K_fold, subject_number, data_directory = data_directory)
# data = np.arange(23)
# K = 10
# test_size = int(len(data)/K)
@@ -655,7 +660,6 @@ if __name__ == '__main__':
# remainder[(k-1) * test_size: k *test_size], test_slice = test_slice, remainder[(k-1) * test_size: k * test_size].copy()
# print("k= {}, test_slice={}, remainder={}".format(k, test_slice, remainder))
# print('SKLEARN TIME!')
# from sklearn.model_selection import KFold, train_test_split