Commit 11b9a3b8 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

added flag for AdamW use

parent ab7a80b4
......@@ -154,8 +154,10 @@ def train(data_parameters, training_parameters, network_parameters, misc_paramet
BrainMapperModel.reset_parameters(custom_weight_reset_flag)
optimizer = torch.optim.Adam
# optimizer = torch.optim.AdamW
if training_parameters['adam_w_flag'] == True:
optimizer = torch.optim.AdamW
else:
optimizer = torch.optim.Adam
solver = Solver(model=BrainMapperModel,
device=misc_parameters['device'],
......
......@@ -22,6 +22,7 @@ loss_log_period = 50
learning_rate_scheduler_step_size = 5
learning_rate_scheduler_gamma = 1e-1
use_last_checkpoint = False
adam_w_flag = False
[NETWORK]
kernel_heigth = 3
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment