"""Brain Mapper U-Net Solver Description: This folder contains the Pytorch implementation of the core U-net solver, used for training the network. Usage: To use this module, import it and instantiate is as you wish: from solver import Solver """ import os import numpy as np import torch import glob from fsl.data.image import Image from fsl.utils.image.roi import roi from datetime import datetime from utils.losses import MSELoss from utils.common_utils import create_folder from utils.data_logging_utils import LogWriter from utils.early_stopping import EarlyStopping from torch.optim import lr_scheduler checkpoint_extension = 'path.tar' class Solver(): """Solver class for the BrainMapper U-net. This class contains the pytorch implementation of the U-net solver required for the BrainMapper project. Args: model (class): BrainMapper model class experiment_name (str): Name of the experiment device (int/str): Device type used for training (int - GPU id, str- CPU) number_of_classes (int): Number of classes optimizer (class): Pytorch class of desired optimizer optimizer_arguments (dict): Dictionary of arguments to be optimized loss_function (func): Function describing the desired loss function model_name (str): Name of the model labels (arr): Vector/Array of labels (if applicable) number_epochs (int): Number of training epochs loss_log_period (int): Period for writing loss value learning_rate_scheduler_step_size (int): Period of learning rate decay learning_rate_scheduler_gamma (int): Multiplicative factor of learning rate decay use_last_checkpoint (bool): Flag for loading the previous checkpoint experiment_directory (str): Experiment output directory name logs_directory (str): Directory for outputing training logs crop_flag (bool): Flag indicating if the volumes should be cropped from 91x109x91 to 72x90x77 to reduce storage space and speed-up training Returns: trained model - working on this! """ def __init__(self, model, device, number_of_classes, experiment_name, optimizer, optimizer_arguments={}, loss_function=MSELoss(), # loss_function=torch.nn.L1Loss(), # loss_function=torch.nn.CosineEmbeddingLoss(), model_name='BrainMapper', labels=None, number_epochs=10, loss_log_period=5, learning_rate_scheduler_step_size=5, learning_rate_scheduler_gamma=0.5, use_last_checkpoint=True, experiment_directory='experiments', logs_directory='logs', checkpoint_directory='checkpoints', save_model_directory='saved_models', final_model_output_file='finetuned_alldata.pth.tar', crop_flag = False ): self.model = model self.device = device self.optimizer = optimizer(model.parameters(), **optimizer_arguments) if torch.cuda.is_available(): self.loss_function = loss_function.cuda(device) self.MSE = MSELoss().cuda(device) else: self.loss_function = loss_function self.MSE = MSELoss() self.model_name = model_name self.labels = labels self.number_epochs = number_epochs self.loss_log_period = loss_log_period # We use a learning rate scheduler, that decays the LR of each paramter group by gamma every step_size epoch. 
        self.learning_rate_scheduler = lr_scheduler.StepLR(self.optimizer,
                                                           step_size=learning_rate_scheduler_step_size,
                                                           gamma=learning_rate_scheduler_gamma)

        self.use_last_checkpoint = use_last_checkpoint

        experiment_directory_path = os.path.join(experiment_directory, experiment_name)
        self.experiment_directory_path = experiment_directory_path

        self.checkpoint_directory = checkpoint_directory

        create_folder(experiment_directory)
        create_folder(experiment_directory_path)
        create_folder(os.path.join(experiment_directory_path, self.checkpoint_directory))

        self.start_epoch = 1
        self.start_iteration = 1

        self.LogWriter = LogWriter(number_of_classes=number_of_classes,
                                   logs_directory=logs_directory,
                                   experiment_name=experiment_name,
                                   use_last_checkpoint=use_last_checkpoint,
                                   labels=labels)

        self.EarlyStopping = EarlyStopping(patience=10, min_delta=0)
        self.early_stop = False

        # Load the MNI152 brain mask used to mask the network output (cropped to 72x90x77 if crop_flag is set).
        if crop_flag:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                roi(Image('utils/MNI152_T1_2mm_brain_mask.nii.gz'), ((9, 81), (10, 100), (0, 77))).data)
        else:
            self.MNI152_T1_2mm_brain_mask = torch.from_numpy(
                Image('utils/MNI152_T1_2mm_brain_mask.nii.gz').data)

        self.save_model_directory = save_model_directory
        self.final_model_output_file = final_model_output_file

        if use_last_checkpoint:
            self.load_checkpoint()

    def train(self, train_loader, validation_loader):
        """Training Function

        This function trains a given model using the provided training data.

        Args:
            train_loader (class): Combined dataset and sampler, providing an iterable over the training dataset (torch.utils.data.DataLoader)
            validation_loader (class): Combined dataset and sampler, providing an iterable over the validation dataset (torch.utils.data.DataLoader)

        Returns:
            validation_loss (float): Validation loss at the last saved checkpoint (None if no training took place)
        """

        model, optimizer, learning_rate_scheduler = self.model, self.optimizer, self.learning_rate_scheduler

        dataloaders = {'train': train_loader, 'validation': validation_loader}

        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # clear memory
            model.cuda(self.device)  # Moving the model to GPU

        previous_checkpoint = None
        previous_loss = None
        previous_MSE = None

        print('****************************************************************')
        print('TRAINING IS STARTING!')
        print('=====================')
        print('Model Name: {}'.format(self.model_name))
        if torch.cuda.is_available():
            print('Device Type: {}'.format(torch.cuda.get_device_name(self.device)))
        else:
            print('Device Type: {}'.format(self.device))
        start_time = datetime.now()
        print('Started At: {}'.format(start_time))
        print('----------------------------------------')

        iteration = self.start_iteration

        for epoch in range(self.start_epoch, self.number_epochs + 1):
            print("Epoch {}/{}".format(epoch, self.number_epochs))

            for phase in ['train', 'validation']:
                print('-> Phase: {}'.format(phase))

                losses = []
                MSEs = []

                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                for batch_index, sampled_batch in enumerate(dataloaders[phase]):
                    X = sampled_batch[0].type(torch.FloatTensor)
                    y = sampled_batch[1].type(torch.FloatTensor)

                    # We add an extra dimension (~ number of channels) for the 3D convolutions.
                    X = torch.unsqueeze(X, dim=1)
                    y = torch.unsqueeze(y, dim=1)

                    MNI152_T1_2mm_brain_mask = torch.unsqueeze(
                        torch.unsqueeze(self.MNI152_T1_2mm_brain_mask, dim=0), dim=0)

                    if model.test_if_cuda:
                        X = X.cuda(self.device, non_blocking=True)
                        y = y.cuda(self.device, non_blocking=True)
                        MNI152_T1_2mm_brain_mask = MNI152_T1_2mm_brain_mask.cuda(
                            self.device, non_blocking=True)

                    y_hat = model(X)  # Forward pass
                    # Masking the prediction with the MNI152 brain mask
                    y_hat = torch.mul(y_hat, MNI152_T1_2mm_brain_mask)

                    loss = self.loss_function(y_hat, y)  # Loss computation
                    # loss = self.loss_function(y_hat+1e-4, y+1e-4, torch.tensor(1.0).cuda(self.device, non_blocking=True))

                    # We also calculate a separate MSE for cost function comparison!
                    MSE = self.MSE(y_hat, y)
                    MSEs.append(MSE.item())

                    if phase == 'train':
                        optimizer.zero_grad()  # Zero the parameter gradients
                        loss.backward()  # Backward propagation
                        optimizer.step()

                        if batch_index % self.loss_log_period == 0:
                            self.LogWriter.loss_per_iteration(loss.item(), batch_index, iteration)

                        iteration += 1

                    losses.append(loss.item())

                    # Clear the memory
                    del X, y, y_hat, loss, MNI152_T1_2mm_brain_mask, MSE
                    torch.cuda.empty_cache()

                    if phase == 'validation':
                        if batch_index != len(dataloaders[phase]) - 1:
                            print("#", end='', flush=True)
                        else:
                            print("100%", flush=True)

                with torch.no_grad():

                    if phase == 'train':
                        self.LogWriter.loss_per_epoch(losses, phase, epoch)
                        self.LogWriter.MSE_per_epoch(MSEs, phase, epoch)
                    elif phase == 'validation':
                        self.LogWriter.loss_per_epoch(
                            losses, phase, epoch, previous_loss=previous_loss)
                        previous_loss = np.mean(losses)
                        self.LogWriter.MSE_per_epoch(
                            MSEs, phase, epoch, previous_loss=previous_MSE)
                        previous_MSE = np.mean(MSEs)

                    if phase == 'validation':
                        early_stop, save_checkpoint = self.EarlyStopping(np.mean(losses))
                        self.early_stop = early_stop

                        if save_checkpoint:
                            validation_loss = np.mean(losses)
                            checkpoint_name = os.path.join(self.experiment_directory_path,
                                                           self.checkpoint_directory,
                                                           'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)

                            self.save_checkpoint(state={'epoch': epoch + 1,
                                                        'start_iteration': iteration + 1,
                                                        'arch': self.model_name,
                                                        'state_dict': model.state_dict(),
                                                        'optimizer': optimizer.state_dict(),
                                                        'scheduler': learning_rate_scheduler.state_dict()
                                                        },
                                                 filename=checkpoint_name
                                                 )

                            # Only the most recent checkpoint is kept on disk.
                            if previous_checkpoint is not None:
                                os.remove(previous_checkpoint)
                            previous_checkpoint = checkpoint_name

                    if phase == 'train':
                        learning_rate_scheduler.step()

            print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))

            # Early Stop Condition
            if self.early_stop:
                print("ATTENTION!: Training stopped early to prevent overfitting!")
                self.load_checkpoint()
                break

        model_output_path = os.path.join(self.save_model_directory, self.final_model_output_file)
        create_folder(self.save_model_directory)
        model.save(model_output_path)

        self.LogWriter.close()

        print('----------------------------------------')
        print('TRAINING IS COMPLETE!')
        print('=====================')
        end_time = datetime.now()
        print('Completed At: {}'.format(end_time))
        print('Training Duration: {}'.format(end_time - start_time))
        print('Final Model Saved in: {}'.format(model_output_path))
        print('****************************************************************')

        if self.start_epoch >= self.number_epochs + 1:
            validation_loss = None

        return validation_loss

    def save_checkpoint(self, state, filename):
        """General Checkpoint Save

        This function saves a general checkpoint for inference and/or resuming training

        Args:
            state (dict): Dictionary of all the relevant model components
            filename (str): Path to the output checkpoint file
        """

        torch.save(state, filename)

    def load_checkpoint(self, epoch=None):
        """General Checkpoint Loader

        This function loads a previous checkpoint for inference and/or resuming training

        Args:
            epoch (int): Epoch whose checkpoint should be loaded; if None, the most recently used checkpoint is loaded
        """

        if epoch is not None:
            checkpoint_file_path = os.path.join(self.experiment_directory_path,
                                                self.checkpoint_directory,
                                                'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
            self._checkpoint_reader(checkpoint_file_path)
        else:
            universal_path = os.path.join(self.experiment_directory_path,
                                          self.checkpoint_directory,
                                          '*.' + checkpoint_extension)
            files_in_universal_path = glob.glob(universal_path)

            # We sort through all the files in the path to find the most recently accessed checkpoint.
            if len(files_in_universal_path) > 0:
                checkpoint_file_path = max(files_in_universal_path, key=os.path.getatime)
                self._checkpoint_reader(checkpoint_file_path)
            else:
                self.LogWriter.log("No Checkpoint found at {}".format(
                    os.path.join(self.experiment_directory_path, self.checkpoint_directory)))

    def _checkpoint_reader(self, checkpoint_file_path):
        """Checkpoint Reader

        This private function reads a checkpoint file and then loads the relevant variables

        Args:
            checkpoint_file_path (str): Path to the checkpoint file
        """

        self.LogWriter.log("Loading Checkpoint {}".format(checkpoint_file_path))

        checkpoint = torch.load(checkpoint_file_path)
        self.start_epoch = checkpoint['epoch']
        self.start_iteration = checkpoint['start_iteration']
        # We are not loading the model_name as we might want to pre-train a model and then use it.
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        # Move the optimizer state tensors to the training device.
        for state in self.optimizer.state.values():
            for key, value in state.items():
                if torch.is_tensor(value):
                    state[key] = value.to(self.device)

        self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])

        self.LogWriter.log(
            "Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
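

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training pipeline).
# It assumes a hypothetical `BrainMapperUNet3D` model class and pre-built
# torch.utils.data.DataLoader objects named `train_loader` and
# `validation_loader`; none of these are defined in this module.
# ---------------------------------------------------------------------------
# from BrainMapperUNet import BrainMapperUNet3D   # hypothetical import
#
# model = BrainMapperUNet3D()                     # hypothetical model instance
# solver = Solver(model=model,
#                 device=0,
#                 number_of_classes=1,
#                 experiment_name='demo_experiment',
#                 optimizer=torch.optim.Adam,
#                 optimizer_arguments={'lr': 1e-3},
#                 number_epochs=10,
#                 use_last_checkpoint=False)
# validation_loss = solver.train(train_loader, validation_loader)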