Commit 7b808175 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

bug fixes

parent 8382f477
......@@ -319,6 +319,9 @@ class AutoEncoder3D(nn.Module):
# Decoder
parameters['input_channels'] = parameters['output_channels']
parameters['convolution_stride'] = parameters['transformer_blocks_stride']
if equal_channels_blocks != 0:
self.decoderBlocks = nn.ModuleList([modules.ResNetDecoderBlock3D(parameters) for i in range(equal_channels_blocks)])
......@@ -365,9 +368,6 @@ class AutoEncoder3D(nn.Module):
# Decoder
parameters['input_channels'] = parameters['output_channels']
parameters['convolution_stride'] = parameters['transformer_blocks_stride']
for index, decoderBlock in enumerate(self.decoderBlocks):
if index < Y_encoder_sizes_lenght:
X = decoderBlock.forward(X, Y_encoder_sizes[index])
......
......@@ -24,7 +24,7 @@ use_last_checkpoint = False
adam_w_flag = False
[NETWORK]
first_kernel_size = 10
first_kernel_size = 9
first_convolution_stride = 1
input_channels = 1
output_channels = 32
......@@ -34,7 +34,7 @@ convolution_stride = 2
number_of_encoder_blocks = 2
transformer_blocks_stride = 1
number_of_transformer_blocks = 6
kernel_classification = 10
kernel_classification = 9
pool_kernel_size = 3
pool_stride = 2
dropout = 0
......
......@@ -24,7 +24,7 @@ negative_flag = False
outlier_flag = True
shrinkage_flag = False
hard_shrinkage_flag = False
crop_flag = False
crop_flag = True
input_data_train = "input_data_train.h5"
target_data_train = "target_data_train.h5"
input_data_validation = "input_data_validation.h5"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment