Commit 9f21c7b3 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added a complete load_checkpoint function

parent 8f3895ff
......@@ -15,6 +15,8 @@ To use this module, import it and instantiate is as you wish:
import os
import numpy as np
import torch
import glob
from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
......@@ -243,8 +245,6 @@ class Solver():
print('Training Duration: {}'.format(end_time - start_time))
print('****************************************************************')
# TODO: MAKE SURE any log writer function is closed!
def save_checkpoint(self, state, filename):
"""General Checkpoint Save
......@@ -283,9 +283,18 @@ class Solver():
checkpoint_file_path = os.path.join(self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
self._checkpoint_reader(checkpoint_file_path)
else:
pizdetz!
universal_path = os.path.join(self.experiment_directory_path, checkpoint_directory, '*.' + checkpoint_extension)
files_in_universal_path = glob.glob(universal_path)
# We will sort through all the files in path to see which one is most recent
if len(files_in_universal_path) > 0:
checkpoint_file_path = max(files_in_universal_path, key= os.path.getatime)
self._checkpoint_reader(checkpoint_file_path)
else:
self.LogWriter.log("No Checkpoint found at {}".format(os.path.join(self.experiment_directory_path, checkpoint_directory)))
def _checkpoint_reader(self, checkpoint_file_path):
"""Checkpoint Reader
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment