Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
31c05c0c
Commit
31c05c0c
authored
Apr 29, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
fixed bug with loading mask onto gpu
parent
d189eaa8
Changes
1
Hide whitespace changes
Inline
Side-by-side
solver.py
View file @
31c05c0c
...
...
@@ -109,8 +109,6 @@ class Solver():
self
.
start_epoch
=
1
self
.
start_iteration
=
1
# self.best_mean_score = 0
# self.best_mean_score_epoch = 0
self
.
LogWriter
=
LogWriter
(
number_of_classes
=
number_of_classes
,
logs_directory
=
logs_directory
,
...
...
@@ -124,8 +122,7 @@ class Solver():
if
use_last_checkpoint
:
self
.
load_checkpoint
()
self
.
MNI_152_2mm_mask
=
torch
.
from_numpy
(
Image
(
'utils/MNI152_T1_2mm_brain_mask.nii.gz'
).
data
)
self
.
MNI152_T1_2mm_brain_mask
=
torch
.
from_numpy
(
Image
(
'utils/MNI152_T1_2mm_brain_mask.nii.gz'
).
data
)
def
train
(
self
,
train_loader
,
validation_loader
):
"""Training Function
...
...
@@ -186,16 +183,17 @@ class Solver():
# We add an extra dimension (~ number of channels) for the 3D convolutions.
X
=
torch
.
unsqueeze
(
X
,
dim
=
1
)
y
=
torch
.
unsqueeze
(
y
,
dim
=
1
)
MNI152_T1_2mm_brain_mask
=
self
.
MNI152_T1_2mm_brain_mask
if
model
.
test_if_cuda
:
X
=
X
.
cuda
(
self
.
device
,
non_blocking
=
True
)
y
=
y
.
cuda
(
self
.
device
,
non_blocking
=
True
)
MNI152_T1_2mm_brain_mask
=
MNI152_T1_2mm_brain_mask
.
cuda
(
self
.
device
,
non_blocking
=
True
)
y_hat
=
model
(
X
)
# Forward pass
y_hat
=
model
(
X
)
# Forward pass
& Masking
### Masking goes here
y_hat
=
torch
.
mul
(
y_hat
,
self
.
MNI_152_2mm_mask
)
###
y_hat
=
torch
.
mul
(
y_hat
,
MNI152_T1_2mm_brain_mask
)
loss
=
self
.
loss_function
(
y_hat
,
y
)
# Loss computation
...
...
@@ -215,7 +213,7 @@ class Solver():
# Clear the memory
del
X
,
y
,
y_hat
,
loss
del
X
,
y
,
y_hat
,
loss
,
MNI152_T1_2mm_brain_mask
torch
.
cuda
.
empty_cache
()
if
phase
==
'validation'
:
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment