"""Brain Mapper U-Net Solver Description: This folder contains the Pytorch implementation of the core U-net solver, used for training the network. Usage: To use this module, import it and instantiate is as you wish: from solver import Solver """ import os import numpy as np import torch import glob from fsl.data.image import Image from fsl.utils.image.roi import roi from datetime import datetime from utils.common_utils import create_folder from utils.data_logging_utils import LogWriter from utils.early_stopping import EarlyStopping from torch.optim import lr_scheduler checkpoint_extension = 'path.tar' class Solver(): """Solver class for the BrainMapper U-net. This class contains the pytorch implementation of the U-net solver required for the BrainMapper project. Args: model (class): BrainMapper model class experiment_name (str): Name of the experiment device (int/str): Device type used for training (int - GPU id, str- CPU) number_of_classes (int): Number of classes optimizer (class): Pytorch class of desired optimizer optimizer_arguments (dict): Dictionary of arguments to be optimized loss_function (func): Function describing the desired loss function model_name (str): Name of the model labels (arr): Vector/Array of labels (if applicable) number_epochs (int): Number of training epochs loss_log_period (int): Period for writing loss value learning_rate_scheduler_step_size (int): Period of learning rate decay learning_rate_scheduler_gamma (int): Multiplicative factor of learning rate decay use_last_checkpoint (bool): Flag for loading the previous checkpoint experiment_directory (str): Experiment output directory name logs_directory (str): Directory for outputing training logs crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training Returns: trained model - working on this! """ def __init__(self, model, device, number_of_classes, experiment_name, optimizer, optimizer_arguments={}, loss_function=torch.nn.MSELoss(), model_name='BrainMapper', labels=None, number_epochs=10, loss_log_period=5, learning_rate_scheduler_step_size=5, learning_rate_scheduler_gamma=0.5, use_last_checkpoint=True, experiment_directory='experiments', logs_directory='logs', checkpoint_directory='checkpoints', save_model_directory='saved_models', crop_flag = False ): self.model = model self.device = device self.optimizer = optimizer(model.parameters(), **optimizer_arguments) if torch.cuda.is_available(): self.loss_function = loss_function.cuda(device) self.MSE = torch.nn.MSELoss().cuda(device) else: self.loss_function = loss_function self.MSE = torch.nn.MSELoss() self.model_name = model_name self.labels = labels self.number_epochs = number_epochs self.loss_log_period = loss_log_period # We use a learning rate scheduler, that decays the LR of each paramter group by gamma every step_size epoch. 
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.early_stop = False

        if not crop_flag:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)
        else:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'), ((9, 81), (10, 100), (0, 77))).data)

        self.save_model_directory = save_model_directory
        self.final_model_output_file = experiment_name + ".pth.tar"

        self.best_score_early_stop = None
        self.counter_early_stop = 0
        self.previous_loss = None
        self.previous_MSE = None
        self.valid_epoch = None

        if use_last_checkpoint:
            self.load_checkpoint()
            self.EarlyStopping = EarlyStopping(patience=5, min_delta=0,
                                               best_score=self.best_score_early_stop,
                                               counter=self.counter_early_stop)
        else:
            self.EarlyStopping = EarlyStopping(patience=5, min_delta=0)

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            None: the trained model is saved to disk in save_model_directory
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler
        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Clear memory
            model.cuda(self.device)  # Move the model to the GPU

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(
                torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs + 1):
            if self.early_stop:
                print("ATTENTION!: Training stopped due to previous early stop flag!")
                break

            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
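                    # For example, with uncropped volumes an input batch of shape
                    # (batch, 91, 109, 91) becomes (batch, 1, 91, 109, 91); the brain
                    # mask gains both a batch and a channel dimension below so that it
                    # broadcasts over the masked predictions.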
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)
                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)  # Forward pass & masking
                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)

                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(
                                loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory
                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():
                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=self.previous_loss)
                        self.previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=self.previous_MSE)
                        self.previous_MSE = np.mean(MSEs)

                    if phase == 'validation':
                        early_stop, best_score_early_stop, counter_early_stop = self.EarlyStopping(
                            np.mean(losses))
                        self.early_stop = early_stop
                        self.best_score_early_stop = best_score_early_stop
                        self.counter_early_stop = counter_early_stop

                        checkpoint_name = os.path.join(self.experiment_directory_path,
                                                       self.checkpoint_directory,
                                                       'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
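                        # A counter of zero indicates that the validation loss has just
                        # improved (see utils.early_stopping), so this epoch is recorded
                        # as the best candidate to reload if training stops early.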
                        if self.counter_early_stop == 0:
                            self.valid_epoch = epoch

                        self.save_checkpoint(state={'epoch': epoch + 1,
                                                    'start_iteration': iteration + 1,
                                                    'arch': self.model_name,
                                                    'state_dict': model.state_dict(),
                                                    'optimizer': optimizer.state_dict(),
                                                    'scheduler': learning_rate_scheduler.state_dict(),
                                                    'best_score_early_stop': self.best_score_early_stop,
                                                    'counter_early_stop': self.counter_early_stop,
                                                    'previous_loss': self.previous_loss,
                                                    'previous_MSE': self.previous_MSE,
                                                    'early_stop': self.early_stop,
                                                    'valid_epoch': self.valid_epoch
                                                    },
                                             filename=checkpoint_name
                                             )

                    if phase == 'train':
                        learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early stop condition
            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint(epoch=self.valid_epoch)
                break

        if self.early_stop:
            self.LogWriter.close()
            print('----------------------------------------')
            print('NO TRAINING DONE TO PREVENT OVERFITTING!')
            print('=====================')
            end_time = datetime.now()
            print('Completed At: {}'.format(end_time))
            print('Training Duration: {}'.format(end_time - start_time))
            print('****************************************************************')
        else:
            model_output_path = os.path.join(
                self.save_model_directory, self.final_model_output_file)
            create_folder(self.save_model_directory)
            model.save(model_output_path)

            self.LogWriter.close()

            print('----------------------------------------')
            print('TRAINING IS COMPLETE!')
            print('=====================')
            end_time = datetime.now()
            print('Completed At: {}'.format(end_time))
            print('Training Duration: {}'.format(end_time - start_time))
            print('Final Model Saved in: {}'.format(model_output_path))
            print('****************************************************************')

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training.

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path where the checkpoint is saved
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training.

        Args:
            epoch (int): Epoch for which the checkpoint should be loaded; if None, the most recent checkpoint is used
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(self.experiment_directory_path,
                                                self.checkpoint_directory,
                                                'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(self.experiment_directory_path,
                                          self.checkpoint_directory,
                                          '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # Pick the most recently accessed checkpoint file as the latest one.
            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(
                    files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)
            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and loads the relevant variables.

        Args:
            checkpoint_file_path (str): Path to the checkpoint file
        """

        self.LogWriter.log(
            "Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_score_early_stop = checkpoint['best_score_early_stop']
        self.counter_early_stop = checkpoint['counter_early_stop']
        self.previous_loss = checkpoint['previous_loss']
        self.previous_MSE = checkpoint['previous_MSE']
        self.early_stop = checkpoint['early_stop']
        self.valid_epoch = checkpoint['valid_epoch']

        # Move any optimizer state tensors (e.g. momentum buffers) to the training device.
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])

        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
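# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, never executed on import). The
# BrainMapperUNet import, the optimizer settings, and the two DataLoader names
# below are assumptions standing in for whatever model and data pipeline the
# experiment actually uses:
#
#     import torch
#     from solver import Solver
#     from BrainMapperUNet import BrainMapperUNet  # hypothetical model import
#
#     model = BrainMapperUNet()
#     solver = Solver(model=model,
#                     device=0,
#                     number_of_classes=1,
#                     experiment_name='demo_experiment',
#                     optimizer=torch.optim.Adam,
#                     optimizer_arguments={'lr': 1e-3},
#                     use_last_checkpoint=False)
#     solver.train(train_loader, validation_loader)  # torch.utils.data.DataLoaders
# -----------------------------------------------------------------------------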