Commit 7eacf841 authored by Andrei-Claudiu Roibu

updated docstrings, pep8 formated, fixed bugs

parent c442542d
"""Brain Mapper U-Net Architecture """Brain Mapper U-Net Architecture
Description: Description:
-------------
This folder contains the Pytorch implementation of the core U-net architecture.
This arcitecture predicts functional connectivity rsfMRI from structural connectivity information from dMRI.
Usage This folder contains the Pytorch implementation of the core U-net architecture.
------------- This arcitecture predicts functional connectivity rsfMRI from structural connectivity information from dMRI.
To use this module, import it and instantiate is as you wish:
from BrainMapperUNet import BrainMapperUNet Usage:
deep_learning_model = BrainMapperUnet(parameters)
To use this module, import it and instantiate is as you wish:
from BrainMapperUNet import BrainMapperUNet
deep_learning_model = BrainMapperUnet(parameters)
""" """
...@@ -19,6 +19,7 @@ import torch
import torch.nn as nn
import utils.modules as modules

class BrainMapperUNet(nn.Module):
"""Architecture class BrainMapper U-net. """Architecture class BrainMapper U-net.
...@@ -39,14 +40,11 @@ class BrainMapperUNet(nn.Module):
            'up_mode': 'upconv'
            'number_of_classes': 1
        }
    Returns:
        probability_map (torch.tensor): Output forward passed tensor through the U-net block
    """
    def __init__(self, parameters):
        super(BrainMapperUNet, self).__init__()
...@@ -80,44 +78,41 @@ class BrainMapperUNet(nn.Module):
        Returns:
            probability_map (torch.tensor): Output forward passed tensor through the U-net block
        """
        Y_encoder_1, Y_np1, pool_indices1 = self.encoderBlock1.forward(X)
        Y_encoder_2, Y_np2, pool_indices2 = self.encoderBlock2.forward(
            Y_encoder_1)
        Y_encoder_3, Y_np3, pool_indices3 = self.encoderBlock3.forward(
            Y_encoder_2)
        Y_encoder_4, Y_np4, pool_indices4 = self.encoderBlock4.forward(
            Y_encoder_3)

        Y_bottleNeck = self.bottleneck.forward(Y_encoder_4)

        Y_decoder_1 = self.decoderBlock1.forward(
            Y_bottleNeck, Y_np4, pool_indices4)
        Y_decoder_2 = self.decoderBlock2.forward(
            Y_decoder_1, Y_np3, pool_indices3)
        Y_decoder_3 = self.decoderBlock3.forward(
            Y_decoder_2, Y_np2, pool_indices2)
        Y_decoder_4 = self.decoderBlock4.forward(
            Y_decoder_3, Y_np1, pool_indices1)

        probability_map = self.classifier.forward(Y_decoder_4)

        return probability_map
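
    # Editor's sketch (not part of this commit): a single forward pass through
    # the network. The input shape below is hypothetical; the real dimensions
    # depend on the dMRI connectivity volumes the network is trained on.
    #
    #     model = BrainMapperUNet(parameters)
    #     X = torch.randn(1, 1, 128, 128)  # (batch, channels, height, width)
    #     probability_map = model(X)       # nn.Module.__call__ invokes forward()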
    def save(self, path):
        """Model Saver

        Function saving the model with all its parameters to a given path.
        The path must end with a *.model argument.

        Args:
            path (str): Path string
        """
        print("Saving Model... {}".format(path))
        torch.save(self, path)
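
        # Note: torch.save(self, path) pickles the entire module object, which
        # ties the saved file to this exact class definition and module layout.
        # A more portable alternative (a sketch, not what this commit does) is
        # to save only the learned parameters:
        #
        #     torch.save(self.state_dict(), path)
        #     # ...and later, after rebuilding the architecture:
        #     model = BrainMapperUNet(parameters)
        #     model.load_state_dict(torch.load(path))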
...@@ -127,21 +122,15 @@ class BrainMapperUNet(nn.Module):
        This function tests if the model parameters are allocated to a CUDA enabled GPU.

        Returns:
            bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
        """
        return next(self.parameters()).is_cuda
    def predict(self, X, device=0):
        """Post-training Output Prediction

        This function predicts the output of the U-net post-training.

        Args:
...@@ -151,11 +140,8 @@ class BrainMapperUNet(nn.Module):
        Returns:
            prediction (ndarray): predicted output after training
        """
        self.eval()  # PyTorch module setting network to evaluation mode

        if type(X) is np.ndarray:
            X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
...@@ -165,13 +151,14 @@ class BrainMapperUNet(nn.Module):
        # .cuda() call transfers the tensor from the CPU to the GPU if that is the case.
        # Non-blocking argument lets the caller bypass synchronization when necessary

        with torch.no_grad():  # Causes operations to have no gradients
            output = self.forward(X)

        _, idx = torch.max(output, 1)
        # We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
        idx = idx.data.cpu().numpy()

        prediction = np.squeeze(idx)

        del X, output, idx
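
    # Editor's sketch (illustrative only): post-training inference on a NumPy
    # array. `trained_model` and the input shape are hypothetical names/values,
    # not taken from this commit.
    #
    #     volume = np.random.rand(1, 1, 128, 128).astype(np.float32)
    #     prediction = trained_model.predict(volume, device=0)
    #     print(prediction.shape)  # class-index map after argmax and squeeze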
...
...@@ -64,9 +64,8 @@ def _parse_values(configurator):
    settings_dictionary = {}

    for section in configurator.sections():
        settings_dictionary[section] = {}
        for key, value in configurator[section].items():
            # Safely evaluate an expression node or a Unicode or Latin-1 encoded string containing a Python expression
            settings_dictionary[section][key] = ast.literal_eval(value)

    return settings_dictionary
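
# Illustration (not part of this commit) of what _parse_values produces:
# configparser hands every value back as a string, and ast.literal_eval
# safely converts each one into a typed Python literal.
#
#     import ast
#     import configparser
#
#     configurator = configparser.ConfigParser()
#     configurator.read_string("[TRAINING]\nlearning_rate = 1e-4\nuse_last_checkpoint = True\n")
#     ast.literal_eval(configurator['TRAINING']['learning_rate'])        # -> 0.0001 (float)
#     ast.literal_eval(configurator['TRAINING']['use_last_checkpoint'])  # -> True (bool)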
...@@ -16,6 +16,5 @@ setup(
        'torch',
        'h5py',
        'tensorboardX',
    ],
)
"""Brain Mapper U-Net Solver """Brain Mapper U-Net Solver
Description: Description:
-------------
This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
Usage This folder contains the Pytorch implementation of the core U-net solver, used for training the network.
-------------
To use this module, import it and instantiate is as you wish:
from solver import Solver Usage:
To use this module, import it and instantiate is as you wish:
from solver import Solver
""" """
import os
...@@ -26,6 +25,7 @@ from torch.optim import lr_scheduler
checkpoint_directory = 'checkpoints'
checkpoint_extension = 'path.tar'


class Solver():
"""Solver class for the BrainMapper U-net. """Solver class for the BrainMapper U-net.
...@@ -52,28 +52,26 @@ class Solver(): ...@@ -52,28 +52,26 @@ class Solver():
Returns: Returns:
trained model(?) - working on this! trained model(?) - working on this!
Raises:
None
""" """
    def __init__(self,
                 model,
                 device,
                 number_of_classes,
                 experiment_name,
                 optimizer=torch.optim.Adam,
                 optimizer_arguments={},
                 loss_function=MSELoss(),
                 model_name='BrainMapper',
                 labels=None,
                 number_epochs=10,
                 loss_log_period=5,
                 learning_rate_scheduler_step_size=5,
                 learning_rate_scheduler_gamma=0.5,
                 use_last_checkpoint=True,
                 experiment_directory='experiments',
                 logs_directory='logs'
                 ):

        self.model = model
        self.device = device
...@@ -88,19 +86,21 @@ class Solver():
        self.labels = labels
        self.number_epochs = number_epochs
        self.loss_log_period = loss_log_period

        # We use a learning rate scheduler that decays the LR of each parameter group by gamma every step_size epochs.
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)
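
        # Worked example (illustrative): with an initial learning rate of 1e-3,
        # step_size=5 and gamma=0.5, StepLR yields
        #     epochs 0-4:   1e-3
        #     epochs 5-9:   5e-4
        #     epochs 10-14: 2.5e-4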
        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, checkpoint_directory))
        self.start_epoch = 1
        self.start_iteration = 1
...@@ -110,12 +110,11 @@ class Solver():
        if use_last_checkpoint:
            self.load_checkpoint()

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)
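
    # Editor's sketch (not part of this commit): wiring the solver to the
    # network. `parameters`, `train_loader` and `test_loader` are hypothetical
    # names standing in for a parsed settings dictionary and two DataLoaders.
    #
    #     model = BrainMapperUNet(parameters)
    #     solver = Solver(model,
    #                     device=0,
    #                     number_of_classes=1,
    #                     experiment_name='test_run',
    #                     optimizer_arguments={'lr': 1e-4})
    #     solver.train(train_loader, test_loader)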
    def train(self, train_loader, test_loader):
        """Training Function
...@@ -127,18 +126,15 @@ class Solver():
            test_loader (class): Combined dataset and sampler, providing an iterable over the testing dataset (torch.utils.data.DataLoader)

        Returns:
            trained model
        """
        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'test': test_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
...@@ -172,25 +168,26 @@ class Solver():
                    y = sampled_batch[1].type(torch.LongTensor)

                    if model.is_cuda():
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)

                    y_hat = model(X)  # Forward pass
                    loss = self.loss_function(y_hat, y)  # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())
                    outputs.append(torch.max(y_hat, dim=1)[1].cpu())
                    y_values.append(y.cpu())
...@@ -206,25 +203,27 @@ class Solver():
print("100%", flush=True) print("100%", flush=True)
with torch.no_grad(): with torch.no_grad():
output_array, y_array = torch.cat(outputs), torch.cat(y_values) output_array, y_array = torch.cat(
outputs), torch.cat(y_values)
self.LogWriter.loss_per_epoch(losses, phase, epoch) self.LogWriter.loss_per_epoch(losses, phase, epoch)
dice_score_mean = self.LogWriter.dice_score_per_epoch(phase, output_array, y_array, epoch) dice_score_mean = self.LogWriter.dice_score_per_epoch(
phase, output_array, y_array, epoch)
if phase == 'test': if phase == 'test':
if dice_score_mean > self.best_mean_score: if dice_score_mean > self.best_mean_score:
self.best_mean_score = dice_score_mean self.best_mean_score = dice_score_mean
self.best_mean_score_epoch = epoch self.best_mean_score_epoch = epoch
index = np.random.choice(len(dataloaders[phase].dataset.X), size=3, replace= False) index = np.random.choice(
self.LogWriter.sample_image_per_epoch(prediction= model.predict(dataloaders[phase].dataset.X[index], self.device), len(dataloaders[phase].dataset.X), size=3, replace=False)
ground_truth= dataloaders[phase].dataset.y[index], self.LogWriter.sample_image_per_epoch(prediction=model.predict(dataloaders[phase].dataset.X[index], self.device),
phase= phase, ground_truth=dataloaders[phase].dataset.y[index],
epoch= epoch) phase=phase,
epoch=epoch)
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs)) print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
self.save_checkpoint(state={'epoch': epoch + 1, self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1, 'start_iteration': iteration + 1,
'arch': self.model_name, 'arch': self.model_name,
...@@ -232,8 +231,9 @@ class Solver(): ...@@ -232,8 +231,9 @@ class Solver():
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict() 'scheduler': learning_rate_scheduler.state_dict()
}, },
filename= os.path.join(self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension) filename=os.path.join(self.experiment_directory_path, checkpoint_directory,
) 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
        self.LogWriter.close()
...@@ -252,48 +252,38 @@ class Solver():
        Args:
            state (dict): Dictionary of all the relevant model components
        """
        torch.save(state, filename)
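
        # Resuming from such a state dict (a sketch, not from this commit; the
        # key holding the model weights is elided by this diff and assumed here):
        #
        #     checkpoint = torch.load(checkpoint_file_path)
        #     model.load_state_dict(checkpoint['state_dict'])   # assumed key
        #     optimizer.load_state_dict(checkpoint['optimizer'])
        #     learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])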
    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Current epoch value
        """
        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, checkpoint_directory, '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # We will sort through all the files in path to see which one is most recent
            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)
            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, checkpoint_directory)))
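
        # Note: max(..., key=os.path.getatime) selects the file with the latest
        # *access* time, which merely reading a file can update. If the intent
        # is "most recently written checkpoint", os.path.getmtime (modification
        # time) is the more robust key.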
    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader
...@@ -302,20 +292,15 @@ class Solver():
        Args:
            checkpoint_file_path (str): path to checkpoint file
Returns: