Commit 7da22503 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

Merge branch 'CrossDomain_autoencoder' into 'master'

Cross domain autoencoder

See merge request !4
parents aa74b23e 5e0551e3
...@@ -118,9 +118,16 @@ dmypy.json ...@@ -118,9 +118,16 @@ dmypy.json
.vscode/ .vscode/
datasets/ datasets/
files.txt files.txt
jobscript.sge.sh *.sge.sh
*.nii.gz *.nii.gz
stuff/ stuff/
test/* test/*
.DS_Store .DS_Store
logs/ logs/
*.ini
experimentInputs/
experiments/
predictions/
*.sh
saved_models/
mock_job.py
...@@ -29,7 +29,7 @@ class BrainMapperAE3D(nn.Module): ...@@ -29,7 +29,7 @@ class BrainMapperAE3D(nn.Module):
Args: Args:
parameters (dict): Contains information relevant parameters parameters (dict): Contains information relevant parameters
parameters = { parameters = {
'kernel_heigth': 5 'kernel_size': 5
'kernel_width': 5 'kernel_width': 5
'kernel_depth': 5 'kernel_depth': 5
'kernel_classification': 1 'kernel_classification': 1
...@@ -50,44 +50,61 @@ class BrainMapperAE3D(nn.Module): ...@@ -50,44 +50,61 @@ class BrainMapperAE3D(nn.Module):
def __init__(self, parameters):
    """Build the encoder / transformer / decoder paths of the 3D autoencoder.

    Args:
        parameters (dict): Network hyper-parameters. Keys read here include
            'cross_domain_x2y_flag', 'input_channels', 'output_channels',
            'kernel_size', 'convolution_stride', 'first_kernel_size',
            'first_convolution_stride', 'number_of_encoder_blocks',
            'max_number_channels', 'transformer_blocks_stride',
            'number_of_transformer_blocks' and
            'number_of_feature_mapping_blocks'.
            NOTE(review): the dict is mutated while blocks are built; only
            'input_channels' and 'output_channels' are restored at the end —
            the other mutated keys keep their last assigned values.
    """
    super(BrainMapperAE3D, self).__init__()

    self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']

    original_input_channels = parameters['input_channels']
    original_output_channels = parameters['output_channels']
    original_kernel_size = parameters['kernel_size']
    original_stride = parameters['convolution_stride']

    # Encoder Path: the first block uses dedicated kernel/stride settings,
    # after which the standard ones are restored.
    parameters['kernel_size'] = parameters['first_kernel_size']
    parameters['convolution_stride'] = parameters['first_convolution_stride']
    self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])
    parameters['kernel_size'] = original_kernel_size
    parameters['convolution_stride'] = original_stride

    # Channels double per encoder block until 'max_number_channels'; blocks
    # past that point keep an equal channel count and are counted so the
    # decoder can mirror them.
    equal_channels_blocks = 0
    for _ in range(parameters['number_of_encoder_blocks']):
        if parameters['output_channels'] < parameters['max_number_channels']:
            parameters['input_channels'] = parameters['output_channels']
            parameters['output_channels'] = parameters['output_channels'] * 2
        else:
            parameters['input_channels'] = parameters['output_channels']
            equal_channels_blocks += 1
        self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))

    # Transformer: residual blocks at the bottleneck; in cross-domain x2y
    # mode an extra stack of feature-mapping layers is inserted mid-way
    # (see forward()).
    parameters['input_channels'] = parameters['output_channels']
    parameters['convolution_stride'] = parameters['transformer_blocks_stride']
    self.transformerBlocks = nn.ModuleList([modules.ResNetBlock3D(parameters) for _ in range(parameters['number_of_transformer_blocks'])])

    if self.cross_domain_x2y_flag:
        self.featureMappingLayers = nn.ModuleList([modules.ResNetFeatureMappingBlock3D(parameters) for _ in range(parameters['number_of_feature_mapping_blocks'])])

    # Decoder: mirror the equal-channel blocks first, then halve the channel
    # count per block, finishing with a classifier block.
    if equal_channels_blocks != 0:
        self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for _ in range(equal_channels_blocks)])

    parameters['output_channels'] = parameters['output_channels'] // 2
    if equal_channels_blocks != 0:
        self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
    else:
        self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])

    for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
        parameters['input_channels'] = parameters['output_channels']
        parameters['output_channels'] = parameters['output_channels'] // 2
        self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))

    parameters['input_channels'] = parameters['output_channels']
    self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))

    # Restore the caller's channel settings (the dict was mutated above).
    parameters['input_channels'] = original_input_channels
    parameters['output_channels'] = original_output_channels
...@@ -107,24 +124,40 @@ class BrainMapperAE3D(nn.Module): ...@@ -107,24 +124,40 @@ class BrainMapperAE3D(nn.Module):
# Encoder # Encoder
X = self.encoderBlock1.forward(X) Y_encoder_sizes = []
Y_encoder_1_size = X.size()
X = self.encoderBlock2.forward(X) for encoderBlock in self.encoderBlocks:
Y_encoder_2_size = X.size() X = encoderBlock.forward(X)
X = self.encoderBlock3.forward(X) Y_encoder_sizes.append(X.size())
Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
Y_encoder_sizes_lenght = len(Y_encoder_sizes)
# Transformer # Transformer
for transformerBlock in self.transformerBlocks: if self.cross_domain_x2y_flag == True:
X = transformerBlock(X) for transformerBlock in self.transformerBlocks[:len(self.transformerBlocks)//2]:
X = transformerBlock(X)
for featureMappingLayer in self.featureMappingLayers:
X = featureMappingLayer(X)
for transformerBlock in self.transformerBlocks[len(self.transformerBlocks)//2:]:
X = transformerBlock(X)
else:
for transformerBlock in self.transformerBlocks:
X = transformerBlock(X)
# Decoder # Decoder
X = self.decoderBlock1.forward(X, Y_encoder_2_size) for index, decoderBlock in enumerate(self.decoderBlocks):
del Y_encoder_2_size if index < Y_encoder_sizes_lenght:
X = self.decoderBlock2.forward(X, Y_encoder_1_size) X = decoderBlock.forward(X, Y_encoder_sizes[index])
del Y_encoder_1_size else:
X = self.decoderBlock3.forward(X) X = decoderBlock.forward(X)
del Y_encoder_sizes, Y_encoder_sizes_lenght
return X return X
...@@ -226,3 +259,220 @@ class BrainMapperAE3D(nn.Module): ...@@ -226,3 +259,220 @@ class BrainMapperAE3D(nn.Module):
subsubmodule.weight.data.normal_(0, std) subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!") print("Initialized network parameters!")
class AutoEncoder3D(nn.Module):
    """Architecture class for CycleGAN inspired BrainMapper 3D Autoencoder.

    This class contains the pytorch implementation of the generator
    architecture underpinning the BrainMapper project.

    Args:
        parameters (dict): Contains information relevant parameters
            parameters = {
                'kernel_size': 5
                'kernel_width': 5
                'kernel_depth': 5
                'kernel_classification': 1
                'input_channels': 1
                'output_channels': 64
                'convolution_stride': 1
                'dropout': 0.2
                'pool_kernel_size': 2
                'pool_stride': 2
                'up_mode': 'upconv'
                'number_of_classes': 1
            }

    Returns:
        probability_map (torch.tensor): Output forward passed tensor through the U-net block
    """

    def __init__(self, parameters):
        super(AutoEncoder3D, self).__init__()

        # Flag kept for parity with BrainMapperAE3D; this architecture has no
        # transformer / feature-mapping path of its own.
        self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']

        original_input_channels = parameters['input_channels']
        original_output_channels = parameters['output_channels']
        original_kernel_size = parameters['kernel_size']
        original_stride = parameters['convolution_stride']

        # Encoder Path: the first block uses dedicated kernel/stride settings,
        # after which the standard ones are restored.
        parameters['kernel_size'] = parameters['first_kernel_size']
        parameters['convolution_stride'] = parameters['first_convolution_stride']
        self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])
        parameters['kernel_size'] = original_kernel_size
        parameters['convolution_stride'] = original_stride

        # Channels double per encoder block until 'max_number_channels';
        # blocks past that point keep an equal channel count and are counted
        # so the decoder can mirror them.
        equal_channels_blocks = 0
        for _ in range(parameters['number_of_encoder_blocks']):
            if parameters['output_channels'] < parameters['max_number_channels']:
                parameters['input_channels'] = parameters['output_channels']
                parameters['output_channels'] = parameters['output_channels'] * 2
            else:
                parameters['input_channels'] = parameters['output_channels']
                equal_channels_blocks += 1
            self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))

        # Decoder: mirror the equal-channel blocks first, then halve the
        # channel count per block, finishing with a classifier block.
        parameters['input_channels'] = parameters['output_channels']
        parameters['convolution_stride'] = parameters['transformer_blocks_stride']
        if equal_channels_blocks != 0:
            self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for _ in range(equal_channels_blocks)])

        parameters['output_channels'] = parameters['output_channels'] // 2
        if equal_channels_blocks != 0:
            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
        else:
            self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])

        for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
            parameters['input_channels'] = parameters['output_channels']
            parameters['output_channels'] = parameters['output_channels'] // 2
            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))

        parameters['input_channels'] = parameters['output_channels']
        self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))

        # Restore the caller's channel settings (the dict was mutated above).
        parameters['input_channels'] = original_input_channels
        parameters['output_channels'] = original_output_channels

    def forward(self, X):
        """Forward pass for 3D CGAN Autoencoder.

        Function computing the forward pass through the 3D generator.
        The input to the function is the dMRI map.

        Args:
            X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W)

        Returns:
            probability_map (torch.tensor): Output forward passed tensor through the CGAN Autoencoder
        """

        # Encoder: record the intermediate sizes so the decoder can upsample
        # back to the matching encoder resolution.
        Y_encoder_sizes = []
        for encoderBlock in self.encoderBlocks:
            X = encoderBlock.forward(X)
            Y_encoder_sizes.append(X.size())

        # Drop the bottleneck size and reverse so the sizes pair up with the
        # decoder blocks in order.
        Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
        Y_encoder_sizes_length = len(Y_encoder_sizes)

        # Decoder: the first blocks receive a target output size; the
        # remaining blocks (including the classifier) do not.
        for index, decoderBlock in enumerate(self.decoderBlocks):
            if index < Y_encoder_sizes_length:
                X = decoderBlock.forward(X, Y_encoder_sizes[index])
            else:
                X = decoderBlock.forward(X)

        del Y_encoder_sizes, Y_encoder_sizes_length

        return X

    def save(self, path):
        """Model Saver.

        Function saving the model with all its parameters to a given path.
        The path must end with a *.model argument.

        Args:
            path (str): Path string
        """

        print("Saving Model... {}".format(path))
        torch.save(self, path)

    @property
    def test_if_cuda(self):
        """Cuda Test.

        This function tests if the model parameters are allocated to a CUDA enabled GPU.

        Returns:
            bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
        """

        return next(self.parameters()).is_cuda

    def predict(self, X, device=0):
        """Post-training Output Prediction.

        This function predicts the output of the U-net post-training.

        Args:
            X (torch.tensor): input dMRI volume
            device (int/str): Device type used for training (int - GPU id, str - CPU)

        Returns:
            prediction (ndarray): predicted output after training
        """

        self.eval()  # PyTorch module setting network to evaluation mode

        if type(X) is np.ndarray:
            # NOTE(review): an ndarray input is converted but never moved to
            # the GPU here, unlike the tensor branch below — confirm intended.
            X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
        elif type(X) is torch.Tensor and not X.is_cuda:
            # The .cuda() call transfers the tensor from the CPU to the GPU.
            # The non-blocking argument lets the caller bypass synchronization
            # when possible.
            X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)

        with torch.no_grad():  # Causes operations to have no gradients
            output = self.forward(X)

        _, idx = torch.max(output, 1)

        # Retrieve the tensor held by idx (.data), and map it to the CPU as an ndarray.
        idx = idx.data.cpu().numpy()
        prediction = np.squeeze(idx)

        del X, output, idx

        return prediction

    def reset_parameters(self, custom_weight_reset_flag):
        """Parameter Initialization.

        This function (re)initializes the parameters of the defined network.
        It is a wrapper for the reset_parameters() function defined for each module.
        More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
        An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639

        Args:
            custom_weight_reset_flag (bool): Flag indicating if the modified weight initialisation approach should be used.
        """

        def _reset(module):
            # Re-initialize a single conv / transposed-conv / instance-norm
            # module; optionally apply the custom N(0, gain/sqrt(fan))
            # initialisation to the convolution weights.
            if isinstance(module, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)):
                module.reset_parameters()
                if custom_weight_reset_flag and isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
                    gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
                    fan, _ = calculate_fan(module.weight)
                    std = np.divide(gain, np.sqrt(fan))
                    module.weight.data.normal_(0, std)

        print("Initializing network parameters...")

        # Mirror the original two-level traversal: the children of each
        # top-level child, plus one further level of nesting.
        for _, module in self.named_children():
            for _, submodule in module.named_children():
                _reset(submodule)
                for _, subsubmodule in submodule.named_children():
                    _reset(subsubmodule)

        print("Initialized network parameters!")
...@@ -39,11 +39,10 @@ import torch.utils.data as data ...@@ -39,11 +39,10 @@ import torch.utils.data as data
import numpy as np import numpy as np
from solver import Solver from solver import Solver
from BrainMapperAE import BrainMapperAE3D from BrainMapperAE import BrainMapperAE3D, AutoEncoder3D
from utils.data_utils import get_datasets from utils.data_utils import get_datasets
from utils.settings import Settings from utils.settings import Settings
import utils.data_evaluation_utils as evaluations import utils.data_evaluation_utils as evaluations
from utils.data_logging_utils import LogWriter
from utils.common_utils import create_folder from utils.common_utils import create_folder
# Set the default floating point tensor type to FloatTensor # Set the default floating point tensor type to FloatTensor
...@@ -51,13 +50,15 @@ from utils.common_utils import create_folder ...@@ -51,13 +50,15 @@ from utils.common_utils import create_folder
torch.set_default_tensor_type(torch.FloatTensor) torch.set_default_tensor_type(torch.FloatTensor)
def load_data(data_parameters): def load_data(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag):
"""Dataset Loader """Dataset Loader
This function loads the training and validation datasets. This function loads the training and validation datasets.
Args: Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles. data_parameters (dict): Dictionary containing relevant information for the datafiles.
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
Returns: Returns:
train_data (dataset object): Pytorch map-style dataset object, mapping indices to training data samples. train_data (dataset object): Pytorch map-style dataset object, mapping indices to training data samples.
...@@ -65,7 +66,7 @@ def load_data(data_parameters): ...@@ -65,7 +66,7 @@ def load_data(data_parameters):
""" """
print("Data is loading...") print("Data is loading...")
train_data, validation_data = get_datasets(data_parameters) train_data, validation_data = get_datasets(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag)
print("Data has loaded!") print("Data has loaded!")
print("Training dataset size is {}".format(len(train_data))) print("Training dataset size is {}".format(len(train_data)))
print("Validation dataset size is {}".format(len(validation_data))) print("Validation dataset size is {}".format(len(validation_data)))
...@@ -116,7 +117,49 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -116,7 +117,49 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
} }
""" """
def _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters):
def _load_pretrained_cross_domain(x2y_model, save_model_directory, experiment_name):
    """Pretrained cross-domain loader.

    This function loads the pretrained X2X and Y2Y autoencoders.
    After, it initializes the X2Y model's weights using the X2X encoder and the Y2Y decoder weights.

    Args:
        x2y_model (class): Original x2y model initialised using the standard parameters.
        save_model_directory (str): Name of the directory where the model is saved
        experiment_name (str): Name of the experiment

    Returns:
        x2y_model (class): New x2y model with encoder and decoder paths weights reinitialised.
    """

    x2y_model_state_dict = x2y_model.state_dict()
    # NOTE(review): torch.load unpickles full model objects here — only load
    # checkpoints from trusted locations.
    x2x_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_x2x.pth.tar')).state_dict()
    y2y_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_y2y.pth.tar')).state_dict()

    # The first half (+1) of the state-dict entries is treated as the encoder
    # path and copied from the X2X model; every later entry with a matching
    # key is copied from the Y2Y model.
    half_point = len(x2x_model_state_dict) // 2 + 1

    for counter, key in enumerate(x2y_model_state_dict, start=1):
        if counter <= half_point:
            x2y_model_state_dict[key] = x2x_model_state_dict[key]
        elif key in y2y_model_state_dict:
            x2y_model_state_dict[key] = y2y_model_state_dict[key]

    x2y_model.load_state_dict(x2y_model_state_dict)

    return x2y_model
def _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer = torch.optim.Adam,
loss_function = torch.nn.MSELoss(),
):
"""Wrapper for the training operation """Wrapper for the training operation
This function wraps the training operation for the network This function wraps the training operation for the network
...@@ -128,7 +171,11 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -128,7 +171,11 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
misc_parameters (dict): Dictionary of aditional hyperparameters misc_parameters (dict): Dictionary of aditional hyperparameters
""" """
train_data, validation_data = load_data(data_parameters)
train_data, validation_data = load_data(data_parameters,
cross_domain_x2x_flag = network_parameters['cross_domain_x2x_flag'],
cross_domain_y2y_flag = network_parameters['cross_domain_y2y_flag']
)
train_loader = data.DataLoader( train_loader = data.DataLoader(
dataset=train_data, dataset=train_data,
...@@ -145,19 +192,20 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -145,19 +192,20 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
) )
if training_parameters['use_pre_trained']: if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load( BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
training_parameters['pre_trained_path'])
else: else:
BrainMapperModel = BrainMapperAE3D(network_parameters) # BrainMapperModel = BrainMapperAE3D(network_parameters)
BrainMapperModel = AutoEncoder3D(network_parameters) # temprorary change for testing encoder-decoder effective receptive field
custom_weight_reset_flag = network_parameters['custom_weight_reset_flag'] custom_weight_reset_flag = network_parameters['custom_weight_reset_flag']
BrainMapperModel.reset_parameters(custom_weight_reset_flag) BrainMapperModel.reset_parameters(custom_weight_reset_flag)
if training_parameters['adam_w_flag'] == True: if network_parameters['cross_domain_x2y_flag'] == True:
optimizer = torch.optim.AdamW BrainMapperModel = _load_pretrained_cross_domain(x2y_model=BrainMapperModel,
else: save_model_directory=misc_parameters['save_model_directory'],
optimizer = torch.optim.Adam experiment_name=training_parameters['experiment_name']
)
solver = Solver(model=BrainMapperModel, solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'], device=misc_parameters['device'],
...@@ -169,6 +217,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -169,6 +217,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
'eps': training_parameters['optimizer_epsilon'], 'eps': training_parameters['optimizer_epsilon'],
'weight_decay': training_parameters['optimizer_weigth_decay'] 'weight_decay': training_parameters['optimizer_weigth_decay']
}, },
loss_function=loss_function,
model_name=training_parameters['experiment_name'], model_name=training_parameters['experiment_name'],
number_epochs=training_parameters['number_of_epochs'], number_epochs=training_parameters['number_of_epochs'],
loss_log_period=training_parameters['loss_log_period'], loss_log_period=training_parameters['loss_log_period'],
...@@ -180,19 +229,79 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -180,19 +229,79 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
logs_directory=misc_parameters['logs_directory'], logs_directory=misc_parameters['logs_directory'],
checkpoint_directory=misc_parameters['checkpoint_directory'], checkpoint_directory=misc_parameters['checkpoint_directory'],
save_model_directory=misc_parameters['save_model_directory'], save_model_directory=misc_parameters['save_model_directory'],
final_model_output_file=training_parameters['final_model_output_file'],
crop_flag = data_parameters['crop_flag'] crop_flag = data_parameters['crop_flag']
) )
validation_loss = solver.train(train_loader, validation_loader) # _ = solver.train(train_loader, validation_loader)
solver.train(train_loader, validation_loader)
del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver, optimizer del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver, optimizer
torch.cuda.empty_cache() torch.cuda.empty_cache()