Commit 9338b510 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

reorganized deprecated nets, code clearning

parent f982dfaf
...@@ -20,8 +20,8 @@ import torch.nn as nn ...@@ -20,8 +20,8 @@ import torch.nn as nn
import utils.modules as modules import utils.modules as modules
class BrainMapperUNet(nn.Module): class BrainMapperUNet3D(nn.Module):
"""Architecture class BrainMapper U-net. """Architecture class BrainMapper 3D U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project. This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
...@@ -30,6 +30,7 @@ class BrainMapperUNet(nn.Module): ...@@ -30,6 +30,7 @@ class BrainMapperUNet(nn.Module):
parameters = { parameters = {
'kernel_heigth': 5 'kernel_heigth': 5
'kernel_width': 5 'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1 'kernel_classification': 1
'input_channels': 1 'input_channels': 1
'output_channels': 64 'output_channels': 64
...@@ -46,35 +47,53 @@ class BrainMapperUNet(nn.Module): ...@@ -46,35 +47,53 @@ class BrainMapperUNet(nn.Module):
""" """
def __init__(self, parameters): def __init__(self, parameters):
super(BrainMapperUNet, self).__init__() super(BrainMapperUNet3D, self).__init__()
# TODO: currently, architecture based on QuickNAT - need to adjust parameter values accordingly! original_input_channels = parameters['input_channels']
original_output_channels = parameters['output_channels']
self.encoderBlock1 = modules.EncoderBlock(parameters) self.encoderBlock1 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
self.encoderBlock2 = modules.EncoderBlock(parameters) parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock3 = modules.EncoderBlock(parameters) self.encoderBlock2 = modules.EncoderBlock3D(parameters)
self.encoderBlock4 = modules.EncoderBlock(parameters) parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock3 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock4 = modules.EncoderBlock3D(parameters)
self.bottleneck = modules.ConvolutionalBlock(parameters) parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.bottleneck = modules.ConvolutionalBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] * 2.0 parameters['input_channels'] = parameters['output_channels']
self.decoderBlock1 = modules.DecoderBlock(parameters) parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock2 = modules.DecoderBlock(parameters) self.decoderBlock1 = modules.DecoderBlock3D(parameters)
self.decoderBlock3 = modules.DecoderBlock(parameters) parameters['input_channels'] = parameters['output_channels']
self.decoderBlock4 = modules.DecoderBlock(parameters) parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock2 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock3 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock4 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
self.classifier = modules.ClassifierBlock(parameters) self.classifier = modules.ClassifierBlock3D(parameters)
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
def forward(self, X): def forward(self, X):
"""Forward pass for U-net """Forward pass for 3D U-net
Function computing the forward pass through the U-Net Function computing the forward pass through the 3D U-Net
The input to the function is the dMRI map The input to the function is the dMRI map
Args: Args:
X (torch.tensor): Input dMRI map, shape = (N x C x H x W) X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W)
Returns: Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block probability_map (torch.tensor): Output forward passed tensor through the U-net block
...@@ -188,8 +207,10 @@ class BrainMapperUNet(nn.Module): ...@@ -188,8 +207,10 @@ class BrainMapperUNet(nn.Module):
return prediction return prediction
class BrainMapperUNet3D(nn.Module): # DEPRECATED ARCHITECTURES!
"""Architecture class BrainMapper 3D U-net.
class BrainMapperUNet(nn.Module):
"""Architecture class BrainMapper U-net.
This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project. This class contains the pytorch implementation of the U-net architecture underpinning the BrainMapper project.
...@@ -198,7 +219,6 @@ class BrainMapperUNet3D(nn.Module): ...@@ -198,7 +219,6 @@ class BrainMapperUNet3D(nn.Module):
parameters = { parameters = {
'kernel_heigth': 5 'kernel_heigth': 5
'kernel_width': 5 'kernel_width': 5
'kernel_depth': 5
'kernel_classification': 1 'kernel_classification': 1
'input_channels': 1 'input_channels': 1
'output_channels': 64 'output_channels': 64
...@@ -215,53 +235,35 @@ class BrainMapperUNet3D(nn.Module): ...@@ -215,53 +235,35 @@ class BrainMapperUNet3D(nn.Module):
""" """
def __init__(self, parameters): def __init__(self, parameters):
super(BrainMapperUNet3D, self).__init__() super(BrainMapperUNet, self).__init__()
original_input_channels = parameters['input_channels'] # TODO: currently, architecture based on QuickNAT - need to adjust parameter values accordingly!
original_output_channels = parameters['output_channels']
self.encoderBlock1 = modules.EncoderBlock3D(parameters) self.encoderBlock1 = modules.EncoderBlock(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock2 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2
self.encoderBlock3 = modules.EncoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] * 2 self.encoderBlock2 = modules.EncoderBlock(parameters)
self.encoderBlock4 = modules.EncoderBlock3D(parameters) self.encoderBlock3 = modules.EncoderBlock(parameters)
self.encoderBlock4 = modules.EncoderBlock(parameters)
parameters['input_channels'] = parameters['output_channels'] self.bottleneck = modules.ConvolutionalBlock(parameters)
parameters['output_channels'] = parameters['output_channels'] * 2
self.bottleneck = modules.ConvolutionalBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels'] * 2.0
parameters['output_channels'] = parameters['output_channels'] // 2 self.decoderBlock1 = modules.DecoderBlock(parameters)
self.decoderBlock1 = modules.DecoderBlock3D(parameters) self.decoderBlock2 = modules.DecoderBlock(parameters)
parameters['input_channels'] = parameters['output_channels'] self.decoderBlock3 = modules.DecoderBlock(parameters)
parameters['output_channels'] = parameters['output_channels'] // 2 self.decoderBlock4 = modules.DecoderBlock(parameters)
self.decoderBlock2 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock3 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels']
parameters['output_channels'] = parameters['output_channels'] // 2
self.decoderBlock4 = modules.DecoderBlock3D(parameters)
parameters['input_channels'] = parameters['output_channels'] parameters['input_channels'] = parameters['output_channels']
self.classifier = modules.ClassifierBlock3D(parameters) self.classifier = modules.ClassifierBlock(parameters)
parameters['input_channels'] = original_input_channels
parameters['output_channels'] = original_output_channels
def forward(self, X): def forward(self, X):
"""Forward pass for 3D U-net """Forward pass for U-net
Function computing the forward pass through the 3D U-Net Function computing the forward pass through the U-Net
The input to the function is the dMRI map The input to the function is the dMRI map
Args: Args:
X (torch.tensor): Input dMRI map, shape = (N x C x D x H x W) X (torch.tensor): Input dMRI map, shape = (N x C x H x W)
Returns: Returns:
probability_map (torch.tensor): Output forward passed tensor through the U-net block probability_map (torch.tensor): Output forward passed tensor through the U-net block
...@@ -375,8 +377,6 @@ class BrainMapperUNet3D(nn.Module): ...@@ -375,8 +377,6 @@ class BrainMapperUNet3D(nn.Module):
return prediction return prediction
# DEPRECATED ARCHITECTURES!
class BrainMapperUNet3D_Simple(nn.Module): class BrainMapperUNet3D_Simple(nn.Module):
"""Architecture class BrainMapper 3D U-net. """Architecture class BrainMapper 3D U-net.
...@@ -548,23 +548,23 @@ class BrainMapperUNet3D_Simple(nn.Module): ...@@ -548,23 +548,23 @@ class BrainMapperUNet3D_Simple(nn.Module):
return prediction return prediction
if __name__ == '__main__': # if __name__ == '__main__':
# For debugging - To be deleted later! TODO # # For debugging - To be deleted later! TODO
parameters = { # parameters = {
'kernel_heigth': 5, # 'kernel_heigth': 5,
'kernel_width': 5, # 'kernel_width': 5,
'kernel_depth': 5, # 'kernel_depth': 5,
'kernel_classification': 1, # 'kernel_classification': 1,
'input_channels': 1, # 'input_channels': 1,
'output_channels': 64, # 'output_channels': 64,
'convolution_stride': 1, # 'convolution_stride': 1,
'dropout': 0.2, # 'dropout': 0.2,
'pool_kernel_size': 2, # 'pool_kernel_size': 2,
'pool_stride': 2, # 'pool_stride': 2,
'up_mode': 'upconv', # 'up_mode': 'upconv',
'number_of_classes': 1 # 'number_of_classes': 1
} # }
network = BrainMapperUNet3D(parameters) # network = BrainMapperUNet3D(parameters)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment