Commit 8382f477 authored by Andrei Roibu

modified to allow dynamic architecture change + added AE class for receptive field test

parent 6cd5bba9
@@ -29,7 +29,7 @@ class BrainMapperAE3D(nn.Module):
Args:
parameters (dict): Contains the relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_size': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
@@ -54,28 +54,33 @@ class BrainMapperAE3D(nn.Module):
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
original_kernel_height = parameters['kernel_heigth']
original_kernel_size = parameters['kernel_size']
original_stride = parameters['convolution_stride']
# Encoder Path
parameters['kernel_heigth'] = 7
self.encoderBlock1 = modules.ResNetEncoderBlock3D(parameters)
parameters['kernel_size'] = parameters['first_kernel_size']
parameters['convolution_stride'] = parameters['first_convolution_stride']
self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
parameters['kernel_heigth'] = original_kernel_height
parameters['convolution_stride'] = 2
self.encoderBlock2 = modules.ResNetEncoderBlock3D(parameters)
parameters['kernel_size'] = original_kernel_size
parameters['convolution_stride'] = original_stride
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock3 = modules.ResNetEncoderBlock3D(parameters)
equal_channels_blocks = 0
for _ in range(parameters['number_of_encoder_blocks']):
if parameters['output_channels'] < parameters['max_number_channels']:
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
else:
parameters['input_channels'] = parameters['output_channels']
equal_channels_blocks += 1
self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))
# Transformer
parameters['input_channels'] = parameters['output_channels']
parameters['convolution_stride'] = original_stride
parameters['convolution_stride'] = parameters['transformer_blocks_stride']
self.transformerBlocks = nn.ModuleList([modules.ResNetBlock3D(parameters) for _ in range(parameters['number_of_transformer_blocks'])])
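The encoder is now assembled dynamically: one shared parameters dict is mutated between block constructions, doubling output_channels until max_number_channels is reached, while equal_channels_blocks counts the capped blocks. A minimal sketch of the resulting channel schedule, using illustrative values taken from the config further down:

# Sketch of the dynamic encoder channel schedule (illustrative values only).
input_channels, output_channels = 1, 32  # first block, from the config below
max_number_channels = 128
number_of_encoder_blocks = 2

schedule = [(input_channels, output_channels)]
equal_channels_blocks = 0
for _ in range(number_of_encoder_blocks):
    if output_channels < max_number_channels:
        input_channels, output_channels = output_channels, output_channels * 2
    else:
        input_channels = output_channels
        equal_channels_blocks += 1
    schedule.append((input_channels, output_channels))

print(schedule)               # [(1, 32), (32, 64), (64, 128)]
print(equal_channels_blocks)  # 0 for these values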
@@ -84,15 +89,22 @@ class BrainMapperAE3D(nn.Module):
# Decoder
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock1 = modules.ResNetDecoderBlock3D(parameters)
if equal_channels_blocks != 0:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for _ in range(equal_channels_blocks)])
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock2 = modules.ResNetDecoderBlock3D(parameters)
if equal_channels_blocks != 0:
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
else:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])
for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
parameters['input_channels'] = parameters['output_channels']
self.decoderBlock3 = modules.ResNetClassifierBlock3D(parameters)
self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
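The decoder mirrors that schedule in reverse: any capped (equal-channel) blocks come first, then number_of_encoder_blocks - equal_channels_blocks - 1 halving blocks, then the classifier. Continuing the sketch above with the same illustrative values:

# Mirror of the encoder schedule (values from the sketch above).
number_of_encoder_blocks, equal_channels_blocks = 2, 0
input_channels = output_channels = 128  # deepest feature width

decoder_schedule = [(128, 128)] * equal_channels_blocks  # capped blocks, if any
output_channels //= 2
decoder_schedule.append((input_channels, output_channels))
for _ in range(number_of_encoder_blocks - equal_channels_blocks - 1):
    input_channels, output_channels = output_channels, output_channels // 2
    decoder_schedule.append((input_channels, output_channels))
decoder_schedule.append((output_channels, 1))  # classifier block, number_of_classes = 1

print(decoder_schedule)  # [(128, 64), (64, 32), (32, 1)]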
@@ -112,11 +124,14 @@ class BrainMapperAE3D(nn.Module):
# Encoder
X = self.encoderBlock1.forward(X)
Y_encoder_1_size = X.size()
X = self.encoderBlock2.forward(X)
Y_encoder_2_size = X.size()
X = self.encoderBlock3.forward(X)
Y_encoder_sizes = []
for encoderBlock in self.encoderBlocks:
X = encoderBlock.forward(X)
Y_encoder_sizes.append(X.size())
Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
Y_encoder_sizes_length = len(Y_encoder_sizes)
# Transformer
@@ -136,11 +151,13 @@ class BrainMapperAE3D(nn.Module):
# Decoder
X = self.decoderBlock1.forward(X, Y_encoder_2_size)
del Y_encoder_2_size
X = self.decoderBlock2.forward(X, Y_encoder_1_size)
del Y_encoder_1_size
X = self.decoderBlock3.forward(X)
for index, decoderBlock in enumerate(self.decoderBlocks):
if index < Y_encoder_sizes_length:
X = decoderBlock.forward(X, Y_encoder_sizes[index])
else:
X = decoderBlock.forward(X)
del Y_encoder_sizes, Y_encoder_sizes_length
return X
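The skip connections are generalised the same way: forward records every encoder output size, drops the deepest one (the decoder starts from that resolution) and reverses the list, so decoderBlocks[index] unpools to the size of its mirror-image encoder block. A small illustration with dummy sizes:

# Dummy encoder output sizes, shallowest to deepest: (N, C, D, H, W).
Y_encoder_sizes = [(1, 32, 64, 64, 64), (1, 64, 32, 32, 32), (1, 128, 16, 16, 16)]
Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]  # drop deepest, then reverse
print(Y_encoder_sizes)  # [(1, 64, 32, 32, 32), (1, 32, 64, 64, 64)]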
@@ -242,3 +259,220 @@ class BrainMapperAE3D(nn.Module):
subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!")
class AutoEncoder3D(nn.Module):
"""Architecture class for CycleGAN inspired BrainMapper 3D Autoencoder.
This class contains the pytorch implementation of the generator architecture underpinning the BrainMapper project.
Args:
parameters (dict): Contains the relevant parameters
parameters = {
'kernel_size': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
'input_channels': 1
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
'number_of_classes': 1
}
Returns:
probability_map (torch.tensor): Tensor produced by the forward pass through the autoencoder
"""
def __init__(self, parameters):
super(AutoEncoder3D, self).__init__()
self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
original_kernel_size = parameters['kernel_size']
original_stride = parameters['convolution_stride']
# Encoder Path
parameters['kernel_size'] = parameters['first_kernel_size']
parameters['convolution_stride'] = parameters['first_convolution_stride']
self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])
parameters['kernel_size'] = original_kernel_size
parameters['convolution_stride'] = original_stride
equal_channels_blocks = 0
for _ in range(parameters['number_of_encoder_blocks']):
if parameters['output_channels'] < parameters['max_number_channels']:
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
else:
parameters['input_channels'] = parameters['output_channels']
equal_channels_blocks += 1
self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))
# Decoder
parameters['input_channels'] = parameters['output_channels']
parameters['convolution_stride'] = parameters['transformer_blocks_stride']
if equal_channels_blocks != 0:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for _ in range(equal_channels_blocks)])
parameters['output_channels'] = parameters['output_channels'] // 2
if equal_channels_blocks != 0:
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
else:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])
for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
parameters['input_channels'] = parameters['output_channels']
self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
def forward(self, X):
"""Forward pass for 3D CGAN Autoencoder
Function computing the forward pass through the 3D generator
The input to the function is the dMRI map
Args:
X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W)
Returns:
probability_map (torch.tensor): Tensor produced by the forward pass through the CGAN autoencoder
"""
# Encoder
Y_encoder_sizes = []
for encoderBlock in self.encoderBlocks:
X = encoderBlock.forward(X)
Y_encoder_sizes.append(X.size())
Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
Y_encoder_sizes_length = len(Y_encoder_sizes)
# Decoder
for index, decoderBlock in enumerate(self.decoderBlocks):
if index < Y_encoder_sizes_length:
X = decoderBlock.forward(X, Y_encoder_sizes[index])
else:
X = decoderBlock.forward(X)
del Y_encoder_sizes, Y_encoder_sizes_length
return X
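A minimal usage sketch for the new class, assuming the modules import resolves; the parameter values mirror the [NETWORK] config below, and any keys consumed only inside the building blocks would need to be added too:

import torch

# Hypothetical parameters dict mirroring the [NETWORK] config below.
parameters = {
    'cross_domain_x2y_flag': False,
    'input_channels': 1,
    'output_channels': 32,
    'kernel_size': 3,
    'first_kernel_size': 10,
    'first_convolution_stride': 1,
    'convolution_stride': 2,
    'number_of_encoder_blocks': 2,
    'max_number_channels': 128,
    'transformer_blocks_stride': 1,
    'kernel_classification': 10,
    'dropout': 0,
    'pool_kernel_size': 3,
    'pool_stride': 2,
    'up_mode': 'upconv',
    'final_activation': 'tanh',
    'number_of_classes': 1,
}

model = AutoEncoder3D(parameters)
with torch.no_grad():
    output = model(torch.randn(1, 1, 32, 32, 32))  # (N, C, D, H, W)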
def save(self, path):
"""Model Saver
Function saving the model with all its parameters to a given path.
The path must end with a *.model extension.
Args:
path (str): Path string
"""
print("Saving Model... {}".format(path))
torch.save(self, path)
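Note that torch.save(self, path) pickles the whole module rather than a state_dict, so the class definition must be importable at load time. A hedged loading counterpart (the file name is hypothetical):

import torch

# Load a model saved with torch.save(self, path); AutoEncoder3D must be importable.
model = torch.load("autoencoder.model", map_location="cpu")
model.eval()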
@property
def test_if_cuda(self):
"""Cuda Test
This function tests if the model parameters are allocated to a CUDA enabled GPU.
Returns:
bool: Flag indicating True if the tensor is stored on the GPU and Flase otherwhise
"""
return next(self.parameters()).is_cuda
def predict(self, X, device=0):
"""Post-training Output Prediction
This function predicts the output of the network after training
Args:
X (torch.tensor): input dMRI volume
device (int/str): Device used for inference (int - GPU id, str - CPU)
Returns:
prediction (ndarray): predicted output after training
"""
self.eval()  # PyTorch call setting the network to evaluation mode
if isinstance(X, np.ndarray):
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
if self.test_if_cuda and not X.is_cuda:
X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
# The .cuda() call transfers the tensor from the CPU to the GPU when the model itself is there.
# The non_blocking argument lets the caller bypass synchronization when possible.
with torch.no_grad(): # Causes operations to have no gradients
output = self.forward(X)
_, idx = torch.max(output, 1)
# We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
idx = idx.data.cpu().numpy()
prediction = np.squeeze(idx)
del X, output, idx
return prediction
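A short usage sketch for predict, assuming a NumPy input volume and the model instance from the sketch above:

import numpy as np

# Hypothetical dMRI volume with shape (N, C, D, H, W).
volume = np.random.rand(1, 1, 32, 32, 32).astype(np.float32)
prediction = model.predict(volume, device=0)  # device is only used if the model is on a GPU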
def reset_parameters(self, custom_weight_reset_flag):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
Args:
custom_weight_reset_flag (bool): Flag indicating if the modified weight initialisation approach should be used.
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
if isinstance(submodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)):
submodule.reset_parameters()
if custom_weight_reset_flag:
if isinstance(submodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(submodule.weight)
std = np.divide(gain, np.sqrt(fan))
submodule.weight.data.normal_(0, std)
for _, subsubmodule in submodule.named_children():
if isinstance(subsubmodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)):
subsubmodule.reset_parameters()
if custom_weight_reset_flag:
if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(subsubmodule.weight)
std = np.divide(gain, np.sqrt(fan))
subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!")
\ No newline at end of file
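The custom reset draws weights from N(0, std^2) with std = gain / sqrt(fan), where gain = sqrt(2 / (1 + 0.25^2)) is the Kaiming gain for a leaky ReLU with negative slope 0.25. A quick check against PyTorch's built-in helper:

import numpy as np
import torch.nn.init as init

gain_manual = np.sqrt(2 / (1 + 0.25 ** 2))
gain_torch = init.calculate_gain('leaky_relu', param=0.25)
print(np.isclose(gain_manual, gain_torch))  # True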
@@ -24,20 +24,22 @@ use_last_checkpoint = False
adam_w_flag = False
[NETWORK]
kernel_heigth = 3
kernel_width = 3
kernel_depth = 3
kernel_classification = 7
first_kernel_size = 10
first_convolution_stride = 1
input_channels = 1
output_channels = 32
convolution_stride = 1
dropout = 0
max_number_channels = 128
kernel_size = 3
convolution_stride = 2
number_of_encoder_blocks = 2
transformer_blocks_stride = 1
number_of_transformer_blocks = 6
kernel_classification = 10
pool_kernel_size = 3
pool_stride = 2
up_mode = "upconv"
dropout = 0
final_activation = 'tanh'
number_of_classes = 1
number_of_transformer_blocks = 6
custom_weight_reset_flag = False
cross_domain_flag = False
cross_domain_x2x_flag = False
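A hedged sketch of turning the [NETWORK] section into the parameters dict with configparser (the file name is hypothetical, and only a subset of keys is shown):

import configparser

config = configparser.ConfigParser()
config.read("settings.ini")  # hypothetical file name

network = config["NETWORK"]
parameters = {
    'kernel_size': network.getint('kernel_size'),
    'first_kernel_size': network.getint('first_kernel_size'),
    'first_convolution_stride': network.getint('first_convolution_stride'),
    'convolution_stride': network.getint('convolution_stride'),
    'number_of_encoder_blocks': network.getint('number_of_encoder_blocks'),
    'max_number_channels': network.getint('max_number_channels'),
    'input_channels': network.getint('input_channels'),
    'output_channels': network.getint('output_channels'),
    'dropout': network.getfloat('dropout'),
    'up_mode': network.get('up_mode').strip('"'),  # INI keeps the quotes
}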
@@ -32,9 +32,7 @@ class ResNetEncoderBlock3D(nn.Module):
Args:
parameters (dict): Contains information on the kernel size, number of channels, number of filters, and whether the convolution is strided.
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth' : 5
'kernel_size': 5
'input_channels': 64
'output_channels': 64
'convolution_stride': 1
@@ -49,15 +47,15 @@ class ResNetEncoderBlock3D(nn.Module):
super(ResNetEncoderBlock3D, self).__init__()
# We first calculate the amount of zero padding required (http://cs231n.github.io/convolutional-networks/)
padding_heigth = int((parameters['kernel_heigth'] - 1) / 2)
padding_width = int((parameters['kernel_heigth'] - 1) / 2)
padding_depth = int((parameters['kernel_heigth'] - 1) / 2)
padding_heigth = int((parameters['kernel_size'] - 1) / 2)
padding_width = int((parameters['kernel_size'] - 1) / 2)
padding_depth = int((parameters['kernel_size'] - 1) / 2)
self.convolutional_layer = nn.Sequential(
nn.Conv3d(
in_channels=parameters['input_channels'],
out_channels=parameters['output_channels'],
kernel_size=parameters['kernel_heigth'],
kernel_size=parameters['kernel_size'],
stride=parameters['convolution_stride'],
padding=(padding_depth, padding_heigth, padding_width)
),
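The zero padding of (kernel_size - 1) / 2 keeps the spatial extent unchanged for odd kernels at stride 1, since out = floor((n + 2p - k) / s) + 1. A quick check:

# Same-padding check: output extent for input n, kernel k, stride s.
def conv_out(n, k, s=1):
    p = (k - 1) // 2
    return (n + 2 * p - k) // s + 1

print(conv_out(64, 3), conv_out(64, 5), conv_out(64, 7))  # 64 64 64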
@@ -109,9 +107,7 @@ class ResNetFeatureMappingBlock3D(nn.Module):
Args:
parameters (dict): Contains information on the kernel size, number of channels, number of filters, and whether the convolution is strided.
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth' : 5
'kernel_size': 5
'input_channels': 64
'output_channels': 64
'convolution_stride': 1
@@ -126,15 +122,15 @@ class ResNetFeatureMappingBlock3D(nn.Module):
super(ResNetFeatureMappingBlock3D, self).__init__()
# We first calculate the amount of zero padding required (http://cs231n.github.io/convolutional-networks/)
padding_heigth = int((parameters['kernel_heigth'] - 1) / 2)
padding_width = int((parameters['kernel_heigth'] - 1) / 2)
padding_depth = int((parameters['kernel_heigth'] - 1) / 2)
padding_heigth = int((parameters['kernel_size'] - 1) / 2)
padding_width = int((parameters['kernel_size'] - 1) / 2)
padding_depth = int((parameters['kernel_size'] - 1) / 2)
self.convolutional_layer = nn.Sequential(
nn.Conv3d(
in_channels=parameters['input_channels'],
out_channels=parameters['output_channels'],
kernel_size=parameters['kernel_heigth'],
kernel_size=parameters['kernel_size'],
stride=parameters['convolution_stride'],
padding=(padding_depth, padding_heigth, padding_width)
),
@@ -179,9 +175,7 @@ class ResNetBlock3D(nn.Module):
Args:
parameters (dict): Contains information on the kernel size, number of channels, number of filters, and whether the convolution is strided.
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth' : 5
'kernel_size': 5
'input_channels': 64
'output_channels': 64
'convolution_stride': 1
@@ -196,17 +190,15 @@ class ResNetBlock3D(nn.Module):
super(ResNetBlock3D, self).__init__()
# We first calculate the amount of zero padding required (http://cs231n.github.io/convolutional-networks/)
padding_heigth = int((parameters['kernel_heigth'] - 1) / 2)
padding_width = int((parameters['kernel_width'] - 1) / 2)
padding_depth = int((parameters['kernel_depth'] - 1) / 2)
padding_heigth = int((parameters['kernel_size'] - 1) / 2)
padding_width = int((parameters['kernel_size'] - 1) / 2)
padding_depth = int((parameters['kernel_size'] - 1) / 2)
self.convolutional_layer = nn.Sequential(
nn.Conv3d(
in_channels=parameters['input_channels'],
out_channels=parameters['output_channels'],
kernel_size=(parameters['kernel_depth'],
parameters['kernel_heigth'],
parameters['kernel_width']),
kernel_size=parameters['kernel_size'],
stride=parameters['convolution_stride'],
padding=(padding_depth, padding_heigth, padding_width)
),
@@ -218,9 +210,7 @@ class ResNetBlock3D(nn.Module):
nn.Conv3d(
in_channels=parameters['input_channels'],
out_channels=parameters['output_channels'],
kernel_size=(parameters['kernel_depth'],
parameters['kernel_heigth'],
parameters['kernel_width']),
kernel_size=parameters['kernel_size'],
stride=parameters['convolution_stride'],
padding=(padding_depth, padding_heigth, padding_width)
),
@@ -271,9 +261,7 @@ class ResNetDecoderBlock3D(nn.Module):
Args:
parameters (dict): Contains the relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_size': 5
'input_channels': 64
'output_channels': 64
'convolution_stride': 1
@@ -343,7 +331,7 @@ class ResNetClassifierBlock3D(nn.Module):
Args:
parameters (dict): Contains the relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_size': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1