Commit ca61c821 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added flag for tanh/sig/none change

parent b79d46ce
......@@ -309,8 +309,12 @@ class ResNetClassifierBlock3D(nn.Module):
self.normalization = nn.InstanceNorm3d(
self.activation = nn.Sigmoid()
# self.activation = nn.Tanh()
if parameters['final_activation'] == 'sigmoid':
self.activation = nn.Sigmoid()
elif parameters['final_activation'] == 'tanh':
self.activation = nn.Tanh()
self.activation = None
# TODO: Might be wworth looking at GANS for image generation, and adding padding
......@@ -329,6 +333,7 @@ class ResNetClassifierBlock3D(nn.Module):
logits = self.normalization(self.convolutional_layer(X))
logits = self.activation(logits)
if isinstance(self.activation, (nn.Sigmoid, nn.Tanh)):
logits = self.activation(logits)
return logits
return logits
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment