Commit 7eacf841 authored by Andrei-Claudiu Roibu

updated docstrings, PEP8 formatted, fixed bugs

parent c442542d
"""Brain Mapper U-Net Architecture
Description:
-------------
This folder contains the Pytorch implementation of the core U-net architecture.
This arcitecture predicts functional connectivity rsfMRI from structural connectivity information from dMRI.
Usage
-------------
To use this module, import it and instantiate is as you wish:
This folder contains the Pytorch implementation of the core U-net architecture.
This arcitecture predicts functional connectivity rsfMRI from structural connectivity information from dMRI.
Usage:
To use this module, import it and instantiate is as you wish:
from BrainMapperUNet import BrainMapperUNet
deep_learning_model = BrainMapperUnet(parameters)
......@@ -19,6 +19,7 @@ import torch
import torch.nn as nn
import utils.modules as modules
class BrainMapperUNet(nn.Module):
"""Architecture class BrainMapper U-net.
......@@ -42,9 +43,6 @@ class BrainMapperUNet(nn.Module):
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
"""
def __init__(self, parameters):
......@@ -80,23 +78,26 @@ class BrainMapperUNet(nn.Module):
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
"""
Y_encoder_1, Y_np1, pool_indices1 = self.encoderBlock1.forward(X)
Y_encoder_2, Y_np2, pool_indices2 = self.encoderBlock2.forward(
Y_encoder_1)
Y_encoder_3, Y_np3, pool_indices3 = self.encoderBlock3.forward(
Y_encoder_2)
Y_encoder_4, Y_np4, pool_indices4 = self.encoderBlock4.forward(
Y_encoder_3)
Y_bottleNeck = self.bottleneck.forward(Y_encoder_4)
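# Decoder path: each block unpools with the stored indices and merges the
# result with the matching encoder activations (skip connections).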
Y_decoder_1 = self.decoderBlock1.forward(
Y_bottleNeck, Y_np4, pool_indices4)
Y_decoder_2 = self.decoderBlock2.forward(
Y_decoder_1, Y_np3, pool_indices3)
Y_decoder_3 = self.decoderBlock3.forward(
Y_decoder_2, Y_np2, pool_indices2)
Y_decoder_4 = self.decoderBlock4.forward(
Y_decoder_3, Y_np1, pool_indices1)
probability_map = self.classifier.forward(Y_decoder_4)
......@@ -110,12 +111,6 @@ class BrainMapperUNet(nn.Module):
Args:
path (str): Path string
"""
print("Saving Model... {}".format(path))
......@@ -127,19 +122,13 @@ class BrainMapperUNet(nn.Module):
This function tests whether the model parameters are allocated to a CUDA-enabled GPU.
Returns:
bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
"""
return next(self.parameters()).is_cuda
def predict(self, X, device=0):
"""Post-training Output Prediction
This function predicts the output of the U-net post-training.
......@@ -151,9 +140,6 @@ class BrainMapperUNet(nn.Module):
Returns:
prediction (ndarray): predicted output after training
"""
self.eval()  # PyTorch module setting network to evaluation mode
......@@ -170,7 +156,8 @@ class BrainMapperUNet(nn.Module):
_, idx = torch.max(output, 1)
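# torch.max along dimension 1 returns (values, indices); the indices give
# the most probable class for each output element.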
# We retrieve the tensor held by idx (.data) and move it to the CPU as an ndarray
idx = idx.data.cpu().numpy()
prediction = np.squeeze(idx)
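# A minimal usage sketch (hypothetical inputs; the `parameters` keys are
# those documented in the class docstring):
#
#     model = BrainMapperUNet(parameters)
#     prediction = model.predict(X, device=0)  # X: input batch, returns ndarray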
......@@ -64,9 +64,8 @@ def _parse_values(configurator):
settings_dictionary = {}
for section in configurator.sections():
settings_dictionary[section] = {}
for key, value in configurator[section].items():
# Safely evaluate an expression node or a Unicode or Latin-1 encoded string containing a Python expression
settings_dictionary[section][key] = ast.literal_eval(value)
return settings_dictionary
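# A minimal sketch of what _parse_values produces, assuming a hypothetical
# configuration whose values are written as Python literals:
import ast
import configparser

example_configurator = configparser.ConfigParser()
example_configurator.read_string("""
[TRAINING]
learning_rate = 0.0001
use_last_checkpoint = True
""")
example_settings = {section: {key: ast.literal_eval(value)
                              for key, value in example_configurator[section].items()}
                    for section in example_configurator.sections()}
# -> {'TRAINING': {'learning_rate': 0.0001, 'use_last_checkpoint': True}}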
......@@ -16,6 +16,5 @@ setup(
'torch',
'h5py',
'tensorboardX',
],
)
"""Brain Mapper U-Net Solver
Description:
-------------
This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
Usage
-------------
To use this module, import it and instantiate is as you wish:
This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
from solver import Solver
Usage:
To use this module, import it and instantiate is as you wish:
from solver import Solver
"""
import os
......@@ -26,6 +25,7 @@ from torch.optim import lr_scheduler
checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'
class Solver():
"""Solver class for the BrainMapper U-net.
......@@ -52,8 +52,6 @@ class Solver():
Returns:
trained model(?) - working on this!
"""
def __init__(self,
......@@ -61,18 +59,18 @@ class Solver():
device,
number_of_classes,
experiment_name,
optimizer=torch.optim.Adam,
optimizer_arguments={},
loss_function=MSELoss(),
model_name='BrainMapper',
labels=None,
number_epochs=10,
loss_log_period=5,
learning_rate_scheduler_step_size=5,
learning_rate_scheduler_gamma=0.5,
use_last_checkpoint=True,
experiment_directory='experiments',
logs_directory='logs'
):
self.model = model
......@@ -91,16 +89,18 @@ class Solver():
# We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
step_size=learning_rate_scheduler_step_size,
gamma=learning_rate_scheduler_gamma)
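# For example, assuming an initial LR of 1e-3 with step_size=5 and gamma=0.5:
# epochs 1-5 run at 1e-3, epochs 6-10 at 5e-4, epochs 11-15 at 2.5e-4, etc.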
self.use_last_checkpoint = use_last_checkpoint
experiment_directory_path = os.path.join(
experiment_directory, experiment_name)
self.experiment_directory_path = experiment_directory_path
create_folder(experiment_directory_path)
create_folder(os.path.join(
experiment_directory_path, checkpoint_directory))
self.start_epoch = 1
self.start_iteration = 1
......@@ -110,12 +110,11 @@ class Solver():
if use_last_checkpoint:
self.load_checkpoint()
self.LogWriter = LogWriter(number_of_classes=number_of_classes,
logs_directory=logs_directory,
experiment_name=experiment_name,
use_last_checkpoint=use_last_checkpoint,
labels=labels)
def train(self, train_loader, test_loader):
"""Training Function
......@@ -127,10 +126,7 @@ class Solver():
test_loader (class): Combined dataset and sampler, providing an iterable over the testing dataset (torch.utils.data.DataLoader)
Returns:
trained model
"""
model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
......@@ -172,8 +168,8 @@ class Solver():
y = sampled_batch[1].type(torch.LongTensor)
if model.is_cuda():
X = X.cuda(self.device, non_blocking=True)
y = y.cuda(self.device, non_blocking=True)
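# non_blocking=True lets the host-to-device copy overlap with computation
# when the DataLoader supplies pinned-memory tensors.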
y_hat = model(X) # Forward pass
......@@ -186,7 +182,8 @@ class Solver():
if batch_index % self.loss_log_period == 0:
self.LogWriter.loss_per_iteration(
loss.item(), batch_index, iteration)
iteration += 1
......@@ -206,25 +203,27 @@ class Solver():
print("100%", flush=True)
with torch.no_grad():
output_array, y_array = torch.cat(
outputs), torch.cat(y_values)
self.LogWriter.loss_per_epoch(losses, phase, epoch)
dice_score_mean = self.LogWriter.dice_score_per_epoch(
phase, output_array, y_array, epoch)
if phase == 'test':
if dice_score_mean > self.best_mean_score:
self.best_mean_score = dice_score_mean
self.best_mean_score_epoch = epoch
index = np.random.choice(
len(dataloaders[phase].dataset.X), size=3, replace=False)
self.LogWriter.sample_image_per_epoch(prediction=model.predict(dataloaders[phase].dataset.X[index], self.device),
ground_truth=dataloaders[phase].dataset.y[index],
phase=phase,
epoch=epoch)
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
......@@ -232,7 +231,8 @@ class Solver():
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
},
filename=os.path.join(self.experiment_directory_path, checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
self.LogWriter.close()
......@@ -252,48 +252,38 @@ class Solver():
Args:
state (dict): Dictionary of all the relevant model components
"""
torch.save(state, filename)
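# A reload sketch using the keys saved above (see also load_checkpoint below):
#
#     checkpoint = torch.load(filename)
#     start_epoch = checkpoint['epoch']
#     optimizer.load_state_dict(checkpoint['optimizer'])
#     scheduler.load_state_dict(checkpoint['scheduler'])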
def load_checkpoint(self, epoch=None):
"""General Checkpoint Loader
This function loads a previous checkpoint for inference and/or resuming training
Args:
epoch (int): Current epoch value
"""
if epoch is not None:
checkpoint_file_path = os.path.join(
self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self._checkpoint_reader(checkpoint_file_path)
else:
universal_path = os.path.join(
self.experiment_directory_path, checkpoint_directory, '*.' + checkpoint_extension)
files_in_universal_path = glob.glob(universal_path)
# We will sort through all the files in path to see which one is most recent
if len(files_in_universal_path) > 0:
checkpoint_file_path = max(
files_in_universal_path, key=os.path.getatime)
self._checkpoint_reader(checkpoint_file_path)
else:
self.LogWriter.log("No Checkpoint found at {}".format(os.path.join(self.experiment_directory_path, checkpoint_directory)))
self.LogWriter.log("No Checkpoint found at {}".format(
os.path.join(self.experiment_directory_path, checkpoint_directory)))
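# Note: key=os.path.getatime selects the most recently *accessed* file;
# os.path.getmtime (last modification time) may be a safer proxy for the
# most recent checkpoint.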
def _checkpoint_reader(self, checkpoint_file_path):
"""Checkpoint Reader
......@@ -302,15 +292,10 @@ class Solver():
Args:
checkpoint_file_path (str): path to checkpoint file
"""
self.LogWriter.log("Loading Checkpoint {}".format(checkpoint_file_path))
self.LogWriter.log(
"Loading Checkpoint {}".format(checkpoint_file_path))
checkpoint = torch.load(checkpoint_file_path)
self.start_epoch = checkpoint['epoch']
......@@ -325,4 +310,5 @@ class Solver():
if torch.is_tensor(value):
state[key] = value.to(self.device)
self.LogWriter.log("Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
\ No newline at end of file
self.LogWriter.log(
"Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
"""Data Logging Functions
Description:
-------------
This folder contains several functions which, either on their own or included in larger pieces of software, perform data logging tasks.
Usage
-------------
To use content from this folder, import the functions and instantiate them as you wish to use them:
This folder contains several functions which, either on their own or included in larger pieces of software, perform data logging tasks.
Usage:
To use content from this folder, import the functions and instantiate them as you wish to use them:
from utils.data_logging_utils import function_name
......@@ -30,6 +30,7 @@ import utils.data_evaluation_utils as evaluation
plt.axis('scaled')
class LogWriter():
"""Log Writer class for the BrainMapper U-net.
......@@ -44,19 +45,15 @@ class LogWriter():
use_last_checkpoint (bool): Flag for loading the previous checkpoint
labels (arr): Vector/Array of labels (if applicable)
confusion_matrix_cmap (class): Colour map to be used for the Confusion Matrix
"""
def __init__(self, number_of_classes, logs_directory, experiment_name, use_last_checkpoint=False, labels=None, confusion_matrix_cmap=plt.cm.Blues):
self.number_of_classes = number_of_classes
training_logs_directory = os.path.join(
logs_directory, experiment_name, "train")
testing_logs_directory = os.path.join(
logs_directory, experiment_name, "test")
# If the logs directories exist, we clear their contents to allow new logs to be created
if not use_last_checkpoint:
......@@ -66,8 +63,8 @@ class LogWriter():
shutil.rmtree(testing_logs_directory)
self.log_writer = {
'train': SummaryWriter(logdir=training_logs_directory),
'test': SummaryWriter(logdir=testing_logs_directory)
}
self.confusion_matrix_color_map = confusion_matrix_cmap
......@@ -77,7 +74,8 @@ class LogWriter():
self.labels = self.labels_generator(labels)
self.logger = logging.getLogger()
file_handler = logging.FileHandler("{}/{}.log".format(os.path.join(logs_directory, experiment_name), "console_logs"))
file_handler = logging.FileHandler(
"{}/{}.log".format(os.path.join(logs_directory, experiment_name), "console_logs"))
self.logger.addHandler(file_handler)
def labels_generator(self, labels):
......@@ -90,17 +88,16 @@ class LogWriter():
Returns:
label_classes (list): List of processed labels
"""
label_classes = []
for label in labels:
label_class = re.sub(
r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', label)
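# e.g. 'LeftCerebralWhiteMatter' -> 'Left Cerebral White Matter'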
# Wrap long label strings at 40 characters for display
label_class = '\n'.join(wrap(label_class, 40))
label_classes.append(label_class)
return label_classes
......@@ -112,15 +109,9 @@ class LogWriter():
Args:
message (str): Message to be logged
"""
self.logger.info(msg=message)
def loss_per_iteration(self, loss_per_iteration, batch_index, iteration):
"""Log of loss / iteration
......@@ -131,16 +122,12 @@ class LogWriter():
loss_per_iteration (torch.tensor): Value of loss for every iteration step
batch_index (int): Index of current batch
iteration (int): Current iteration value
"""
print("Loss for Iteration {} is: {}".format(batch_index, loss_per_iteration))
self.log_writer['train'].add_scalar('loss / iteration', loss_per_iteration, iteration)
print("Loss for Iteration {} is: {}".format(
batch_index, loss_per_iteration))
self.log_writer['train'].add_scalar(
'loss / iteration', loss_per_iteration, iteration)
def loss_per_epoch(self, losses, phase, epoch):
"""Log function
......@@ -151,12 +138,6 @@ class LogWriter():
losses (list): Values of all the losses recorded during the training epoch
phase (str): Current run mode or phase
epoch (int): Current epoch value
"""
if phase == 'train':
......@@ -180,18 +161,14 @@ class LogWriter():
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
epoch (int): Current epoch value
Returns:
mean_dice_score (torch.tensor): Mean dice score value
"""
print("Dice Score is being calculated...", end='', flush= True)
dice_score = evaluation.dice_score_calculator(outputs, correct_labels, self.number_of_classes)
print("Dice Score is being calculated...", end='', flush=True)
dice_score = evaluation.dice_score_calculator(
outputs, correct_labels, self.number_of_classes)
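# Assuming dice_score_calculator follows the standard definition, the
# per-class score is: dice = 2 * |prediction ∩ ground truth| / (|prediction| + |ground truth|)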
mean_dice_score = torch.mean(dice_score)
self.plot_dice_score(
dice_score, phase, plot_name='dice_score_per_epoch', title='Dice Score', epochs=epoch)
print("Dice Score calculated successfully")
return mean_dice_score.item()
......@@ -206,12 +183,6 @@ class LogWriter():
plot_name (str): Caption name for later reference
title (str): Plot title
epoch (int): Current epoch value
"""
figure = matplotlib.figure.Figure() # Might add some arguments here later
......@@ -224,7 +195,8 @@ class LogWriter():
ax.xaxis.tick_bottom()
if epochs:
self.log_writer[phase].add_figure(
plot_name + '/' + phase, figure, global_step=epochs)
else:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure)
......@@ -240,16 +212,10 @@ class LogWriter():
ground_truth (torch.tensor): Labelled ground truth image
phase (str): Current run mode or phase
epoch (int): Current epoch value
"""
print("Sample Image is being loaded...", end='', flush= True)
figure, ax = plt.subplots(nrows = len(prediction), ncols = 2)
print("Sample Image is being loaded...", end='', flush=True)
figure, ax = plt.subplots(nrows=len(prediction), ncols=2)
for i in range(len(prediction)):
ax[i][0].imshow(prediction[i])
......@@ -261,7 +227,8 @@ class LogWriter():
ax[i][1].axis('off')
figure.set_tight_layout(True)
self.log_writer[phase].add_figure(
'sample_prediction/'+phase, figure, epoch)
print("Sample Image successfully loaded!")
......@@ -269,15 +236,6 @@ class LogWriter():