Commit d492260e authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

bug fixes for better cross validation

parent 09367ded
......@@ -153,10 +153,13 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
BrainMapperModel.reset_parameters()
optimizer = torch.optim.Adam
solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'],
number_of_classes=network_parameters['number_of_classes'],
experiment_name=training_parameters['experiment_name'],
optimizer= optimizer,
optimizer_arguments={'lr': training_parameters['learning_rate'],
'betas': training_parameters['optimizer_beta'],
'eps': training_parameters['optimizer_epsilon'],
......@@ -185,7 +188,7 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
print("Final Model Saved in: {}".format(model_output_path))
del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver
del train_data, validation_data, train_loader, validation_loader, BrainMapperModel, solver, optimizer
torch.cuda.empty_cache()
return validation_loss
......
......@@ -59,7 +59,7 @@ class Solver():
device,
number_of_classes,
experiment_name,
optimizer=torch.optim.Adam,
optimizer,
optimizer_arguments={},
loss_function=MSELoss(),
model_name='BrainMapper',
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment