"""Brain Mapper Run File Description: This file contains all the relevant functions for running BrainMapper. The network can be ran in one of these modes: - train - evaluate path - evaluate whole TODO: Might be worth adding some information on uncertaintiy estimation, later down the line Usage: In order to run the network, in the terminal, the user needs to pass it relevant arguments: $ ./setup.sh $ source env/bin/activate $ python run.py --mode ... The arguments for mode are the following: mode=train # For training the model mode=evaluate-score # For evaluating the model score mode=evaluate-mapping # For evaluating the model mapping # For clearning the experiments and logs directories of the last experiment mode=clear-experiment mode=clear-all # For clearing all the files from the experiments and logs directories/ """ import os import shutil import argparse import logging import torch import torch.utils.data as data import numpy as np from solver import Solver from BrainMapperAE import BrainMapperAE3D, AutoEncoder3D from utils.data_utils import get_datasets from utils.settings import Settings import utils.data_evaluation_utils as evaluations from utils.common_utils import create_folder # Set the default floating point tensor type to FloatTensor torch.set_default_tensor_type(torch.FloatTensor) def load_data(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag): """Dataset Loader This function loads the training and validation datasets. Args: data_parameters (dict): Dictionary containing relevant information for the datafiles. cross_domain_x2x_flag (bool): Flag indicating if cross-domain training is occuring between the inputs cross_domain_y2y_flag (bool): Flag indicating if cross-domain training is occuring between the targets Returns: train_data (dataset object): Pytorch map-style dataset object, mapping indices to training data samples. validation_data (dataset object): Pytorch map-style dataset object, mapping indices to testing data samples. """ print("Data is loading...") train_data, validation_data = get_datasets(data_parameters, cross_domain_x2x_flag, cross_domain_y2y_flag) print("Data has loaded!") print("Training dataset size is {}".format(len(train_data))) print("Validation dataset size is {}".format(len(validation_data))) return train_data, validation_data def train(data_parameters, training_parameters, network_parameters, misc_parameters): """Training Function This function trains a given model using the provided training data. Currently, the data loaded is set to have multiple sub-processes. A high enough number of workers assures that CPU computations are efficiently managed, i.e. that the bottleneck is indeed the neural network's forward and backward operations on the GPU (and not data generation) Loader memory is also pinned, to speed up data transfer from CPU to GPU by using the page-locked memory. Train data is also re-shuffled at each training epoch. Args: data_parameters (dict): Dictionary containing relevant information for the datafiles. training_parameters(dict): Dictionary containing relevant hyperparameters for training the network. training_parameters = { 'training_batch_size': 5 'validation_batch_size: 5 'use_pre_trained': False 'pre_trained_path': 'pre_trained/path' 'experiment_name': 'experiment_name' 'learning_rate': 1e-4 'optimizer_beta': (0.9, 0.999) 'optimizer_epsilon': 1e-8 'optimizer_weigth_decay': 1e-5 'number_of_epochs': 10 'loss_log_period': 50 'learning_rate_scheduler_step_size': 3 'learning_rate_scheduler_gamma': 1e-1 'use_last_checkpoint': True 'final_model_output_file': 'path/to/model' } network_parameters (dict): Contains information relevant parameters misc_parameters (dict): Dictionary of aditional hyperparameters misc_parameters = { 'save_model_directory': 'directory_name' 'model_name': 'BrainMapper' 'logs_directory': 'log-directory' 'device': 1 'experiments_directory': 'experiments-directory' } """ def _load_pretrained_cross_domain(x2y_model, save_model_directory, experiment_name): """ Pretrained cross-domain loader This function loads the pretrained X2X and Y2Y autuencoders. After, it initializes the X2Y model's weights using the X2X encoder and teh Y2Y decoder weights. Args: x2y_model (class): Original x2y model initialised using the standard parameters. save_model_directory (str): Name of the directory where the model is saved experiment_name (str): Name of the experiment Returns: x2y_model (class): New x2y model with encoder and decoder paths weights reinitialised. """ x2y_model_state_dict = x2y_model.state_dict() x2x_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_x2x.pth.tar')).state_dict() y2y_model_state_dict = torch.load(os.path.join(save_model_directory, experiment_name + '_y2y.pth.tar')).state_dict() half_point = len(x2x_model_state_dict)//2 + 1 counter = 1 for key, _ in x2y_model_state_dict.items(): if counter <= half_point: x2y_model_state_dict.update({key : x2x_model_state_dict[key]}) counter+=1 else: if key in y2y_model_state_dict: x2y_model_state_dict.update({key : y2y_model_state_dict[key]}) x2y_model.load_state_dict(x2y_model_state_dict) return x2y_model def _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters, optimizer = torch.optim.Adam, loss_function = torch.nn.MSELoss(), ): """Wrapper for the training operation This function wraps the training operation for the network Args: data_parameters (dict): Dictionary containing relevant information for the datafiles. training_parameters(dict): Dictionary containing relevant hyperparameters for training the network. network_parameters (dict): Contains information relevant parameters misc_parameters (dict): Dictionary of aditional hyperparameters """ train_data, validation_data = load_data(data_parameters, cross_domain_x2x_flag = network_parameters['cross_domain_x2x_flag'], cross_domain_y2y_flag = network_parameters['cross_domain_y2y_flag'] ) train_loader = data.DataLoader( dataset=train_data, batch_size=training_parameters['training_batch_size'], shuffle=True, pin_memory=True ) validation_loader = data.DataLoader( dataset=validation_data, batch_size=training_parameters['validation_batch_size'], shuffle=False, pin_memory=True ) if training_parameters['use_pre_trained']: BrainMapperModel = torch.load(training_parameters['pre_trained_path']) else: # BrainMapperModel = BrainMapperAE3D(network_parameters) BrainMapperModel = AutoEncoder3D(network_parameters) # temprorary change for testing encoder-decoder effective receptive field custom_weight_reset_flag = network_parameters['custom_weight_reset_flag'] BrainMapperModel.reset_parameters(custom_weight_reset_flag) if network_parameters['cross_domain_x2y_flag'] == True: BrainMapperModel = _load_pretrained_cross_domain(x2y_model=BrainMapperModel, save_model_directory=misc_parameters['save_model_directory'], experiment_name=training_parameters['experiment_name'] ) solver = Solver(model=BrainMapperModel, device=misc_parameters['device'], number_of_classes=network_parameters['number_of_classes'], experiment_name=training_parameters['experiment_name'], optimizer=optimizer, optimizer_arguments={'lr': training_parameters['learning_rate'], 'betas': training_parameters['optimizer_beta'], 'eps': training_parameters['optimizer_epsilon'], 'weight_decay': training_parameters['optimizer_weigth_decay'] }, loss_function=loss_function, model_name=training_parameters['experiment_name'], number_epochs=training_parameters['number_of_epochs'], loss_log_period=training_parameters['loss_log_period'], learning_rate_scheduler_step_size=training_parameters[ 'learning_rate_scheduler_step_size'], learning_rate_scheduler_gamma=training_parameters['learning_rate_scheduler_gamma'], use_last_checkpoint=training_parameters['use_last_checkpoint'], experiment_directory=misc_parameters['experiments_directory'], logs_directory=misc_parameters['logs_directory'], checkpoint_directory=misc_parameters['checkpoint_directory'], save_model_directory=misc_parameters['save_model_directory'], crop_flag = data_parameters['crop_flag'] ) # _ = solver.train(train_loader, validation_loader) solver.train(train_loader, validation_loader) del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver, optimizer torch.cuda.empty_cache() # return None if training_parameters['adam_w_flag'] == True: optimizer = torch.optim.AdamW else: optimizer = torch.optim.Adam loss_function = torch.nn.MSELoss() # loss_function=torch.nn.L1Loss() # loss_function=torch.nn.CosineEmbeddingLoss() if network_parameters['cross_domain_flag'] == False: _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters, optimizer=optimizer, loss_function=loss_function ) elif network_parameters['cross_domain_flag'] == True: if network_parameters['cross_domain_x2x_flag'] == True: training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x" data_parameters['target_data_train'] = data_parameters['input_data_train'] data_parameters['target_data_validation'] = data_parameters['input_data_validation'] # loss_function = torch.nn.L1Loss() _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters, optimizer=optimizer, loss_function=loss_function ) if network_parameters['cross_domain_y2y_flag'] == True: training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y" data_parameters['input_data_train'] = data_parameters['target_data_train'] data_parameters['input_data_validation'] = data_parameters['target_data_validation'] # loss_function = torch.nn.L1Loss() _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters, optimizer=optimizer, loss_function=loss_function ) if network_parameters['cross_domain_x2y_flag'] == True: _train_runner(data_parameters, training_parameters, network_parameters, misc_parameters, optimizer=optimizer, loss_function=loss_function ) def evaluate_mapping(mapping_evaluation_parameters): """Mapping Evaluator This function passes through the network an input and generates the rsfMRI outputs. Args: mapping_evaluation_parameters (dict): Dictionary of parameters useful during mapping evaluation. mapping_evaluation_parameters = { 'trained_model_path': 'path/to/model' 'data_directory': 'path/to/data' 'data_list': 'path/to/datalist.txt/ 'prediction_output_path': 'directory-of-saved-predictions' 'batch_size': 2 'device': 0 'exit_on_error': True } """ trained_model_path = mapping_evaluation_parameters['trained_model_path'] data_directory = mapping_evaluation_parameters['data_directory'] mapping_data_file = mapping_evaluation_parameters['mapping_data_file'] mapping_targets_file = mapping_evaluation_parameters['mapping_targets_file'] data_list = mapping_evaluation_parameters['data_list_reduced'] prediction_output_path = mapping_evaluation_parameters['prediction_output_path'] dmri_mean_mask_path = mapping_evaluation_parameters['dmri_mean_mask_path'] rsfmri_mean_mask_path = mapping_evaluation_parameters['rsfmri_mean_mask_path'] device = mapping_evaluation_parameters['device'] exit_on_error = mapping_evaluation_parameters['exit_on_error'] brain_mask_path = mapping_evaluation_parameters['brain_mask_path'] regression_factors = mapping_evaluation_parameters['regression_factors'] mean_regression_flag = mapping_evaluation_parameters['mean_regression_flag'] mean_regression_all_flag = mapping_evaluation_parameters['mean_regression_all_flag'] mean_subtraction_flag = mapping_evaluation_parameters['mean_subtraction_flag'] scale_volumes_flag = mapping_evaluation_parameters['scale_volumes_flag'] normalize_flag = mapping_evaluation_parameters['normalize_flag'] minus_one_scaling_flag = mapping_evaluation_parameters['minus_one_scaling_flag'] negative_flag = mapping_evaluation_parameters['negative_flag'] outlier_flag = mapping_evaluation_parameters['outlier_flag'] shrinkage_flag = mapping_evaluation_parameters['shrinkage_flag'] hard_shrinkage_flag = mapping_evaluation_parameters['hard_shrinkage_flag'] crop_flag = mapping_evaluation_parameters['crop_flag'] cross_domain_x2x_flag = mapping_evaluation_parameters['cross_domain_x2x_flag'] cross_domain_y2y_flag = mapping_evaluation_parameters['cross_domain_y2y_flag'] evaluations.evaluate_mapping(trained_model_path, data_directory, mapping_data_file, mapping_targets_file, data_list, prediction_output_path, brain_mask_path, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, mean_regression_flag, mean_regression_all_flag, mean_subtraction_flag, scale_volumes_flag, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, crop_flag, device, exit_on_error, cross_domain_x2x_flag, cross_domain_y2y_flag ) def evaluate_data(mapping_evaluation_parameters): """Mapping Evaluator This function passes through the network an input and generates the rsfMRI outputs. Args: mapping_evaluation_parameters (dict): Dictionary of parameters useful during mapping evaluation. mapping_evaluation_parameters = { 'trained_model_path': 'path/to/model' 'data_directory': 'path/to/data' 'data_list': 'path/to/datalist.txt/ 'prediction_output_path': 'directory-of-saved-predictions' 'batch_size': 2 'device': 0 'exit_on_error': True } """ trained_model_path = mapping_evaluation_parameters['trained_model_path'] data_directory = mapping_evaluation_parameters['data_directory'] mapping_data_file = mapping_evaluation_parameters['mapping_data_file'] mapping_targets_file = mapping_evaluation_parameters['mapping_targets_file'] if mapping_evaluation_parameters['evaluate_all_data'] == False: data_list = mapping_evaluation_parameters['data_list_reduced'] elif mapping_evaluation_parameters['evaluate_all_data'] == True: data_list = mapping_evaluation_parameters['data_list_all'] prediction_output_path = mapping_evaluation_parameters['prediction_output_path'] prediction_output_database_name = mapping_evaluation_parameters['prediction_output_database_name'] prediction_output_statistics_name = mapping_evaluation_parameters['prediction_output_statistics_name'] dmri_mean_mask_path = mapping_evaluation_parameters['dmri_mean_mask_path'] rsfmri_mean_mask_path = mapping_evaluation_parameters['rsfmri_mean_mask_path'] device = mapping_evaluation_parameters['device'] exit_on_error = mapping_evaluation_parameters['exit_on_error'] brain_mask_path = mapping_evaluation_parameters['brain_mask_path'] regression_factors = mapping_evaluation_parameters['regression_factors'] mean_regression_flag = mapping_evaluation_parameters['mean_regression_flag'] mean_regression_all_flag = mapping_evaluation_parameters['mean_regression_all_flag'] mean_subtraction_flag = mapping_evaluation_parameters['mean_subtraction_flag'] scale_volumes_flag = mapping_evaluation_parameters['scale_volumes_flag'] normalize_flag = mapping_evaluation_parameters['normalize_flag'] minus_one_scaling_flag = mapping_evaluation_parameters['minus_one_scaling_flag'] negative_flag = mapping_evaluation_parameters['negative_flag'] outlier_flag = mapping_evaluation_parameters['outlier_flag'] shrinkage_flag = mapping_evaluation_parameters['shrinkage_flag'] hard_shrinkage_flag = mapping_evaluation_parameters['hard_shrinkage_flag'] crop_flag = mapping_evaluation_parameters['crop_flag'] output_database_flag = mapping_evaluation_parameters['output_database_flag'] cross_domain_x2x_flag = mapping_evaluation_parameters['cross_domain_x2x_flag'] cross_domain_y2y_flag = mapping_evaluation_parameters['cross_domain_y2y_flag'] evaluations.evaluate_data(trained_model_path, data_directory, mapping_data_file, mapping_targets_file, data_list, prediction_output_path, prediction_output_database_name, prediction_output_statistics_name, brain_mask_path, dmri_mean_mask_path, rsfmri_mean_mask_path, regression_factors, mean_regression_flag, mean_regression_all_flag, mean_subtraction_flag, scale_volumes_flag, normalize_flag, minus_one_scaling_flag, negative_flag, outlier_flag, shrinkage_flag, hard_shrinkage_flag, crop_flag, device, exit_on_error, output_database_flag, cross_domain_x2x_flag, cross_domain_y2y_flag ) def delete_files(folder): """ Clear Folder Contents Function which clears contents (like experiments or logs) Args: folder (str): Name of folders whose conents is to be deleted """ for object_name in os.listdir(folder): file_path = os.path.join(folder, object_name) try: if os.path.isfile(file_path): os.unlink(file_path) elif os.path.isdir(file_path): shutil.rmtree(file_path) except Exception as exception: print(exception) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--mode', '-m', required=True, help='run mode, valid values are train or evaluate') parser.add_argument('--model_name', '-n', required=True, help='model name, required for identifying the settings file modelName.ini & modelName_eval.ini') parser.add_argument('--use_last_checkpoint', '-c', required=False, help='flag indicating if the last checkpoint should be used if 1; useful when wanting to time-limit jobs.') parser.add_argument('--number_of_epochs', '-e', required=False, help='flag indicating how many epochs the network will train for; should be limited to ~3 hours or 2/3 epochs') arguments = parser.parse_args() settings_file_name = arguments.model_name + '.ini' evaluation_settings_file_name = arguments.model_name + '_eval.ini' settings = Settings(settings_file_name) data_parameters = settings['DATA'] training_parameters = settings['TRAINING'] network_parameters = settings['NETWORK'] misc_parameters = settings['MISC'] evaluation_parameters = settings['EVALUATION'] if arguments.use_last_checkpoint == '1': training_parameters['use_last_checkpoint'] = True elif arguments.use_last_checkpoint == '0': training_parameters['use_last_checkpoint'] = False if arguments.number_of_epochs is not None: training_parameters['number_of_epochs'] = int(arguments.number_of_epochs) if arguments.mode == 'train': train(data_parameters, training_parameters, network_parameters, misc_parameters) elif arguments.mode == 'evaluate-mapping': logging.basicConfig(filename='evaluate-mapping-error.log') settings_evaluation = Settings(evaluation_settings_file_name) mapping_evaluation_parameters = settings_evaluation['MAPPING'] evaluate_mapping(mapping_evaluation_parameters) elif arguments.mode == 'evaluate-data': logging.basicConfig(filename='evaluate-data-error.log') settings_evaluation = Settings(evaluation_settings_file_name) mapping_evaluation_parameters = settings_evaluation['MAPPING'] evaluate_data(mapping_evaluation_parameters) elif arguments.mode == 'clear-checkpoints': if network_parameters['cross_domain_flag'] == True: if network_parameters['cross_domain_x2x_flag'] == True: training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x" if network_parameters['cross_domain_y2y_flag'] == True: training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y" shutil.rmtree(os.path.join(misc_parameters['experiments_directory'], training_parameters['experiment_name'])) print('Cleared the current experiment checkpoints successfully!') elif arguments.mode == 'clear-logs': if network_parameters['cross_domain_flag'] == True: if network_parameters['cross_domain_x2x_flag'] == True: training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x" if network_parameters['cross_domain_y2y_flag'] == True: training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y" shutil.rmtree(os.path.join(misc_parameters['logs_directory'], training_parameters['experiment_name'])) print('Cleared the current experiment logs directory successfully!') elif arguments.mode == 'clear-experiment': if network_parameters['cross_domain_flag'] == True: if network_parameters['cross_domain_x2x_flag'] == True: training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_x2x" if network_parameters['cross_domain_y2y_flag'] == True: training_parameters['experiment_name'] = training_parameters['experiment_name'] + "_y2y" shutil.rmtree(os.path.join(misc_parameters['experiments_directory'], training_parameters['experiment_name'])) shutil.rmtree(os.path.join(misc_parameters['logs_directory'], training_parameters['experiment_name'])) print('Cleared the current experiment checkpoints and logs directory successfully!') # elif arguments.mode == 'clear-everything': # delete_files(misc_parameters['experiments_directory']) # delete_files(misc_parameters['logs_directory']) # print('Cleared the all the checkpoints and logs directory successfully!') elif arguments.mode == 'train-and-evaluate-mapping': settings_evaluation = Settings(evaluation_settings_file_name) mapping_evaluation_parameters = settings_evaluation['MAPPING'] train(data_parameters, training_parameters, network_parameters, misc_parameters) logging.basicConfig(filename='evaluate-mapping-error.log') evaluate_mapping(mapping_evaluation_parameters) else: raise ValueError( 'Invalid mode value! Only supports: train, evaluate-data, evaluate-mapping, train-and-evaluate-mapping, clear-checkpoints, clear-logs, clear-experiment and clear-everything (req uncomment for safety!)')