Commit 7da22503 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

Merge branch 'CrossDomain_autoencoder' into 'master'

Cross domain autoencoder

See merge request !4
parents aa74b23e 5e0551e3
......@@ -118,9 +118,16 @@ dmypy.json
.vscode/
datasets/
files.txt
jobscript.sge.sh
*.sge.sh
*.nii.gz
stuff/
test/*
.DS_Store
logs/
*.ini
experimentInputs/
experiments/
predictions/
*.sh
saved_models/
mock_job.py
......@@ -29,7 +29,7 @@ class BrainMapperAE3D(nn.Module):
Args:
parameters (dict): Contains information relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_size': 5
'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1
......@@ -50,44 +50,61 @@ class BrainMapperAE3D(nn.Module):
def __init__(self, parameters):
super(BrainMapperAE3D, self).__init__()
self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
original_kernel_height = parameters['kernel_heigth']
original_kernel_size = parameters['kernel_size']
original_stride = parameters['convolution_stride']
# Encoder Path
parameters['kernel_heigth'] = 7
self.encoderBlock1 = modules.ResNetEncoderBlock3D(parameters)
parameters['kernel_size'] = parameters['first_kernel_size']
parameters['convolution_stride'] = parameters['first_convolution_stride']
self.encoderBlocks = nn.ModuleList([modules.ResNetEncoderBlock3D(parameters)])
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
parameters['kernel_heigth'] = original_kernel_height
parameters['convolution_stride'] = 2
self.encoderBlock2 = modules.ResNetEncoderBlock3D(parameters)
parameters['kernel_size'] = original_kernel_size
parameters['convolution_stride'] = original_stride
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock3 = modules.ResNetEncoderBlock3D(parameters)
equal_channels_blocks = 0
for _ in range(parameters['number_of_encoder_blocks']):
if parameters['output_channels'] < parameters['max_number_channels']:
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
else:
parameters['input_channels'] = parameters['output_channels']
equal_channels_blocks += 1
self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))
# Transformer
parameters['input_channels'] = parameters['output_channels']
parameters['convolution_stride'] = original_stride
parameters['convolution_stride'] = parameters['transformer_blocks_stride']
self.transformerBlocks = nn.ModuleList([modules.ResNetBlock3D(parameters) for i in range(parameters['number_of_transformer_blocks'])])
if self.cross_domain_x2y_flag == True:
self.featureMappingLayers = nn.ModuleList([modules.ResNetFeatureMappingBlock3D(parameters) for i in range(parameters['number_of_feature_mapping_blocks'])])
# Decoder
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock1 = modules.ResNetDecoderBlock3D(parameters)
if equal_channels_blocks != 0:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for i in range(equal_channels_blocks)])
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock2 = modules.ResNetDecoderBlock3D(parameters)
if equal_channels_blocks != 0:
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
else:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters)])
for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
parameters['input_channels'] = parameters['output_channels']
self.decoderBlock3 = modules.ResNetClassifierBlock3D(parameters)
self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
......@@ -107,24 +124,40 @@ class BrainMapperAE3D(nn.Module):
# Encoder
X = self.encoderBlock1.forward(X)
Y_encoder_1_size = X.size()
X = self.encoderBlock2.forward(X)
Y_encoder_2_size = X.size()
X = self.encoderBlock3.forward(X)
Y_encoder_sizes = []
for encoderBlock in self.encoderBlocks:
X = encoderBlock.forward(X)
Y_encoder_sizes.append(X.size())
Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
Y_encoder_sizes_lenght = len(Y_encoder_sizes)
# Transformer
for transformerBlock in self.transformerBlocks:
X = transformerBlock(X)
if self.cross_domain_x2y_flag == True:
for transformerBlock in self.transformerBlocks[:len(self.transformerBlocks)//2]:
X = transformerBlock(X)
for featureMappingLayer in self.featureMappingLayers:
X = featureMappingLayer(X)
for transformerBlock in self.transformerBlocks[len(self.transformerBlocks)//2:]:
X = transformerBlock(X)
else:
for transformerBlock in self.transformerBlocks:
X = transformerBlock(X)
# Decoder
X = self.decoderBlock1.forward(X, Y_encoder_2_size)
del Y_encoder_2_size
X = self.decoderBlock2.forward(X, Y_encoder_1_size)
del Y_encoder_1_size
X = self.decoderBlock3.forward(X)
for index, decoderBlock in enumerate(self.decoderBlocks):
if index < Y_encoder_sizes_lenght:
X = decoderBlock.forward(X, Y_encoder_sizes[index])
else:
X = decoderBlock.forward(X)
del Y_encoder_sizes, Y_encoder_sizes_lenght
return X
......@@ -226,3 +259,220 @@ class BrainMapperAE3D(nn.Module):
subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!")
class AutoEncoder3D(nn.Module):
    """Architecture class for CycleGAN inspired BrainMapper 3D Autoencoder.

    This class contains the pytorch implementation of the generator architecture
    underpinning the BrainMapper project: a stack of ResNet encoder blocks followed
    by a mirrored stack of ResNet decoder blocks and a final classifier block.

    Args:
        parameters (dict): Contains information relevant parameters
        parameters = {
            'kernel_size': 5
            'kernel_width': 5
            'kernel_depth': 5
            'kernel_classification': 1
            'input_channels': 1
            'output_channels': 64
            'convolution_stride': 1
            'dropout': 0.2
            'pool_kernel_size': 2
            'pool_stride': 2
            'up_mode': 'upconv'
            'number_of_classes': 1
        }

    Returns:
        probability_map (torch.tensor): Output forward passed tensor through the autoencoder
    """

    def __init__(self, parameters):
        super(AutoEncoder3D, self).__init__()

        # NOTE(review): stored but not used inside this class — presumably read by
        # external training code; confirm before removing.
        self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']

        # The shared `parameters` dict is mutated while building the blocks and the
        # channel entries are restored at the end so callers can reuse it.
        original_input_channels = parameters['input_channels']
        original_output_channels = parameters['output_channels']
        original_kernel_size = parameters['kernel_size']
        original_stride = parameters['convolution_stride']

        # Encoder Path: the first block uses a dedicated (typically larger) kernel.
        parameters['kernel_size'] = parameters['first_kernel_size']
        parameters['convolution_stride'] = parameters['first_convolution_stride']
        self.encoderBlocks = nn.ModuleList(
            [modules.ResNetEncoderBlock3D(parameters)])
        parameters['kernel_size'] = original_kernel_size
        parameters['convolution_stride'] = original_stride

        # Double the channel count per encoder block until max_number_channels is
        # reached; after that, blocks keep equal channels (counted so the decoder
        # can mirror them).
        equal_channels_blocks = 0
        for _ in range(parameters['number_of_encoder_blocks']):
            if parameters['output_channels'] < parameters['max_number_channels']:
                parameters['input_channels'] = parameters['output_channels']
                parameters['output_channels'] = parameters['output_channels'] * 2
            else:
                parameters['input_channels'] = parameters['output_channels']
                equal_channels_blocks += 1
            self.encoderBlocks.append(modules.ResNetEncoderBlock3D(parameters))

        # Decoder: mirror the encoder — first the equal-channel blocks, then the
        # halving blocks, ending with the classifier block.
        parameters['input_channels'] = parameters['output_channels']
        parameters['convolution_stride'] = parameters['transformer_blocks_stride']
        if equal_channels_blocks != 0:
            self.decoderBlocks = nn.ModuleList(
                [modules.ResNetDecoderBlock3D(parameters) for i in range(equal_channels_blocks)])
        parameters['output_channels'] = parameters['output_channels'] // 2
        if equal_channels_blocks != 0:
            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
        else:
            self.decoderBlocks = nn.ModuleList(
                [modules.ResNetDecoderBlock3D(parameters)])
        for _ in range(parameters['number_of_encoder_blocks'] - equal_channels_blocks - 1):
            parameters['input_channels'] = parameters['output_channels']
            parameters['output_channels'] = parameters['output_channels'] // 2
            self.decoderBlocks.append(modules.ResNetDecoderBlock3D(parameters))
        parameters['input_channels'] = parameters['output_channels']
        self.decoderBlocks.append(modules.ResNetClassifierBlock3D(parameters))

        # Restore the caller-visible channel entries of the shared dict.
        parameters['input_channels'] = original_input_channels
        parameters['output_channels'] = original_output_channels

    def forward(self, X):
        """Forward pass for 3D CGAN Autoencoder

        Function computing the forward pass through the 3D generator.
        The input to the function is the dMRI map.

        Args:
            X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W)

        Returns:
            probability_map (torch.tensor): Output forward passed tensor through the CGAN Autoencoder
        """
        # Encoder: record intermediate sizes so the decoder blocks can upsample
        # back to matching spatial shapes (skip the deepest size, reverse order).
        Y_encoder_sizes = []
        for encoderBlock in self.encoderBlocks:
            X = encoderBlock.forward(X)
            Y_encoder_sizes.append(X.size())
        Y_encoder_sizes = Y_encoder_sizes[:-1][::-1]
        Y_encoder_sizes_length = len(Y_encoder_sizes)

        # Decoder: the first len(Y_encoder_sizes) blocks receive a target size;
        # the remaining blocks (incl. the classifier) take only the tensor.
        for index, decoderBlock in enumerate(self.decoderBlocks):
            if index < Y_encoder_sizes_length:
                X = decoderBlock.forward(X, Y_encoder_sizes[index])
            else:
                X = decoderBlock.forward(X)

        del Y_encoder_sizes, Y_encoder_sizes_length

        return X

    def save(self, path):
        """Model Saver

        Function saving the model with all its parameters to a given path.
        The path must end with a *.model argument.

        Args:
            path (str): Path string
        """
        print("Saving Model... {}".format(path))
        # Saves the full module object (architecture + weights), not just state_dict.
        torch.save(self, path)

    @property
    def test_if_cuda(self):
        """Cuda Test

        This function tests if the model parameters are allocated to a CUDA enabled GPU.

        Returns:
            bool: Flag indicating True if the tensor is stored on the GPU and False otherwise
        """
        return next(self.parameters()).is_cuda

    def predict(self, X, device=0):
        """Post-training Output Prediction

        This function predicts the output of the network post-training.

        Args:
            X (torch.tensor): input dMRI volume
            device (int/str): Device type used for training (int - GPU id, str- CPU)

        Returns:
            prediction (ndarray): predicted output after training
        """
        self.eval()  # PyTorch module setting network to evaluation mode

        # NOTE(review): an ndarray input is converted to a float tensor but never
        # moved to the GPU (the elif only fires for CPU torch tensors) — confirm
        # this is intended for CPU-only prediction paths.
        if type(X) is np.ndarray:
            X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
        elif type(X) is torch.Tensor and not X.is_cuda:
            X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
            # .cuda() call transfers the tensor from the CPU to the GPU if that is the case.
            # Non-blocking argument lets the caller bypass synchronization when necessary

        with torch.no_grad():  # Causes operations to have no gradients
            output = self.forward(X)

        # Argmax over the channel dimension; we retrieve the tensor held by idx
        # (.data) and map it to the CPU as an ndarray.
        _, idx = torch.max(output, 1)
        idx = idx.data.cpu().numpy()
        prediction = np.squeeze(idx)

        del X, output, idx

        return prediction

    def reset_parameters(self, custom_weight_reset_flag):
        """Parameter Initialization

        This function (re)initializes the parameters of the defined network.
        This function is a wrapper for the reset_parameters() function defined for each module.
        More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
        An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639

        Args:
            custom_weight_reset_flag (bool): Flag indicating if the modified weight initialisation approach should be used.
        """
        print("Initializing network parameters...")

        # Walk two levels of child modules; reset conv / transposed-conv /
        # instance-norm layers, optionally overriding conv weights with a
        # leaky-ReLU-style gain (slope 0.25) normal initialisation.
        for _, module in self.named_children():
            for _, submodule in module.named_children():
                if isinstance(submodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)):
                    submodule.reset_parameters()
                    if custom_weight_reset_flag:
                        if isinstance(submodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
                            gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
                            fan, _ = calculate_fan(submodule.weight)
                            std = np.divide(gain, np.sqrt(fan))
                            submodule.weight.data.normal_(0, std)
                for _, subsubmodule in submodule.named_children():
                    if isinstance(subsubmodule, (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)):
                        subsubmodule.reset_parameters()
                        if custom_weight_reset_flag:
                            if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
                                gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
                                fan, _ = calculate_fan(subsubmodule.weight)
                                std = np.divide(gain, np.sqrt(fan))
                                subsubmodule.weight.data.normal_(0, std)

        print("Initialized network parameters!")
\ No newline at end of file
This diff is collapsed.
......@@ -4,12 +4,11 @@ input_data_train = "input_data_train.h5"
target_data_train = "target_data_train.h5"
input_data_validation = "input_data_validation.h5"
target_data_validation = "target_data_validation.h5"
crop_flag = False
crop_flag = True
[TRAINING]
experiment_name = "VA2-1"
pre_trained_path = "saved_models/VA2-1.pth.tar"
final_model_output_file = "VA2-1.pth.tar"
training_batch_size = 3
validation_batch_size = 3
use_pre_trained = False
......@@ -25,21 +24,28 @@ use_last_checkpoint = False
adam_w_flag = False
[NETWORK]
kernel_heigth = 3
kernel_width = 3
kernel_depth = 3
kernel_classification = 7
first_kernel_size = 9
first_convolution_stride = 1
input_channels = 1
output_channels = 32
convolution_stride = 1
dropout = 0
max_number_channels = 128
kernel_size = 3
convolution_stride = 2
number_of_encoder_blocks = 2
transformer_blocks_stride = 1
number_of_transformer_blocks = 6
kernel_classification = 9
pool_kernel_size = 3
pool_stride = 2
up_mode = "upconv"
dropout = 0
final_activation = 'tanh'
number_of_classes = 1
number_of_transformer_blocks = 6
custom_weight_reset_flag = False
cross_domain_flag = False
cross_domain_x2x_flag = False
cross_domain_y2y_flag = False
cross_domain_x2y_flag = False
number_of_feature_mapping_blocks = 1
[MISC]
save_model_directory = "saved_models"
......
[MAPPING]
trained_model_path = "saved_models/VA2-1.pth.tar"
prediction_output_path = "VA2-1_predictions"
prediction_output_database_name = "output_test_data.h5"
prediction_output_statistics_name = "output_statistics.csv"
data_directory = "/well/win-biobank/projects/imaging/data/data3/subjectsAll/"
mapping_data_file = "dMRI/autoptx_preproc/tractsNormSummed.nii.gz"
data_list = "datasets/test.txt"
mapping_targets_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
data_list_reduced = "datasets/test_reduced.txt"
data_list_all = "datasets/test_all.txt"
evaluate_all_data = False
output_database_flag = False
brain_mask_path = "utils/MNI152_T1_2mm_brain_mask.nii.gz"
rsfmri_mean_mask_path = "utils/mean_dr_stage2.nii.gz"
dmri_mean_mask_path = "utils/mean_tractsNormSummed_downsampled.nii.gz"
......@@ -21,4 +27,6 @@ shrinkage_flag = False
hard_shrinkage_flag = False
crop_flag = True
device = 0
exit_on_error = True
\ No newline at end of file
exit_on_error = True
cross_domain_x2x_flag = False
cross_domain_y2y_flag = False
\ No newline at end of file
......@@ -9,7 +9,6 @@ setup(
maintainer_email='andrei-claudiu.roibu@dtc.ox.ac.uk',
install_requires=[
'pip',
'matplotlib',
'numpy',
'pandas',
'torch==1.4',
......
......@@ -19,7 +19,6 @@ import glob
from fsl.data.image import Image
from fsl.utils.image.roi import roi
from datetime import datetime
from utils.losses import MSELoss
from utils.common_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
......@@ -64,9 +63,7 @@ class Solver():
experiment_name,
optimizer,
optimizer_arguments={},
loss_function=MSELoss(),
# loss_function=torch.nn.L1Loss(),
# loss_function=torch.nn.CosineEmbeddingLoss(),
loss_function=torch.nn.MSELoss(),
model_name='BrainMapper',
labels=None,
number_epochs=10,
......@@ -78,7 +75,6 @@ class Solver():
logs_directory='logs',
checkpoint_directory='checkpoints',
save_model_directory='saved_models',
final_model_output_file='finetuned_alldata.pth.tar',
crop_flag = False
):
......@@ -88,10 +84,10 @@ class Solver():
if torch.cuda.is_available():
self.loss_function = loss_function.cuda(device)
self.MSE = MSELoss().cuda(device)
self.MSE = torch.nn.MSELoss().cuda(device)
else:
self.loss_function = loss_function
self.MSE = MSELoss()
self.MSE = torch.nn.MSELoss()
self.model_name = model_name
self.labels = labels
......@@ -125,7 +121,6 @@ class Solver():
use_last_checkpoint=use_last_checkpoint,
labels=labels)
self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
self.early_stop = False
if crop_flag == False:
......@@ -134,10 +129,19 @@ class Solver():
self.MNI152_T1_2mm_brain_mask = torch.from_numpy(roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'),((9,81),(10,100),(0,77))).data)
self.save_model_directory = save_model_directory
self.final_model_output_file = final_model_output_file
self.final_model_output_file = experiment_name + ".pth.tar"
self.best_score_early_stop = None
self.counter_early_stop = 0
self.previous_loss = None
self.previous_MSE = None
self.valid_epoch = None
if use_last_checkpoint:
self.load_checkpoint()
self.EarlyStopping = EarlyStopping(patience=5, min_delta=0, best_score=self.best_score_early_stop, counter=self.counter_early_stop)
else:
self.EarlyStopping = EarlyStopping(patience=5, min_delta=0)
def train(self, train_loader, validation_loader):
"""Training Function
......@@ -159,10 +163,6 @@ class Solver():
torch.cuda.empty_cache() # clear memory
model.cuda(self.device) # Moving the model to GPU
previous_checkpoint = None
previous_loss = None
previous_MSE = None
print('****************************************************************')
print('TRAINING IS STARTING!')
print('=====================')
......@@ -179,6 +179,11 @@ class Solver():
iteration = self.start_iteration
for epoch in range(self.start_epoch, self.number_epochs+1):
if self.early_stop == True:
print("ATTENTION!: Training stopped due to previous early stop flag!")
break
print("Epoch {}/{}".format(epoch, self.number_epochs))
for phase in ['train', 'validation']:
......@@ -253,36 +258,40 @@ class Solver():
self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
elif phase == 'validation':
self.LogWriter.loss_per_epoch(
losses, phase, epoch, previous_loss=previous_loss)
previous_loss = np.mean(losses)
losses, phase, epoch, previous_loss=self.previous_loss)
self.previous_loss = np.mean(losses)
self.LogWriter.MSE_per_epoch(
MSEs, phase, epoch, previous_loss=previous_MSE)
previous_MSE = np.mean(MSEs)
MSEs, phase, epoch, previous_loss=self.previous_MSE)
self.previous_MSE = np.mean(MSEs)
if phase == 'validation':
early_stop, save_checkpoint = self.EarlyStopping(
np.mean(losses))
self.early_stop = early_stop
if save_checkpoint == True:
validation_loss = np.mean(losses)
checkpoint_name = os.path.join(
self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
},
filename=checkpoint_name
)
if previous_checkpoint != None:
os.remove(previous_checkpoint)
previous_checkpoint = checkpoint_name
else:
previous_checkpoint = checkpoint_name
early_stop, best_score_early_stop, counter_early_stop = self.EarlyStopping(np.mean(losses))
self.early_stop = early_stop
self.best_score_early_stop = best_score_early_stop
self.counter_early_stop = counter_early_stop
checkpoint_name = os.path.join(
self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
if self.counter_early_stop == 0:
self.valid_epoch = epoch
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict(),
'best_score_early_stop': self.best_score_early_stop,
'counter_early_stop': self.counter_early_stop,
'previous_loss': self.previous_loss,
'previous_MSE': self.previous_MSE,
'early_stop': self.early_stop,
'valid_epoch': self.valid_epoch
},
filename=checkpoint_name