Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
4f522872
Commit
4f522872
authored
Apr 29, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
added masking between network output and loss calculation
parent
d12bcf8e
Changes
2
Hide whitespace changes
Inline
Side-by-side
solver.py
View file @
4f522872
...
@@ -16,6 +16,7 @@ import numpy as np
...
@@ -16,6 +16,7 @@ import numpy as np
import
torch
import
torch
import
glob
import
glob
from
fsl.data.image
import
Image
from
datetime
import
datetime
from
datetime
import
datetime
from
utils.losses
import
MSELoss
from
utils.losses
import
MSELoss
from
utils.data_utils
import
create_folder
from
utils.data_utils
import
create_folder
...
@@ -123,6 +124,9 @@ class Solver():
...
@@ -123,6 +124,9 @@ class Solver():
if
use_last_checkpoint
:
if
use_last_checkpoint
:
self
.
load_checkpoint
()
self
.
load_checkpoint
()
self
.
MNI_152_2mm_mask
=
torch
.
from_numpy
(
Image
(
'utils/MNI152_T1_2mm_brain_mask.nii.gz'
).
data
)
def
train
(
self
,
train_loader
,
validation_loader
):
def
train
(
self
,
train_loader
,
validation_loader
):
"""Training Function
"""Training Function
...
@@ -189,6 +193,10 @@ class Solver():
...
@@ -189,6 +193,10 @@ class Solver():
y_hat
=
model
(
X
)
# Forward pass
y_hat
=
model
(
X
)
# Forward pass
### Masking goes here
y_hat
=
torch
.
mul
(
y_hat
,
self
.
MNI_152_2mm_mask
)
###
loss
=
self
.
loss_function
(
y_hat
,
y
)
# Loss computation
loss
=
self
.
loss_function
(
y_hat
,
y
)
# Loss computation
if
phase
==
'train'
:
if
phase
==
'train'
:
...
@@ -237,9 +245,9 @@ class Solver():
...
@@ -237,9 +245,9 @@ class Solver():
filename
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
self
.
checkpoint_directory
,
filename
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
self
.
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
)
)
if
epoch
!=
self
.
start_epoch
:
#
if epoch != self.start_epoch:
os
.
remove
(
os
.
path
.
join
(
self
.
experiment_directory_path
,
self
.
checkpoint_directory
,
#
os.remove(os.path.join(self.experiment_directory_path, self.checkpoint_directory,
'checkpoint_epoch_'
+
str
(
epoch
-
1
)
+
'.'
+
checkpoint_extension
))
#
'checkpoint_epoch_' + str(epoch-1) + '.' + checkpoint_extension))
if
phase
==
'train'
:
if
phase
==
'train'
:
learning_rate_scheduler
.
step
()
learning_rate_scheduler
.
step
()
...
...
utils/MNI152_T1_2mm_brain_mask.nii.gz
0 → 100644
View file @
4f522872
File added
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment