Commit 7eacf841 authored by Andrei-Claudiu Roibu

updated docstrings, pep8 formatted, fixed bugs

parent c442542d
"""Brain Mapper U-Net Architecture
Description:
-------------
This folder contains the Pytorch implementation of the core U-net architecture.
This arcitecture predicts functional connectivity rsfMRI from structural connectivity information from dMRI.
Usage
-------------
To use this module, import it and instantiate is as you wish:
This folder contains the Pytorch implementation of the core U-net architecture.
This arcitecture predicts functional connectivity rsfMRI from structural connectivity information from dMRI.
from BrainMapperUNet import BrainMapperUNet
deep_learning_model = BrainMapperUnet(parameters)
Usage:
To use this module, import it and instantiate is as you wish:
from BrainMapperUNet import BrainMapperUNet
deep_learning_model = BrainMapperUnet(parameters)
"""
@@ -19,6 +19,7 @@ import torch
import torch.nn as nn
import utils.modules as modules
class BrainMapperUNet(nn.Module):
"""Architecture class BrainMapper U-net.
@@ -39,14 +40,11 @@ class BrainMapperUNet(nn.Module):
'up_mode': 'upconv'
'number_of_classes': 1
}
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
Raises:
None
"""
def __init__(self, parameters):
super(BrainMapperUNet, self).__init__()
@@ -80,44 +78,41 @@ class BrainMapperUNet(nn.Module):
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
Raises:
None
"""
Y_encoder_1, Y_np1, pool_indices1 = self.encoderBlock1.forward(X)
Y_encoder_2, Y_np2, pool_indices2 = self.encoderBlock2.forward(
Y_encoder_1)
Y_encoder_3, Y_np3, pool_indices3 = self.encoderBlock3.forward(
Y_encoder_2)
Y_encoder_4, Y_np4, pool_indices4 = self.encoderBlock4.forward(
Y_encoder_3)
Y_bottleNeck = self.bottleneck.forward(Y_encoder_4)
Y_decoder_1 = self.decoderBlock1.forward(
Y_bottleNeck, Y_np4, pool_indices4)
Y_decoder_2 = self.decoderBlock2.forward(
Y_decoder_1, Y_np3, pool_indices3)
Y_decoder_3 = self.decoderBlock3.forward(
Y_decoder_2, Y_np2, pool_indices2)
Y_decoder_4 = self.decoderBlock4.forward(
Y_decoder_3, Y_np1, pool_indices1)
probability_map = self.classifier.forward(Y_decoder_4)
return probability_map
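# Data-flow note (descriptive summary of the forward pass above): the four
# encoder blocks successively downsample X, handing their pre-pooled feature
# maps (Y_np1-Y_np4) and pooling indices to the decoder blocks in reverse
# order as skip connections; the bottleneck sits between encoderBlock4 and
# decoderBlock1, and the classifier maps the final decoder output to the
# probability map.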
def save(self, path):
"""Model Saver
Function saving the model with all its parameters to a given path.
The path must end with the *.model extension.
Args:
path (str): Path string
Returns:
None
Raises:
None
"""
print("Saving Model... {}".format(path))
torch.save(self, path)
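# Usage sketch (illustrative only; the file name below is an assumption, but
# the docstring above requires the path to end in *.model):
#
# deep_learning_model.save('experiments/brain_mapper.model')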
@@ -127,21 +122,15 @@ class BrainMapperUNet(nn.Module):
This function tests if the model parameters are allocated to a CUDA enabled GPU.
Args:
None
Returns:
bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
Raises:
None
"""
return next(self.parameters()).is_cuda
def predict(self, X, device=0):
"""Post-training Output Prediction
This function predicts the output of the U-net post-training
Args:
@@ -151,11 +140,8 @@ class BrainMapperUNet(nn.Module):
Returns:
prediction (ndarray): predicted output after training
Raises:
None
"""
self.eval()  # PyTorch module setting network to evaluation mode
if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
@@ -165,13 +151,14 @@
# .cuda() call transfers the tensor from the CPU to the GPU if that is the case.
# Non-blocking argument lets the caller bypass synchronization when necessary
with torch.no_grad():  # Causes operations to have no gradients
output = self.forward(X)
_, idx = torch.max(output, 1)
# We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
idx = idx.data.cpu().numpy()
prediction = np.squeeze(idx)
del X, output, idx
@@ -64,9 +64,8 @@ def _parse_values(configurator):
settings_dictionary = {}
for section in configurator.sections():
settings_dictionary[section] = {}
for key, value in configurator[section].items():
# Safely evaluate an expression node or a Unicode or Latin-1 encoded string containing a Python expression
settings_dictionary[section][key] = ast.literal_eval(value)
return settings_dictionary
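# Illustrative sketch of the parsing performed above (the section and key
# names below are assumptions, not taken from this repository's settings file):
#
# [TRAINING]
# learning_rate = 1e-4
# use_last_checkpoint = True
#
# becomes {'TRAINING': {'learning_rate': 0.0001, 'use_last_checkpoint': True}},
# since ast.literal_eval() converts each string value into the Python literal
# it spells out.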
@@ -16,6 +16,5 @@ setup(
'torch',
'h5py',
'tensorboardX',
],
)
"""Brain Mapper U-Net Solver
Description:
-------------
This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
Usage
-------------
To use this module, import it and instantiate is as you wish:
This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
from solver import Solver
Usage:
To use this module, import it and instantiate is as you wish:
from solver import Solver
"""
import os
@@ -26,6 +25,7 @@ from torch.optim import lr_scheduler
checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'
class Solver():
"""Solver class for the BrainMapper U-net.
@@ -52,28 +52,26 @@ class Solver():
Returns:
trained model(?) - working on this!
Raises:
None
"""
def __init__(self,
model,
device,
number_of_classes,
experiment_name,
optimizer=torch.optim.Adam,
optimizer_arguments={},
loss_function=MSELoss(),
model_name='BrainMapper',
labels=None,
number_epochs=10,
loss_log_period=5,
learning_rate_scheduler_step_size=5,
learning_rate_scheduler_gamma=0.5,
use_last_checkpoint=True,
experiment_directory='experiments',
logs_directory='logs'
):
self.model = model
self.device = device
@@ -88,19 +86,21 @@ class Solver():
self.labels = labels
self.number_epochs = number_epochs
self.loss_log_period = loss_log_period
# We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
step_size=learning_rate_scheduler_step_size,
gamma=learning_rate_scheduler_gamma)
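# Worked example (illustrative, assuming an initial learning rate of 1e-3 and
# the defaults step_size=5, gamma=0.5): the learning rate stays at 1e-3 for the
# first five epochs, halves to 5e-4 after epoch 5, and halves again to 2.5e-4
# after epoch 10.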
self.use_last_checkpoint = use_last_checkpoint
experiment_directory_path = os.path.join(
experiment_directory, experiment_name)
self.experiment_directory_path = experiment_directory_path
create_folder(experiment_directory_path)
create_folder(os.path.join(
experiment_directory_path, checkpoint_directory))
self.start_epoch = 1
self.start_iteration = 1
@@ -110,12 +110,11 @@ class Solver():
if use_last_checkpoint:
self.load_checkpoint()
self.LogWriter = LogWriter(number_of_classes=number_of_classes,
logs_directory=logs_directory,
experiment_name=experiment_name,
use_last_checkpoint=use_last_checkpoint,
labels=labels)
def train(self, train_loader, test_loader):
"""Training Function
@@ -127,18 +126,15 @@ class Solver():
test_loader (class): Combined dataset and sampler, providing an iterable over the testing dataset (torch.utils.data.DataLoader)
Returns:
trained model
"""
model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
dataloaders = {'train': train_loader, 'test': test_loader}
if torch.cuda.is_available():
torch.cuda.empty_cache()  # clear memory
model.cuda(self.device)  # Moving the model to GPU
print('****************************************************************')
print('TRAINING IS STARTING!')
@@ -172,25 +168,26 @@ class Solver():
y = sampled_batch[1].type(torch.LongTensor)
if model.is_cuda():
X = X.cuda(self.device, non_blocking=True)
y = y.cuda(self.device, non_blocking=True)
y_hat = model(X)  # Forward pass
loss = self.loss_function(y_hat, y)  # Loss computation
if phase == 'train':
optimizer.zero_grad()  # Zero the parameter gradients
loss.backward()  # Backward propagation
optimizer.step()
if batch_index % self.loss_log_period == 0:
self.LogWriter.loss_per_iteration(
loss.item(), batch_index, iteration)
iteration += 1
losses.append(loss.item())
outputs.append(torch.max(y_hat, dim=1)[1].cpu())
y_values.append(y.cpu())
@@ -206,25 +203,27 @@ class Solver():
print("100%", flush=True)
with torch.no_grad():
output_array, y_array = torch.cat(
outputs), torch.cat(y_values)
self.LogWriter.loss_per_epoch(losses, phase, epoch)
dice_score_mean = self.LogWriter.dice_score_per_epoch(
phase, output_array, y_array, epoch)
if phase == 'test':
if dice_score_mean > self.best_mean_score:
self.best_mean_score = dice_score_mean
self.best_mean_score_epoch = epoch
index = np.random.choice(
len(dataloaders[phase].dataset.X), size=3, replace=False)
self.LogWriter.sample_image_per_epoch(prediction=model.predict(dataloaders[phase].dataset.X[index], self.device),
ground_truth=dataloaders[phase].dataset.y[index],
phase=phase,
epoch=epoch)
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
@@ -232,8 +231,9 @@ class Solver():
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
},
filename=os.path.join(self.experiment_directory_path, checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
self.LogWriter.close()
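# Note: with the defaults defined earlier in this file (experiment_directory
# 'experiments', checkpoint_directory 'checkpoints', checkpoint_extension
# 'path.tar'), each epoch produces a file such as
# experiments/<experiment_name>/checkpoints/checkpoint_epoch_3.path.tar
# (the epoch number is illustrative).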
@@ -252,48 +252,38 @@ class Solver():
Args:
state (dict): Dictionary of all the relevant model components
Returns:
None
Raises:
None
"""
torch.save(state, filename)
def load_checkpoint(self, epoch=None):
"""General Checkpoint Loader
This function loads a previous checkpoint for inference and/or resuming training
Args:
epoch (int): Current epoch value
Returns:
None
Raises:
None
"""
if epoch is not None:
checkpoint_file_path = os.path.join(
self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self._checkpoint_reader(checkpoint_file_path)
else:
universal_path = os.path.join(
self.experiment_directory_path, checkpoint_directory, '*.' + checkpoint_extension)
files_in_universal_path = glob.glob(universal_path)
# We will sort through all the files in path to see which one is most recent
if len(files_in_universal_path) > 0:
checkpoint_file_path = max(
files_in_universal_path, key=os.path.getatime)
self._checkpoint_reader(checkpoint_file_path)
else:
self.LogWriter.log("No Checkpoint found at {}".format(os.path.join(self.experiment_directory_path, checkpoint_directory)))
self.LogWriter.log("No Checkpoint found at {}".format(
os.path.join(self.experiment_directory_path, checkpoint_directory)))
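# Usage sketch (illustrative only): solver.load_checkpoint() resumes from the
# most recently accessed checkpoint found on disk, while
# solver.load_checkpoint(epoch=5) loads the checkpoint saved after a specific
# epoch (the epoch number is an assumption for illustration).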
def _checkpoint_reader(self, checkpoint_file_path):
"""Checkpoint Reader
@@ -302,20 +292,15 @@ class Solver():
Args:
checkpoint_file_path (str): path to checkpoint file
Returns:
None
Raises:
None
"""
self.LogWriter.log("Loading Checkpoint {}".format(checkpoint_file_path))
self.LogWriter.log(
"Loading Checkpoint {}".format(checkpoint_file_path))
checkpoint = torch.load(checkpoint_file_path)
self.start_epoch = checkpoint['epoch']
self.start_iteration = checkpoint['start_iteration']
# We are not loading the model_name as we might want to pre-train a model and then use it.
self.model.load_state_dict(checkpoint['state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
@@ -325,4 +310,5 @@ class Solver():
if torch.is_tensor(value):
state[key] = value.to(self.device)
self.LogWriter.log("Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
\ No newline at end of file
self.LogWriter.log(
"Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
"""Data Logging Functions
Description:
-------------
This folder contains several functions which, either on their own or included in larger pieces of software, perform data logging tasks.
Usage
-------------
To use content from this folder, import the functions and instantiate them as you wish to use them:
This folder contains several functions which, either on their own or included in larger pieces of software, perform data logging tasks.
from utils.data_logging_utils import function_name
Usage:
To use content from this folder, import the functions and instantiate them as you wish to use them:
from utils.data_logging_utils import function_name
"""
@@ -30,6 +30,7 @@ import utils.data_evaluation_utils as evaluation
plt.axis('scaled')
class LogWriter():
"""Log Writer class for the BrainMapper U-net.
@@ -44,19 +45,15 @@ class LogWriter():
use_last_checkpoint (bool): Flag for loading the previous checkpoint
labels (arr): Vector/Array of labels (if applicable)
confusion_matrix_cmap (class): Colour Map to be used for the Confusion Matrix
Returns:
None
Raises:
None
"""
def __init__(self, number_of_classes, logs_directory, experiment_name, use_last_checkpoint=False, labels=None, confusion_matrix_cmap=plt.cm.Blues):
self.number_of_classes = number_of_classes
training_logs_directory = os.path.join(
logs_directory, experiment_name, "train")
testing_logs_directory = os.path.join(
logs_directory, experiment_name, "test")
# If the logs directories exist, we clear their contents to allow new logs to be created
if not use_last_checkpoint:
@@ -66,8 +63,8 @@ class LogWriter():
shutil.rmtree(testing_logs_directory)
self.log_writer = {
'train': SummaryWriter(logdir=training_logs_directory),
'test': SummaryWriter(logdir=testing_logs_directory)
}
self.confusion_matrix_color_map = confusion_matrix_cmap
@@ -77,7 +74,8 @@ class LogWriter():
self.labels = self.labels_generator(labels)
self.logger = logging.getLogger()
file_handler = logging.FileHandler(
"{}/{}.log".format(os.path.join(logs_directory, experiment_name), "console_logs"))
self.logger.addHandler(file_handler)
def labels_generator(self, labels):
@@ -90,17 +88,16 @@ class LogWriter():
Returns:
label_classes (list): List of processed labels
Raises:
None
"""
label_classes = []
for label in labels:
label_class = re.sub(
r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ', label)
label_class = ['\n'.join(wrap(element, 40))
for element in label_class]
label_classes.append(label_class)
return label_classes
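# Worked example of the regular expression above (the label value is
# illustrative): re.sub(r'([a-z](?=[A-Z])|[A-Z](?=[A-Z][a-z]))', r'\1 ',
# 'CorpusCallosum') returns 'Corpus Callosum', i.e. a space is inserted after
# a lowercase letter that precedes an uppercase one, or after an uppercase
# letter that precedes another uppercase letter followed by a lowercase one.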
@@ -112,15 +109,9 @@ class LogWriter():
Args:
message (str): Message to be logged
Returns:
None
Raises:
None
"""
self.logger.info(msg=message)
def loss_per_iteration(self, loss_per_iteration, batch_index, iteration):
"""Log of loss / iteration
@@ -131,16 +122,12 @@ class LogWriter():
loss_per_iteration (torch.tensor): Value of loss for every iteration step
batch_index (int): Index of current batch
iteration (int): Current iteration value
Returns:
None
Raises:
None
"""
print("Loss for Iteration {} is: {}".format(batch_index, loss_per_iteration))