Commit 697cd299 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

updated run to import 3D network

parent 3c3b8f2a
...@@ -38,7 +38,7 @@ import torch ...@@ -38,7 +38,7 @@ import torch
import torch.utils.data as data import torch.utils.data as data
from solver import Solver from solver import Solver
from BrainMapperUNet import BrainMapperUNet from BrainMapperUNet import BrainMapperUNet3D
from utils.data_utils import get_datasets, data_test_train_validation_split, update_shuffling_flag from utils.data_utils import get_datasets, data_test_train_validation_split, update_shuffling_flag
import utils.data_evaluation_utils as evaluations import utils.data_evaluation_utils as evaluations
from utils.data_logging_utils import LogWriter from utils.data_logging_utils import LogWriter
...@@ -131,38 +131,39 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet ...@@ -131,38 +131,39 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
pin_memory=True pin_memory=True
) )
if training_parameters['use_pre_trained']: if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(training_parameters['pre_trained_path']) BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
else: else:
BrainMapperModel = BrainMapperUNet(network_parameters) BrainMapperModel = BrainMapperUNet3D(network_parameters)
solver = Solver(model=BrainMapperModel, # solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'], # device=misc_parameters['device'],
number_of_classes=network_parameters['number_of_classes'], # number_of_classes=network_parameters['number_of_classes'],
experiment_name=training_parameters['experiment_name'], # experiment_name=training_parameters['experiment_name'],
optimizer_arguments={'lr': training_parameters['learning_rate'], # optimizer_arguments={'lr': training_parameters['learning_rate'],
'betas': training_parameters['optimizer_beta'], # 'betas': training_parameters['optimizer_beta'],
'eps': training_parameters['optimizer_epsilon'], # 'eps': training_parameters['optimizer_epsilon'],
'weight_decay': training_parameters['optimizer_weigth_decay'] # 'weight_decay': training_parameters['optimizer_weigth_decay']
}, # },
model_name=misc_parameters['model_name'], # model_name=misc_parameters['model_name'],
number_epochs=training_parameters['number_of_epochs'], # number_epochs=training_parameters['number_of_epochs'],
loss_log_period=training_parameters['loss_log_period'], # loss_log_period=training_parameters['loss_log_period'],
learning_rate_scheduler_step_size=training_parameters[ # learning_rate_scheduler_step_size=training_parameters[
'learning_rate_scheduler_step_size'], # 'learning_rate_scheduler_step_size'],
learning_rate_scheduler_gamma=training_parameters['learning_rate_scheduler_gamma'], # learning_rate_scheduler_gamma=training_parameters['learning_rate_scheduler_gamma'],
use_last_checkpoint=training_parameters['use_last_checkpoint'], # use_last_checkpoint=training_parameters['use_last_checkpoint'],
experiment_directory=misc_parameters['experiments_directory'], # experiment_directory=misc_parameters['experiments_directory'],
logs_directory=misc_parameters['logs_directory'] # logs_directory=misc_parameters['logs_directory']
) # )
solver.train(train_loader, validation_loader) # solver.train(train_loader, validation_loader)
model_output_path = os.path.join( # model_output_path = os.path.join(
misc_parameters['save_model_directory'], training_parameters['final_model_output_file']) # misc_parameters['save_model_directory'], training_parameters['final_model_output_file'])
BrainMapperModel.save(model_output_path) # BrainMapperModel.save(model_output_path)
print("Final Model Saved in: {}".format(model_output_path)) # print("Final Model Saved in: {}".format(model_output_path))
def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters): def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
...@@ -351,18 +352,16 @@ if __name__ == '__main__': ...@@ -351,18 +352,16 @@ if __name__ == '__main__':
data_shuffling_flag = data_parameters['data_split_flag'] data_shuffling_flag = data_parameters['data_split_flag']
load_data(data_parameters) if data_shuffling_flag == True:
# Here we shuffle the data!
# if data_shuffling_flag == True: data_test_train_validation_split(data_parameters['data_directory'], data_parameters['train_percentage'], data_parameters['validation_percentage'])
# # Here we shuffle the data! update_shuffling_flag('settings.ini')
# data_test_train_validation_split(data_parameters['data_directory'], data_parameters['train_percentage'], data_parameters['validation_percentage']) # TODO: This might also be a very good point to add cross-validation later
# update_shuffling_flag('settings.ini') else:
# # TODO: This might also be a very good point to add cross-validation later
# else:
# if arguments.mode == 'train': if arguments.mode == 'train':
# train(data_parameters, training_parameters, train(data_parameters, training_parameters,
# network_parameters, misc_parameters) network_parameters, misc_parameters)
# elif arguments.mode == 'evaluate-score': # elif arguments.mode == 'evaluate-score':
# evaluate_score(training_parameters, # evaluate_score(training_parameters,
# network_parameters, misc_parameters, evaluation_parameters) # network_parameters, misc_parameters, evaluation_parameters)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment