Commit 4c63df04 authored by Andrei-Claudiu Roibu

added early stopping to prevent overfitting and long training

parent 0b7bd09d
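For context, the patience-based rule introduced by this commit can be sketched roughly as follows (an illustrative standalone example with made-up loss values, not the committed code; the actual implementation is the EarlyStopping class added in the diff below):
# Illustrative sketch of patience-based early stopping (hypothetical values).
best_loss = float("inf")
counter = 0
patience = 5  # epochs the validation loss may fail to improve before stopping
for epoch, validation_loss in enumerate([0.90, 0.70, 0.71, 0.72, 0.73, 0.74, 0.75]):
    if validation_loss < best_loss:
        best_loss = validation_loss
        counter = 0          # improvement: reset the counter (and save a checkpoint)
    else:
        counter += 1         # no improvement this epoch
        if counter >= patience:
            print("Stopping early at epoch {}".format(epoch))
            break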
@@ -20,6 +20,7 @@ from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
from utils.data_logging_utils import LogWriter
from utils.early_stopping import EarlyStopping
from torch.optim import lr_scheduler
checkpoint_directory = 'checkpoints'
@@ -114,24 +115,27 @@ class Solver():
use_last_checkpoint=use_last_checkpoint,
labels=labels)
self.EarlyStopping = EarlyStopping()
self.early_stop = False
if use_last_checkpoint:
self.load_checkpoint()
def train(self, train_loader, test_loader):
def train(self, train_loader, validation_loader):
"""Training Function
This function trains a given model using the provided training data.
Args:
train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
test_loader (class): Combined dataset and sampler, providing an iterable over the testing dataset (torch.utils.data.DataLoader)
validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)
Returns:
trained model
"""
model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
dataloaders = {'train': train_loader, 'test': test_loader}
dataloaders = {'train': train_loader, 'validation': validation_loader}
if torch.cuda.is_available():
torch.cuda.empty_cache() # clear memory
@@ -155,7 +159,7 @@ class Solver():
for epoch in range(self.start_epoch, self.number_epochs+1):
print("Epoch {}/{}".format(epoch, self.number_epochs))
for phase in ['train', 'test']:
for phase in ['train', 'validation']:
print('-> Phase: {}'.format(phase))
losses = []
@@ -194,14 +198,15 @@ class Solver():
iteration += 1
losses.append(loss.item())
losses.append(loss.item())
# Clear the memory
del X, y, y_hat, loss
torch.cuda.empty_cache()
if phase == 'test':
if phase == 'validation':
if batch_index != len(dataloaders[phase]) - 1:
print("#", end='', flush=True)
else:
@@ -211,18 +216,34 @@ class Solver():
self.LogWriter.loss_per_epoch(losses, phase, epoch)
if phase == 'validation':
validation_loss = np.mean(losses)
early_stop, save_checkpoint = self.EarlyStopping(validation_loss)
self.early_stop = early_stop
if save_checkpoint:
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
},
filename=os.path.join(self.experiment_directory_path, checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
del validation_loss, early_stop, save_checkpoint
torch.cuda.empty_cache()
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
},
filename=os.path.join(self.experiment_directory_path, checkpoint_directory,
'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
# Early Stop Condition
if self.early_stop:
print("ATTENTION!: Training stopped early to prevent overfitting!")
break
self.LogWriter.close()
new file: utils/early_stopping.py
"""Early Stopping Class
Description:
This file contains a class which is used to early stop the training of the network to avoid overfitting on the training dataset.
Usage:
To use this module, import the class and instantiate it as follows:
from utils.early_stopping import EarlyStopping
early_stop_check = EarlyStopping(patience=5)
early_stop, save_checkpoint = early_stop_check(validation_loss)
"""
import torch
import numpy as np
class EarlyStopping:
"""Early Stopping class
This class implements a form of regularization used to avoid overfitting on the training dataset.
Early stopping keeps track of the validation loss.
If the validation loss stops improving for a given number of consecutive epochs (the patience), a signal to stop training is transmitted.
Args:
patience (int): Number of consecutive epochs the validation loss is allowed to not improve before training is stopped.
verbose (bool): Flag for printing out useful information.
Returns:
early_stop (bool): Flag indicating if the training should be terminated
save_checkpoint (bool): Flag indicating if the checkpoint should be saved
"""
def __init__(self, patience=5):
self.patience = patience
self.counter = 0
self.best_score = None
self.early_stop = False
self.save_checkpoint = False
def __call__(self, validation_loss):
score = - validation_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint = True
elif score < self.best_score:
self.counter += 1
self.save_checkpoint = False
print("Early Stopping Counter: {}/{}".format(self.counter, self.patience))
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.counter = 0
self.save_checkpoint = True
return self.early_stop, self.save_checkpoint
\ No newline at end of file
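A minimal usage sketch of the EarlyStopping class above (the validation-loss sequence is made up purely for illustration):
from utils.early_stopping import EarlyStopping

early_stop_check = EarlyStopping(patience=3)

for validation_loss in [0.50, 0.40, 0.42, 0.43, 0.44]:   # made-up losses
    early_stop, save_checkpoint = early_stop_check(validation_loss)
    if save_checkpoint:
        print("Validation loss improved to {} - checkpoint would be saved".format(validation_loss))
    if early_stop:
        print("Patience exhausted - stop training")
        break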