Commit 697cd299 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

updated run to import 3D network

parent 3c3b8f2a
......@@ -38,7 +38,7 @@ import torch
import torch.utils.data as data
from solver import Solver
from BrainMapperUNet import BrainMapperUNet
from BrainMapperUNet import BrainMapperUNet3D
from utils.data_utils import get_datasets, data_test_train_validation_split, update_shuffling_flag
import utils.data_evaluation_utils as evaluations
from utils.data_logging_utils import LogWriter
......@@ -131,38 +131,39 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
pin_memory=True
)
if training_parameters['use_pre_trained']:
BrainMapperModel = torch.load(training_parameters['pre_trained_path'])
else:
BrainMapperModel = BrainMapperUNet(network_parameters)
solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'],
number_of_classes=network_parameters['number_of_classes'],
experiment_name=training_parameters['experiment_name'],
optimizer_arguments={'lr': training_parameters['learning_rate'],
'betas': training_parameters['optimizer_beta'],
'eps': training_parameters['optimizer_epsilon'],
'weight_decay': training_parameters['optimizer_weigth_decay']
},
model_name=misc_parameters['model_name'],
number_epochs=training_parameters['number_of_epochs'],
loss_log_period=training_parameters['loss_log_period'],
learning_rate_scheduler_step_size=training_parameters[
'learning_rate_scheduler_step_size'],
learning_rate_scheduler_gamma=training_parameters['learning_rate_scheduler_gamma'],
use_last_checkpoint=training_parameters['use_last_checkpoint'],
experiment_directory=misc_parameters['experiments_directory'],
logs_directory=misc_parameters['logs_directory']
)
solver.train(train_loader, validation_loader)
model_output_path = os.path.join(
misc_parameters['save_model_directory'], training_parameters['final_model_output_file'])
BrainMapperModel.save(model_output_path)
print("Final Model Saved in: {}".format(model_output_path))
BrainMapperModel = BrainMapperUNet3D(network_parameters)
# solver = Solver(model=BrainMapperModel,
# device=misc_parameters['device'],
# number_of_classes=network_parameters['number_of_classes'],
# experiment_name=training_parameters['experiment_name'],
# optimizer_arguments={'lr': training_parameters['learning_rate'],
# 'betas': training_parameters['optimizer_beta'],
# 'eps': training_parameters['optimizer_epsilon'],
# 'weight_decay': training_parameters['optimizer_weigth_decay']
# },
# model_name=misc_parameters['model_name'],
# number_epochs=training_parameters['number_of_epochs'],
# loss_log_period=training_parameters['loss_log_period'],
# learning_rate_scheduler_step_size=training_parameters[
# 'learning_rate_scheduler_step_size'],
# learning_rate_scheduler_gamma=training_parameters['learning_rate_scheduler_gamma'],
# use_last_checkpoint=training_parameters['use_last_checkpoint'],
# experiment_directory=misc_parameters['experiments_directory'],
# logs_directory=misc_parameters['logs_directory']
# )
# solver.train(train_loader, validation_loader)
# model_output_path = os.path.join(
# misc_parameters['save_model_directory'], training_parameters['final_model_output_file'])
# BrainMapperModel.save(model_output_path)
# print("Final Model Saved in: {}".format(model_output_path))
def evaluate_score(training_parameters, network_parameters, misc_parameters, evaluation_parameters):
......@@ -351,18 +352,16 @@ if __name__ == '__main__':
data_shuffling_flag = data_parameters['data_split_flag']
load_data(data_parameters)
# if data_shuffling_flag == True:
# # Here we shuffle the data!
# data_test_train_validation_split(data_parameters['data_directory'], data_parameters['train_percentage'], data_parameters['validation_percentage'])
# update_shuffling_flag('settings.ini')
# # TODO: This might also be a very good point to add cross-validation later
# else:
if data_shuffling_flag == True:
# Here we shuffle the data!
data_test_train_validation_split(data_parameters['data_directory'], data_parameters['train_percentage'], data_parameters['validation_percentage'])
update_shuffling_flag('settings.ini')
# TODO: This might also be a very good point to add cross-validation later
else:
# if arguments.mode == 'train':
# train(data_parameters, training_parameters,
# network_parameters, misc_parameters)
if arguments.mode == 'train':
train(data_parameters, training_parameters,
network_parameters, misc_parameters)
# elif arguments.mode == 'evaluate-score':
# evaluate_score(training_parameters,
# network_parameters, misc_parameters, evaluation_parameters)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment