Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
41956de8
Commit
41956de8
authored
Jul 31, 2020
by
Andrei Roibu
Browse files
fixed if-bug in reset parameters
parent
9a385a85
Changes
1
Hide whitespace changes
Inline
Side-by-side
BrainMapperAE.py
View file @
41956de8
...
...
@@ -209,18 +209,20 @@ class BrainMapperAE3D(nn.Module):
for
_
,
submodule
in
module
.
named_children
():
if
isinstance
(
submodule
,
(
torch
.
nn
.
ConvTranspose3d
,
torch
.
nn
.
Conv3d
,
torch
.
nn
.
InstanceNorm3d
))
==
True
:
submodule
.
reset_parameters
()
if
custom_weight_reset_flag
==
True
&
isinstance
(
submodule
,
(
torch
.
nn
.
Conv3d
,
torch
.
nn
.
ConvTranspose3d
)):
gain
=
np
.
sqrt
(
np
.
divide
(
2
,
1
+
np
.
power
(
0.25
,
2
)))
fan
,
_
=
calculate_fan
(
submodule
.
weight
)
std
=
np
.
divide
(
gain
,
np
.
sqrt
(
fan
))
submodule
.
weight
.
data
.
normal_
(
0
,
std
)
if
custom_weight_reset_flag
==
True
:
if
isinstance
(
submodule
,
(
torch
.
nn
.
Conv3d
,
torch
.
nn
.
ConvTranspose3d
)):
gain
=
np
.
sqrt
(
np
.
divide
(
2
,
1
+
np
.
power
(
0.25
,
2
)))
fan
,
_
=
calculate_fan
(
submodule
.
weight
)
std
=
np
.
divide
(
gain
,
np
.
sqrt
(
fan
))
submodule
.
weight
.
data
.
normal_
(
0
,
std
)
for
_
,
subsubmodule
in
submodule
.
named_children
():
if
isinstance
(
subsubmodule
,
(
torch
.
nn
.
ConvTranspose3d
,
torch
.
nn
.
Conv3d
,
torch
.
nn
.
InstanceNorm3d
))
==
True
:
subsubmodule
.
reset_parameters
()
if
custom_weight_reset_flag
==
True
&
isinstance
(
subsubmodule
,
(
torch
.
nn
.
Conv3d
,
torch
.
nn
.
ConvTranspose3d
)):
gain
=
np
.
sqrt
(
np
.
divide
(
2
,
1
+
np
.
power
(
0.25
,
2
)))
fan
,
_
=
calculate_fan
(
subsubmodule
.
weight
)
std
=
np
.
divide
(
gain
,
np
.
sqrt
(
fan
))
subsubmodule
.
weight
.
data
.
normal_
(
0
,
std
)
if
custom_weight_reset_flag
==
True
:
if
isinstance
(
subsubmodule
,
(
torch
.
nn
.
Conv3d
,
torch
.
nn
.
ConvTranspose3d
)):
gain
=
np
.
sqrt
(
np
.
divide
(
2
,
1
+
np
.
power
(
0.25
,
2
)))
fan
,
_
=
calculate_fan
(
subsubmodule
.
weight
)
std
=
np
.
divide
(
gain
,
np
.
sqrt
(
fan
))
subsubmodule
.
weight
.
data
.
normal_
(
0
,
std
)
print
(
"Initialized network parameters!"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment