Commit f72471db authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

created a post-training prediction function

parent ea6e15b8
...@@ -141,23 +141,57 @@ class BrainMapperUNet(nn.Module): ...@@ -141,23 +141,57 @@ class BrainMapperUNet(nn.Module):
def predict(self, X): def predict(self, X):
"""Post-training Output Prediction """Post-training Output Prediction
Description
This function predicts the output of the of the U-net post-training
Args:
X (torch.tensor): input dMRI volume
Returns:
prediction (ndarray): predicted output after training
Raises:
None
""" """
return None self.eval() # PyToch module setting network to evaluation mode
if __name__ == '__main__': if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor)
parameters = { elif type(X) is torch.Tensor and not X.is_cuda:
'kernel_heigth': 5, X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
'kernel_width': 5,
'kernel_classification': 1, # .cuda() call transfers the densor from the CPU to the GPU if that is the case.
'input_channels': 1, # Non-blocking argument lets the caller bypas synchronization when necessary
'output_channels': 64,
'convolution_stride': 1, with torch.no_grad(): # Causes operations to have no gradients
'dropout': 0.2, output = self.forward(X)
'pool_kernel_size': 2,
'pool_stride': 2, _, idx = torch.max(output, 1)
'up_mode': 'upconv',
'number_of_classes': 1 idx = idx.data.cpu().numpy() # We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
}
network = BrainMapperUNet(parameters) prediction = np.squeeze(idx)
del X, output, idx
return prediction
# if __name__ == '__main__':
# # For debugging - To be deleted later! TODO
# parameters = {
# 'kernel_heigth': 5,
# 'kernel_width': 5,
# 'kernel_classification': 1,
# 'input_channels': 1,
# 'output_channels': 64,
# 'convolution_stride': 1,
# 'dropout': 0.2,
# 'pool_kernel_size': 2,
# 'pool_stride': 2,
# 'up_mode': 'upconv',
# 'number_of_classes': 1
# }
# network = BrainMapperUNet(parameters)
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment