Commit d3b668b0 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

refactored train variable names to validation

parent 98ccda23
......@@ -52,7 +52,6 @@ def load_data(data_parameters):
"""Dataset Loader
This function loads the training and validation datasets.
TODO: Will need to define if all the training data is loaded as bulk or individually!
Args:
data_parameters (dict): Dictionary containing relevant information for the datafiles.
......@@ -86,7 +85,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
training_parameters(dict): Dictionary containing relevant hyperparameters for training the network.
training_parameters = {
'training_batch_size': 5
'test_batch_size: 5
'validation_batch_size: 5
'use_pre_trained': False
'pre_trained_path': 'pre_trained/path'
'experiment_name': 'experiment_name'
......@@ -114,7 +113,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
}
"""
train_data, test_data = load_data(data_parameters)
train_data, validation_data = load_data(data_parameters)
train_loader = data.DataLoader(
dataset=train_data,
......@@ -124,9 +123,9 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
pin_memory=True
)
test_loader = data.DataLoader(
dataset=test_data,
batch_size=training_parameters['test_batch_size'],
validation_loader = data.DataLoader(
dataset=validation_data,
batch_size=training_parameters['validation_batch_size'],
shuffle=False,
num_workers=4,
pin_memory=True
......@@ -157,7 +156,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
logs_directory=misc_parameters['logs_directory']
)
solver.train(train_loader, test_loader)
solver.train(train_loader, validation_loader)
model_output_path = os.path.join(
misc_parameters['save_model_directory'], training_parameters['final_model_output_file'])
......
......@@ -13,7 +13,7 @@ validation_target_file = "fMRI/rfMRI_25.dr/dr_stage2.nii.gz"
[TRAINING]
training_batch_size = 5
test_batch_size = 5
validation_batch_size = 5
use_pre_trained = False
pre_trained_path = "saved_models/preTrained.pth.tar"
experiment_name = "experiment_name"
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment