Commit de062ac3 authored by Andrei Roibu
Browse files

saving valid checkpoints, fixed bugs in overfitting prevention for multiple epoch patience

parent 235db1fd
...@@ -287,7 +287,8 @@ class Solver(): ...@@ -287,7 +287,8 @@ class Solver():
'counter_early_stop': self.counter_early_stop, 'counter_early_stop': self.counter_early_stop,
'previous_loss': self.previous_loss, 'previous_loss': self.previous_loss,
'previous_MSE': self.previous_MSE, 'previous_MSE': self.previous_MSE,
'early_stop': self.early_stop 'early_stop': self.early_stop,
'valid_epoch': self.valid_epoch
}, },
filename=checkpoint_name filename=checkpoint_name
) )
...@@ -305,23 +306,35 @@ class Solver(): ...@@ -305,23 +306,35 @@ class Solver():
else: else:
continue continue
model_output_path = os.path.join( if self.early_stop == True:
self.save_model_directory, self.final_model_output_file)
self.LogWriter.close()
print('----------------------------------------')
print('NO TRAINING DONE TO PREVENT OVERFITTING!')
print('=====================')
end_time = datetime.now()
print('Completed At: {}'.format(end_time))
print('Training Duration: {}'.format(end_time - start_time))
print('****************************************************************')
else:
model_output_path = os.path.join(
self.save_model_directory, self.final_model_output_file)
create_folder(self.save_model_directory) create_folder(self.save_model_directory)
model.save(model_output_path) model.save(model_output_path)
self.LogWriter.close() self.LogWriter.close()
print('----------------------------------------') print('----------------------------------------')
print('TRAINING IS COMPLETE!') print('TRAINING IS COMPLETE!')
print('=====================') print('=====================')
end_time = datetime.now() end_time = datetime.now()
print('Completed At: {}'.format(end_time)) print('Completed At: {}'.format(end_time))
print('Training Duration: {}'.format(end_time - start_time)) print('Training Duration: {}'.format(end_time - start_time))
print('Final Model Saved in: {}'.format(model_output_path)) print('Final Model Saved in: {}'.format(model_output_path))
print('****************************************************************') print('****************************************************************')
def save_checkpoint(self, state, filename): def save_checkpoint(self, state, filename):
"""General Checkpoint Save """General Checkpoint Save
...@@ -386,6 +399,7 @@ class Solver(): ...@@ -386,6 +399,7 @@ class Solver():
self.previous_loss = checkpoint['previous_loss'] self.previous_loss = checkpoint['previous_loss']
self.previous_MSE = checkpoint['previous_MSE'] self.previous_MSE = checkpoint['previous_MSE']
self.early_stop = checkpoint['early_stop'] self.early_stop = checkpoint['early_stop']
self.valid_epoch = checkpoint['valid_epoch']
for state in self.optimizer.state.values(): for state in self.optimizer.state.values():
for key, value in state.items(): for key, value in state.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment