Commit ee7d936a authored by Andrei-Claudiu Roibu

completed training function

parent 5a96100e
@@ -20,6 +20,8 @@ import torch
from utils.data_utils import get_datasets
from BrainMapperUNet import BrainMapperUNet
import torch.utils.data as data
from solver import Solver
import os
# Set the default floating point tensor type to FloatTensor
@@ -34,11 +36,11 @@ def load_data(data_parameters):
Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles.
data_parameters = {
data_directory: 'path/to/directory'
train_data_file: 'training_data'
train_output_targets: 'training_targets'
test_data_file: 'testing_data'
test_target_file: 'testing_targets'
'data_directory': 'path/to/directory'
'train_data_file': 'training_data'
'train_output_targets': 'training_targets'
'test_data_file': 'testing_data'
'test_target_file': 'testing_targets'
}
Returns:
@@ -57,23 +59,46 @@ def load_data(data_parameters):
return train_data, test_data
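# Minimal usage sketch for load_data(); the directory and file names below are
# placeholders, and only the dictionary keys are taken from the docstring above.
example_data_parameters = {
    'data_directory': 'path/to/directory',
    'train_data_file': 'training_data',
    'train_output_targets': 'training_targets',
    'test_data_file': 'testing_data',
    'test_target_file': 'testing_targets',
}
train_data, test_data = load_data(example_data_parameters)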
def train(data_parameters, training_parameters, network_parameters):
"""Name
def train(data_parameters, training_parameters, network_parameters, misc_parameters):
"""Training Function
Desc
This function trains a given model using the provided training data.
Currently, data loading is performed using multiple sub-processes.
A sufficiently high number of workers ensures that CPU computations are managed efficiently, i.e. that the bottleneck is the neural network's forward and backward operations on the GPU, and not data generation.
Loader memory is also pinned, to speed up CPU-to-GPU data transfer by using page-locked memory.
Training data is also re-shuffled at each training epoch.
Args:
data_parameters(dict):{
paramters
data_parameters (dict): Dictionary containing relevant information for the datafiles.
data_parameters = {
'data_directory': 'path/to/directory'
'train_data_file': 'training_data'
'train_output_targets': 'training_targets'
'test_data_file': 'testing_data'
'test_target_file': 'testing_targets'
}
training_parameters(dict):{
paraters
training_parameters (dict): Dictionary containing relevant hyperparameters for training the network.
training_parameters = {
'training_batch_size': 5
'test_batch_size': 5
'use_pre_trained': False
'pre_trained_path': 'pre_trained/path'
'experiment_name': 'experiment_name'
'learning_rate': 1e-4
'optimizer_beta': (0.9, 0.999)
'optimizer_epsilon': 1e-8
'optimizer_weigth_decay': 1e-5
'number_of_epochs': 10
'loss_log_period': 50
'learning_rate_scheduler_step_size': 3
'learning_rate_scheduler_gamma': 1e-1
'use_last_checkpoint': True
'final_model_output_file': 'path/to/model'
}
network_parameters (dict): Contains information relevant parameters = {
network_parameters (dict): Dictionary containing the relevant parameters for the network architecture.
network_parameters = {
'kernel_heigth': 5
'kernel_width': 5
'kernel_classification': 1
@@ -87,6 +112,15 @@ def train(data_parameters, training_parameters, network_parameters):
'number_of_classes': 1
}
misc_parameters (dict): Dictionary of additional hyperparameters
misc_parameters = {
'save_model_directory': 'directory_name'
'model_name': 'BrainMapper'
'logs_directory': 'log-directory'
'device': 1
'experiments_directory': 'experiments-directory'
}
Returns:
None
@@ -98,7 +132,7 @@ def train(data_parameters, training_parameters, network_parameters):
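# The DataLoader settings below follow the rationale in the docstring: worker
# sub-processes keep data generation off the GPU's critical path, pinned memory
# uses page-locked host memory for faster CPU-to-GPU transfer, and the training
# data is re-shuffled every epoch.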
train_loader = data.DataLoader(
dataset= train_data,
batch_size= training_parameters['train_batch_size'],
batch_size= training_parameters['training_batch_size'],
shuffle= True,
num_workers= 4,
pin_memory= True
@@ -117,9 +151,31 @@ def train(data_parameters, training_parameters, network_parameters):
else:
BrainMapperModel = BrainMapperUNet(network_parameters)
solver = Solver(
# TODO - need to write the solver !
)
solver = Solver(model= BrainMapperModel,
device= misc_parameters['device'],
number_of_classes= network_parameters['number_of_classes'],
experiment_name= training_parameters['experiment_name'],
optimizer_arguments = {'lr': training_parameters['learning_rate'],
'betas': training_parameters['optimizer_beta'],
'eps': training_parameters['optimizer_epsilon'],
'weight_decay': training_parameters['optimizer_weigth_decay']
},
model_name = misc_parameters['model_name'],
number_epochs = training_parameters['number_of_epochs'],
loss_log_period = training_parameters['loss_log_period'],
learning_rate_scheduler_step_size = training_parameters['learning_rate_scheduler_step_size'],
learning_rate_scheduler_gamma = training_parameters['learning_rate_scheduler_gamma'],
use_last_checkpoint = training_parameters['use_last_checkpoint'],
experiment_directory = misc_parameters['experiments_directory'],
logs_directory = misc_parameters['logs_directory']
)
solver.train(train_loader, test_loader)
model_output_path = os.path.join(misc_parameters['save_model_directory'], training_parameters['final_model_output_file'])
BrainMapperModel.save(model_output_path)
print("Final Model Saved in: {}".format(model_output_path))
def evaluate_path():
pass