Commit d8b4134c authored by Andrei Roibu's avatar Andrei Roibu
Browse files

constructed classifier block

parent f79ee18a
......@@ -17,6 +17,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
# TODO: Currently, it appears that we are using constant size filters. We will need to adjust this in the network architecture, to allow it to encode/decode information!
class ConvolutionalBlock(nn.Module):
"""Parent class for a convolutional block.
......@@ -241,16 +243,72 @@ class DecoderBlock(ConvolutionalBlock):
Y = super(DecoderBlock, self).forward(concatenation)
if self.dropout_needed:ß
if self.dropout_needed:
Y = self.dropout_needed(Y)
return Y
class ClassifierBlock(ConvolutionalBlock):
"""Classifier block for a U-net.
This class creates a simple classifier block following the architecture:
parameters (dict): Contains information relevant parameters
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_classification': 1
'input_channels': 1
'output_channels': 1
'convolution_stride': 1
'dropout': 0.2
'pool_kernel_size': 2
'pool_stride': 2
'up_mode': 'upconv'
'number_of_classes': 1
Y (torch.tensor): Output forward passed tensor through the decoder block
def __init__(self, parameters):
super(ClassifierBlock, self).__init__()
self.convolutional_layer = nn.Conv2d(
in_channels= parameters['input_channels'],
out_channels= parameters['number_of_classes'],
kernel_size= parameters['kernel_classification'],
stride= parameters['convolution_stride'],
# TODO: Might be wworth looking at GANS for image generation, and adding padding
def forward(self, X):
"""Forward pass for U-net classifier block
Function computing the forward pass through the classifier block.
The input to the function is a torch tensor of shape N (batch size) x C (number of channels) x H (input heigth) x W (input width).
X (torch.tensor): Input tensor, shape = (N x C x H x W)
logits (torch.tensor): Output logits from forward pass tensor through the classifier block
logits = self.convolutional_layer(X)
# TODO: Currently, this has no activation function. Might be worth considering adding a tanh activation function, similar to GANs
# For refernece :
# For refernece 2:
return logits
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment