Commit a868b271 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

fixed upsampling bugs; currently only upconv works!

parent 4bbd4184
......@@ -457,18 +457,18 @@ class DecoderBlock3D(ConvolutionalBlock3D):
def __init__(self, parameters):
super(DecoderBlock3D, self).__init__(parameters)
up_mode = parameters['up_mode']
if up_mode == 'upconv': # Attention - this will need to be checked to confirm that it is working!
self.up_mode = parameters['up_mode']
if self.up_mode == 'upconv': # Attention - this will need to be checked to confirm that it is working!
self.up = nn.ConvTranspose3d(
in_channels=parameters['input_channels'],
out_channels=parameters['output_channels'],
kernel_size=parameters['pool_kernel_size'],
stride=parameters['pool_stride'],
)
elif up_mode == 'upsample':
elif self.up_mode == 'upsample':
self.up = nn.Sequential(
nn.Upsample(
mode='bilinear',
mode='nearest',
scale_factor=2,
),
nn.Conv3d(
......@@ -477,7 +477,7 @@ class DecoderBlock3D(ConvolutionalBlock3D):
kernel_size=1,
)
)
elif up_mode == 'unpool':
elif self.up_mode == 'unpool':
self.up = nn.MaxUnpool3d(
kernel_size=parameters['pool_kernel_size'],
stride=parameters['pool_stride']
......@@ -500,7 +500,15 @@ class DecoderBlock3D(ConvolutionalBlock3D):
Y (torch.tensor): Output forward passed tensor through the decoder block
"""
upsampling = self.up(X, pool_indices)
# ATTENTION: As of this code version, only "upconv" works! Debugging is ongoing for upconv and upsample!
# It seems that errors are generated by variable filter sizes and the unorthodox input sizes 91x109x91.
if self.up_mode == 'upconv':
upsampling = self.up(X, pool_indices, output_size=Y_encoder.size())
elif self.up_mode == 'upsample':
upsampling = self.up(X)
elif self.up_mode == 'upconv':
upsampling = self.up(X, output_size=Y_encoder.size())
if Y_encoder is None:
concatenation = upsampling
......@@ -510,7 +518,7 @@ class DecoderBlock3D(ConvolutionalBlock3D):
Y = super(DecoderBlock3D, self).forward(concatenation)
if self.dropout_needed:
Y = self.dropout_needed(Y)
Y = self.dropout(Y)
return Y
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment