Commit de062ac3 authored by Andrei Roibu
Browse files

saving valid checkpoints, fixed bugs in overfitting prevention for multiple epoch patience

parent 235db1fd
......@@ -287,7 +287,8 @@ class Solver():
'counter_early_stop': self.counter_early_stop,
'previous_loss': self.previous_loss,
'previous_MSE': self.previous_MSE,
'early_stop': self.early_stop
'early_stop': self.early_stop,
'valid_epoch': self.valid_epoch
},
filename=checkpoint_name
)
......@@ -305,23 +306,35 @@ class Solver():
else:
continue
model_output_path = os.path.join(
self.save_model_directory, self.final_model_output_file)
if self.early_stop == True:
self.LogWriter.close()
print('----------------------------------------')
print('NO TRAINING DONE TO PREVENT OVERFITTING!')
print('=====================')
end_time = datetime.now()
print('Completed At: {}'.format(end_time))
print('Training Duration: {}'.format(end_time - start_time))
print('****************************************************************')
else:
model_output_path = os.path.join(
self.save_model_directory, self.final_model_output_file)
create_folder(self.save_model_directory)
create_folder(self.save_model_directory)
model.save(model_output_path)
model.save(model_output_path)
self.LogWriter.close()
self.LogWriter.close()
print('----------------------------------------')
print('TRAINING IS COMPLETE!')
print('=====================')
end_time = datetime.now()
print('Completed At: {}'.format(end_time))
print('Training Duration: {}'.format(end_time - start_time))
print('Final Model Saved in: {}'.format(model_output_path))
print('****************************************************************')
print('----------------------------------------')
print('TRAINING IS COMPLETE!')
print('=====================')
end_time = datetime.now()
print('Completed At: {}'.format(end_time))
print('Training Duration: {}'.format(end_time - start_time))
print('Final Model Saved in: {}'.format(model_output_path))
print('****************************************************************')
def save_checkpoint(self, state, filename):
"""General Checkpoint Save
......@@ -386,6 +399,7 @@ class Solver():
self.previous_loss = checkpoint['previous_loss']
self.previous_MSE = checkpoint['previous_MSE']
self.early_stop = checkpoint['early_stop']
self.valid_epoch = checkpoint['valid_epoch']
for state in self.optimizer.state.values():
for key, value in state.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment