Commit 7da22503 authored by Andrei-Claudiu Roibu

Merge branch 'CrossDomain_autoencoder' into 'master'

Cross domain autoencoder

See merge request !4
parents aa74b23e 5e0551e3
@@ -118,9 +118,16 @@ dmypy.json
 .vscode/
 datasets/
 files.txt
-jobscript.sge.sh
+*.sge.sh
 *.nii.gz
 stuff/
 test/*
 .DS_Store
 logs/
+*.ini
+experimentInputs/
+experiments/
+predictions/
+*.sh
+saved_models/
+mock_job.py
@@ -29,7 +29,7 @@ class BrainMapperAE3D(nn.Module):
     Args:
         parameters (dict): Contains information relevant parameters
         parameters = {
-                       'kernel_heigth': 5
+                       'kernel_size': 5
                        'kernel_width': 5
                        'kernel_depth': 5
                        'kernel_classification': 1
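For orientation between hunks: a parameters dict consistent with the renamed keys above might look as follows. This is an illustrative sketch only; the values are taken from the updated [NETWORK] config section later in this merge request, and the project's actual settings loader is not part of the diff.

```python
# Illustrative only: keys follow the renamed docstring above; values come
# from the updated [NETWORK] config section shown further down.
parameters = {
    'input_channels': 1,
    'output_channels': 32,
    'max_number_channels': 128,
    'kernel_size': 3,
    'first_kernel_size': 9,
    'first_convolution_stride': 1,
    'convolution_stride': 2,
    'number_of_encoder_blocks': 2,
    'transformer_blocks_stride': 1,
    'number_of_transformer_blocks': 6,
    'number_of_feature_mapping_blocks': 1,
    'kernel_classification': 9,
    'cross_domain_x2y_flag': False,
    'pool_kernel_size': 3,
    'pool_stride': 2,
    'dropout': 0,
    'number_of_classes': 1,
}
```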
@@ -50,44 +50,61 @@ class BrainMapperAE3D(nn.Module):
     def __init__(self, parameters):
         super(BrainMapperAE3D, self).__init__()

+        self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']
+
         original_input_channels = parameters['input_channels']
         original_output_channels = parameters['output_channels']
-        original_kernel_height = parameters['kernel_heigth']
+        original_kernel_size = parameters['kernel_size']
         original_stride = parameters['convolution_stride']

         # Encoder Path

-        parameters['kernel_heigth'] = 7
-        self.encoderBlock1 = modules.ResNetEncoderBlock3D(parameters)
+        parameters['kernel_size'] = parameters['first_kernel_size']
+        parameters['convolution_stride'] = parameters['first_convolution_stride']
+        self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])

-        parameters['input_channels'] = parameters['output_channels']
-        parameters['output_channels'] = parameters['output_channels'] * 2
-        parameters['kernel_heigth'] = original_kernel_height
-        parameters['convolution_stride'] = 2
-        self.encoderBlock2 = modules.ResNetEncoderBlock3D(parameters)
+        parameters['kernel_size'] = original_kernel_size
+        parameters['convolution_stride'] = original_stride

-        parameters['input_channels'] = parameters['output_channels']
-        parameters['output_channels'] = parameters['output_channels'] * 2
-        self.encoderBlock3 = modules.ResNetEncoderBlock3D(parameters)
+        equal_channels_blocks = 0
+
+        for _ in range(parameters['number_of_encoder_blocks']):
+            if parameters['output_channels'] < parameters['max_number_channels']:
+                parameters['input_channels'] = parameters['output_channels']
+                parameters['output_channels'] = parameters['output_channels'] * 2
+            else:
+                parameters['input_channels'] = parameters['output_channels']
+                equal_channels_blocks += 1
+
+            self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))

         # Transformer

         parameters['input_channels'] = parameters['output_channels']
-        parameters['convolution_stride'] = original_stride
+        parameters['convolution_stride'] = parameters['transformer_blocks_stride']
         self.transformerBlocks = nn.ModuleList([modules.ResNetBlock3D(parameters) for i in range(parameters['number_of_transformer_blocks'])])

+        if self.cross_domain_x2y_flag == True:
+            self.featureMappingLayers = nn.ModuleList([modules.ResNetFeatureMappingBlock3D(parameters) for i in range(parameters['number_of_feature_mapping_blocks'])])
+
         # Decoder

-        parameters['output_channels'] = parameters['output_channels'] // 2
-        self.decoderBlock1 = modules.ResNetDecoderBlock3D(parameters)
+        if equal_channels_blocks != 0:
+            self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for i in range(equal_channels_blocks)])

-        parameters['input_channels'] = parameters['output_channels']
         parameters['output_channels'] = parameters['output_channels'] // 2
-        self.decoderBlock2 = modules.ResNetDecoderBlock3D(parameters)
+
+        if equal_channels_blocks != 0:
+            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
+        else:
+            self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])
+
+        for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
+            parameters['input_channels'] = parameters['output_channels']
+            parameters['output_channels'] = parameters['output_channels'] // 2
+            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))

         parameters['input_channels'] = parameters['output_channels']
-        self.decoderBlock3 = modules.ResNetClassifierBlock3D(parameters)
+        self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))

         parameters['input_channels'] = original_input_channels
         parameters['output_channels'] = original_output_channels
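The new constructor derives every block's width by mutating the shared parameters dict: after the first block (input_channels to output_channels, built with first_kernel_size), the channel count doubles per encoder block until max_number_channels is reached, and equal_channels_blocks records how many capped blocks the decoder must mirror before it starts halving. A minimal, PyTorch-free sketch of that bookkeeping, using a hypothetical helper name:

```python
# Hypothetical helper mirroring the constructor's channel bookkeeping above.
def channel_schedule(output_channels, number_of_encoder_blocks, max_number_channels):
    """Return (in, out) channel pairs for the looped encoder and decoder blocks."""
    encoder, equal_channels_blocks = [], 0
    for _ in range(number_of_encoder_blocks):
        input_channels = output_channels
        if output_channels < max_number_channels:
            output_channels *= 2          # keep doubling until the cap is hit
        else:
            equal_channels_blocks += 1    # capped blocks keep equal channels
        encoder.append((input_channels, output_channels))
    # Capped encoder blocks are mirrored by equal-width decoder blocks,
    # then the decoder halves back down before the classifier block.
    decoder = [(output_channels, output_channels)] * equal_channels_blocks
    decoder.append((output_channels, output_channels // 2))
    output_channels //= 2
    for _ in range(number_of_encoder_blocks - equal_channels_blocks - 1):
        decoder.append((output_channels, output_channels // 2))
        output_channels //= 2
    return encoder, decoder

# With the updated config (output_channels=32, 2 encoder blocks, cap 128):
# ([(32, 64), (64, 128)], [(128, 64), (64, 32)])
print(channel_schedule(32, 2, 128))
```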
@@ -107,24 +124,40 @@ class BrainMapperAE3D(nn.Module):
         # Encoder

-        X = self.encoderBlock1.forward(X)
-        Y_encoder_1_size = X.size()
-        X = self.encoderBlock2.forward(X)
-        Y_encoder_2_size = X.size()
-        X = self.encoderBlock3.forward(X)
+        Y_encoder_sizes = []
+
+        for encoderBlock in self.encoderBlocks:
+            X = encoderBlock.forward(X)
+            Y_encoder_sizes.append(X.size())
+
+        Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
+        Y_encoder_sizes_length = len(Y_encoder_sizes)

         # Transformer

-        for transformerBlock in self.transformerBlocks:
-            X = transformerBlock(X)
+        if self.cross_domain_x2y_flag == True:
+            for transformerBlock in self.transformerBlocks[:len(self.transformerBlocks)//2]:
+                X = transformerBlock(X)
+
+            for featureMappingLayer in self.featureMappingLayers:
+                X = featureMappingLayer(X)
+
+            for transformerBlock in self.transformerBlocks[len(self.transformerBlocks)//2:]:
+                X = transformerBlock(X)
+        else:
+            for transformerBlock in self.transformerBlocks:
+                X = transformerBlock(X)

         # Decoder

-        X = self.decoderBlock1.forward(X, Y_encoder_2_size)
-        del Y_encoder_2_size
-        X = self.decoderBlock2.forward(X, Y_encoder_1_size)
-        del Y_encoder_1_size
-        X = self.decoderBlock3.forward(X)
+        for index, decoderBlock in enumerate(self.decoderBlocks):
+            if index < Y_encoder_sizes_length:
+                X = decoderBlock.forward(X, Y_encoder_sizes[index])
+            else:
+                X = decoderBlock.forward(X)
+
+        del Y_encoder_sizes, Y_encoder_sizes_length

         return X
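The size bookkeeping in the new forward pass drives the skip connections: the deepest encoder output feeds the transformer directly, so its size is dropped with [:-1], and the remaining sizes are reversed so that decoder block i un-pools to the size recorded by its mirror-image encoder block; decoder blocks beyond that length (the classifier) run without a target size. A toy illustration of the reversal:

```python
# Toy stand-ins for the torch.Size values collected during encoding.
Y_encoder_sizes = ['after_block_0', 'after_block_1', 'after_block_2']

# Drop the deepest size (it enters the transformer unchanged), then reverse so
# the first decoder block targets the second-deepest encoder output's size.
Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
print(Y_encoder_sizes)  # ['after_block_1', 'after_block_0']
```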
@@ -226,3 +259,220 @@ class BrainMapperAE3D(nn.Module):
                             subsubmodule.weight.data.normal_(0, std)

         print("Initialized network parameters!")
+
+
+class AutoEncoder3D(nn.Module):
+    """Architecture class for CycleGAN-inspired BrainMapper 3D Autoencoder.
+
+    This class contains the pytorch implementation of the generator architecture underpinning the BrainMapper project.
+
+    Args:
+        parameters (dict): Contains information relevant parameters
+        parameters = {
+                       'kernel_size': 5
+                       'kernel_width': 5
+                       'kernel_depth': 5
+                       'kernel_classification': 1
+                       'input_channels': 1
+                       'output_channels': 64
+                       'convolution_stride': 1
+                       'dropout': 0.2
+                       'pool_kernel_size': 2
+                       'pool_stride': 2
+                       'up_mode': 'upconv'
+                       'number_of_classes': 1
+                     }
+
+    Returns:
+        probability_map (torch.tensor): Output forward passed tensor through the U-net block
+    """
+
+    def __init__(self, parameters):
+        super(AutoEncoder3D, self).__init__()
+
+        self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']
+
+        original_input_channels = parameters['input_channels']
+        original_output_channels = parameters['output_channels']
+        original_kernel_size = parameters['kernel_size']
+        original_stride = parameters['convolution_stride']
+
+        # Encoder Path
+
+        parameters['kernel_size'] = parameters['first_kernel_size']
+        parameters['convolution_stride'] = parameters['first_convolution_stride']
+        self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])
+
+        parameters['kernel_size'] = original_kernel_size
+        parameters['convolution_stride'] = original_stride
+
+        equal_channels_blocks = 0
+
+        for _ in range(parameters['number_of_encoder_blocks']):
+            if parameters['output_channels'] < parameters['max_number_channels']:
+                parameters['input_channels'] = parameters['output_channels']
+                parameters['output_channels'] = parameters['output_channels'] * 2
+            else:
+                parameters['input_channels'] = parameters['output_channels']
+                equal_channels_blocks += 1
+
+            self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))
+
+        # Decoder
+
+        parameters['input_channels'] = parameters['output_channels']
+        parameters['convolution_stride'] = parameters['transformer_blocks_stride']
+
+        if equal_channels_blocks != 0:
+            self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for i in range(equal_channels_blocks)])
+
+        parameters['output_channels'] = parameters['output_channels'] // 2
+
+        if equal_channels_blocks != 0:
+            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
+        else:
+            self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])
+
+        for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
+            parameters['input_channels'] = parameters['output_channels']
+            parameters['output_channels'] = parameters['output_channels'] // 2
+            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
+
+        parameters['input_channels'] = parameters['output_channels']
+        self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))
+
+        parameters['input_channels'] = original_input_channels
+        parameters['output_channels'] = original_output_channels
+
+    def forward(self, X):
+        """Forward pass for 3D CGAN Autoencoder
+
+        Function computing the forward pass through the 3D generator.
+        The input to the function is the dMRI map.
+
+        Args:
+            X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W)
+
+        Returns:
+            probability_map (torch.tensor): Output forward passed tensor through the CGAN Autoencoder
+        """
+
+        # Encoder
+
+        Y_encoder_sizes = []
+
+        for encoderBlock in self.encoderBlocks:
+            X = encoderBlock.forward(X)
+            Y_encoder_sizes.append(X.size())
+
+        Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
+        Y_encoder_sizes_length = len(Y_encoder_sizes)
+
+        # Decoder
+
+        for index, decoderBlock in enumerate(self.decoderBlocks):
+            if index < Y_encoder_sizes_length:
+                X = decoderBlock.forward(X, Y_encoder_sizes[index])
+            else:
+                X = decoderBlock.forward(X)
+
+        del Y_encoder_sizes, Y_encoder_sizes_length
+
+        return X
+
+    def save(self, path):
+        """Model Saver
+
+        Function saving the model with all its parameters to a given path.
+        The path must end with a *.model argument.
+
+        Args:
+            path (str): Path string
+        """
+
+        print("Saving Model... {}".format(path))
+        torch.save(self, path)
+
+    @property
+    def test_if_cuda(self):
+        """Cuda Test
+
+        This function tests if the model parameters are allocated to a CUDA enabled GPU.
+
+        Returns:
+            bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
+        """
+
+        return next(self.parameters()).is_cuda
+
+    def predict(self, X, device=0):
+        """Post-training Output Prediction
+
+        This function predicts the output of the U-net post-training.
+
+        Args:
+            X (torch.tensor): input dMRI volume
+            device (int/str): Device type used for training (int - GPU id, str - CPU)
+
+        Returns:
+            prediction (ndarray): predicted output after training
+        """
+
+        self.eval()  # PyTorch module setting network to evaluation mode
+
+        if type(X) is np.ndarray:
+            X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
+        elif type(X) is torch.Tensor and not X.is_cuda:
+            X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
+            # The .cuda() call transfers the tensor from the CPU to the GPU if that is the case.
+            # The non_blocking argument lets the caller bypass synchronization when necessary.
+
+        with torch.no_grad():  # Causes operations to have no gradients
+            output = self.forward(X)
+
+        _, idx = torch.max(output, 1)
+
+        # We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
+        idx = idx.data.cpu().numpy()
+        prediction = np.squeeze(idx)
+
+        del X, output, idx
+
+        return prediction
+
+    def reset_parameters(self, custom_weight_reset_flag):
+        """Parameter Initialization
+
+        This function (re)initializes the parameters of the defined network.
+        This function is a wrapper for the reset_parameters() function defined for each module.
+        More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
+        An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
+
+        Args:
+            custom_weight_reset_flag (bool): Flag indicating if the modified weight initialisation approach should be used.
+        """
+
+        print("Initializing network parameters...")
+
+        for _, module in self.named_children():
+            for _, submodule in module.named_children():
+                if isinstance(submodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
+                    submodule.reset_parameters()
+                    if custom_weight_reset_flag == True:
+                        if isinstance(submodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
+                            gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
+                            fan, _ = calculate_fan(submodule.weight)
+                            std = np.divide(gain, np.sqrt(fan))
+                            submodule.weight.data.normal_(0, std)
+
+                for _, subsubmodule in submodule.named_children():
+                    if isinstance(subsubmodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)) == True:
+                        subsubmodule.reset_parameters()
+                        if custom_weight_reset_flag == True:
+                            if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
+                                gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
+                                fan, _ = calculate_fan(subsubmodule.weight)
+                                std = np.divide(gain, np.sqrt(fan))
+                                subsubmodule.weight.data.normal_(0, std)
+
+        print("Initialized network parameters!")
\ No newline at end of file
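The custom branch of reset_parameters() implements a Kaiming-style initialisation: gain = sqrt(2 / (1 + a^2)) with a = 0.25 (consistent with the slope of a leaky activation, though the diff only shows the constant), and std = gain / sqrt(fan), where calculate_fan is assumed to return the weight tensor's fan value. A standalone check of the arithmetic:

```python
import numpy as np

a = 0.25                              # assumed activation slope; the code hard-codes 0.25
gain = np.sqrt(2 / (1 + a ** 2))      # same expression as in reset_parameters()
fan = 64 * 3 * 3 * 3                  # e.g. fan-in of a Conv3d with 64 channels, 3x3x3 kernel
std = gain / np.sqrt(fan)
print(round(gain, 4), round(std, 5))  # 1.372 0.03301
```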
@@ -4,12 +4,11 @@ input_data_train = "input_data_train.h5"
 target_data_train = "target_data_train.h5"
 input_data_validation = "input_data_validation.h5"
 target_data_validation = "target_data_validation.h5"
-crop_flag = False
+crop_flag = True

 [TRAINING]
 experiment_name = "VA2-1"
 pre_trained_path = "saved_models/VA2-1.pth.tar"
-final_model_output_file = "VA2-1.pth.tar"
 training_batch_size = 3
 validation_batch_size = 3
 use_pre_trained = False
@@ -25,21 +24,28 @@ use_last_checkpoint = False
 adam_w_flag = False

 [NETWORK]
-kernel_heigth = 3
-kernel_width = 3
-kernel_depth = 3
-kernel_classification = 7
+first_kernel_size = 9
+first_convolution_stride = 1
 input_channels = 1
 output_channels = 32
-convolution_stride = 1
-dropout = 0
+max_number_channels = 128
+kernel_size = 3
+convolution_stride = 2
+number_of_encoder_blocks = 2
+transformer_blocks_stride = 1
+number_of_transformer_blocks = 6
+kernel_classification = 9
 pool_kernel_size = 3
 pool_stride = 2
-up_mode = "upconv"
+dropout = 0
 final_activation = 'tanh'
 number_of_classes = 1
-number_of_transformer_blocks = 6
 custom_weight_reset_flag = False
+cross_domain_flag = False
+cross_domain_x2x_flag = False
+cross_domain_y2y_flag = False
+cross_domain_x2y_flag = False
+number_of_feature_mapping_blocks = 1

 [MISC]
 save_model_directory = "saved_models"
...
 [MAPPING]
 trained_model_path = "saved_models/VA2-1.pth.tar"
 prediction_output_path = "VA2-1_predictions"
+prediction_output_database_name = "output_test_data.h5"
+prediction_output_statistics_name = "output_statistics.csv"
 data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
 mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
-data_list = "datasets/test.txt"
+mapping_targets_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
+data_list_reduced = "datasets/test_reduced.txt"
+data_list_all = "datasets/test_all.txt"
+evaluate_all_data = False
+output_database_flag = False
 brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
 rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
 dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
@@ -21,4 +27,6 @@ shrinkage_flag = False
 hard_shrinkage_flag = False
 crop_flag = True
 device = 0
-exit_on_error = True
\ No newline at end of file
+exit_on_error = True
+cross_domain_x2x_flag = False
+cross_domain_y2y_flag = False
\ No newline at end of file
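The diff does not show how these .ini files are consumed. As an illustrative sketch under that caveat, the new cross-domain flags could be read with Python's standard configparser, using ast.literal_eval because the files store Python-style literals such as False and 'tanh' (the file name here is hypothetical):

```python
import ast
import configparser

config = configparser.ConfigParser()
config.read('settings.ini')  # hypothetical path for the configuration shown above

# configparser returns raw strings; literal_eval turns "False" into False, "9" into 9, etc.
network = {key: ast.literal_eval(value) for key, value in config['NETWORK'].items()}
print(network['cross_domain_x2y_flag'], network['number_of_feature_mapping_blocks'])
```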
@@ -9,7 +9,6 @@ setup(
     maintainer_email='andrei-claudiu.roibu@dtc.ox.ac.uk',
     install_requires=[
         'pip',
-        'matplotlib',
         'numpy',
         'pandas',
         'torch==1.4',
...
@@ -19,7 +19,6 @@ import glob
 from fsl.data.image import Image
 from fsl.utils.image.roi import roi
 from datetime import datetime
-from utils.losses import MSELoss
 from utils.common_utils import create_folder
 from utils.data_logging_utils import LogWriter
 from utils.early_stopping import EarlyStopping
@@ -64,9 +63,7 @@ class Solver():
                  experiment_name,
                  optimizer,
                  optimizer_arguments={},
-                 loss_function=MSELoss(),
-                 # loss_function=torch.nn.L1Loss(),
-                 # loss_function=torch.nn.CosineEmbeddingLoss(),
+                 loss_function=torch.nn.MSELoss(),
                  model_name='BrainMapper',
                  labels=None,
                  number_epochs=10,
@@ -78,7 +75,6 @@ class Solver():
                  logs_directory='logs',
                  checkpoint_directory='checkpoints',
                  save_model_directory='saved_models',
-                 final_model_output_file='finetuned_alldata.pth.tar',
                  crop_flag = False
                  ):
@@ -88,10 +84,10 @@ class Solver():

         if torch.cuda.is_available():
             self.loss_function = loss_function.cuda(device)
-            self.MSE = MSELoss().cuda(device)
+            self.MSE = torch.nn.MSELoss().cuda(device)
         else:
             self.loss_function = loss_function
-            self.MSE = MSELoss()
+            self.MSE = torch.nn.MSELoss()

         self.model_name = model_name
         self.labels = labels
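The custom utils.losses.MSELoss wrapper is replaced here by PyTorch's built-in module, used the same way: construct once, then call on (prediction, target) pairs. A minimal usage sketch:

```python
import torch

criterion = torch.nn.MSELoss()  # replaces the removed utils.losses.MSELoss
prediction, target = torch.zeros(2, 3), torch.ones(2, 3)
loss = criterion(prediction, target)
print(loss.item())  # 1.0 -- mean of the squared errors
```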
@@ -125,7 +121,6 @@ class Solver():
                                     use_last_checkpoint=use_last_checkpoint,
                                     labels=labels)

-        self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
         self.early_stop = False

         if crop_flag == False:
@@ -134,10 +129,19 @@ class Solver():
             self.MNI152_T1_2mm_brain_mask = torch.from_numpy(roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'),((9,81),(10,100),(0,77))).data)

         self.save_model_directory = save_model_directory
-        self.final_model_output_file = final_model_output_file
+        self.final_model_output_file = experiment_name + ".pth.tar"
+
+        self.best_score_early_stop = None
+        self.counter_early_stop = 0
+        self.previous_loss = None
+        self.previous_MSE = None
+        self.valid_epoch = None
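With the EarlyStopping helper removed, the solver now carries the early-stopping state itself (best_score_early_stop, counter_early_stop, and the previous loss/MSE values), and the final model file name is derived from experiment_name rather than passed in. The per-epoch update logic lies outside this diff; a hypothetical counter-based sketch matching the removed EarlyStopping(patience=10, min_delta=0):

```python
def update_early_stop(best_score, counter, current_score, patience=10, min_delta=0):
    """Hypothetical per-epoch update mirroring the removed EarlyStopping helper."""
    if best_score is None or current_score < best_score - min_delta:
        return current_score, 0, False  # improvement: reset the patience counter
    counter += 1                        # no improvement this epoch
    return best_score, counter, counter >= patience

best_score, counter = None, 0
for validation_loss in [0.9, 0.8, 0.8, 0.8]:
    best_score, counter, early_stop = update_early_stop(best_score, counter, validation_loss)
print(best_score, counter, early_stop)  # 0.8 2 False
```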