Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
40295fb1
Commit
40295fb1
authored
Aug 05, 2020
by
Andrei Roibu
Browse files
added modifications required for cross-domain training
parent
41956de8
Changes
6
Hide whitespace changes
Inline
Side-by-side
BrainMapperAE.py
View file @
40295fb1
...
...
@@ -50,6 +50,8 @@ class BrainMapperAE3D(nn.Module):
def
__init__
(
self
,
parameters
):
super
(
BrainMapperAE3D
,
self
).
__init__
()
self
.
cross_domain_x2y_flag
=
parameters
[
'cross_domain_x2y_flag'
]
original_input_channels
=
parameters
[
'input_channels'
]
original_output_channels
=
parameters
[
'output_channels'
]
original_kernel_height
=
parameters
[
'kernel_heigth'
]
...
...
@@ -77,6 +79,9 @@ class BrainMapperAE3D(nn.Module):
self
.
transformerBlocks
=
nn
.
ModuleList
([
modules
.
ResNetBlock3D
(
parameters
)
for
i
in
range
(
parameters
[
'number_of_transformer_blocks'
])])
if
self
.
cross_domain_x2y_flag
==
True
:
self
.
featureMappingLayers
=
nn
.
ModuleList
([
modules
.
ResNetFeatureMappingBlock3D
(
parameters
)
for
i
in
range
(
parameters
[
'number_of_feature_mapping_blocks'
])])
# Decoder
parameters
[
'output_channels'
]
=
parameters
[
'output_channels'
]
//
2
...
...
@@ -115,8 +120,19 @@ class BrainMapperAE3D(nn.Module):
# Transformer
for
transformerBlock
in
self
.
transformerBlocks
:
X
=
transformerBlock
(
X
)
if
self
.
cross_domain_x2y_flag
==
True
:
for
transformerBlock
in
self
.
transformerBlocks
[:
len
(
self
.
transformerBlocks
)
//
2
]:
X
=
transformerBlock
(
X
)
for
featureMappingLayer
in
self
.
featureMappingLayers
:
X
=
featureMappingLayer
(
X
)
for
transformerBlock
in
self
.
transformerBlocks
[
len
(
self
.
transformerBlocks
)
//
2
:]:
X
=
transformerBlock
(
X
)
else
:
for
transformerBlock
in
self
.
transformerBlocks
:
X
=
transformerBlock
(
X
)
# Decoder
...
...
run.py
View file @
40295fb1
...
...
@@ -51,13 +51,15 @@ from utils.common_utils import create_folder
torch
.
set_default_tensor_type
(
torch
.
FloatTensor
)
def
load_data
(
data_parameters
):
def
load_data
(
data_parameters
,
cross_domain_x2x_flag
,
cross_domain_y2y_flag
):
"""Dataset Loader
This function loads the training and validation datasets.
Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles.
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
Returns:
train_data (dataset object): Pytorch map-style dataset object, mapping indices to training data samples.
...
...
@@ -65,7 +67,7 @@ def load_data(data_parameters):
"""
print
(
"Data is loading..."
)
train_data
,
validation_data
=
get_datasets
(
data_parameters
)
train_data
,
validation_data
=
get_datasets
(
data_parameters
,
cross_domain_x2x_flag
,
cross_domain_y2y_flag
)
print
(
"Data has loaded!"
)
print
(
"Training dataset size is {}"
.
format
(
len
(
train_data
)))
print
(
"Validation dataset size is {}"
.
format
(
len
(
validation_data
)))
...
...
@@ -116,7 +118,49 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
}
"""
def
_train_runner
(
data_parameters
,
training_parameters
,
network_parameters
,
misc_parameters
):
def _load_pretrained_cross_domain(x2y_model, save_model_directory, experiment_name):
    """Pretrained cross-domain loader.

    This function loads the pretrained X2X and Y2Y autoencoders.
    After, it initializes the X2Y model's weights using the X2X encoder and the Y2Y decoder weights.

    Args:
        x2y_model (class): Original x2y model initialised using the standard parameters.
        save_model_directory (str): Name of the directory where the model is saved
        experiment_name (str): Name of the experiment

    Returns:
        x2y_model (class): New x2y model with encoder and decoder path weights reinitialised.
    """
    x2y_model_state_dict = x2y_model.state_dict()
    # The saved checkpoints are whole serialized models (not bare state dicts),
    # hence state_dict() is called on the objects returned by torch.load.
    x2x_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_x2x.pth.tar')).state_dict()
    y2y_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_y2y.pth.tar')).state_dict()
    # The first half (+1) of the parameter entries is treated as the encoder
    # path; the rest as the decoder path.
    # NOTE(review): this relies on state-dict insertion order and on the X2X /
    # X2Y architectures sharing key names for the encoder half -- confirm
    # against the BrainMapperAE3D layer layout.
    half_point = len(x2x_model_state_dict) // 2 + 1
    counter = 1
    for key, _ in x2y_model_state_dict.items():
        if counter <= half_point:
            # Encoder weights are copied from the X2X autoencoder.
            x2y_model_state_dict.update({key: x2x_model_state_dict[key]})
            counter += 1
        else:
            # Decoder weights are copied from the Y2Y autoencoder; keys absent
            # from the Y2Y checkpoint keep their freshly initialised values.
            if key in y2y_model_state_dict:
                x2y_model_state_dict.update({key: y2y_model_state_dict[key]})
    x2y_model.load_state_dict(x2y_model_state_dict)
    return x2y_model
def
_train_runner
(
data_parameters
,
training_parameters
,
network_parameters
,
misc_parameters
,
optimizer
=
torch
.
optim
.
Adam
,
loss_function
=
torch
.
nn
.
MSELoss
(),
):
"""Wrapper for the training operation
This function wraps the training operation for the network
...
...
@@ -128,7 +172,11 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
misc_parameters (dict): Dictionary of aditional hyperparameters
"""
train_data
,
validation_data
=
load_data
(
data_parameters
)
train_data
,
validation_data
=
load_data
(
data_parameters
,
cross_domain_x2x_flag
=
network_parameters
[
'cross_domain_x2x_flag'
],
cross_domain_y2y_flag
=
network_parameters
[
'cross_domain_y2y_flag'
]
)
train_loader
=
data
.
DataLoader
(
dataset
=
train_data
,
...
...
@@ -145,8 +193,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
)
if
training_parameters
[
'use_pre_trained'
]:
BrainMapperModel
=
torch
.
load
(
training_parameters
[
'pre_trained_path'
])
BrainMapperModel
=
torch
.
load
(
training_parameters
[
'pre_trained_path'
])
else
:
BrainMapperModel
=
BrainMapperAE3D
(
network_parameters
)
...
...
@@ -154,8 +201,11 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
BrainMapperModel
.
reset_parameters
(
custom_weight_reset_flag
)
optimizer
=
torch
.
optim
.
Adam
# optimizer = torch.optim.AdamW
if
network_parameters
[
'cross_domain_x2y_flag'
]
==
True
:
BrainMapperModel
=
_load_pretrained_cross_domain
(
x2y_model
=
BrainMapperModel
,
save_model_directory
=
misc_parameters
[
'save_model_directory'
],
experiment_name
=
training_parameters
[
'experiment_name'
]
)
solver
=
Solver
(
model
=
BrainMapperModel
,
device
=
misc_parameters
[
'device'
],
...
...
@@ -167,6 +217,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
'eps'
:
training_parameters
[
'optimizer_epsilon'
],
'weight_decay'
:
training_parameters
[
'optimizer_weigth_decay'
]
},
loss_function
=
loss_function
,
model_name
=
training_parameters
[
'experiment_name'
],
number_epochs
=
training_parameters
[
'number_of_epochs'
],
loss_log_period
=
training_parameters
[
'loss_log_period'
],
...
...
@@ -178,7 +229,6 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
logs_directory
=
misc_parameters
[
'logs_directory'
],
checkpoint_directory
=
misc_parameters
[
'checkpoint_directory'
],
save_model_directory
=
misc_parameters
[
'save_model_directory'
],
final_model_output_file
=
training_parameters
[
'final_model_output_file'
],
crop_flag
=
data_parameters
[
'crop_flag'
]
)
...
...
@@ -190,7 +240,64 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
return
validation_loss
_
=
_train_runner
(
data_parameters
,
training_parameters
,
network_parameters
,
misc_parameters
)
optimizer
=
torch
.
optim
.
Adam
# optimizer = torch.optim.AdamW
loss_function
=
torch
.
nn
.
MSELoss
()
# loss_function=torch.nn.L1Loss()
# loss_function=torch.nn.CosineEmbeddingLoss()
if
network_parameters
[
'cross_domain_flag'
]
==
False
:
_
=
_train_runner
(
data_parameters
,
training_parameters
,
network_parameters
,
misc_parameters
,
optimizer
=
optimizer
,
loss_function
=
loss_function
)
elif
network_parameters
[
'cross_domain_flag'
]
==
True
:
if
network_parameters
[
'cross_domain_x2x_flag'
]
==
True
:
training_parameters
[
'experiment_name'
]
=
training_parameters
[
'experiment_name'
]
+
'_x2x'
data_parameters
[
'target_data_train'
]
=
data_parameters
[
'input_data_train'
]
data_parameters
[
'target_data_validation'
]
=
data_parameters
[
'input_data_validation'
]
loss_function
=
torch
.
nn
.
L1Loss
()
_
=
_train_runner
(
data_parameters
,
training_parameters
,
network_parameters
,
misc_parameters
,
optimizer
=
optimizer
,
loss_function
=
loss_function
)
if
network_parameters
[
'cross_domain_y2y_flag'
]
==
True
:
training_parameters
[
'experiment_name'
]
=
training_parameters
[
'experiment_name'
]
+
'_y2y'
data_parameters
[
'input_data_train'
]
=
data_parameters
[
'target_data_train'
]
data_parameters
[
'input_data_validation'
]
=
data_parameters
[
'target_data_validation'
]
loss_function
=
torch
.
nn
.
L1Loss
()
_
=
_train_runner
(
data_parameters
,
training_parameters
,
network_parameters
,
misc_parameters
,
optimizer
=
optimizer
,
loss_function
=
loss_function
)
if
network_parameters
[
'cross_domain_x2y_flag'
]
==
True
:
_
=
_train_runner
(
data_parameters
,
training_parameters
,
network_parameters
,
misc_parameters
,
optimizer
=
optimizer
,
loss_function
=
loss_function
)
def
evaluate_mapping
(
mapping_evaluation_parameters
):
...
...
settings.ini
View file @
40295fb1
...
...
@@ -4,12 +4,11 @@ input_data_train = "input_data_train.h5"
target_data_train
=
"target_data_train.h5"
input_data_validation
=
"input_data_validation.h5"
target_data_validation
=
"target_data_validation.h5"
crop_flag
=
Fals
e
crop_flag
=
Tru
e
[TRAINING]
experiment_name
=
"VA2-1"
pre_trained_path
=
"saved_models/VA2-1.pth.tar"
final_model_output_file
=
"VA2-1.pth.tar"
training_batch_size
=
5
validation_batch_size
=
5
use_pre_trained
=
False
...
...
@@ -29,15 +28,20 @@ kernel_width = 3
kernel_depth
=
3
kernel_classification
=
7
input_channels
=
1
output_channels
=
64
output_channels
=
32
convolution_stride
=
1
dropout
=
0
pool_kernel_size
=
3
pool_stride
=
2
up_mode
=
"upconv"
number_of_classes
=
1
number_of_transformer_blocks
=
10
number_of_transformer_blocks
=
6
custom_weight_reset_flag
=
False
cross_domain_flag
=
False
cross_domain_x2x_flag
=
False
cross_domain_y2y_flag
=
False
cross_domain_x2y_flag
=
False
number_of_feature_mapping_blocks
=
1
[MISC]
save_model_directory
=
"saved_models"
...
...
solver.py
View file @
40295fb1
...
...
@@ -65,8 +65,6 @@ class Solver():
optimizer
,
optimizer_arguments
=
{},
loss_function
=
MSELoss
(),
# loss_function=torch.nn.L1Loss(),
# loss_function=torch.nn.CosineEmbeddingLoss(),
model_name
=
'BrainMapper'
,
labels
=
None
,
number_epochs
=
10
,
...
...
@@ -78,7 +76,6 @@ class Solver():
logs_directory
=
'logs'
,
checkpoint_directory
=
'checkpoints'
,
save_model_directory
=
'saved_models'
,
final_model_output_file
=
'finetuned_alldata.pth.tar'
,
crop_flag
=
False
):
...
...
@@ -88,10 +85,10 @@ class Solver():
if
torch
.
cuda
.
is_available
():
self
.
loss_function
=
loss_function
.
cuda
(
device
)
self
.
MSE
=
MSELoss
().
cuda
(
device
)
self
.
MSE
=
torch
.
nn
.
MSELoss
().
cuda
(
device
)
else
:
self
.
loss_function
=
loss_function
self
.
MSE
=
MSELoss
()
self
.
MSE
=
torch
.
nn
.
MSELoss
()
self
.
model_name
=
model_name
self
.
labels
=
labels
...
...
@@ -134,7 +131,7 @@ class Solver():
self
.
MNI152_T1_2mm_brain_mask
=
torch
.
from_numpy
(
roi
(
Image
(
'utils/MNI152_T1_2mm_brain_mask.nii.gz'
),((
9
,
81
),(
10
,
100
),(
0
,
77
))).
data
)
self
.
save_model_directory
=
save_model_directory
self
.
final_model_output_file
=
final_model_output_file
self
.
final_model_output_file
=
experiment_name
+
".pth.tar"
if
use_last_checkpoint
:
self
.
load_checkpoint
()
...
...
utils/data_utils.py
View file @
40295fb1
...
...
@@ -53,7 +53,7 @@ class DataMapper(data.Dataset):
return
len
(
self
.
y
)
def
get_datasets
(
data_parameters
):
def
get_datasets
(
data_parameters
,
cross_domain_x2x_flag
,
cross_domain_y2y_flag
):
"""Data Loader Function.
This function loads the various data files and returns the relevant mapped datasets.
...
...
@@ -67,20 +67,32 @@ def get_datasets(data_parameters):
input_data_validation = "input_data_validation.h5"
target_data_validation = "target_data_validation.h5"
}
cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occurring between the inputs
cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occurring between the targets
Returns:
tuple: the relevant train and validation datasets
"""
key_X
=
'input'
key_y
=
'target'
X_train_data
=
h5py
.
File
(
os
.
path
.
join
(
data_parameters
[
"data_folder_name"
],
data_parameters
[
"input_data_train"
]),
'r'
)
y_train_data
=
h5py
.
File
(
os
.
path
.
join
(
data_parameters
[
"data_folder_name"
],
data_parameters
[
"target_data_train"
]),
'r'
)
X_validation_data
=
h5py
.
File
(
os
.
path
.
join
(
data_parameters
[
"data_folder_name"
],
data_parameters
[
"input_data_validation"
]),
'r'
)
y_validation_data
=
h5py
.
File
(
os
.
path
.
join
(
data_parameters
[
"data_folder_name"
],
data_parameters
[
"target_data_validation"
]),
'r'
)
if
cross_domain_x2x_flag
==
True
:
key_X
=
'input'
key_y
=
'input'
elif
cross_domain_y2y_flag
==
True
:
key_X
=
'target'
key_y
=
'target'
return
(
DataMapper
(
X_train_data
[
'input'
][()],
y_train_data
[
'target'
][()]
),
DataMapper
(
X_validation_data
[
'input'
][()],
y_validation_data
[
'target'
][()]
)
DataMapper
(
X_train_data
[
key_X
][()],
y_train_data
[
key_y
][()]
),
DataMapper
(
X_validation_data
[
key_X
][()],
y_validation_data
[
key_y
][()]
)
)
...
...
utils/modules.py
View file @
40295fb1
...
...
@@ -100,6 +100,76 @@ class ResNetEncoderBlock3D(nn.Module):
return
X
class ResNetFeatureMappingBlock3D(nn.Module):
    """3D convolutional feature mapping block.

    A single strided 3D convolution followed by instance normalisation, with
    optional spatial dropout. Subclass of nn.Module, inheriting its
    functionality.

    Args:
        parameters (dict): Contains information on kernel size, number of channels, number of filters, and if convolution is strided.
        parameters = {
            'kernel_heigth': 5
            'kernel_width': 5
            'kernel_depth' : 5
            'input_channels': 64
            'output_channels': 64
            'convolution_stride': 1
            'dropout': 0.2
        }

    Returns:
        torch.tensor: Output forward passed tensor
    """

    def __init__(self, parameters):
        super(ResNetFeatureMappingBlock3D, self).__init__()

        # "Same" padding for an odd cubic kernel: (k - 1) / 2 on every axis
        # (http://cs231n.github.io/convolutional-networks/). The single
        # 'kernel_heigth' value is used for all three spatial dimensions.
        same_pad = int((parameters['kernel_heigth'] - 1) / 2)
        padding_depth = padding_heigth = padding_width = same_pad

        self.convolutional_layer = nn.Sequential(
            nn.Conv3d(
                in_channels=parameters['input_channels'],
                out_channels=parameters['output_channels'],
                kernel_size=parameters['kernel_heigth'],
                stride=parameters['convolution_stride'],
                padding=(padding_depth, padding_heigth, padding_width)
            ),
            # Instance normalisation is used due to the small batch size, and
            # as it has shown promise during experiments with the simple network.
            nn.InstanceNorm3d(num_features=parameters['output_channels']),
        )

        # Dropout is only instantiated when a positive rate is configured.
        self.dropout_needed = parameters['dropout'] > 0
        if self.dropout_needed:
            self.dropout = nn.Dropout3d(parameters['dropout'])

    def forward(self, X):
        """Forward pass.

        Function computing the forward pass through the convolutional layer.
        The input to the function is a torch tensor of shape N (batch size) x C (number of channels) x D (input depth) x H (input heigth) x W (input width)

        Args:
            X (torch.tensor): Input tensor, shape = (N x C x D x H x W)

        Returns:
            torch.tensor: Output forward passed tensor
        """
        mapped = self.convolutional_layer(X)
        if self.dropout_needed:
            mapped = self.dropout(mapped)
        return mapped
class
ResNetBlock3D
(
nn
.
Module
):
"""Parent class for a 3D ResNet convolutional block.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment