Commit bd826338 authored by Andrei-Claudiu Roibu

added dense block Unet architecture

parent 7e359259
...@@ -20,6 +20,199 @@ import torch.nn as nn
import utils.modules as modules
class BrainMapperResUNet3D(nn.Module):
"""Architecture class for Residual DenseBlock BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
Args:
parameters (dict): Contains the relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
'input_channels': 1
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
'number_of_classes': 1
}
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
"""
def __init__(self, parameters):
super(BrainMapperResUNet3D, self).__init__()
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
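# The shared parameters dictionary is mutated below so that each block is constructed with the correct number of input channels; the original values are restored at the end of __init__.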
self.encoderBlock1 = modules.DensEncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.encoderBlock2 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock3 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock4 = modules.DensEncoderBlock3D(parameters)
self.bottleneck = modules.DensBlock3D(parameters)
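# Each decoder block receives the upsampled tensor concatenated with the corresponding encoder skip connection, hence double the input channels.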
parameters['input_channels'] = parameters['output_channels'] * 2
self.decoderBlock1 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock2 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock3 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock4 = modules.DensDecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.classifier = modules.DensClassifierBlock3D(parameters)
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
def forward(self, X):
"""Forward pass for 3D U-net
Function computing the forward pass through the 3D U-Net
The input to the function is the dMRI map
Args:
X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W)
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
"""
Y_encoder_1, Y_np1, _ = self.encoderBlock1(X)
Y_encoder_2, Y_np2, _ = self.encoderBlock2(Y_encoder_1)
del Y_encoder_1
Y_encoder_3, Y_np3, _ = self.encoderBlock3(Y_encoder_2)
del Y_encoder_2
Y_encoder_4, Y_np4, _ = self.encoderBlock4(Y_encoder_3)
del Y_encoder_3
Y_bottleNeck = self.bottleneck(Y_encoder_4)
del Y_encoder_4
Y_decoder_1 = self.decoderBlock1(Y_bottleNeck, Y_np4)
del Y_bottleNeck, Y_np4
Y_decoder_2 = self.decoderBlock2(Y_decoder_1, Y_np3)
del Y_decoder_1, Y_np3
Y_decoder_3 = self.decoderBlock3(Y_decoder_2, Y_np2)
del Y_decoder_2, Y_np2
Y_decoder_4 = self.decoderBlock4(Y_decoder_3, Y_np1)
del Y_decoder_3, Y_np1
probability_map = self.classifier(Y_decoder_4)
del Y_decoder_4
return probability_map
def save(self, path):
"""Model Saver
Function saving the model with all its parameters to a given path.
The path must end with a *.model extension.
Args:
path (str): Path string
"""
print("Saving Model... {}".format(path))
torch.save(self, path)
@property
def test_if_cuda(self):
"""Cuda Test
This function tests if the model parameters are allocated to a CUDA enabled GPU.
Returns:
bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
"""
return next(self.parameters()).is_cuda
def predict(self, X, device=0):
"""Post-training Output Prediction
This function predicts the output of the U-net post-training
Args:
X (torch.tensor): input dMRI volume
device (int/str): Device type used for training (int - GPU id, str - CPU)
Returns:
prediction (ndarray): predicted output after training
"""
self.eval() # PyTorch module call setting the network to evaluation mode
if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
if self.test_if_cuda and not X.is_cuda:
# The .cuda() call transfers the tensor from the CPU to the GPU.
# The non_blocking argument lets the caller bypass synchronization when possible.
X = X.cuda(device, non_blocking=True)
with torch.no_grad(): # Disables gradient tracking during inference
output = self.forward(X)
_, idx = torch.max(output, 1)
# Note: taking the argmax over the class dimension assumes number_of_classes > 1.
# We retrieve the tensor held by idx (.data) and move it to the CPU as an ndarray.
idx = idx.data.cpu().numpy()
prediction = np.squeeze(idx)
del X, output, idx
return prediction
def reset_parameters(self):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 and https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
for _, subsubmodule in submodule.named_children():
if not isinstance(subsubmodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)):
subsubmodule.reset_parameters()
print("Initialized network parameters!")
class BrainMapperUNet3D(nn.Module):
"""Architecture class for Traditional BrainMapper 3D U-net.
...
...@@ -40,7 +40,7 @@ import torch.utils.data as data
import numpy as np
from solver import Solver
-from BrainMapperUNet import BrainMapperUNet3D
+from BrainMapperUNet import BrainMapperUNet3D, BrainMapperResUNet3D
from utils.data_utils import get_datasets, data_test_train_validation_split, update_shuffling_flag, create_folder
import utils.data_evaluation_utils as evaluations
from utils.data_logging_utils import LogWriter
...@@ -149,7 +149,9 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
BrainMapperModel = torch.load(
training_parameters['pre_trained_path'])
else:
-BrainMapperModel = BrainMapperUNet3D(network_parameters)
+# BrainMapperModel = BrainMapperUNet3D(network_parameters)
+BrainMapperModel = BrainMapperResUNet3D(network_parameters)
BrainMapperModel.reset_parameters()
...
...@@ -19,6 +19,335 @@ import torch.nn.functional as F
# TODO: Currently, it appears that we are using constant size filters. We will need to adjust this in the network architecture, to allow it to encode/decode information!
# DensBlock 3D UNet:
class DensBlock3D(nn.Module):
"""Parent class for a 3D convolutional residual block.
This class represents a generic parent class for a convolutional residual 3D encoder or decoder block.
The class represents a subclass/child class of nn.Module, inheriting its functionality.
Args:
parameters (dict): Contains information on kernel size, number of channels, number of filters, and if convolution is strided.
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth' : 5
'kernel_classification': 1
'input_channels': 64
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
}
Returns:
torch.tensor: Output forward passed tensor
"""
def __init__(self, parameters):
super(DensBlock3D, self).__init__()
# We first calculate the amount of zero padding required (http://cs231n.github.io/convolutional-networks/)
padding_heigth = int((parameters['kernel_heigth'] - 1) / 2)
padding_width = int((parameters['kernel_width'] - 1) / 2)
padding_depth = int((parameters['kernel_depth'] - 1) / 2)
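# Dense connectivity: each successive convolution also receives the concatenated outputs of all preceding layers, so its input channel count grows by output_channels per layer.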
convolutional_layer2_input = int(parameters['input_channels'] + parameters['output_channels'])
convolutional_layer3_input = int(convolutional_layer2_input + parameters['output_channels'])
self.convolutional_layer1 = nn.Sequential(
nn.InstanceNorm3d(num_features=parameters['input_channels']),
nn.PReLU(),
nn.Conv3d(
in_channels=parameters['input_channels'],
out_channels=parameters['output_channels'],
kernel_size=(parameters['kernel_depth'],
parameters['kernel_heigth'],
parameters['kernel_width']),
stride=parameters['convolution_stride'],
padding=(padding_depth, padding_heigth, padding_width)
),
)
self.convolutional_layer2 = nn.Sequential(
nn.InstanceNorm3d(num_features=convolutional_layer2_input),
nn.PReLU(),
nn.Conv3d(
in_channels=convolutional_layer2_input,
out_channels=parameters['output_channels'],
kernel_size=(parameters['kernel_depth'],
parameters['kernel_heigth'],
parameters['kernel_width']),
stride=parameters['convolution_stride'],
padding=(padding_depth, padding_heigth, padding_width)
),
)
self.convolutional_layer3 = nn.Sequential(
nn.InstanceNorm3d(num_features=convolutional_layer3_input),
nn.PReLU(),
nn.Conv3d(
in_channels=convolutional_layer3_input,
out_channels=parameters['output_channels'],
kernel_size=(parameters['kernel_classification'],
parameters['kernel_classification'],
parameters['kernel_classification']),
stride=parameters['convolution_stride'],
padding=(0, 0, 0)
),
)
# Other activation functions which might be interesting to test:
# More reading: https://arxiv.org/abs/1706.02515 ; https://mlfromscratch.com/activation-functions-explained/#/
# self.activation = nn.SELU()
# self.activation = nn.ELU()
# self.activation = nn.ReLU()
# Instance normalisation is used due to the small batch size, as it has shown promise during the experiments with the simple network.
if parameters['dropout'] > 0:
self.dropout_needed = True
self.dropout = nn.Dropout3d(parameters['dropout'])
else:
self.dropout_needed = False
def forward(self, X):
"""Forward pass
Function computing the forward pass through the convolutional layer.
The input to the function is a torch tensor of shape N (batch size) x C (number of channels) x D (input depth) x H (input height) x W (input width)
Args:
X (torch.tensor): Input tensor, shape = (N x C x D x H x W)
Returns:
torch.tensor: Output forward passed tensor
"""
feature_map1 = self.convolutional_layer1(X)
feature_map2 = self.convolutional_layer2(torch.cat((X, feature_map1), dim=1))
feature_map3 = self.convolutional_layer3(torch.cat((X, feature_map1, feature_map2), dim=1))
return feature_map3
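# Shape sketch (illustrative; all parameter values are assumptions made only for this example): the block preserves the spatial dimensions, while the inner convolutions see channel counts growing as input -> input + output -> input + 2 * output.
import torch
example_block = DensBlock3D({
'kernel_heigth': 5,
'kernel_width': 5,
'kernel_depth': 5,
'kernel_classification': 1,
'input_channels': 16,
'output_channels': 16,
'convolution_stride': 1,
'dropout': 0,
})
print(example_block(torch.randn(1, 16, 8, 8, 8)).shape)  # torch.Size([1, 16, 8, 8, 8])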
class DensEncoderBlock3D(DensBlock3D):
"""Forward 3D encoder path block for a U-net.
This class creates a dense encoder block following the architecture:
DensBlock3D -> (Dropout) -> MaxPool
Args:
parameters (dict): Contains the relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
'input_channels': 64
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
}
Returns:
Y (torch.tensor): Output forward passed tensor through the encoder block, with maxpool
Y_np (torch.tensor): Output forward passed tensor through the encoder block, with no pooling
pool_indices (torch.tensor): Indices for unpooling
"""
def __init__(self, parameters):
super(DensEncoderBlock3D, self).__init__(parameters)
self.maxpool = nn.MaxPool3d(
kernel_size=parameters['pool_kernel_size'],
stride=parameters['pool_stride'],
# This option returns the max index along with the outputs, useful for MaxUnpool3d
return_indices=True
)
def forward(self, X):
"""Forward pass for U-net 3D encoder block
Function computing the forward pass through the encoder block.
The input to the function is a torch tensor of shape N (batch size) x C (number of channels) x D (depth) x H (input height) x W (input width)
Args:
X (torch.tensor): Input tensor, shape = (N x C x D x H x W)
Returns:
Y (torch.tensor): Output forward passed tensor through the encoder block, with maxpool
Y_np (torch.tensor): Output forward passed tensor through the encoder block, with no pooling
pool_indices (torch.tensor): Indices for unpooling
"""
# The parent dense block applies its three convolutions in the forward pass
Y_np = super(DensEncoderBlock3D, self).forward(X)
if self.dropout_needed:
Y_np = self.dropout(Y_np)
Y, pool_indices = self.maxpool(Y_np)
return Y, Y_np, pool_indices
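# Encoder sketch (illustrative; parameter values are assumptions made only for this example): Y is the pooled, half-resolution output, Y_np is the full-resolution tensor kept for the skip connection, and pool_indices can drive MaxUnpool3d in the decoder.
import torch
example_encoder = DensEncoderBlock3D({
'kernel_heigth': 5,
'kernel_width': 5,
'kernel_depth': 5,
'kernel_classification': 1,
'input_channels': 1,
'output_channels': 8,
'convolution_stride': 1,
'dropout': 0,
'pool_kernel_size': 2,
'pool_stride': 2,
})
Y, Y_np, pool_indices = example_encoder(torch.randn(1, 1, 16, 16, 16))
print(Y.shape, Y_np.shape)  # torch.Size([1, 8, 8, 8, 8]) torch.Size([1, 8, 16, 16, 16])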
class DensDecoderBlock3D(DensBlock3D):
"""Forward 3D decoder path block for a U-net.
This class creates a dense decoder block following the architecture:
Upsampling (Transposed Convolution, Upsample + 1x1 Convolution, or MaxUnpool) -> Skip-Connection Concatenation -> DensBlock3D -> (Dropout)
Args:
parameters (dict): Contains the relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
'input_channels': 64
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
}
Returns:
Y (torch.tensor): Output forward passed tensor through the decoder block
"""
def __init__(self, parameters):
super(DensDecoderBlock3D, self).__init__(parameters)
self.up_mode = parameters['up_mode']
if self.up_mode == 'upconv': # Attention: this still needs to be checked to confirm that it is working!
self.up = nn.ConvTranspose3d(
in_channels=int(parameters['input_channels']/2),
out_channels=parameters['output_channels'],
kernel_size=parameters['pool_kernel_size'],
stride=parameters['pool_stride'],
)
elif self.up_mode == 'upsample':
self.up = nn.Sequential(
nn.Upsample(
mode='nearest',
scale_factor=2,
),
nn.Conv3d(
in_channels=parameters['input_channels'],
out_channels=parameters['output_channels'],
kernel_size=1,
)
)
elif self.up_mode == 'unpool':
self.up = nn.MaxUnpool3d(
kernel_size=parameters['pool_kernel_size'],
stride=parameters['pool_stride']
)
def forward(self, X, Y_encoder=None, pool_indices=None):
"""Forward pass for U-net decoder block
Function computing the forward pass through the decoder block.
The input to the function is a torch tensor of shape N (batch size) x C (number of channels) x D (input depth) x H (input heigth) x W (input width).
A second input is a tensor for the skip connection, of shape (N x C x D x H x W); it defaults to None.
The function also takes the previous pool indices, for the unpooling operation; they also default to None.
Args:
X (torch.tensor): Input tensor, shape = (N x C x D x H x W)
Y_encoder (torch.tensor): Skip-connection tensor, shape = (N x C x D x H x W)
pool_indices (torch.tensor): Indices for unpooling
Returns:
Y (torch.tensor): Output forward passed tensor through the decoder block
"""
# ATTENTION: As of this code version, only "upconv" works! Debugging is ongoing for "unpool" and "upsample"!
# It seems that errors are generated by variable filter sizes and the unorthodox input size of 91x109x91.
if self.up_mode == 'unpool':
upsampling = self.up(X, pool_indices, output_size=Y_encoder.size())
elif self.up_mode == 'upsample':
upsampling = self.up(X)
elif self.up_mode == 'upconv':
upsampling = self.up(X, output_size=Y_encoder.size())
if Y_encoder is None:
concatenation = upsampling
else:
concatenation = torch.cat((Y_encoder, upsampling), dim=1)
Y = super(DensDecoderBlock3D, self).forward(concatenation)
if self.dropout_needed:
Y = self.dropout(Y)
return Y
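# Decoder sketch (illustrative; parameter values are assumptions made only for this example): with up_mode='upconv' the transposed convolution doubles the spatial size to match the encoder skip tensor, whose size is passed explicitly through output_size, before the two tensors are concatenated and passed through the dense block.
import torch
example_decoder = DensDecoderBlock3D({
'kernel_heigth': 5,
'kernel_width': 5,
'kernel_depth': 5,
'kernel_classification': 1,
'input_channels': 16,  # upsampled channels + skip-connection channels
'output_channels': 8,
'convolution_stride': 1,
'dropout': 0,
'pool_kernel_size': 2,
'pool_stride': 2,
'up_mode': 'upconv',
})
low_resolution = torch.randn(1, 8, 4, 4, 4)
skip_connection = torch.randn(1, 8, 8, 8, 8)
print(example_decoder(low_resolution, skip_connection).shape)  # torch.Size([1, 8, 8, 8, 8])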
class DensClassifierBlock3D(nn.Module):
"""Classifier block for a U-net.
This class creates a simple classifier block following the architecture:
Convolution (1x1x1) -> Class Logits
Args:
parameters (dict): Contains the relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
'input_channels': 1
'output_channels': 1
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
'number_of_classes': 1
}
Returns:
logits (torch.tensor): Output logits from the forward pass through the classifier block
"""
def __init__(self, parameters):
super(DensClassifierBlock3D, self).__init__()
self.convolutional_layer = nn.Conv3d(
in_channels=parameters['input_channels'],
out_channels=parameters['number_of_classes'],
kernel_size=parameters['kernel_classification'],
stride=parameters['convolution_stride'],
)
# TODO: Might be worth looking at GANs for image generation, and adding padding
def forward(self, X):
"""Forward pass for U-net classifier block
Function computing the forward pass through the classifier block.
The input to the function is a torch tensor of shape N (batch size) x C (number of channels) x D (input depth) x H (input height) x W (input width).
Args:
X (torch.tensor): Input tensor, shape = (N x C x D x H x W)
Returns:
logits (torch.tensor): Output logits from forward pass tensor through the classifier block
"""
logits = self.convolutional_layer(X)
# TODO: Currently, this has no activation function. Might be worth considering adding a tanh activation function, similar to GANs
# For reference: https://machinelearningmastery.com/how-to-implement-pix2pix-gan-models-from-scratch-with-keras/
# For reference 2: https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
return logits
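# Classifier sketch (illustrative; parameter values are assumptions made only for this example): a 1x1x1 convolution maps the final feature channels onto one logit map per class without changing the spatial dimensions.
import torch
example_classifier = DensClassifierBlock3D({
'input_channels': 8,
'number_of_classes': 1,
'kernel_classification': 1,
'convolution_stride': 1,
})
print(example_classifier(torch.randn(1, 8, 16, 16, 16)).shape)  # torch.Size([1, 1, 16, 16, 16])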
# Classical 3D UNet:
...