Commit 74c4b384 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

created constructor function for UNet

parent 328192ce
...@@ -17,14 +17,57 @@ To use this module, import it and instantiate is as you wish: ...@@ -17,14 +17,57 @@ To use this module, import it and instantiate is as you wish:
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import utils.modules as modules
class BrainMapperUnet(nn.Module): class BrainMapperUNet(nn.Module):
""" """Architecture class BrainMapper U-net.
Description
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
Args:
parameters (dict): Contains information relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_classification': 1
'input_channels': 1
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
'number_of_classes': 1
}
Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block
Raises:
None
""" """
def __init__(self, parameters): def __init__(self, parameters):
pass super(BrainMapperUNet, self).__init__()
# TODO: currently, architecture based on QuickNAT - need to adjust parameter values accordingly!
self.encoderBlock1 = modules.EncoderBlock(parameters)
parameters['input_channels'] = parameters['output_channels']
self.encoderBlock2 = modules.EncoderBlock(parameters)
self.encoderBlock3 = modules.EncoderBlock(parameters)
self.encoderBlock4 = modules.EncoderBlock(parameters)
self.bottleneck = modules.ConvolutionalBlock(parameters)
parameters['input_channels'] = parameters['output_channels'] * 2.0
self.decoderBlock1 = modules.DecoderBlock(parameters)
self.decoderBlock2 = modules.DecoderBlock(parameters)
self.decoderBlock3 = modules.DecoderBlock(parameters)
self.decoderBlock4 = modules.DecoderBlock(parameters)
parameters['input_channels'] = parameters['output_channels']
self.classifier = modules.ClassifierBlock(parameters)
def forward(self, X): def forward(self, X):
""" """
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment