Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
852cdef8
Commit
852cdef8
authored
Aug 21, 2020
by
Andrei Roibu
Browse files
removed last checkpoint del, added code for keeping track of early stops
parent
9d940084
Changes
1
Hide whitespace changes
Inline
Side-by-side
solver.py
View file @
852cdef8
...
...
@@ -133,9 +133,9 @@ class Solver():
self
.
best_score_early_stop
=
None
self
.
counter_early_stop
=
0
self
.
previous_checkpoint
=
None
self
.
previous_loss
=
None
self
.
previous_MSE
=
None
self
.
valid_epoch
=
None
if
use_last_checkpoint
:
self
.
load_checkpoint
()
...
...
@@ -143,7 +143,6 @@ class Solver():
else
:
self
.
EarlyStopping
=
EarlyStopping
(
patience
=
2
,
min_delta
=
0
)
def
train
(
self
,
train_loader
,
validation_loader
):
"""Training Function
...
...
@@ -180,6 +179,11 @@ class Solver():
iteration
=
self
.
start_iteration
for
epoch
in
range
(
self
.
start_epoch
,
self
.
number_epochs
+
1
):
if
self
.
early_stop
==
True
:
print
(
"ATTENTION!: Training stopped due to previous early stop flag!"
)
break
print
(
"Epoch {}/{}"
.
format
(
epoch
,
self
.
number_epochs
))
for
phase
in
[
'train'
,
'validation'
]:
...
...
@@ -261,36 +265,32 @@ class Solver():
self
.
previous_MSE
=
np
.
mean
(
MSEs
)
if
phase
==
'validation'
:
early_stop
,
save_checkpoint
,
best_score_early_stop
,
counter_early_stop
=
self
.
EarlyStopping
(
np
.
mean
(
losses
))
early_stop
,
best_score_early_stop
,
counter_early_stop
=
self
.
EarlyStopping
(
np
.
mean
(
losses
))
self
.
early_stop
=
early_stop
self
.
best_score_early_stop
=
best_score_early_stop
self
.
counter_early_stop
=
counter_early_stop
if
save_checkpoint
==
True
:
validation_loss
=
np
.
mean
(
losses
)
checkpoint_name
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
self
.
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
self
.
save_checkpoint
(
state
=
{
'epoch'
:
epoch
+
1
,
'start_iteration'
:
iteration
+
1
,
'arch'
:
self
.
model_name
,
'state_dict'
:
model
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
'scheduler'
:
learning_rate_scheduler
.
state_dict
(),
'best_score_early_stop'
:
self
.
best_score_early_stop
,
'counter_early_stop'
:
self
.
counter_early_stop
,
'previous_checkpoint'
:
self
.
previous_checkpoint
,
'previous_loss'
:
self
.
previous_loss
,
'previous_MSE'
:
self
.
previous_MSE
,
},
filename
=
checkpoint_name
)
if
self
.
previous_checkpoint
!=
None
:
os
.
remove
(
self
.
previous_checkpoint
)
self
.
previous_checkpoint
=
checkpoint_name
else
:
self
.
previous_checkpoint
=
checkpoint_name
checkpoint_name
=
os
.
path
.
join
(
self
.
experiment_directory_path
,
self
.
checkpoint_directory
,
'checkpoint_epoch_'
+
str
(
epoch
)
+
'.'
+
checkpoint_extension
)
if
self
.
counter_early_stop
==
0
:
self
.
valid_epoch
=
epoch
self
.
save_checkpoint
(
state
=
{
'epoch'
:
epoch
+
1
,
'start_iteration'
:
iteration
+
1
,
'arch'
:
self
.
model_name
,
'state_dict'
:
model
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
'scheduler'
:
learning_rate_scheduler
.
state_dict
(),
'best_score_early_stop'
:
self
.
best_score_early_stop
,
'counter_early_stop'
:
self
.
counter_early_stop
,
'previous_loss'
:
self
.
previous_loss
,
'previous_MSE'
:
self
.
previous_MSE
,
'early_stop'
:
self
.
early_stop
},
filename
=
checkpoint_name
)
if
phase
==
'train'
:
learning_rate_scheduler
.
step
()
...
...
@@ -300,7 +300,7 @@ class Solver():
if
self
.
early_stop
==
True
:
print
(
"ATTENTION!: Training stopped early to prevent overfitting!"
)
self
.
load_checkpoint
()
self
.
load_checkpoint
(
epoch
=
self
.
valid_epoch
)
break
else
:
continue
...
...
@@ -323,11 +323,6 @@ class Solver():
print
(
'Final Model Saved in: {}'
.
format
(
model_output_path
))
print
(
'****************************************************************'
)
if
self
.
start_epoch
>=
self
.
number_epochs
+
1
:
validation_loss
=
None
return
validation_loss
def
save_checkpoint
(
self
,
state
,
filename
):
"""General Checkpoint Save
...
...
@@ -388,9 +383,9 @@ class Solver():
self
.
optimizer
.
load_state_dict
(
checkpoint
[
'optimizer'
])
self
.
best_score_early_stop
=
checkpoint
[
'best_score_early_stop'
]
self
.
counter_early_stop
=
checkpoint
[
'counter_early_stop'
]
self
.
previous_checkpoint
=
checkpoint
[
'previous_checkpoint'
]
self
.
previous_loss
=
checkpoint
[
'previous_loss'
]
self
.
previous_MSE
=
checkpoint
[
'previous_MSE'
]
self
.
early_stop
=
checkpoint
[
'early_stop'
]
for
state
in
self
.
optimizer
.
state
.
values
():
for
key
,
value
in
state
.
items
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment