Commit f72471db authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

created a post-training prediction function

parent ea6e15b8
......@@ -141,23 +141,57 @@ class BrainMapperUNet(nn.Module):
def predict(self, X):
"""Post-training Output Prediction
Description
This function predicts the output of the of the U-net post-training
Args:
X (torch.tensor): input dMRI volume
Returns:
prediction (ndarray): predicted output after training
Raises:
None
"""
return None
if __name__ == '__main__':
parameters = {
'kernel_heigth': 5,
'kernel_width': 5,
'kernel_classification': 1,
'input_channels': 1,
'output_channels': 64,
'convolution_stride': 1,
'dropout': 0.2,
'pool_kernel_size': 2,
'pool_stride': 2,
'up_mode': 'upconv',
'number_of_classes': 1
}
network = BrainMapperUNet(parameters)
self.eval() # PyToch module setting network to evaluation mode
if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
elif type(X) is torch.Tensor and not X.is_cuda:
X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
# .cuda() call transfers the densor from the CPU to the GPU if that is the case.
# Non-blocking argument lets the caller bypas synchronization when necessary
with torch.no_grad(): # Causes operations to have no gradients
output = self.forward(X)
_, idx = torch.max(output, 1)
idx = idx.data.cpu().numpy() # We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
prediction = np.squeeze(idx)
del X, output, idx
return prediction
# if __name__ == '__main__':
# # For debugging - To be deleted later! TODO
# parameters = {
# 'kernel_heigth': 5,
# 'kernel_width': 5,
# 'kernel_classification': 1,
# 'input_channels': 1,
# 'output_channels': 64,
# 'convolution_stride': 1,
# 'dropout': 0.2,
# 'pool_kernel_size': 2,
# 'pool_stride': 2,
# 'up_mode': 'upconv',
# 'number_of_classes': 1
# }
# network = BrainMapperUNet(parameters)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment