Commit 8f3895ff authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added checkpoint reader private function

parent c14030f7
......@@ -279,9 +279,44 @@ class Solver():
"""
def _load_checkpoint_file(self):
# Name is private = can't be called outisde of this module
pass
if epoch is None:
checkpoint_file_path = os.path.join(self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self._checkpoint_reader(checkpoint_file_path)
else:
pizdetz!
def _checkpoint_reader(self, checkpoint_file_path):
"""Checkpoint Reader
This private function reads a checkpoint file and then loads the relevant variables
Args:
checkpoint_file_path (str): path to checkpoint file
Returns:
None
Raises:
None
"""
self.LogWriter.log("Loading Checkpoint {}".format(checkpoint_file_path))
checkpoint = torch.load(checkpoint_file_path)
self.start_epoch = checkpoint['epoch']
self.start_iteration = checkpoint['start_iteration']
# We are not loading the model_name as we might want to pre-train a model and then use it.
self.model.load_state_dict = checkpoint['state_dict']
self.optimizer.load_state_dict = checkpoint['optimizer']
self.scheduler.load_state_dict = checkpoint['scheduler']
for state in self.optimizer.state.values():
for key, value in state.items{}:
if torch.is_tensor(value):
state[key] = value.to(self.device)
self.LogWriter.log("Checkpoint Loaded {} - epoch {}".format(checkpoint_file_path, checkpoint['epoch']))
def save_model(self):
pass
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment