Commit 852cdef8 authored by Andrei Roibu

Removed deletion of the last checkpoint; added code for keeping track of early stops

parent 9d940084
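The new bookkeeping assumes an EarlyStopping helper that is constructed with a patience and a min_delta and, when called with the mean validation loss, now returns three values: the stop flag, the best score so far, and the patience counter (the old call site also expected a save_checkpoint flag). A minimal sketch of that assumed interface, for orientation only, not the project's actual implementation:

# Sketch of the EarlyStopping interface assumed by the diff below (illustrative).
class EarlyStopping:
    def __init__(self, patience=2, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, validation_loss):
        score = -validation_loss  # higher score corresponds to lower loss
        if self.best_score is None or score > self.best_score + self.min_delta:
            # Improvement: remember the new best and reset the patience counter.
            self.best_score = score
            self.counter = 0
        else:
            # No improvement: count towards the patience limit.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop, self.best_score, self.counter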
@@ -133,9 +133,9 @@ class Solver():
self.best_score_early_stop = None
self.counter_early_stop = 0
self.previous_checkpoint = None
self.previous_loss = None
self.previous_MSE = None
self.valid_epoch = None
if use_last_checkpoint:
self.load_checkpoint()
@@ -143,7 +143,6 @@ class Solver():
else:
self.EarlyStopping = EarlyStopping(patience=2, min_delta=0)
def train(self, train_loader, validation_loader):
"""Training Function
@@ -180,6 +179,11 @@ class Solver():
iteration = self.start_iteration
for epoch in range(self.start_epoch, self.number_epochs+1):
if self.early_stop == True:
print("ATTENTION!: Training stopped due to previous early stop flag!")
break
print("Epoch {}/{}".format(epoch, self.number_epochs))
for phase in ['train', 'validation']:
@@ -261,36 +265,32 @@ class Solver():
self.previous_MSE = np.mean(MSEs)
if phase == 'validation':
early_stop, save_checkpoint, best_score_early_stop, counter_early_stop = self.EarlyStopping(
np.mean(losses))
early_stop, best_score_early_stop, counter_early_stop = self.EarlyStopping(np.mean(losses))
self.early_stop = early_stop
self.best_score_early_stop = best_score_early_stop
self.counter_early_stop = counter_early_stop
if save_checkpoint == True:
validation_loss = np.mean(losses)
checkpoint_name = os.path.join(
self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict(),
'best_score_early_stop': self.best_score_early_stop,
'counter_early_stop': self.counter_early_stop,
'previous_checkpoint': self.previous_checkpoint,
'previous_loss': self.previous_loss,
'previous_MSE': self.previous_MSE,
},
filename=checkpoint_name
)
if self.previous_checkpoint != None:
os.remove(self.previous_checkpoint)
self.previous_checkpoint = checkpoint_name
else:
self.previous_checkpoint = checkpoint_name
checkpoint_name = os.path.join(
self.experiment_directory_path, self.checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
if self.counter_early_stop == 0:
self.valid_epoch = epoch
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict(),
'best_score_early_stop': self.best_score_early_stop,
'counter_early_stop': self.counter_early_stop,
'previous_loss': self.previous_loss,
'previous_MSE': self.previous_MSE,
'early_stop': self.early_stop
},
filename=checkpoint_name
)
if phase == 'train':
learning_rate_scheduler.step()
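With this change the Solver no longer deletes the previous checkpoint; instead it writes a checkpoint at the end of every validation phase in which counter_early_stop is 0 (i.e. the validation loss just improved) and records that epoch in self.valid_epoch. A hedged sketch of how an epoch-specific checkpoint produced by this naming scheme might be located later; the helper name and the 'pth.tar' extension are illustrative assumptions, not taken from the repository:

import os

def checkpoint_path_for_epoch(experiment_directory_path, checkpoint_directory,
                              epoch, checkpoint_extension='pth.tar'):
    # Mirrors the checkpoint_name construction in train():
    # <experiment_dir>/<checkpoint_dir>/checkpoint_epoch_<epoch>.<extension>
    return os.path.join(experiment_directory_path, checkpoint_directory,
                        'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)

# e.g. after an early stop, the best epoch's file would be
# checkpoint_path_for_epoch(exp_dir, ckpt_dir, solver.valid_epoch)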
@@ -300,7 +300,7 @@ class Solver():
if self.early_stop == True:
print("ATTENTION!: Training stopped early to prevent overfitting!")
self.load_checkpoint()
self.load_checkpoint(epoch=self.valid_epoch)
break
else:
continue
@@ -323,11 +323,6 @@ class Solver():
print('Final Model Saved in: {}'.format(model_output_path))
print('****************************************************************')
if self.start_epoch >= self.number_epochs+1:
validation_loss = None
return validation_loss
def save_checkpoint(self, state, filename):
"""General Checkpoint Save
@@ -388,9 +383,9 @@ class Solver():
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.best_score_early_stop = checkpoint['best_score_early_stop']
self.counter_early_stop = checkpoint['counter_early_stop']
self.previous_checkpoint = checkpoint['previous_checkpoint']
self.previous_loss = checkpoint['previous_loss']
self.previous_MSE = checkpoint['previous_MSE']
self.early_stop = checkpoint['early_stop']
for state in self.optimizer.state.values():
for key, value in state.items():
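Because the early_stop flag is now saved and restored alongside the optimizer and scheduler state, a run resumed from such a checkpoint hits the guard at the top of train() and stops immediately instead of training past the early stop. A small illustrative check of the persisted flags; torch.load for deserialisation and the file name are assumptions, since save_checkpoint itself is not shown in this diff:

import torch

# Inspect the early-stopping bookkeeping stored in a checkpoint (hypothetical file name).
checkpoint = torch.load('checkpoint_epoch_12.pth.tar', map_location='cpu')
print('early_stop:', checkpoint['early_stop'])
print('best_score_early_stop:', checkpoint['best_score_early_stop'])
print('counter_early_stop:', checkpoint['counter_early_stop'])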