Commit 6cd5bba9 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added tanh flag

parent 734d7acb
......@@ -379,8 +379,12 @@ class ResNetClassifierBlock3D(nn.Module):
self.normalization = nn.InstanceNorm3d(
num_features=parameters['number_of_classes'])
self.activation = nn.Sigmoid()
# self.activation = nn.Tanh()
if parameters['final_activation'] == 'sigmoid':
self.activation = nn.Sigmoid()
elif parameters['final_activation'] == 'tanh':
self.activation = nn.Tanh()
else:
self.activation = None
# TODO: Might be wworth looking at GANS for image generation, and adding padding
......@@ -399,6 +403,7 @@ class ResNetClassifierBlock3D(nn.Module):
logits = self.normalization(self.convolutional_layer(X))
logits = self.activation(logits)
if isinstance(self.activation, (nn.Sigmoid, nn.Tanh)):
logits = self.activation(logits)
return logits
return logits
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment