Commit a897de36 authored by Andrei Roibu's avatar Andrei Roibu
Browse files

fixed checkpoint loader

parent 91ecfd99
......@@ -376,14 +376,14 @@ class Solver():
self.start_epoch = checkpoint['epoch']
self.start_iteration = checkpoint['start_iteration']
# We are not loading the model_name as we might want to pre-train a model and then use it.
self.model.load_state_dict = checkpoint['state_dict']
self.optimizer.load_state_dict = checkpoint['optimizer']
self.model.load_state_dict(checkpoint['state_dict'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
for state in self.optimizer.state.values():
for state in self.optimizer.state. ():
for key, value in state.items():
if torch.is_tensor(value):
state[key] = value.to(self.device)
self.learning_rate_scheduler.load_state_dict = checkpoint['scheduler']
self.learning_rate_scheduler.load_state_dict(checkpoint['scheduler'])
self.LogWriter.log(
"Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment