Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
74c4b384
Commit
74c4b384
authored
Mar 25, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
created constructor function for UNet
parent
328192ce
Changes
1
Hide whitespace changes
Inline
Side-by-side
BrainMapperUNet.py
View file @
74c4b384
...
@@ -17,14 +17,57 @@ To use this module, import it and instantiate is as you wish:
...
@@ -17,14 +17,57 @@ To use this module, import it and instantiate is as you wish:
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
utils.modules
as
modules
class
BrainMapperUnet
(
nn
.
Module
):
class
BrainMapperUNet
(
nn
.
Module
):
"""
"""Architecture class BrainMapper U-net.
Description
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
Args:
parameters (dict): Contains information relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_classification': 1
'input_channels': 1
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
'number_of_classes': 1
}
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
Raises:
None
"""
"""
def
__init__
(
self
,
parameters
):
def
__init__
(
self
,
parameters
):
pass
super
(
BrainMapperUNet
,
self
).
__init__
()
# TODO: currently, architecture based on QuickNAT - need to adjust parameter values accordingly!
self
.
encoderBlock1
=
modules
.
EncoderBlock
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
self
.
encoderBlock2
=
modules
.
EncoderBlock
(
parameters
)
self
.
encoderBlock3
=
modules
.
EncoderBlock
(
parameters
)
self
.
encoderBlock4
=
modules
.
EncoderBlock
(
parameters
)
self
.
bottleneck
=
modules
.
ConvolutionalBlock
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
*
2.0
self
.
decoderBlock1
=
modules
.
DecoderBlock
(
parameters
)
self
.
decoderBlock2
=
modules
.
DecoderBlock
(
parameters
)
self
.
decoderBlock3
=
modules
.
DecoderBlock
(
parameters
)
self
.
decoderBlock4
=
modules
.
DecoderBlock
(
parameters
)
parameters
[
'input_channels'
]
=
parameters
[
'output_channels'
]
self
.
classifier
=
modules
.
ClassifierBlock
(
parameters
)
def
forward
(
self
,
X
):
def
forward
(
self
,
X
):
"""
"""
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment