Commit a998549f authored by Andrei Roibu's avatar Andrei Roibu
Browse files

fixed module import error, saved original in/out parameters

parent df0d197a
...@@ -217,36 +217,42 @@ class BrainMapperUNet3D(nn.Module): ...@@ -217,36 +217,42 @@ class BrainMapperUNet3D(nn.Module):
def __init__(self, parameters): def __init__(self, parameters):
super(BrainMapperUNet3D, self).__init__() super(BrainMapperUNet3D, self).__init__()
self.encoderBlock1 = EncoderBlock3D(parameters) original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
self.encoderBlock1 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2 parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock2 = EncoderBlock3D(parameters) self.encoderBlock2 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2 parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock3 = EncoderBlock3D(parameters) self.encoderBlock3 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2 parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock4 = EncoderBlock3D(parameters) self.encoderBlock4 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2 parameters['output_channels'] = parameters['output_channels'] * 2
self.bottleneck = ConvolutionalBlock3D(parameters) self.bottleneck = modules.ConvolutionalBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2 parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock1 = DecoderBlock3D(parameters) self.decoderBlock1 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2 parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock2 = DecoderBlock3D(parameters) self.decoderBlock2 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2 parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock3 = DecoderBlock3D(parameters) self.decoderBlock3 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2 parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock4 = DecoderBlock3D(parameters) self.decoderBlock4 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
self.classifier = ClassifierBlock3D(parameters) self.classifier = modules.ClassifierBlock3D(parameters)
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
def forward(self, X): def forward(self, X):
"""Forward pass for 3D U-net """Forward pass for 3D U-net
...@@ -398,7 +404,7 @@ class BrainMapperUNet3D_Simple(nn.Module): ...@@ -398,7 +404,7 @@ class BrainMapperUNet3D_Simple(nn.Module):
""" """
def __init__(self, parameters): def __init__(self, parameters):
super(BrainMapperUNet3D, self).__init__() super(BrainMapperUNet3D_Simple, self).__init__()
# TODO: currently, architecture based on QuickNAT - need to adjust parameter values accordingly! # TODO: currently, architecture based on QuickNAT - need to adjust parameter values accordingly!
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment