Commit a2b51e13 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

TRAIN WORKS! + code clean, pep8 formated, debug coms deleted

parent 98e87284
...@@ -123,7 +123,7 @@ class BrainMapperUNet3D(nn.Module): ...@@ -123,7 +123,7 @@ class BrainMapperUNet3D(nn.Module):
Y_bottleNeck, Y_np4) Y_bottleNeck, Y_np4)
del Y_bottleNeck, Y_np4 del Y_bottleNeck, Y_np4
Y_decoder_2 = self.decoderBlock2.forward( Y_decoder_2 = self.decoderBlock2.forward(
Y_decoder_1, Y_np3) Y_decoder_1, Y_np3)
...@@ -209,6 +209,7 @@ class BrainMapperUNet3D(nn.Module): ...@@ -209,6 +209,7 @@ class BrainMapperUNet3D(nn.Module):
# DEPRECATED ARCHITECTURES! # DEPRECATED ARCHITECTURES!
class BrainMapperUNet(nn.Module): class BrainMapperUNet(nn.Module):
"""Architecture class BrainMapper U-net. """Architecture class BrainMapper U-net.
...@@ -293,7 +294,7 @@ class BrainMapperUNet(nn.Module): ...@@ -293,7 +294,7 @@ class BrainMapperUNet(nn.Module):
Y_bottleNeck, Y_np4, pool_indices4) Y_bottleNeck, Y_np4, pool_indices4)
del Y_bottleNeck, Y_np4, pool_indices4 del Y_bottleNeck, Y_np4, pool_indices4
Y_decoder_2 = self.decoderBlock2.forward( Y_decoder_2 = self.decoderBlock2.forward(
Y_decoder_1, Y_np3, pool_indices3) Y_decoder_1, Y_np3, pool_indices3)
...@@ -377,6 +378,7 @@ class BrainMapperUNet(nn.Module): ...@@ -377,6 +378,7 @@ class BrainMapperUNet(nn.Module):
return prediction return prediction
class BrainMapperUNet3D_Simple(nn.Module): class BrainMapperUNet3D_Simple(nn.Module):
"""Architecture class BrainMapper 3D U-net. """Architecture class BrainMapper 3D U-net.
...@@ -462,7 +464,7 @@ class BrainMapperUNet3D_Simple(nn.Module): ...@@ -462,7 +464,7 @@ class BrainMapperUNet3D_Simple(nn.Module):
Y_bottleNeck, Y_np4, pool_indices4) Y_bottleNeck, Y_np4, pool_indices4)
del Y_bottleNeck, Y_np4, pool_indices4 del Y_bottleNeck, Y_np4, pool_indices4
Y_decoder_2 = self.decoderBlock2.forward( Y_decoder_2 = self.decoderBlock2.forward(
Y_decoder_1, Y_np3, pool_indices3) Y_decoder_1, Y_np3, pool_indices3)
......
...@@ -131,7 +131,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -131,7 +131,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
pin_memory=True pin_memory=True
) )
if training_parameters['use_pre_trained']: if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(training_parameters['pre_trained_path']) BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
else: else:
...@@ -216,17 +215,17 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva ...@@ -216,17 +215,17 @@ def evaluate_score(training_parameters, network_parameters, misc_parameters, eva
evaluation_parameters['saved_predictions_directory'] evaluation_parameters['saved_predictions_directory']
) )
average_dice_score = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'], _ = evaluations.evaluate_dice_score(trained_model_path=evaluation_parameters['trained_model_path'],
number_of_classes=network_parameters['number_of_classes'], number_of_classes=network_parameters['number_of_classes'],
data_directory=evaluation_parameters['data_directory'], data_directory=evaluation_parameters['data_directory'],
targets_directory=evaluation_parameters[ targets_directory=evaluation_parameters[
'targets_directory'], 'targets_directory'],
data_list=evaluation_parameters['data_list'], data_list=evaluation_parameters['data_list'],
orientation=evaluation_parameters['orientation'], orientation=evaluation_parameters['orientation'],
prediction_output_path=prediction_output_path, prediction_output_path=prediction_output_path,
device=misc_parameters['device'], device=misc_parameters['device'],
LogWriter=logWriter LogWriter=logWriter
) )
logWriter.close() logWriter.close()
...@@ -357,35 +356,39 @@ if __name__ == '__main__': ...@@ -357,35 +356,39 @@ if __name__ == '__main__':
if data_shuffling_flag == True: if data_shuffling_flag == True:
# Here we shuffle the data! # Here we shuffle the data!
data_test_train_validation_split(data_parameters['data_directory'], data_parameters['train_percentage'], data_parameters['validation_percentage']) data_test_train_validation_split(
data_parameters['data_directory'], data_parameters['train_percentage'], data_parameters['validation_percentage'])
update_shuffling_flag('settings.ini') update_shuffling_flag('settings.ini')
# TODO: This might also be a very good point to add cross-validation later # TODO: This might also be a very good point to add cross-validation later
else: else:
if arguments.mode == 'train': if arguments.mode == 'train':
train(data_parameters, training_parameters, train(data_parameters, training_parameters,
network_parameters, misc_parameters) network_parameters, misc_parameters)
# elif arguments.mode == 'evaluate-score':
# evaluate_score(training_parameters, # NOTE: THE EVAL FUNCTIONS HAVE NOT YET BEEN DEBUGGED (16/04/20)
# network_parameters, misc_parameters, evaluation_parameters)
# elif arguments.mode == 'evaluate-mapping': elif arguments.mode == 'evaluate-score':
# logging.basicConfig(filename='evaluate-mapping-error.log') evaluate_score(training_parameters,
# if arguments.settings_path is not None: network_parameters, misc_parameters, evaluation_parameters)
# settings_evaluation = Settings(arguments.settings_path) elif arguments.mode == 'evaluate-mapping':
# else: logging.basicConfig(filename='evaluate-mapping-error.log')
# settings_evaluation = Settings('settings_evaluation.ini') if arguments.settings_path is not None:
# mapping_evaluation_parameters = settings_evaluation['MAPPING'] settings_evaluation = Settings(arguments.settings_path)
# evaluate_mapping(mapping_evaluation_parameters) else:
# elif arguments.mode == 'clear-experiments': settings_evaluation = Settings('settings_evaluation.ini')
# shutil.rmtree(os.path.join( mapping_evaluation_parameters = settings_evaluation['MAPPING']
# misc_parameters['experiments_directory'], training_parameters['experiment_name'])) evaluate_mapping(mapping_evaluation_parameters)
# shutil.rmtree(os.path.join( elif arguments.mode == 'clear-experiments':
# misc_parameters['logs_directory'], training_parameters['experiment_name'])) shutil.rmtree(os.path.join(
# print('Cleared the current experiments and logs directory successfully!') misc_parameters['experiments_directory'], training_parameters['experiment_name']))
# elif arguments.mode == 'clear-everything': shutil.rmtree(os.path.join(
# delete_files(misc_parameters['experiments_directory']) misc_parameters['logs_directory'], training_parameters['experiment_name']))
# delete_files(misc_parameters['logs_directory']) print('Cleared the current experiments and logs directory successfully!')
# print('Cleared the current experiments and logs directory successfully!') elif arguments.mode == 'clear-everything':
# else: delete_files(misc_parameters['experiments_directory'])
# raise ValueError( delete_files(misc_parameters['logs_directory'])
# 'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, clear-experiments and clear-everything') print('Cleared the current experiments and logs directory successfully!')
else:
raise ValueError(
'Invalid mode value! Only supports: train, evaluate-score, evaluate-mapping, clear-experiments and clear-everything')
...@@ -117,7 +117,6 @@ class Solver(): ...@@ -117,7 +117,6 @@ class Solver():
if use_last_checkpoint: if use_last_checkpoint:
self.load_checkpoint() self.load_checkpoint()
def train(self, train_loader, test_loader): def train(self, train_loader, test_loader):
"""Training Function """Training Function
...@@ -143,7 +142,8 @@ class Solver(): ...@@ -143,7 +142,8 @@ class Solver():
print('=====================') print('=====================')
print('Model Name: {}'.format(self.model_name)) print('Model Name: {}'.format(self.model_name))
if torch.cuda.is_available(): if torch.cuda.is_available():
print('Device Type: {}'.format(torch.cuda.get_device_name(self.device))) print('Device Type: {}'.format(
torch.cuda.get_device_name(self.device)))
else: else:
print('Device Type: {}'.format(self.device)) print('Device Type: {}'.format(self.device))
start_time = datetime.now() start_time = datetime.now()
...@@ -159,8 +159,6 @@ class Solver(): ...@@ -159,8 +159,6 @@ class Solver():
print('-> Phase: {}'.format(phase)) print('-> Phase: {}'.format(phase))
losses = [] losses = []
# outputs = []
# y_values = []
if phase == 'train': if phase == 'train':
model.train() model.train()
...@@ -173,8 +171,8 @@ class Solver(): ...@@ -173,8 +171,8 @@ class Solver():
y = sampled_batch[1].type(torch.FloatTensor) y = sampled_batch[1].type(torch.FloatTensor)
# We add an extra dimension (~ number of channels) for the 3D convolutions. # We add an extra dimension (~ number of channels) for the 3D convolutions.
X = torch.unsqueeze(X, dim= 1) X = torch.unsqueeze(X, dim=1)
y = torch.unsqueeze(y, dim= 1) y = torch.unsqueeze(y, dim=1)
if model.test_if_cuda: if model.test_if_cuda:
X = X.cuda(self.device, non_blocking=True) X = X.cuda(self.device, non_blocking=True)
...@@ -196,11 +194,7 @@ class Solver(): ...@@ -196,11 +194,7 @@ class Solver():
iteration += 1 iteration += 1
losses.append(loss.item()) losses.append(loss.item())
# outputs.append(torch.max(y_hat, dim=1)[1].cpu())
# y_values.append(y.cpu())
# Clear the memory # Clear the memory
...@@ -214,26 +208,9 @@ class Solver(): ...@@ -214,26 +208,9 @@ class Solver():
print("100%", flush=True) print("100%", flush=True)
with torch.no_grad(): with torch.no_grad():
# output_array, y_array = torch.cat(
# outputs), torch.cat(y_values)
self.LogWriter.loss_per_epoch(losses, phase, epoch) self.LogWriter.loss_per_epoch(losses, phase, epoch)
# dice_score_mean = self.LogWriter.dice_score_per_epoch(
# phase, output_array, y_array, epoch)
# if phase == 'test':
# if dice_score_mean > self.best_mean_score:
# self.best_mean_score = dice_score_mean
# self.best_mean_score_epoch = epoch
# index = np.random.choice(
# len(dataloaders[phase].dataset.X), size=3, replace=False)
# self.LogWriter.sample_image_per_epoch(prediction=model.predict(dataloaders[phase].dataset.X[index], self.device),
# ground_truth=dataloaders[phase].dataset.y[index],
# phase=phase,
# epoch=epoch)
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs)) print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
self.save_checkpoint(state={'epoch': epoch + 1, self.save_checkpoint(state={'epoch': epoch + 1,
......
...@@ -65,7 +65,7 @@ class LogWriter(): ...@@ -65,7 +65,7 @@ class LogWriter():
self.log_writer = { self.log_writer = {
'train': SummaryWriter(logdir=training_logs_directory), 'train': SummaryWriter(logdir=training_logs_directory),
'test': SummaryWriter(logdir=testing_logs_directory) 'test': SummaryWriter(logdir=testing_logs_directory)
} }
self.confusion_matrix_color_map = confusion_matrix_cmap self.confusion_matrix_color_map = confusion_matrix_cmap
......
...@@ -560,13 +560,3 @@ def get_datasetsHDF5(data_parameters): ...@@ -560,13 +560,3 @@ def get_datasetsHDF5(data_parameters):
training_labels['label'][()]), training_labels['label'][()]),
DataMapperHDF5(testing_data['data'][()], testing_labels['label'][()]) DataMapperHDF5(testing_data['data'][()], testing_labels['label'][()])
) )
if __name__ == "__main__":
pass
# folder_location = "../well/win-biobank/projects/imaging/data/data3/subjectsAll/"
# # data_test_train_validation_split(folder_location, 90, 5)
# subDirectoryList = directory_reader(folder_location, write_txt=True)
# print(subDirectoryList)
# tract_sum_generator(folder_location)
\ No newline at end of file
...@@ -78,9 +78,10 @@ class MSELoss(_WeightedLoss): ...@@ -78,9 +78,10 @@ class MSELoss(_WeightedLoss):
return self.loss(X, y) return self.loss(X, y)
# DEPRECATED LOSSES # DEPRECATED LOSSES
# NOTE: THESE LOSSES ARE USUALLY USED FOR CLASSIFICATION TASKS. # NOTE: THESE LOSSES ARE USUALLY USED FOR CLASSIFICATION TASKS.
# THIS IS NOT A CLASSIFICATION TASK. THUS, THESE ARE IGNORED FOR NOW! # THIS IS NOT A CLASSIFICATION TASK. THUS, THESE ARE IGNORED FOR NOW!
class CrossEntropyLoss(_WeightedLoss): class CrossEntropyLoss(_WeightedLoss):
"""Cross Entropy Loss """Cross Entropy Loss
......
...@@ -500,7 +500,7 @@ class DecoderBlock3D(ConvolutionalBlock3D): ...@@ -500,7 +500,7 @@ class DecoderBlock3D(ConvolutionalBlock3D):
Y (torch.tensor): Output forward passed tensor through the decoder block Y (torch.tensor): Output forward passed tensor through the decoder block
""" """
# ATTENTION: As of this code version, only "upconv" works! Debugging is ongoing for upconv and upsample! # ATTENTION: As of this code version, only "upconv" works! Debugging is ongoing for upconv and upsample!
# It seems that errors are generated by variable filter sizes and the unorthodox input sizes 91x109x91. # It seems that errors are generated by variable filter sizes and the unorthodox input sizes 91x109x91.
if self.up_mode == 'unpool': if self.up_mode == 'unpool':
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment