"""Brain Mapper U-Net Solver Description: This folder contains the Pytorch implementation of the core U-net solver, used for training the network. Usage: To use this module, import it and instantiate is as you wish: from solver import Solver """ import os import numpy as np import torch import glob from datetime import datetime from utils.losses import MSELoss from utils.data_utils import create_folder from utils.data_logging_utils import LogWriter from utils.early_stopping import EarlyStopping from torch.optim import lr_scheduler checkpoint_extension = 'path.tar' class Solver(): """Solver class for the BrainMapper U-net. This class contains the pytorch implementation of the U-net solver required for the BrainMapper project. Args: model (class): BrainMapper model class experiment_name (str): Name of the experiment device (int/str): Device type used for training (int - GPU id, str- CPU) number_of_classes (int): Number of classes optimizer (class): Pytorch class of desired optimizer optimizer_arguments (dict): Dictionary of arguments to be optimized loss_function (func): Function describing the desired loss function model_name (str): Name of the model labels (arr): Vector/Array of labels (if applicable) number_epochs (int): Number of training epochs loss_log_period (int): Period for writing loss value learning_rate_scheduler_step_size (int): Period of learning rate decay learning_rate_scheduler_gamma (int): Multiplicative factor of learning rate decay use_last_checkpoint (bool): Flag for loading the previous checkpoint experiment_directory (str): Experiment output directory name logs_directory (str): Directory for outputing training logs Returns: trained model - working on this! """ def __init__(self, model, device, number_of_classes, experiment_name, optimizer, optimizer_arguments={}, loss_function=MSELoss(), model_name='BrainMapper', labels=None, number_epochs=10, loss_log_period=5, learning_rate_scheduler_step_size=5, learning_rate_scheduler_gamma=0.5, use_last_checkpoint=True, experiment_directory='experiments', logs_directory='logs', checkpoint_directory = 'checkpoints' ): self.model = model self.device = device self.optimizer = optimizer(model.parameters(), **optimizer_arguments) if torch.cuda.is_available(): self.loss_function = loss_function.cuda(device) else: self.loss_function = loss_function self.model_name = model_name self.labels = labels self.number_epochs = number_epochs self.loss_log_period = loss_log_period # We use a learning rate scheduler, that decays the LR of each paramter group by gamma every step_size epoch. 
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(
            experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(
            experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1
        # self.best_mean_score = 0
        # self.best_mean_score_epoch = 0

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.EarlyStopping = EarlyStopping()
        self.early_stop = False

        if use_last_checkpoint:
            self.load_checkpoint()

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            float: Validation loss recorded at the last saved checkpoint (None if no checkpoint was saved)
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        # Guards the return value in case no checkpoint is ever saved.
        validation_loss = None

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear memory
            model.cuda(self.device)  # Move the model to the GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs + 1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    # X = ( X - X.min() ) / ( X.max() - X.min() )
                    # X = ( X - X.mean() ) / X.std()
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ the number of channels)
                    # for the 3D convolutions.
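                    # e.g. an (N, D, H, W) batch of volumes becomes
                    # (N, 1, D, H, W), matching the (N, C, D, H, W) input
                    # layout that torch.nn.Conv3d expects.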
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)

                    y_hat = model(X)  # Forward pass
                    loss = self.loss_function(y_hat, y)  # Loss computation

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory
                    del X, y, y_hat, loss
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():
                    self.LogWriter.loss_per_epoch(losses, phase, epoch)

                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop

                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                 filename=os.path.join(self.experiment_directory_path,
                                                                       self.checkpoint_directory,
                                                                       'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
                                                 )

                            # Keep only the most recent checkpoint.
                            if epoch != self.start_epoch:
                                os.remove(os.path.join(self.experiment_directory_path,
                                                       self.checkpoint_directory,
                                                       'checkpoint_epoch_' + str(epoch - 1) + '.' + checkpoint_extension))

                if phase == 'train':
                    learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early stop condition
            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                break

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('****************************************************************')

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training.

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path to the file where the checkpoint will be saved
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training.

        Args:
            epoch (int): Epoch whose checkpoint should be loaded; if None, the most recent checkpoint is used
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory,
                'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(
                self.experiment_directory_path, self.checkpoint_directory,
                '*.' + checkpoint_extension)
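            # e.g. experiments/<experiment_name>/checkpoints/*.path.tar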
            files_in_universal_path = glob.glob(universal_path)

            # Sort through all the files in the path to find the most recent one.
            # Modification time is used, as access time can be updated by reads.
            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getmtime)
                self._checkpoint_reader(checkpoint_file_path)
            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables.

        Args:
            checkpoint_file_path (str): Path to the checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']

        # We are not loading the model_name, as we might want to pre-train a
        # model and then use it under a different name.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])

        # Move any optimizer state tensors back onto the training device.
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
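
# ----------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal, self-contained example of
# driving the Solver with a toy 3D model and random data. The _ToyModel
# class, tensor shapes, and hyperparameter values below are assumptions
# for demonstration, not part of the BrainMapper project; a real run
# would use the project's U-Net and data pipeline. The sketch assumes a
# CPU-only environment (on a CUDA machine, pass a GPU id as `device`)
# and that the bundled utils modules are importable.
# ----------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    class _ToyModel(nn.Module):
        """Hypothetical stand-in for the BrainMapper U-Net."""

        test_if_cuda = False  # Attribute the Solver expects on the model

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv3d(1, 1, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    # Eight random 8x8x8 volumes paired with random targets.
    volumes = TensorDataset(torch.rand(8, 8, 8, 8), torch.rand(8, 8, 8, 8))
    loader = DataLoader(volumes, batch_size=2)

    solver = Solver(model=_ToyModel(),
                    device='cpu',
                    number_of_classes=1,
                    experiment_name='toy_run',
                    optimizer=torch.optim.Adam,
                    optimizer_arguments={'lr': 1e-3},
                    number_epochs=2,
                    use_last_checkpoint=False)

    # Writes checkpoints under experiments/toy_run/checkpoints/ and
    # returns the validation loss of the last checkpointed epoch.
    final_validation_loss = solver.train(loader, loader)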