Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
9c3d9d33
Commit
9c3d9d33
authored
Aug 20, 2020
by
Andrei Roibu
Browse files
added early stopping counter and best score to the checkpoints
parent
7b7ee1b0
Changes
2
Hide whitespace changes
Inline
Side-by-side
solver.py
View file @
9c3d9d33
...
...
@@ -121,7 +121,6 @@ class Solver():
use_last_checkpoint
=
use_last_checkpoint
,
labels
=
labels
)
self
.
EarlyStopping
=
EarlyStopping
(
patience
=
10
,
min_delta
=
0
)
self
.
early_stop
=
False
if
crop_flag
==
False
:
...
...
@@ -132,8 +131,14 @@ class Solver():
self
.
save_model_directory
=
save_model_directory
self
.
final_model_output_file
=
experiment_name
+
".pth.tar"
self
.
best_score_early_stop
=
None
self
.
counter_early_stop
=
0
if
use_last_checkpoint
:
self
.
load_checkpoint
()
self
.
EarlyStopping
=
EarlyStopping
(
patience
=
2
,
min_delta
=
0
,
best_score
=
self
.
best_score_early_stop
,
counter
=
self
.
counter_early_stop
)
else
:
self
.
EarlyStopping
=
EarlyStopping
(
patience
=
2
,
min_delta
=
0
)
def
train
(
self
,
train_loader
,
validation_loader
):
"""Training Function
...
...
@@ -256,9 +261,11 @@ class Solver():
previous_MSE
=
np
.
mean
(
MSEs
)
if
phase
==
'validation'
:
early_stop
,
save_checkpoint
=
self
.
EarlyStopping
(
early_stop
,
save_checkpoint
,
best_score_early_stop
,
counter_early_stop
=
self
.
EarlyStopping
(
np
.
mean
(
losses
))
self
.
early_stop
=
early_stop
self
.
best_score_early_stop
=
best_score_early_stop
self
.
counter_early_stop
=
counter_early_stop
if
save_checkpoint
==
True
:
validation_loss
=
np
.
mean
(
losses
)
checkpoint_name
=
os
.
path
.
join
(
...
...
@@ -268,7 +275,9 @@ class Solver():
'arch'
:
self
.
model_name
,
'state_dict'
:
model
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
'scheduler'
:
learning_rate_scheduler
.
state_dict
()
'scheduler'
:
learning_rate_scheduler
.
state_dict
(),
'best_score_early_stop'
:
self
.
best_score_early_stop
,
'counter_early_stop'
:
self
.
counter_early_stop
},
filename
=
checkpoint_name
)
...
...
@@ -374,6 +383,8 @@ class Solver():
# We are not loading the model_name as we might want to pre-train a model and then use it.
self
.
model
.
load_state_dict
(
checkpoint
[
'state_dict'
])
self
.
optimizer
.
load_state_dict
(
checkpoint
[
'optimizer'
])
self
.
best_score_early_stop
=
checkpoint
[
'best_score_early_stop'
]
self
.
counter_early_stop
=
checkpoint
[
'counter_early_stop'
]
for
state
in
self
.
optimizer
.
state
.
values
():
for
key
,
value
in
state
.
items
():
...
...
utils/early_stopping.py
View file @
9c3d9d33
...
...
@@ -34,11 +34,11 @@ class EarlyStopping:
"""
def
__init__
(
self
,
patience
=
5
,
min_delta
=
0
):
def
__init__
(
self
,
patience
=
5
,
min_delta
=
0
,
best_score
=
None
,
counter
=
0
):
self
.
patience
=
patience
self
.
counter
=
0
self
.
best_score
=
Non
e
self
.
counter
=
counter
self
.
best_score
=
best_scor
e
self
.
early_stop
=
False
self
.
save_checkpoint
=
False
self
.
min_delta
=
min_delta
...
...
@@ -63,7 +63,7 @@ class EarlyStopping:
self
.
counter
=
0
self
.
save_checkpoint
=
True
return
self
.
early_stop
,
self
.
save_checkpoint
return
self
.
early_stop
,
self
.
save_checkpoint
,
self
.
best_score
,
self
.
counter
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment