Commit 1a6a8939 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added logs for loss per iteration and epoch

parent 2d53f299
...@@ -16,6 +16,7 @@ import os ...@@ -16,6 +16,7 @@ import os
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import shutil import shutil
import logging import logging
import numpy as np
# The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it. # The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it.
# More here: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html # More here: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html
...@@ -90,15 +91,68 @@ class LogWriter(): ...@@ -90,15 +91,68 @@ class LogWriter():
self.logger.info(msg= message) self.logger.info(msg= message)
def loss_per_iteration(self): def loss_per_iteration(self, loss_per_iteration, batch_index, current_iteration):
pass """Log of loss / iteration
def loss_per_epoch(self): This function records the loss for every iteration.
pass
Args:
loss_per_iteration (torch.tensor): Value of loss for every iteration step
batch_index (int): Index of current batch
current_iteartion (int): Current iteration value
Returns:
None
Raises:
None
"""
print("Loss for Iteration {} is: {}".format(batch_index, loss_per_iteration))
self.log_writer['train'].add_scalar(tag= 'loss / iteration', loss_per_iteration, current_iteration)
def loss_per_epoch(self, losses, phase, epoch):
"""Log function
This function records the loss for every epoch.
Args:
losses (list): Values of all the losses recorded during the training epoch
phase (str): Current run mode or phase
epoch (int): Current epoch value
Returns:
None
Raises:
None
"""
if phase == 'train':
loss = losses[-1]
else:
loss = np.mean(losses)
print("Loss for Epoch {} of {} is: {}".format(epoch, phase, loss))
self.log_writer[phase].add_scalar(tag= 'loss / iteration', loss, epoch)
def confusion_matrix_per_epoch(self): def confusion_matrix_per_epoch(self):
pass """Log function
This function logs a message in the logger.
Args:
message (str): Message to be logged
Returns:
None
Raises:
None
"""
def plot_confusion_matrix(self): def plot_confusion_matrix(self):
pass pass
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment