Commit 171d95ec authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added previous loss, MSE and checkpoint to checkpoint

parent 9c3d9d33
......@@ -139,6 +139,9 @@ class Solver():
self.EarlyStopping = EarlyStopping(patience=2, min_delta=0, best_score=self.best_score_early_stop, counter=self.counter_early_stop)
else:
self.EarlyStopping = EarlyStopping(patience=2, min_delta=0)
self.previous_checkpoint = None
self.previous_loss = None
self.previous_MSE = None
def train(self, train_loader, validation_loader):
"""Training Function
......@@ -160,10 +163,6 @@ class Solver():
torch.cuda.empty_cache() # clear memory
model.cuda(self.device) # Moving the model to GPU
previous_checkpoint = None
previous_loss = None
previous_MSE = None
print('****************************************************************')
print('TRAINING IS STARTING!')
print('=====================')
......@@ -254,11 +253,11 @@ class Solver():
self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
elif phase == 'validation':
self.LogWriter.loss_per_epoch(
losses, phase, epoch, previous_loss=previous_loss)
previous_loss = np.mean(losses)
losses, phase, epoch, previous_loss=self.previous_loss)
self.previous_loss = np.mean(losses)
self.LogWriter.MSE_per_epoch(
MSEs, phase, epoch, previous_loss=previous_MSE)
previous_MSE = np.mean(MSEs)
MSEs, phase, epoch, previous_loss=self.previous_MSE)
self.previous_MSE = np.mean(MSEs)
if phase == 'validation':
early_stop, save_checkpoint, best_score_early_stop, counter_early_stop = self.EarlyStopping(
......@@ -277,16 +276,19 @@ class Solver():
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict(),
'best_score_early_stop': self.best_score_early_stop,
'counter_early_stop': self.counter_early_stop
'counter_early_stop': self.counter_early_stop,
'previous_checkpoint': self.previous_checkpoint,
'previous_loss': self.previous_loss,
'previous_MSE': self.previous_MSE,
},
filename=checkpoint_name
)
if previous_checkpoint != None:
os.remove(previous_checkpoint)
previous_checkpoint = checkpoint_name
if self.previous_checkpoint != None:
os.remove(self.previous_checkpoint)
self.previous_checkpoint = checkpoint_name
else:
previous_checkpoint = checkpoint_name
self.previous_checkpoint = checkpoint_name
if phase == 'train':
learning_rate_scheduler.step()
......@@ -385,6 +387,9 @@ class Solver():
self.optimizer.load_state_dict(checkpoint['optimizer'])
self.best_score_early_stop = checkpoint['best_score_early_stop']
self.counter_early_stop = checkpoint['counter_early_stop']
self.previous_checkpoint = checkpoint['previous_checkpoint']
self.previous_loss = checkpoint['previous_loss']
self.previous_MSE = checkpoint['previous_MSE']
for state in self.optimizer.state.values():
for key, value in state.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment