Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
f72471db
Commit
f72471db
authored
Mar 25, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
created a post-training prediction function
parent
ea6e15b8
Changes
1
Show whitespace changes
Inline
Side-by-side
BrainMapperUNet.py
View file @
f72471db
...
...
@@ -141,23 +141,57 @@ class BrainMapperUNet(nn.Module):
def
predict
(
self
,
X
):
"""Post-training Output Prediction
Description
This function predicts the output of the of the U-net post-training
Args:
X (torch.tensor): input dMRI volume
Returns:
prediction (ndarray): predicted output after training
Raises:
None
"""
return
Non
e
self
.
eval
()
# PyToch module setting network to evaluation mod
e
if
__name__
==
'__main__'
:
if
type
(
X
)
is
np
.
ndarray
:
X
=
torch
.
tensor
(
X
,
requires_grad
=
False
).
type
(
torch
.
FloatTensor
)
elif
type
(
X
)
is
torch
.
Tensor
and
not
X
.
is_cuda
:
X
=
X
.
type
(
torch
.
FloatTensor
).
cuda
(
device
,
non_blocking
=
True
)
parameters
=
{
'kernel_heigth'
:
5
,
'kernel_width'
:
5
,
'kernel_classification'
:
1
,
'input_channels'
:
1
,
'output_channels'
:
64
,
'convolution_stride'
:
1
,
'dropout'
:
0.2
,
'pool_kernel_size'
:
2
,
'pool_stride'
:
2
,
'up_mode'
:
'upconv'
,
'number_of_classes'
:
1
}
network
=
BrainMapperUNet
(
parameters
)
# .cuda() call transfers the densor from the CPU to the GPU if that is the case.
# Non-blocking argument lets the caller bypas synchronization when necessary
with
torch
.
no_grad
():
# Causes operations to have no gradients
output
=
self
.
forward
(
X
)
_
,
idx
=
torch
.
max
(
output
,
1
)
idx
=
idx
.
data
.
cpu
().
numpy
()
# We retrieve the tensor held by idx (.data), and map it to a cpu as an ndarray
prediction
=
np
.
squeeze
(
idx
)
del
X
,
output
,
idx
return
prediction
# if __name__ == '__main__':
# # For debugging - To be deleted later! TODO
# parameters = {
# 'kernel_heigth': 5,
# 'kernel_width': 5,
# 'kernel_classification': 1,
# 'input_channels': 1,
# 'output_channels': 64,
# 'convolution_stride': 1,
# 'dropout': 0.2,
# 'pool_kernel_size': 2,
# 'pool_stride': 2,
# 'up_mode': 'upconv',
# 'number_of_classes': 1
# }
# network = BrainMapperUNet(parameters)
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment