# .cuda() call transfers the densor from the CPU to the GPU if that is the case.
# Non-blocking argument lets the caller bypas synchronization when necessary
withtorch.no_grad():# Causes operations to have no gradients
output=self.forward(X)
_,idx=torch.max(output,1)
# We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
idx=idx.data.cpu().numpy()
prediction=np.squeeze(idx)
delX,output,idx
returnprediction
defreset_parameters(self):
"""Parameter Initialization
This function (re)initializes the parameters of the defined network.
This function is a wrapper for the reset_parameters() function defined for each module.
More information can be found here: https://discuss.pytorch.org/t/what-is-the-default-initialization-of-a-conv2d-layer-and-linear-layer/16055 + https://discuss.pytorch.org/t/how-to-reset-model-weights-to-effectively-implement-crossvalidation/53859
An alternative (re)initialization method is described here: https://discuss.pytorch.org/t/how-to-reset-variables-values-in-nn-modules/32639
@@ -19,6 +19,335 @@ import torch.nn.functional as F
# TODO: Currently, it appears that we are using constant size filters. We will need to adjust this in the network architecture, to allow it to encode/decode information!
# ResBlock 3D UNet:
classDensBlock3D(nn.Module):
"""Parent class for a 3D convolutional residual block.
This class represents a generic parent class for a convolutional residual 3D encoder or decoder block.
The class represents a subclass/child class of nn.Module, inheriting its functionality.
Args:
parameters (dict): Contains information on kernel size, number of channels, number of filters, and if convolution is strided.
parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_depth' : 5
'input_channels': 64
'output_channels': 64
'convolution_stride': 1
'dropout': 0.2
}
Returns:
torch.tensor: Output forward passed tensor
"""
def__init__(self,parameters):
super(DensBlock3D,self).__init__()
# We first calculate the amount of zero padding required (http://cs231n.github.io/convolutional-networks/)