Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
12db841c
Commit
12db841c
authored
Mar 31, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
code cleaning & fixing semantic mistakes
parent
9f21c7b3
Changes
2
Hide whitespace changes
Inline
Side-by-side
solver.py
View file @
12db841c
...
...
@@ -169,7 +169,7 @@ class Solver():
for
batch_index
,
sampled_batch
in
enumerate
(
dataloaders
[
phase
]):
X
=
sampled_batch
[
0
].
type
(
torch
.
FloatTensor
)
y
=
sampled_batch
[
1
].
type
(
torch
.
Lon
d
Tensor
)
y
=
sampled_batch
[
1
].
type
(
torch
.
Lon
g
Tensor
)
if
model
.
is_cuda
():
X
=
X
.
cuda
(
self
.
device
,
non_blocking
=
True
)
...
...
@@ -186,7 +186,7 @@ class Solver():
if
batch_index
%
self
.
loss_log_period
==
0
:
self
.
LogWriter
.
loss_per_iteration
(
self
,
loss
.
item
(),
batch_index
,
iteration
)
self
.
LogWriter
.
loss_per_iteration
(
loss
.
item
(),
batch_index
,
iteration
)
iteration
+=
1
...
...
@@ -219,7 +219,7 @@ class Solver():
index
=
np
.
random
.
choice
(
len
(
dataloaders
[
phase
].
dataset
.
X
),
size
=
3
,
replace
=
False
)
self
.
LogWriter
.
sample_image_per_epoch
(
prediction
=
model
.
predict
(
dataloaders
[
phase
].
dataset
.
X
[
index
],
self
.
device
),
ground_truth
=
dataloaders
[
phase
].
dataset
.
y
[
index
],
phase
=
phase
phase
=
phase
,
epoch
=
epoch
)
print
(
"Epoch {}/{} DONE!"
.
format
(
epoch
,
self
.
number_epochs
))
...
...
@@ -318,10 +318,10 @@ class Solver():
# We are not loading the model_name as we might want to pre-train a model and then use it.
self
.
model
.
load_state_dict
=
checkpoint
[
'state_dict'
]
self
.
optimizer
.
load_state_dict
=
checkpoint
[
'optimizer'
]
self
.
scheduler
.
load_state_dict
=
checkpoint
[
'scheduler'
]
self
.
learning_rate_
scheduler
.
load_state_dict
=
checkpoint
[
'scheduler'
]
for
state
in
self
.
optimizer
.
state
.
values
():
for
key
,
value
in
state
.
items
{}
:
for
key
,
value
in
state
.
items
()
:
if
torch
.
is_tensor
(
value
):
state
[
key
]
=
value
.
to
(
self
.
device
)
...
...
utils/data_logging_utils.py
View file @
12db841c
...
...
@@ -20,6 +20,7 @@ import logging
import
numpy
as
np
import
re
from
textwrap
import
wrap
import
torch
# The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it.
# More here: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html
...
...
@@ -139,7 +140,7 @@ class LogWriter():
"""
print
(
"Loss for Iteration {} is: {}"
.
format
(
batch_index
,
loss_per_iteration
))
self
.
log_writer
[
'train'
].
add_scalar
(
tag
=
'loss / iteration'
,
loss_per_iteration
,
iteration
)
self
.
log_writer
[
'train'
].
add_scalar
(
'loss / iteration'
,
loss_per_iteration
,
iteration
)
def
loss_per_epoch
(
self
,
losses
,
phase
,
epoch
):
"""Log function
...
...
@@ -164,7 +165,7 @@ class LogWriter():
loss
=
np
.
mean
(
losses
)
print
(
"Loss for Epoch {} of {} is: {}"
.
format
(
epoch
,
phase
,
loss
))
self
.
log_writer
[
phase
].
add_scalar
(
tag
=
'loss / iteration'
,
loss
,
epoch
)
self
.
log_writer
[
phase
].
add_scalar
(
'loss / iteration'
,
loss
,
epoch
)
# Currently, no confusion matrix is required
# TODO: add a confusion matrix per epoch and confusion matrix plot functions if required
...
...
@@ -190,12 +191,11 @@ class LogWriter():
print
(
"Dice Score is being calculated..."
,
end
=
''
,
flush
=
True
)
dice_score
=
evaluation
.
dice_score_calculator
(
outputs
,
correct_labels
,
self
.
number_of_classes
)
mean_dice_score
=
torch
.
mean
(
dice_score
)
self
.
plot_dice_score
(
dice_score
,
phase
,
plot_name
=
'dice_score_per_epoch'
,
title
=
'Dice Score'
,
epoch
)
self
.
log_writer
[
phase
].
add_scalar
(
tag
=
'loss / iteration'
,
loss
,
epoch
)
self
.
plot_dice_score
(
dice_score
,
phase
,
plot_name
=
'dice_score_per_epoch'
,
title
=
'Dice Score'
,
epochs
=
epoch
)
print
(
"Dice Score calculated successfully"
)
return
mean_dice_score
return
mean_dice_score
.
item
()
def
plot_dice_score
(
self
,
dice_score
,
phase
,
plot_name
,
title
,
epochs
):
def
plot_dice_score
(
self
,
dice_score
,
phase
,
plot_name
,
title
,
epochs
=
None
):
"""Function plotting dice score for multiple epochs
This function plots the dice score for each epoch.
...
...
@@ -223,7 +223,7 @@ class LogWriter():
ax
.
set_xticklabels
(
self
.
labels
)
ax
.
xaxis
.
tick_bottom
()
if
st
ep
:
if
ep
ochs
:
self
.
log_writer
[
phase
].
add_figure
(
plot_name
+
'/'
+
phase
,
figure
,
global_step
=
epochs
)
else
:
self
.
log_writer
[
phase
].
add_figure
(
plot_name
+
'/'
+
phase
,
figure
)
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment