Commit c14030f7 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added save_checkpoint to solver train

parent 8ab6cb09
......@@ -209,21 +209,30 @@ class Solver():
self.LogWriter.loss_per_epoch(losses, phase, epoch)
dice_score_mean = self.LogWriter.dice_score_per_epoch(phase, output_array, y_array, epoch)
if phase === 'test':
if phase == 'test':
if dice_score_mean > self.best_mean_score:
self.best_mean_score = dice_score_mean
self.best_mean_score_epoch = epoch
index = np.random.choice(len(dataloaders[phase].dataset.X), size=3, replace= False)
self.LogWriter.sample_image_per_epoch(prediction= model.predict(dataloaders[phase].dataset.X[index], self.device)
ground_truth= dataloaders[phase].dataset.y[index],\
self.LogWriter.sample_image_per_epoch(prediction= model.predict(dataloaders[phase].dataset.X[index], self.device),
ground_truth= dataloaders[phase].dataset.y[index],
phase= phase
epoch= epoch)
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
self.save_checkpoint() # TODO - write function and save the checkpoint!
self.save_checkpoint(state={'epoch': epoch + 1,
'start_iteration': iteration + 1,
'arch': self.model_name,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'scheduler': learning_rate_scheduler.state_dict()
},
filename= os.path.join(self.experiment_directory_path, checkpoint_directory, 'checkpoint_epoch_' + str(epoch) + '.' + checkpoint_extension)
)
self.LogWriter.close()
print('----------------------------------------')
......@@ -236,9 +245,6 @@ class Solver():
# TODO: MAKE SURE any log writer function is closed!
def save_model(self):
pass
def save_checkpoint(self, state, filename):
"""General Checkpoint Save
......@@ -257,10 +263,25 @@ class Solver():
torch.save(state, filename)
def load_checkpoint(self):
pass
def load_checkpoint(self, epoch= None):
"""General Checkpoint Loader
This function loads a previous checkpoint for inference and/or resuming training
Args:
epoch (int): Current epoch value
Returns:
None
Raises:
None
"""
def _load_checkpoint_file(self):
# Name is private = can't be called outisde of this module
pass
def save_model(self):
pass
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment