Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
c14030f7
Commit
c14030f7
authored
Mar 31, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
added save_checkpoint to solver train
parent
8ab6cb09
Changes
1
Hide whitespace changes
Inline
Side-by-side
solver.py
View file @
c14030f7
...
...
@@ -209,21 +209,30 @@ class Solver():
self
.
LogWriter
.
loss_per_epoch
(
losses
,
phase
,
epoch
)
dice_score_mean
=
self
.
LogWriter
.
dice_score_per_epoch
(
phase
,
output_array
,
y_array
,
epoch
)
if
phase
==
=
'test'
:
if
phase
==
'test'
:
if
dice_score_mean
>
self
.
best_mean_score
:
self
.
best_mean_score
=
dice_score_mean
self
.
best_mean_score_epoch
=
epoch
index
=
np
.
random
.
choice
(
len
(
dataloaders
[
phase
].
dataset
.
X
),
size
=
3
,
replace
=
False
)
self
.
LogWriter
.
sample_image_per_epoch
(
prediction
=
model
.
predict
(
dataloaders
[
phase
].
dataset
.
X
[
index
],
self
.
device
)
ground_truth
=
dataloaders
[
phase
].
dataset
.
y
[
index
],
\
self
.
LogWriter
.
sample_image_per_epoch
(
prediction
=
model
.
predict
(
dataloaders
[
phase
].
dataset
.
X
[
index
],
self
.
device
)
,
ground_truth
=
dataloaders
[
phase
].
dataset
.
y
[
index
],
phase
=
phase
epoch
=
epoch
)
print
(
"Epoch {}/{} DONE!"
.
format
(
epoch
,
self
.
number_epochs
))
self
.
save_checkpoint
()
# TODO - write function and save the checkpoint!
self
.
save_checkpoint
(
state
=
{
'epoch'
:
epoch
+
1
,
'start_iteration'
:
iteration
+
1
,
'arch'
:
self
.
model_name
,
'state_dict'
:
model
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
'scheduler'
:
learning_rate_scheduler
.
state_dict
()
},
filename
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
)
self
.
LogWriter
.
close
()
print
(
'----------------------------------------'
)
...
...
@@ -236,9 +245,6 @@ class Solver():
# TODO: MAKE SURE any log writer function is closed!
def
save_model
(
self
):
pass
def
save_checkpoint
(
self
,
state
,
filename
):
"""General Checkpoint Save
...
...
@@ -257,10 +263,25 @@ class Solver():
torch
.
save
(
state
,
filename
)
def
load_checkpoint
(
self
):
pass
def
load_checkpoint
(
self
,
epoch
=
None
):
"""General Checkpoint Loader
This function loads a previous checkpoint for inference and/or resuming training
Args:
epoch (int): Current epoch value
Returns:
None
Raises:
None
"""
def
_load_checkpoint_file
(
self
):
# Name is private = can't be called outisde of this module
pass
def
save_model
(
self
):
pass
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment