Commit a1a5c934 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added dice score calculation and plotting

parent 1a6a8939
......@@ -13,6 +13,7 @@ To use content from this folder, import the functions and instantiate them as yo
"""
import os
import matplotlib
import matplotlib.pyplot as plt
import shutil
import logging
......@@ -22,6 +23,8 @@ import numpy as np
# More here: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html
from tensorboardX import SummaryWriter
import utils.data_evaluation_utils as evaluation
plt.axis('scaled')
class LogWriter():
......@@ -91,7 +94,7 @@ class LogWriter():
self.logger.info(msg= message)
def loss_per_iteration(self, loss_per_iteration, batch_index, current_iteration):
def loss_per_iteration(self, loss_per_iteration, batch_index, iteration):
"""Log of loss / iteration
This function records the loss for every iteration.
......@@ -99,7 +102,7 @@ class LogWriter():
Args:
loss_per_iteration (torch.tensor): Value of loss for every iteration step
batch_index (int): Index of current batch
current_iteartion (int): Current iteration value
iteration (int): Current iteration value
Returns:
None
......@@ -109,7 +112,7 @@ class LogWriter():
"""
print("Loss for Iteration {} is: {}".format(batch_index, loss_per_iteration))
self.log_writer['train'].add_scalar(tag= 'loss / iteration', loss_per_iteration, current_iteration)
self.log_writer['train'].add_scalar(tag= 'loss / iteration', loss_per_iteration, iteration)
def loss_per_epoch(self, losses, phase, epoch):
"""Log function
......@@ -136,31 +139,67 @@ class LogWriter():
print("Loss for Epoch {} of {} is: {}".format(epoch, phase, loss))
self.log_writer[phase].add_scalar(tag= 'loss / iteration', loss, epoch)
def confusion_matrix_per_epoch(self):
"""Log function
# Currently, no confusion matrix is required
# TODO: add a confusion matrix per epoch and confusion matrix plot functions if required
This function logs a message in the logger.
def dice_score_per_epoch(self, phase, outputs, correct_labels, epoch):
"""Function calculating dice score for each epoch
This function computes the dice score for each epoch.
Args:
message (str): Message to be logged
phase (str): Current run mode or phase
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
epoch (int): Current epoch value
Returns:
None
mean_dice_score (torch.tensor): Mean dice score value
Raises:
Raises
None
"""
print("Dice Score is being calculated...", end='', flush= True)
dice_score = evaluation.dice_score_calculator(outputs, correct_labels, self.number_of_classes)
mean_dice_score = torch.mean(dice_score)
self.plot_dice_score(dice_score, phase, plot_name='dice_score_per_epoch', title='Dice Score', epoch)
self.log_writer[phase].add_scalar(tag= 'loss / iteration', loss, epoch)
print("Dice Score calculated successfully")
return mean_dice_score
def plot_dice_score(self, dice_score, phase, plot_name, title, epochs):
"""Function plotting dice score for multiple epochs
def plot_confusion_matrix(self):
pass
This function plots the dice score for each epoch.
def dice_score_per_epoch(self):
pass
Args:
dice_score (torch.tensor): Dice score value for each class
phase (str): Current run mode or phase
plot_name (str): Caption name for later refference
title (str): Plot title
epoch (int): Current epoch value
def plot_dice_score(self):
pass
Returns:
None
Raises
None
"""
figure = matplotlib.figure.Figure() # Might add some arguments here later
ax = figure.add_subplot(1, 1, 1)
ax.set_xlabel(title)
ax.xaxis.set_label_position('top')
ax.bar(np.arange(self.number_of_classes), dice_score)
ax.set_xticks(np.arange(self.number_of_classes))
ax.set_xticklabels(self.labels)
ax.xaxis.tick_bottom()
if step:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure, global_step= epochs)
else:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure)
def plot_evaluation_box(self):
pass
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment