Commit 9c3d9d33 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added early stopping counter and best score to the checkpoints

parent 7b7ee1b0
...@@ -121,7 +121,6 @@ class Solver(): ...@@ -121,7 +121,6 @@ class Solver():
use_last_checkpoint=use_last_checkpoint, use_last_checkpoint=use_last_checkpoint,
labels=labels) labels=labels)
self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
self.early_stop = False self.early_stop = False
if crop_flag == False: if crop_flag == False:
...@@ -132,8 +131,14 @@ class Solver(): ...@@ -132,8 +131,14 @@ class Solver():
self.save_model_directory = save_model_directory self.save_model_directory = save_model_directory
self.final_model_output_file = experiment_name + ".pth.tar" self.final_model_output_file = experiment_name + ".pth.tar"
self.best_score_early_stop = None
self.counter_early_stop = 0
if use_last_checkpoint: if use_last_checkpoint:
self.load_checkpoint() self.load_checkpoint()
self.EarlyStopping = EarlyStopping(patience=2, min_delta=0, best_score=self.best_score_early_stop, counter=self.counter_early_stop)
else:
self.EarlyStopping = EarlyStopping(patience=2, min_delta=0)
def train(self, train_loader, validation_loader): def train(self, train_loader, validation_loader):
"""Training Function """Training Function
...@@ -256,9 +261,11 @@ class Solver(): ...@@ -256,9 +261,11 @@ class Solver():
previous_MSE = np.mean(MSEs) previous_MSE = np.mean(MSEs)
if phase == 'validation': if phase == 'validation':
early_stop, save_checkpoint = self.EarlyStopping( early_stop, save_checkpoint, best_score_early_stop, counter_early_stop = self.EarlyStopping(
np.mean(losses)) np.mean(losses))
self.early_stop = early_stop self.early_stop = early_stop
self.best_score_early_stop = best_score_early_stop
self.counter_early_stop = counter_early_stop
if save_checkpoint == True: if save_checkpoint == True:
validation_loss = np.mean(losses) validation_loss = np.mean(losses)
checkpoint_name = os.path.join( checkpoint_name = os.path.join(
...@@ -268,7 +275,9 @@ class Solver(): ...@@ -268,7 +275,9 @@ class Solver():
'arch': self.model_name, 'arch': self.model_name,
'state_dict': model.state_dict(), 'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(), 'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict() 'scheduler': learning_rate_scheduler.state_dict(),
'best_score_early_stop': self.best_score_early_stop,
'counter_early_stop': self.counter_early_stop
}, },
filename=checkpoint_name filename=checkpoint_name
) )
...@@ -374,6 +383,8 @@ class Solver(): ...@@ -374,6 +383,8 @@ class Solver():
# We are not loading the model_name as we might want to pre-train a model and then use it. # We are not loading the model_name as we might want to pre-train a model and then use it.
self.model.load_state_dict(checkpoint['state_dict']) self.model.load_state_dict(checkpoint['state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer']) self.optimizer.load_state_dict(checkpoint['optimizer'])
self.best_score_early_stop = checkpoint['best_score_early_stop']
self.counter_early_stop = checkpoint['counter_early_stop']
for state in self.optimizer.state.values(): for state in self.optimizer.state.values():
for key, value in state.items(): for key, value in state.items():
......
...@@ -34,11 +34,11 @@ class EarlyStopping: ...@@ -34,11 +34,11 @@ class EarlyStopping:
""" """
def __init__(self, patience=5, min_delta=0): def __init__(self, patience=5, min_delta=0, best_score=None, counter=0):
self.patience = patience self.patience = patience
self.counter = 0 self.counter = counter
self.best_score = None self.best_score = best_score
self.early_stop = False self.early_stop = False
self.save_checkpoint = False self.save_checkpoint = False
self.min_delta = min_delta self.min_delta = min_delta
...@@ -63,7 +63,7 @@ class EarlyStopping: ...@@ -63,7 +63,7 @@ class EarlyStopping:
self.counter = 0 self.counter = 0
self.save_checkpoint = True self.save_checkpoint = True
return self.early_stop, self.save_checkpoint return self.early_stop, self.save_checkpoint, self.best_score, self.counter
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment