Commit 12ce71d4 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

shallow dense block + competitive dense blocks architectures

parent bd826338
...@@ -20,6 +20,379 @@ import torch.nn as nn ...@@ -20,6 +20,379 @@ import torch.nn as nn
import utils.modules as modules import utils.modules as modules
class BrainMapperCompResUNet3D(nn.Module):
"""Architecture class for Competitive Residual DenseBlock BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
Args:
parameters (dict): Contains information relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
'input_channels': 1
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
'number_of_classes': 1
}
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
"""
def __init__(self, parameters):
super(BrainMapperCompResUNet3D, self).__init__()
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
self.encoderBlock1 = modules.InCompDensEncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.encoderBlock2 = modules.CompDensEncoderBlock3D(parameters)
self.encoderBlock3 = modules.CompDensEncoderBlock3D(parameters)
self.encoderBlock4 = modules.CompDensEncoderBlock3D(parameters)
self.bottleneck = modules.CompDensBlock3D(parameters)
self.decoderBlock1 = modules.CompDensDecoderBlock3D(parameters)
self.decoderBlock2 = modules.CompDensDecoderBlock3D(parameters)
self.decoderBlock3 = modules.CompDensDecoderBlock3D(parameters)
self.decoderBlock4 = modules.CompDensDecoderBlock3D(parameters)
self.classifier = modules.DensClassifierBlock3D(parameters)
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
def forward(self, X):
"""Forward pass for 3D U-net
Function computing the forward pass through the 3D U-Net
The input to the function is the dMRI map
Args:
X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W)
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
"""
Y_encoder_1, Y_np1, _ = self.encoderBlock1.forward(X)
Y_encoder_2, Y_np2, _ = self.encoderBlock2.forward(
Y_encoder_1)
del Y_encoder_1
Y_encoder_3, Y_np3, _ = self.encoderBlock3.forward(
Y_encoder_2)
del Y_encoder_2
Y_encoder_4, Y_np4, _ = self.encoderBlock4.forward(
Y_encoder_3)
del Y_encoder_3
Y_bottleNeck = self.bottleneck.forward(Y_encoder_4)
del Y_encoder_4
Y_decoder_1 = self.decoderBlock1.forward(
Y_bottleNeck, Y_np4)
del Y_bottleNeck, Y_np4
Y_decoder_2 = self.decoderBlock2.forward(
Y_decoder_1, Y_np3)
del Y_decoder_1, Y_np3
Y_decoder_3 = self.decoderBlock3.forward(
Y_decoder_2, Y_np2)
del Y_decoder_2, Y_np2
Y_decoder_4 = self.decoderBlock4.forward(
Y_decoder_3, Y_np1)
del Y_decoder_3, Y_np1
probability_map = self.classifier.forward(Y_decoder_4)
del Y_decoder_4
return probability_map
def save(self, path):
"""Model Saver
Function saving the model with all its parameters to a given path.
The path must end with a *.model argument.
Args:
path (str): Path string
"""
print("Saving Model... {}".format(path))
torch.save(self, path)
@property
def test_if_cuda(self):
"""Cuda Test
This function tests if the model parameters are allocated to a CUDA enabled GPU.
Returns:
bool: Flag indicating True if the tensor is stored on the GPU and Flase otherwhise
"""
return next(self.parameters()).is_cuda
def predict(self, X, device=0):
"""Post-training Output Prediction
This function predicts the output of the of the U-net post-training
Args:
X (torch.tensor): input dMRI volume
device (int/str): Device type used for training (int - GPU id, str- CPU)
Returns:
prediction (ndarray): predicted output after training
"""
self.eval() # PyToch module setting network to evaluation mode
if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
elif type(X) is torch.Tensor and not X.is_cuda:
X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
# .cuda() call transfers the densor from the CPU to the GPU if that is the case.
# Non-blocking argument lets the caller bypas synchronization when necessary
with torch.no_grad(): # Causes operations to have no gradients
output = self.forward(X)
_, idx = torch.max(output, 1)
# We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
idx = idx.data.cpu().numpy()
prediction = np.squeeze(idx)
del X, output, idx
return prediction
def reset_parameters(self):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
for _, subsubmodule in submodule.named_children():
if isinstance(subsubmodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)) == False:
subsubmodule.reset_parameters()
print("Initialized network parameters!")
class BrainMapperResUNet3Dshallow(nn.Module):
"""Architecture class for Residual DenseBlock BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
Args:
parameters (dict): Contains information relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
'input_channels': 1
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
'number_of_classes': 1
}
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
"""
def __init__(self, parameters):
super(BrainMapperResUNet3Dshallow, self).__init__()
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
self.encoderBlock1 = modules.DensEncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.encoderBlock2 = modules.DensEncoderBlock3D(parameters)
self.encoderBlock3 = modules.DensEncoderBlock3D(parameters)
self.bottleneck = modules.DensBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] * 2
self.decoderBlock1 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock2 = modules.DensDecoderBlock3D(parameters)
self.decoderBlock3 = modules.DensDecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
self.classifier = modules.DensClassifierBlock3D(parameters)
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
def forward(self, X):
"""Forward pass for 3D U-net
Function computing the forward pass through the 3D U-Net
The input to the function is the dMRI map
Args:
X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W)
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
"""
Y_encoder_1, Y_np1, _ = self.encoderBlock1.forward(X)
Y_encoder_2, Y_np2, _ = self.encoderBlock2.forward(
Y_encoder_1)
del Y_encoder_1
Y_encoder_3, Y_np3, _ = self.encoderBlock3.forward(
Y_encoder_2)
del Y_encoder_2
Y_bottleNeck = self.bottleneck.forward(Y_encoder_3)
del Y_encoder_3
Y_decoder_1 = self.decoderBlock1.forward(
Y_bottleNeck, Y_np3)
del Y_bottleNeck, Y_np3
Y_decoder_2 = self.decoderBlock2.forward(
Y_decoder_1, Y_np2)
del Y_decoder_1, Y_np2
Y_decoder_3 = self.decoderBlock3.forward(
Y_decoder_2, Y_np1)
del Y_decoder_2, Y_np1
probability_map = self.classifier.forward(Y_decoder_3)
del Y_decoder_3
return probability_map
def save(self, path):
"""Model Saver
Function saving the model with all its parameters to a given path.
The path must end with a *.model argument.
Args:
path (str): Path string
"""
print("Saving Model... {}".format(path))
torch.save(self, path)
@property
def test_if_cuda(self):
"""Cuda Test
This function tests if the model parameters are allocated to a CUDA enabled GPU.
Returns:
bool: Flag indicating True if the tensor is stored on the GPU and Flase otherwhise
"""
return next(self.parameters()).is_cuda
def predict(self, X, device=0):
"""Post-training Output Prediction
This function predicts the output of the of the U-net post-training
Args:
X (torch.tensor): input dMRI volume
device (int/str): Device type used for training (int - GPU id, str- CPU)
Returns:
prediction (ndarray): predicted output after training
"""
self.eval() # PyToch module setting network to evaluation mode
if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
elif type(X) is torch.Tensor and not X.is_cuda:
X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
# .cuda() call transfers the densor from the CPU to the GPU if that is the case.
# Non-blocking argument lets the caller bypas synchronization when necessary
with torch.no_grad(): # Causes operations to have no gradients
output = self.forward(X)
_, idx = torch.max(output, 1)
# We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
idx = idx.data.cpu().numpy()
prediction = np.squeeze(idx)
del X, output, idx
return prediction
def reset_parameters(self):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
for _, subsubmodule in submodule.named_children():
if isinstance(subsubmodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)) == False:
subsubmodule.reset_parameters()
print("Initialized network parameters!")
class BrainMapperResUNet3D(nn.Module): class BrainMapperResUNet3D(nn.Module):
"""Architecture class for Residual DenseBlock BrainMapper 3D U-net. """Architecture class for Residual DenseBlock BrainMapper 3D U-net.
......
...@@ -40,7 +40,7 @@ import torch.utils.data as data ...@@ -40,7 +40,7 @@ import torch.utils.data as data
import numpy as np import numpy as np
from solver import Solver from solver import Solver
from BrainMapperUNet import BrainMapperUNet3D, BrainMapperResUNet3D from BrainMapperUNet import BrainMapperUNet3D, BrainMapperResUNet3D, BrainMapperResUNet3Dshallow, BrainMapperCompResUNet3D
from utils.data_utils import get_datasets, data_test_train_validation_split, update_shuffling_flag, create_folder from utils.data_utils import get_datasets, data_test_train_validation_split, update_shuffling_flag, create_folder
import utils.data_evaluation_utils as evaluations import utils.data_evaluation_utils as evaluations
from utils.data_logging_utils import LogWriter from utils.data_logging_utils import LogWriter
...@@ -150,7 +150,9 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -150,7 +150,9 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
training_parameters['pre_trained_path']) training_parameters['pre_trained_path'])
else: else:
# BrainMapperModel = BrainMapperUNet3D(network_parameters) # BrainMapperModel = BrainMapperUNet3D(network_parameters)
BrainMapperModel = BrainMapperResUNet3D(network_parameters) # BrainMapperModel = BrainMapperResUNet3D(network_parameters)
# BrainMapperModel = BrainMapperResUNet3Dshallow(network_parameters)
BrainMapperModel = BrainMapperCompResUNet3D(network_parameters)
BrainMapperModel.reset_parameters() BrainMapperModel.reset_parameters()
......
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment