Commit 7da22503 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

Merge branch 'CrossDomain_autoencoder' into 'master'

Cross domain autoencoder

See merge request !4
parents aa74b23e 5e0551e3
......@@ -118,9 +118,16 @@ dmypy.json
.vscode/
datasets/
files.txt
jobscript.sge.sh
*.sge.sh
*.nii.gz
stuff/
test/*
.DS_Store
logs/
*.ini
experimentInputs/
experiments/
predictions/
*.sh
saved_models/
mock_job.py
......@@ -29,7 +29,7 @@ class BrainMapperAE3D(nn.Module):
Args:
parameters (dict): Contains information relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_size': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
......@@ -50,44 +50,61 @@ class BrainMapperAE3D(nn.Module):
def __init__(self, parameters):
super(BrainMapperAE3D, self).__init__()
self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
original_kernel_height = parameters['kernel_heigth']
original_kernel_size = parameters['kernel_size']
original_stride = parameters['convolution_stride']
# Encoder Path
parameters['kernel_heigth'] = 7
self.encoderBlock1 = modules.ResNetEncoderBlock3D(parameters)
parameters['kernel_size'] = parameters['first_kernel_size']
parameters['convolution_stride'] = parameters['first_convolution_stride']
self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
parameters['kernel_heigth'] = original_kernel_height
parameters['convolution_stride'] = 2
self.encoderBlock2 = modules.ResNetEncoderBlock3D(parameters)
parameters['kernel_size'] = original_kernel_size
parameters['convolution_stride'] = original_stride
equal_channels_blocks = 0
for _ in range(parameters['number_of_encoder_blocks']):
if parameters['output_channels'] < parameters['max_number_channels']:
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock3 = modules.ResNetEncoderBlock3D(parameters)
else:
parameters['input_channels'] = parameters['output_channels']
equal_channels_blocks += 1
self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))
# Transformer
parameters['input_channels'] = parameters['output_channels']
parameters['convolution_stride'] = original_stride
parameters['convolution_stride'] = parameters['transformer_blocks_stride']
self.transformerBlocks = nn.ModuleList([modules.ResNetBlock3D(parameters) for i in range(parameters['number_of_transformer_blocks'])])
if self.cross_domain_x2y_flag == True:
self.featureMappingLayers = nn.ModuleList([modules.ResNetFeatureMappingBlock3D(parameters) for i in range(parameters['number_of_feature_mapping_blocks'])])
# Decoder
if equal_channels_blocks != 0:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for i in range(equal_channels_blocks)])
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock1 = modules.ResNetDecoderBlock3D(parameters)
if equal_channels_blocks != 0:
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
else:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])
for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock2 = modules.ResNetDecoderBlock3D(parameters)
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
parameters['input_channels'] = parameters['output_channels']
self.decoderBlock3 = modules.ResNetClassifierBlock3D(parameters)
self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
......@@ -107,24 +124,257 @@ class BrainMapperAE3D(nn.Module):
# Encoder
X = self.encoderBlock1.forward(X)
Y_encoder_1_size = X.size()
X = self.encoderBlock2.forward(X)
Y_encoder_2_size = X.size()
X = self.encoderBlock3.forward(X)
Y_encoder_sizes = []
for encoderBlock in self.encoderBlocks:
X = encoderBlock.forward(X)
Y_encoder_sizes.append(X.size())
Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
Y_encoder_sizes_lenght = len(Y_encoder_sizes)
# Transformer
if self.cross_domain_x2y_flag == True:
for transformerBlock in self.transformerBlocks[:len(self.transformerBlocks)//2]:
X = transformerBlock(X)
for featureMappingLayer in self.featureMappingLayers:
X = featureMappingLayer(X)
for transformerBlock in self.transformerBlocks[len(self.transformerBlocks)//2:]:
X = transformerBlock(X)
else:
for transformerBlock in self.transformerBlocks:
X = transformerBlock(X)
# Decoder
X = self.decoderBlock1.forward(X, Y_encoder_2_size)
del Y_encoder_2_size
X = self.decoderBlock2.forward(X, Y_encoder_1_size)
del Y_encoder_1_size
X = self.decoderBlock3.forward(X)
for index, decoderBlock in enumerate(self.decoderBlocks):
if index < Y_encoder_sizes_lenght:
X = decoderBlock.forward(X, Y_encoder_sizes[index])
else:
X = decoderBlock.forward(X)
del Y_encoder_sizes, Y_encoder_sizes_lenght
return X
def save(self, path):
"""Model Saver
Function saving the model with all its parameters to a given path.
The path must end with a *.model argument.
Args:
path (str): Path string
"""
print("Saving Model... {}".format(path))
torch.save(self, path)
@property
def test_if_cuda(self):
"""Cuda Test
This function tests if the model parameters are allocated to a CUDA enabled GPU.
Returns:
bool: Flag indicating True if the tensor is stored on the GPU and Flase otherwhise
"""
return next(self.parameters()).is_cuda
def predict(self, X, device=0):
"""Post-training Output Prediction
This function predicts the output of the of the U-net post-training
Args:
X (torch.tensor): input dMRI volume
device (int/str): Device type used for training (int - GPU id, str- CPU)
Returns:
prediction (ndarray): predicted output after training
"""
self.eval() # PyToch module setting network to evaluation mode
if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
elif type(X) is torch.Tensor and not X.is_cuda:
X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
# .cuda() call transfers the densor from the CPU to the GPU if that is the case.
# Non-blocking argument lets the caller bypas synchronization when necessary
with torch.no_grad(): # Causes operations to have no gradients
output = self.forward(X)
_, idx = torch.max(output, 1)
# We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
idx = idx.data.cpu().numpy()
prediction = np.squeeze(idx)
del X, output, idx
return prediction
def reset_parameters(self, custom_weight_reset_flag):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
Args:
custom_weight_reset_flag (bool): Flag indicating if the modified weight initialisation approach should be used.
"""
print("Initializing network parameters...")
for _, module in self.named_children():
for _, submodule in module.named_children():
if isinstance(submodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
submodule.reset_parameters()
if custom_weight_reset_flag == True:
if isinstance(submodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(submodule.weight)
std = np.divide(gain, np.sqrt(fan))
submodule.weight.data.normal_(0, std)
for _, subsubmodule in submodule.named_children():
if isinstance(subsubmodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
subsubmodule.reset_parameters()
if custom_weight_reset_flag == True:
if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(subsubmodule.weight)
std = np.divide(gain, np.sqrt(fan))
subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!")
class AutoEncoder3D(nn.Module):
    """Architecture class for CycleGAN inspired BrainMapper 3D Autoencoder.

    This class contains the pytorch implementation of the generator architecture underpinning the BrainMapper project.

    Args:
        parameters (dict): Contains information relevant parameters
        parameters = {
            'kernel_size': 5
            'kernel_width': 5
            'kernel_depth': 5
            'kernel_classification': 1
            'input_channels': 1
            'output_channels': 64
            'convolution_stride': 1
            'dropout': 0.2
            'pool_kernel_size': 2
            'pool_stride': 2
            'up_mode': 'upconv'
            'number_of_classes': 1
        }

    Returns:
        probability_map (torch.tensor): Output forward passed tensor through the U-net block
    """

    def __init__(self, parameters):
        super(AutoEncoder3D, self).__init__()

        self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']

        # NOTE: 'parameters' is mutated in place while the blocks are built;
        # the caller's channel settings are restored at the end of __init__.
        original_input_channels = parameters['input_channels']
        original_output_channels = parameters['output_channels']
        original_kernel_size = parameters['kernel_size']
        original_stride = parameters['convolution_stride']

        # Encoder Path
        # The first encoder block uses its own (typically larger) kernel/stride.
        parameters['kernel_size'] = parameters['first_kernel_size']
        parameters['convolution_stride'] = parameters['first_convolution_stride']
        self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])
        parameters['kernel_size'] = original_kernel_size
        parameters['convolution_stride'] = original_stride

        # Double the channel count per block until 'max_number_channels' is
        # reached; after that the remaining blocks keep an equal channel count.
        equal_channels_blocks = 0
        for _ in range(parameters['number_of_encoder_blocks']):
            if parameters['output_channels'] < parameters['max_number_channels']:
                parameters['input_channels'] = parameters['output_channels']
                parameters['output_channels'] = parameters['output_channels'] * 2
            else:
                parameters['input_channels'] = parameters['output_channels']
                equal_channels_blocks += 1
            self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))

        # Decoder
        # Mirror the encoder: first the equal-channel blocks (if any), then
        # halve the channel count per block, ending with the classifier block.
        parameters['input_channels'] = parameters['output_channels']
        parameters['convolution_stride'] = parameters['transformer_blocks_stride']
        if equal_channels_blocks != 0:
            self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for i in range(equal_channels_blocks)])

        parameters['output_channels'] = parameters['output_channels'] // 2
        if equal_channels_blocks != 0:
            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
        else:
            self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])

        for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
            parameters['input_channels'] = parameters['output_channels']
            parameters['output_channels'] = parameters['output_channels'] // 2
            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))

        parameters['input_channels'] = parameters['output_channels']
        self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))

        # Restore the caller's channel settings so the dict can be reused.
        parameters['input_channels'] = original_input_channels
        parameters['output_channels'] = original_output_channels

    def forward(self, X):
        """Forward pass for 3D CGAN Autoencoder

        Function computing the forward pass through the 3D generator
        The input to the function is the dMRI map

        Args:
            X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W)

        Returns:
            probability_map (torch.tensor): Output forward passed tensor through the CGAN Autoencoder
        """

        # Encoder
        # Record post-block sizes; the decoder uses them as upsampling targets.
        Y_encoder_sizes = []
        for encoderBlock in self.encoderBlocks:
            X = encoderBlock.forward(X)
            Y_encoder_sizes.append(X.size())

        # Drop the deepest size and reverse: decoder block i upsamples towards
        # the size produced by the mirrored encoder block.
        Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
        Y_encoder_sizes_lenght = len(Y_encoder_sizes)

        # Decoder
        for index, decoderBlock in enumerate(self.decoderBlocks):
            if index < Y_encoder_sizes_lenght:
                X = decoderBlock.forward(X, Y_encoder_sizes[index])
            else:
                # Remaining blocks (classifier included) take no size target.
                X = decoderBlock.forward(X)

        del Y_encoder_sizes, Y_encoder_sizes_lenght

        return X
......
......@@ -39,11 +39,10 @@ import torch.utils.data as data
import numpy as np
from solver import Solver
from BrainMapperAE import BrainMapperAE3D
from BrainMapperAE import BrainMapperAE3D, AutoEncoder3D
from utils.data_utils import get_datasets
from utils.settings import Settings
import utils.data_evaluation_utils as evaluations
from utils.data_logging_utils import LogWriter
from utils.common_utils import create_folder
# Set the default floating point tensor type to FloatTensor
......@@ -51,13 +50,15 @@ from utils.common_utils import create_folder
torch.set_default_tensor_type(torch.FloatTensor)
def load_data(data_parameters):
def load_data(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag):
"""Dataset Loader
This function loads the training and validation datasets.
Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles.
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
Returns:
train_data (dataset object): Pytorch map-style dataset object, mapping indices to training data samples.
......@@ -65,7 +66,7 @@ def load_data(data_parameters):
"""
print("Data is loading...")
train_data, validation_data = get_datasets(data_parameters)
train_data, validation_data = get_datasets(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag)
print("Data has loaded!")
print("Training dataset size is {}".format(len(train_data)))
print("Validation dataset size is {}".format(len(validation_data)))
......@@ -116,7 +117,49 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
}
"""
def _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters):
def _load_pretrained_cross_domain(x2y_model, save_model_directory, experiment_name):
""" Pretrained cross-domain loader
This function loads the pretrained X2X and Y2Y autuencoders.
After, it initializes the X2Y model's weights using the X2X encoder and teh Y2Y decoder weights.
Args:
x2y_model (class): Original x2y model initialised using the standard parameters.
save_model_directory (str): Name of the directory where the model is saved
experiment_name (str): Name of the experiment
Returns:
x2y_model (class): New x2y model with encoder and decoder paths weights reinitialised.
"""
x2y_model_state_dict = x2y_model.state_dict()
x2x_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_x2x.pth.tar')).state_dict()
y2y_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_y2y.pth.tar')).state_dict()
half_point = len(x2x_model_state_dict)//2 + 1
counter = 1
for key, _ in x2y_model_state_dict.items():
if counter <= half_point:
x2y_model_state_dict.update({key : x2x_model_state_dict[key]})
counter+=1
else:
if key in y2y_model_state_dict:
x2y_model_state_dict.update({key : y2y_model_state_dict[key]})
x2y_model.load_state_dict(x2y_model_state_dict)
return x2y_model
def _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer = torch.optim.Adam,
loss_function = torch.nn.MSELoss(),
):
"""Wrapper for the training operation
This function wraps the training operation for the network
......@@ -128,7 +171,11 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
misc_parameters (dict): Dictionary of aditional hyperparameters
"""
train_data, validation_data = load_data(data_parameters)
train_data, validation_data = load_data(data_parameters,
cross_domain_x2x_flag = network_parameters['cross_domain_x2x_flag'],
cross_domain_y2y_flag = network_parameters['cross_domain_y2y_flag']
)
train_loader = data.DataLoader(
dataset=train_data,
......@@ -145,19 +192,20 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
)
if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(
training_parameters['pre_trained_path'])
BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
else:
BrainMapperModel = BrainMapperAE3D(network_parameters)
# BrainMapperModel = BrainMapperAE3D(network_parameters)
BrainMapperModel = AutoEncoder3D(network_parameters) # temprorary change for testing encoder-decoder effective receptive field
custom_weight_reset_flag = network_parameters['custom_weight_reset_flag']
BrainMapperModel.reset_parameters(custom_weight_reset_flag)
if training_parameters['adam_w_flag'] == True:
optimizer = torch.optim.AdamW
else:
optimizer = torch.optim.Adam
if network_parameters['cross_domain_x2y_flag'] == True:
BrainMapperModel = _load_pretrained_cross_domain(x2y_model=BrainMapperModel,
save_model_directory=misc_parameters['save_model_directory'],
experiment_name=training_parameters['experiment_name']
)
solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'],
......@@ -169,6 +217,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
'eps': training_parameters['optimizer_epsilon'],
'weight_decay': training_parameters['optimizer_weigth_decay']
},
loss_function=loss_function,
model_name=training_parameters['experiment_name'],
number_epochs=training_parameters['number_of_epochs'],
loss_log_period=training_parameters['loss_log_period'],
......@@ -180,19 +229,79 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
logs_directory=misc_parameters['logs_directory'],
checkpoint_directory=misc_parameters['checkpoint_directory'],
save_model_directory=misc_parameters['save_model_directory'],
final_model_output_file=training_parameters['final_model_output_file'],
crop_flag = data_parameters['crop_flag']
)
validation_loss = solver.train(train_loader, validation_loader)
# _ = solver.train(train_loader, validation_loader)
solver.train(train_loader, validation_loader)
del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver, optimizer
torch.cuda.empty_cache()
return validation_loss
# return None
_ = _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters)
if training_parameters['adam_w_flag'] == True:
optimizer = torch.optim.AdamW
else:
optimizer = torch.optim.Adam
loss_function = torch.nn.MSELoss()
# loss_function=torch.nn.L1Loss()
# loss_function=torch.nn.CosineEmbeddingLoss()
if network_parameters['cross_domain_flag'] == False:
_train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
elif network_parameters['cross_domain_flag'] == True:
if network_parameters['cross_domain_x2x_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x"
data_parameters['target_data_train'] = data_parameters['input_data_train']
data_parameters['target_data_validation'] = data_parameters['input_data_validation']
# loss_function = torch.nn.L1Loss()
_train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
if network_parameters['cross_domain_y2y_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y"
data_parameters['input_data_train'] = data_parameters['target_data_train']
data_parameters['input_data_validation'] = data_parameters['target_data_validation']
# loss_function = torch.nn.L1Loss()
_train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
if network_parameters['cross_domain_x2y_flag'] == True:
_train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
def evaluate_mapping(mapping_evaluation_parameters):
......@@ -216,7 +325,10 @@ def evaluate_mapping(mapping_evaluation_parameters):
trained_model_path = mapping_evaluation_parameters['trained_model_path']
data_directory = mapping_evaluation_parameters['data_directory']
mapping_data_file = mapping_evaluation_parameters['mapping_data_file']
data_list = mapping_evaluation_parameters['data_list']
mapping_targets_file = mapping_evaluation_parameters['mapping_targets_file']
data_list = mapping_evaluation_parameters['data_list_reduced']
prediction_output_path = mapping_evaluation_parameters['prediction_output_path']
dmri_mean_mask_path = mapping_evaluation_parameters['dmri_mean_mask_path']
rsfmri_mean_mask_path = mapping_evaluation_parameters['rsfmri_mean_mask_path']
......@@ -235,12 +347,97 @@ def evaluate_mapping(mapping_evaluation_parameters):
shrinkage_flag = mapping_evaluation_parameters['shrinkage_flag']
hard_shrinkage_flag = mapping_evaluation_parameters['hard_shrinkage_flag']
crop_flag = mapping_evaluation_parameters['crop_flag']
cross_domain_x2x_flag = mapping_evaluation_parameters['cross_domain_x2x_flag']
cross_domain_y2y_flag = mapping_evaluation_parameters['cross_domain_y2y_flag']
evaluations.evaluate_mapping(trained_model_path,
data_directory,
mapping_data_file,
mapping_targets_file,
data_list,
prediction_output_path,
brain_mask_path,
dmri_mean_mask_path,
rsfmri_mean_mask_path,
regression_factors,
mean_regression_flag,
mean_regression_all_flag,
mean_subtraction_flag,