Commit 91ecfd99 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

Added settings options for the number of ResNet blocks and for the use of custom weight initialisation.

parent 8d308d99
......@@ -75,16 +75,7 @@ class BrainMapperAE3D(nn.Module):
parameters['input_channels'] = parameters['output_channels']
parameters['convolution_stride'] = original_stride
self.transformerBlock1 = modules.ResNetBlock3D(parameters)
self.transformerBlock2 = modules.ResNetBlock3D(parameters)
self.transformerBlock3 = modules.ResNetBlock3D(parameters)
self.transformerBlock4 = modules.ResNetBlock3D(parameters)
self.transformerBlock5 = modules.ResNetBlock3D(parameters)
self.transformerBlock6 = modules.ResNetBlock3D(parameters)
self.transformerBlock7 = modules.ResNetBlock3D(parameters)
self.transformerBlock8 = modules.ResNetBlock3D(parameters)
self.transformerBlock9 = modules.ResNetBlock3D(parameters)
self.transformerBlock10 = modules.ResNetBlock3D(parameters)
self.transformerBlocks = nn.ModuleList([modules.ResNetBlock3D(parameters) for i in range(parameters['number_of_transformer_blocks'])])
# Decoder
......@@ -124,16 +115,8 @@ class BrainMapperAE3D(nn.Module):
# Transformer
X = self.transformerBlock1.forward(X)
X = self.transformerBlock2.forward(X)
X = self.transformerBlock3.forward(X)
X = self.transformerBlock4.forward(X)
X = self.transformerBlock5.forward(X)
X = self.transformerBlock6.forward(X)
X = self.transformerBlock7.forward(X)
X = self.transformerBlock8.forward(X)
X = self.transformerBlock9.forward(X)
X = self.transformerBlock10.forward(X)
for transformerBlock in self.transformerBlocks:
X = transformerBlock(X)
# Decoder
......@@ -207,26 +190,37 @@ class BrainMapperAE3D(nn.Module):
return prediction
def reset_parameters(self):
def reset_parameters(self, custom_weight_reset_flag=False):
    """Parameter Initialization

    This function (re)initializes the parameters of the defined network.
    It wraps the reset_parameters() function defined for each relevant module
    (Conv3d, ConvTranspose3d, InstanceNorm3d), walking two levels of children.
    More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
    An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639

    Args:
        custom_weight_reset_flag (bool): Flag indicating if the modified
            weight initialisation approach should be used for conv layers.
            Defaults to False (plain PyTorch reset_parameters()).
    """

    def _custom_conv_init(conv_module):
        # Kaiming-style normal init with the gain for a leaky slope of 0.25
        # (the default negative slope of PReLU): gain = sqrt(2 / (1 + 0.25^2)).
        gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
        # NOTE(review): calculate_fan is a project helper — presumably returns
        # (fan_in, fan_out) like torch.nn.init._calculate_fan_in_and_fan_out.
        fan, _ = calculate_fan(conv_module.weight)
        std = np.divide(gain, np.sqrt(fan))
        # Overwrites the default init performed by reset_parameters() above.
        conv_module.weight.data.normal_(0, std)

    print("Initializing network parameters...")

    resettable = (torch.nn.ConvTranspose3d, torch.nn.Conv3d, torch.nn.InstanceNorm3d)
    conv_types = (torch.nn.Conv3d, torch.nn.ConvTranspose3d)

    for _, module in self.named_children():
        for _, submodule in module.named_children():
            if isinstance(submodule, resettable):
                submodule.reset_parameters()
            # BUGFIX: the original used `flag == True & isinstance(...)`;
            # `&` binds tighter than `==`, so it parsed as
            # `flag == isinstance(...)` and fired on the wrong modules.
            if custom_weight_reset_flag and isinstance(submodule, conv_types):
                _custom_conv_init(submodule)
            for _, subsubmodule in submodule.named_children():
                if not isinstance(subsubmodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)):
                    if isinstance(subsubmodule, resettable):
                        subsubmodule.reset_parameters()
                if custom_weight_reset_flag and isinstance(subsubmodule, conv_types):
                    _custom_conv_init(subsubmodule)

    print("Initialized network parameters!")
......@@ -36,6 +36,8 @@ pool_kernel_size = 3
pool_stride = 2
up_mode = "upconv"
number_of_classes = 1
number_of_transformer_blocks = 10
custom_weight_reset_flag = False
[MISC]
save_model_directory = "saved_models"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment