Commit 40295fb1 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added modifications required for cross-domain training

parent 41956de8
......@@ -50,6 +50,8 @@ class BrainMapperAE3D(nn.Module):
def __init__(self, parameters):
super(BrainMapperAE3D, self).__init__()
self.cross_domain_x2y_flag = parameters['cross_domain_x2y_flag']
original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
original_kernel_height = parameters['kernel_heigth']
......@@ -77,6 +79,9 @@ class BrainMapperAE3D(nn.Module):
self.transformerBlocks = nn.ModuleList([modules.ResNetBlock3D(parameters) for i in range(parameters['number_of_transformer_blocks'])])
if self.cross_domain_x2y_flag == True:
self.featureMappingLayers = nn.ModuleList([modules.ResNetFeatureMappingBlock3D(parameters) for i in range(parameters['number_of_feature_mapping_blocks'])])
# Decoder
parameters['output_channels'] = parameters['output_channels'] // 2
......@@ -115,8 +120,19 @@ class BrainMapperAE3D(nn.Module):
# Transformer
for transformerBlock in self.transformerBlocks:
X = transformerBlock(X)
if self.cross_domain_x2y_flag == True:
for transformerBlock in self.transformerBlocks[:len(self.transformerBlocks)//2]:
X = transformerBlock(X)
for featureMappingLayer in self.featureMappingLayers:
X = featureMappingLayer(X)
for transformerBlock in self.transformerBlocks[len(self.transformerBlocks)//2:]:
X = transformerBlock(X)
else:
for transformerBlock in self.transformerBlocks:
X = transformerBlock(X)
# Decoder
......
......@@ -51,13 +51,15 @@ from utils.common_utils import create_folder
torch.set_default_tensor_type(torch.FloatTensor)
def load_data(data_parameters):
def load_data(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag):
"""Dataset Loader
This function loads the training and validation datasets.
Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles.
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns:
train_data (dataset object): Pytorch map-style dataset object, mapping indices to training data samples.
......@@ -65,7 +67,7 @@ def load_data(data_parameters):
"""
print("Data is loading...")
train_data, validation_data = get_datasets(data_parameters)
train_data, validation_data = get_datasets(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag)
print("Data has loaded!")
print("Training dataset size is {}".format(len(train_data)))
print("Validation dataset size is {}".format(len(validation_data)))
......@@ -116,7 +118,49 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
}
"""
def _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters):
def _load_pretrained_cross_domain(x2y_model, save_model_directory, experiment_name):
    """Pretrained cross-domain weight loader.

    Loads the pretrained X2X and Y2Y autoencoders saved as
    "<experiment_name>_x2x.pth.tar" and "<experiment_name>_y2y.pth.tar" and
    re-initialises the X2Y model's weights: the first half of its state-dict
    entries (the encoder path) are copied from the X2X model, and the
    remaining entries (the decoder path) are copied from the Y2Y model where
    the parameter names match.

    Args:
        x2y_model (torch.nn.Module): Original x2y model initialised using the standard parameters.
        save_model_directory (str): Name of the directory where the models are saved.
        experiment_name (str): Name of the experiment.

    Returns:
        torch.nn.Module: The x2y model with encoder and decoder path weights re-initialised.
    """

    x2y_model_state_dict = x2y_model.state_dict()
    x2x_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_x2x.pth.tar')).state_dict()
    y2y_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_y2y.pth.tar')).state_dict()

    # NOTE(review): assumes the first half (+1) of the ordered state-dict keys
    # corresponds to the encoder path — verify against the model definition.
    half_point = len(x2x_model_state_dict) // 2 + 1

    # State dicts are ordered, so enumeration position identifies the
    # encoder/decoder split; assigning to existing keys while iterating is safe.
    for counter, key in enumerate(x2y_model_state_dict, start=1):
        if counter <= half_point:
            # Encoder weights come from the pretrained X2X autoencoder.
            x2y_model_state_dict[key] = x2x_model_state_dict[key]
        elif key in y2y_model_state_dict:
            # Decoder weights come from the pretrained Y2Y autoencoder.
            x2y_model_state_dict[key] = y2y_model_state_dict[key]

    x2y_model.load_state_dict(x2y_model_state_dict)

    return x2y_model
def _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer = torch.optim.Adam,
loss_function = torch.nn.MSELoss(),
):
"""Wrapper for the training operation
This function wraps the training operation for the network
......@@ -128,7 +172,11 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
misc_parameters (dict): Dictionary of aditional hyperparameters
"""
train_data, validation_data = load_data(data_parameters)
train_data, validation_data = load_data(data_parameters,
cross_domain_x2x_flag = network_parameters['cross_domain_x2x_flag'],
cross_domain_y2y_flag = network_parameters['cross_domain_y2y_flag']
)
train_loader = data.DataLoader(
dataset=train_data,
......@@ -145,8 +193,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
)
if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(
training_parameters['pre_trained_path'])
BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
else:
BrainMapperModel = BrainMapperAE3D(network_parameters)
......@@ -154,8 +201,11 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
BrainMapperModel.reset_parameters(custom_weight_reset_flag)
optimizer = torch.optim.Adam
# optimizer = torch.optim.AdamW
if network_parameters['cross_domain_x2y_flag'] == True:
BrainMapperModel = _load_pretrained_cross_domain(x2y_model=BrainMapperModel,
save_model_directory=misc_parameters['save_model_directory'],
experiment_name=training_parameters['experiment_name']
)
solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'],
......@@ -167,6 +217,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
'eps': training_parameters['optimizer_epsilon'],
'weight_decay': training_parameters['optimizer_weigth_decay']
},
loss_function=loss_function,
model_name=training_parameters['experiment_name'],
number_epochs=training_parameters['number_of_epochs'],
loss_log_period=training_parameters['loss_log_period'],
......@@ -178,7 +229,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
logs_directory=misc_parameters['logs_directory'],
checkpoint_directory=misc_parameters['checkpoint_directory'],
save_model_directory=misc_parameters['save_model_directory'],
final_model_output_file=training_parameters['final_model_output_file'],
crop_flag = data_parameters['crop_flag']
)
......@@ -190,7 +240,64 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
return validation_loss
_ = _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters)
optimizer = torch.optim.Adam
# optimizer = torch.optim.AdamW
loss_function = torch.nn.MSELoss()
# loss_function=torch.nn.L1Loss()
# loss_function=torch.nn.CosineEmbeddingLoss()
if network_parameters['cross_domain_flag'] == False:
_ = _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
elif network_parameters['cross_domain_flag'] == True:
if network_parameters['cross_domain_x2x_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + '_x2x'
data_parameters['target_data_train'] = data_parameters['input_data_train']
data_parameters['target_data_validation'] = data_parameters['input_data_validation']
loss_function = torch.nn.L1Loss()
_ = _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
if network_parameters['cross_domain_y2y_flag'] == True:
training_parameters['experiment_name'] = training_parameters['experiment_name'] + '_y2y'
data_parameters['input_data_train'] = data_parameters['target_data_train']
data_parameters['input_data_validation'] = data_parameters['target_data_validation']
loss_function = torch.nn.L1Loss()
_ = _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
if network_parameters['cross_domain_x2y_flag'] == True:
_ = _train_runner(data_parameters,
training_parameters,
network_parameters,
misc_parameters,
optimizer=optimizer,
loss_function=loss_function
)
def evaluate_mapping(mapping_evaluation_parameters):
......
......@@ -4,12 +4,11 @@ input_data_train = "input_data_train.h5"
target_data_train = "target_data_train.h5"
input_data_validation = "input_data_validation.h5"
target_data_validation = "target_data_validation.h5"
crop_flag = False
crop_flag = True
[TRAINING]
experiment_name = "VA2-1"
pre_trained_path = "saved_models/VA2-1.pth.tar"
final_model_output_file = "VA2-1.pth.tar"
training_batch_size = 5
validation_batch_size = 5
use_pre_trained = False
......@@ -29,15 +28,20 @@ kernel_width = 3
kernel_depth = 3
kernel_classification = 7
input_channels = 1
output_channels = 64
output_channels = 32
convolution_stride = 1
dropout = 0
pool_kernel_size = 3
pool_stride = 2
up_mode = "upconv"
number_of_classes = 1
number_of_transformer_blocks = 10
number_of_transformer_blocks = 6
custom_weight_reset_flag = False
cross_domain_flag = False
cross_domain_x2x_flag = False
cross_domain_y2y_flag = False
cross_domain_x2y_flag = False
number_of_feature_mapping_blocks = 1
[MISC]
save_model_directory = "saved_models"
......
......@@ -65,8 +65,6 @@ class Solver():
optimizer,
optimizer_arguments={},
loss_function=MSELoss(),
# loss_function=torch.nn.L1Loss(),
# loss_function=torch.nn.CosineEmbeddingLoss(),
model_name='BrainMapper',
labels=None,
number_epochs=10,
......@@ -78,7 +76,6 @@ class Solver():
logs_directory='logs',
checkpoint_directory='checkpoints',
save_model_directory='saved_models',
final_model_output_file='finetuned_alldata.pth.tar',
crop_flag = False
):
......@@ -88,10 +85,10 @@ class Solver():
if torch.cuda.is_available():
self.loss_function = loss_function.cuda(device)
self.MSE = MSELoss().cuda(device)
self.MSE = torch.nn.MSELoss().cuda(device)
else:
self.loss_function = loss_function
self.MSE = MSELoss()
self.MSE = torch.nn.MSELoss()
self.model_name = model_name
self.labels = labels
......@@ -134,7 +131,7 @@ class Solver():
self.MNI152_T1_2mm_brain_mask = torch.from_numpy(roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'),((9,81),(10,100),(0,77))).data)
self.save_model_directory = save_model_directory
self.final_model_output_file = final_model_output_file
self.final_model_output_file = experiment_name + ".pth.tar"
if use_last_checkpoint:
self.load_checkpoint()
......
......@@ -53,7 +53,7 @@ class DataMapper(data.Dataset):
return len(self.y)
def get_datasets(data_parameters):
def get_datasets(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag):
"""Data Loader Function.
This function loads the various data file and returns the relevand mapped datasets.
......@@ -67,20 +67,32 @@ def get_datasets(data_parameters):
input_data_validation = "input_data_validation.h5"
target_data_validation = "target_data_validation.h5"
}
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets
Returns:
touple: the relevant train and validation datasets
"""
key_X = 'input'
key_y = 'target'
X_train_data = h5py.File(os.path.join(data_parameters["data_folder_name"], data_parameters["input_data_train"]), 'r')
y_train_data = h5py.File(os.path.join(data_parameters["data_folder_name"], data_parameters["target_data_train"]), 'r')
X_validation_data = h5py.File(os.path.join(data_parameters["data_folder_name"], data_parameters["input_data_validation"]), 'r')
y_validation_data = h5py.File(os.path.join(data_parameters["data_folder_name"], data_parameters["target_data_validation"]), 'r')
if cross_domain_x2x_flag == True:
key_X = 'input'
key_y = 'input'
elif cross_domain_y2y_flag == True:
key_X = 'target'
key_y = 'target'
return (
DataMapper( X_train_data['input'][()], y_train_data['target'][()] ),
DataMapper( X_validation_data['input'][()], y_validation_data['target'][()] )
DataMapper( X_train_data[key_X][()], y_train_data[key_y][()] ),
DataMapper( X_validation_data[key_X][()], y_validation_data[key_y][()] )
)
......
......@@ -100,6 +100,76 @@ class ResNetEncoderBlock3D(nn.Module):
return X
class ResNetFeatureMappingBlock3D(nn.Module):
    """Parent class for a 3D convolutional feature mapping block.

    A single Conv3d + InstanceNorm3d (+ optional Dropout3d) block used to map
    features during cross-domain training. The class is a subclass/child class
    of nn.Module, inheriting its functionality.

    Args:
        parameters (dict): Contains information on kernel size, number of channels,
            number of filters, and if convolution is strided.
            parameters = {
                'kernel_heigth': 5
                'kernel_width': 5
                'kernel_depth': 5
                'input_channels': 64
                'output_channels': 64
                'convolution_stride': 1
                'dropout': 0.2
            }

    Returns:
        torch.tensor: Output forward passed tensor
    """

    def __init__(self, parameters):
        super(ResNetFeatureMappingBlock3D, self).__init__()

        # Zero padding that preserves spatial size for a stride-1 convolution
        # (http://cs231n.github.io/convolutional-networks/). The kernel is
        # cubic (kernel_size below is the scalar parameters['kernel_heigth']),
        # so a single padding value serves all three spatial dimensions.
        padding = (parameters['kernel_heigth'] - 1) // 2

        # Instance normalisation is used due to the small batch size, and as it
        # has shown promise during the experiments with the simple network.
        self.convolutional_layer = nn.Sequential(
            nn.Conv3d(
                in_channels=parameters['input_channels'],
                out_channels=parameters['output_channels'],
                kernel_size=parameters['kernel_heigth'],
                stride=parameters['convolution_stride'],
                padding=(padding, padding, padding)
            ),
            nn.InstanceNorm3d(num_features=parameters['output_channels']),
        )

        if parameters['dropout'] > 0:
            self.dropout_needed = True
            self.dropout = nn.Dropout3d(parameters['dropout'])
        else:
            self.dropout_needed = False

    def forward(self, X):
        """Forward pass

        Function computing the forward pass through the convolutional layer.
        The input to the function is a torch tensor of shape
        N (batch size) x C (number of channels) x D (input depth) x H (input heigth) x W (input width)

        Args:
            X (torch.tensor): Input tensor, shape = (N x C x D x H x W)

        Returns:
            torch.tensor: Output forward passed tensor
        """

        X = self.convolutional_layer(X)

        if self.dropout_needed:
            X = self.dropout(X)

        return X
class ResNetBlock3D(nn.Module):
"""Parent class for a 3D ResNet convolutional block.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment