Commit 1de05ff6 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

corrected weight initiation for PReLU

parent 9f6e9cb1
......@@ -18,6 +18,7 @@ import numpy as np
import torch
import torch.nn as nn
import utils.modules as modules
from torch.nn.init import _calculate_fan_in_and_fan_out as calculate_fan
class BrainMapperAE3D(nn.Module):
......@@ -214,5 +215,10 @@ class BrainMapperAE3D(nn.Module):
for _, subsubmodule in submodule.named_children():
if isinstance(subsubmodule, (torch.nn.PReLU, torch.nn.Dropout3d, torch.nn.MaxPool3d)) == False:
subsubmodule.reset_parameters()
if isinstance(subsubmodule, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
gain = np.sqrt(np.divide(2, 1 + np.power(0.25, 2)))
fan, _ = calculate_fan(subsubmodule.weight)
std = np.divide(gain, np.sqrt(fan))
subsubmodule.weight.data.normal_(0, std)
print("Initialized network parameters!")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment