Commit 0a75ffb5 authored by Andrei-Claudiu Roibu
Browse files

Added logging functionality to the train function

parent 05d17c56
......@@ -18,7 +18,7 @@ import torch
from datetime import datetime
from utils.losses import MSELoss
from utils.data_utils import create_folder
from utils.data_logging_utils import #BLA - need to write something first
from utils.data_logging_utils import LogWriter
from torch.optim import lr_scheduler
checkpoint_directory = 'checkpoints'
......@@ -103,11 +103,17 @@ class Solver():
self.start_epoch = 1
self.start_iteration = 1
self.best_mean_score = 0
self.best_mean_epoch = 0
self.best_mean_score_epoch = 0
if use_last_checkpoint:
self.load_checkpoint()
self.LogWriter = LogWriter(number_of_classes= number_of_classes,
logs_directory= logs_directory,
experiment_name= experiment_name,
use_last_checkpoint= use_last_checkpoint,
labels= labels)
def train(self, train_loader, test_loader):
"""Training Function
......@@ -178,8 +184,7 @@ class Solver():
if batch_index % self.loss_log_period == 0:
# TODO: NEED A FUNCTION that logs outputs for debugging!
# Here, I need it to log the loss, batch id and iteration number\
self.LogWriter.loss_per_iteration(self, loss.item(), batch_index, iteration)
iteration += 1
......@@ -201,13 +206,26 @@ class Solver():
with torch.no_grad():
output_array, y_array = torch.cat(outputs), torch.cat(y_values)
# TODO - using log functions, record loss per epoch, maybe generated images per epoch, dice score and any other relevant metrics?
self.LogWriter.loss_per_epoch(losses, phase, epoch)
dice_score_mean = self.LogWriter.dice_score_per_epoch(phase, output_array, y_array, epoch)
if phase === 'test':
if dice_score_mean > self.best_mean_score:
self.best_mean_score = dice_score_mean
self.best_mean_score_epoch = epoch
index = np.random.choice(len(dataloaders[phase].dataset.X), size=3, replace= False)
self.LogWriter.sample_image_per_epoch(prediction= model.predict(dataloaders[phase].dataset.X[index], self.device)
ground_truth= dataloaders[phase].dataset.y[index],\
phase= phase
epoch= epoch)
print("Epoch {}/{} DONE!".format(epoch, self.number_epochs))
self.save_checkpoint() # TODO - write function and save the checkpoint!
self.LogWriter.close()
print('----------------------------------------')
print('TRAINING IS COMPLETE!')
print('=====================')
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment