Commit 9c3d9d33 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added early stopping counter and best score to the checkpoints

parent 7b7ee1b0
......@@ -121,7 +121,6 @@ class Solver():
use_last_checkpoint=use_last_checkpoint,
labels=labels)
self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
self.early_stop = False
if crop_flag == False:
......@@ -132,8 +131,14 @@ class Solver():
self.save_model_directory = save_model_directory
self.final_model_output_file = experiment_name + ".pth.tar"
self.best_score_early_stop = None
self.counter_early_stop = 0
if use_last_checkpoint:
self.load_checkpoint()
self.EarlyStopping = EarlyStopping(patience=2, min_delta=0, best_score=self.best_score_early_stop, counter=self.counter_early_stop)
else:
self.EarlyStopping = EarlyStopping(patience=2, min_delta=0)
def train(self, train_loader, validation_loader):
"""Training Function
......@@ -256,9 +261,11 @@ class Solver():
previous_MSE = np.mean(MSEs)
if phase == 'validation':
early_stop, save_checkpoint = self.EarlyStopping(
early_stop, save_checkpoint, best_score_early_stop, counter_early_stop = self.EarlyStopping(
np.mean(losses))
self.early_stop = early_stop
self.best_score_early_stop = best_score_early_stop
self.counter_early_stop = counter_early_stop
if save_checkpoint == True:
validation_loss = np.mean(losses)
checkpoint_name = os.path.join(
......@@ -268,7 +275,9 @@ class Solver():
'arch': self.model_name,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
'scheduler': learning_rate_scheduler.state_dict(),
'best_score_early_stop': self.best_score_early_stop,
'counter_early_stop': self.counter_early_stop
},
filename=checkpoint_name
)
......@@ -374,6 +383,8 @@ class Solver():
# We are not loading the model_name as we might want to pre-train a model and then use it.
self.model.load_state_dict(checkpoint['state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.best_score_early_stop = checkpoint['best_score_early_stop']
self.counter_early_stop = checkpoint['counter_early_stop']
for state in self.optimizer.state.values():
for key, value in state.items():
......
......@@ -34,11 +34,11 @@ class EarlyStopping:
"""
def __init__(self, patience=5, min_delta=0):
def __init__(self, patience=5, min_delta=0, best_score=None, counter=0):
self.patience = patience
self.counter = 0
self.best_score = None
self.counter = counter
self.best_score = best_score
self.early_stop = False
self.save_checkpoint = False
self.min_delta = min_delta
......@@ -63,7 +63,7 @@ class EarlyStopping:
self.counter = 0
self.save_checkpoint = True
return self.early_stop, self.save_checkpoint
return self.early_stop, self.save_checkpoint, self.best_score, self.counter
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment