Commit a1a5c934 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added dice score calculation and plotting

parent 1a6a8939
...@@ -13,6 +13,7 @@ To use content from this folder, import the functions and instantiate them as yo ...@@ -13,6 +13,7 @@ To use content from this folder, import the functions and instantiate them as yo
""" """
import os import os
import matplotlib
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import shutil import shutil
import logging import logging
...@@ -22,6 +23,8 @@ import numpy as np ...@@ -22,6 +23,8 @@ import numpy as np
# More here: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html # More here: https://tensorboardx.readthedocs.io/en/latest/tensorboard.html
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
import utils.data_evaluation_utils as evaluation
plt.axis('scaled') plt.axis('scaled')
class LogWriter(): class LogWriter():
...@@ -91,7 +94,7 @@ class LogWriter(): ...@@ -91,7 +94,7 @@ class LogWriter():
self.logger.info(msg= message) self.logger.info(msg= message)
def loss_per_iteration(self, loss_per_iteration, batch_index, current_iteration): def loss_per_iteration(self, loss_per_iteration, batch_index, iteration):
"""Log of loss / iteration """Log of loss / iteration
This function records the loss for every iteration. This function records the loss for every iteration.
...@@ -99,7 +102,7 @@ class LogWriter(): ...@@ -99,7 +102,7 @@ class LogWriter():
Args: Args:
loss_per_iteration (torch.tensor): Value of loss for every iteration step loss_per_iteration (torch.tensor): Value of loss for every iteration step
batch_index (int): Index of current batch batch_index (int): Index of current batch
current_iteartion (int): Current iteration value iteration (int): Current iteration value
Returns: Returns:
None None
...@@ -109,7 +112,7 @@ class LogWriter(): ...@@ -109,7 +112,7 @@ class LogWriter():
""" """
print("Loss for Iteration {} is: {}".format(batch_index, loss_per_iteration)) print("Loss for Iteration {} is: {}".format(batch_index, loss_per_iteration))
self.log_writer['train'].add_scalar(tag= 'loss / iteration', loss_per_iteration, current_iteration) self.log_writer['train'].add_scalar(tag= 'loss / iteration', loss_per_iteration, iteration)
def loss_per_epoch(self, losses, phase, epoch): def loss_per_epoch(self, losses, phase, epoch):
"""Log function """Log function
...@@ -136,31 +139,67 @@ class LogWriter(): ...@@ -136,31 +139,67 @@ class LogWriter():
print("Loss for Epoch {} of {} is: {}".format(epoch, phase, loss)) print("Loss for Epoch {} of {} is: {}".format(epoch, phase, loss))
self.log_writer[phase].add_scalar(tag= 'loss / iteration', loss, epoch) self.log_writer[phase].add_scalar(tag= 'loss / iteration', loss, epoch)
def confusion_matrix_per_epoch(self): # Currently, no confusion matrix is required
"""Log function # TODO: add a confusion matrix per epoch and confusion matrix plot functions if required
This function logs a message in the logger. def dice_score_per_epoch(self, phase, outputs, correct_labels, epoch):
"""Function calculating dice score for each epoch
This function computes the dice score for each epoch.
Args: Args:
message (str): Message to be logged phase (str): Current run mode or phase
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
epoch (int): Current epoch value
Returns: Returns:
None mean_dice_score (torch.tensor): Mean dice score value
Raises: Raises
None None
""" """
print("Dice Score is being calculated...", end='', flush= True)
dice_score = evaluation.dice_score_calculator(outputs, correct_labels, self.number_of_classes)
mean_dice_score = torch.mean(dice_score)
self.plot_dice_score(dice_score, phase, plot_name='dice_score_per_epoch', title='Dice Score', epoch)
self.log_writer[phase].add_scalar(tag= 'loss / iteration', loss, epoch)
print("Dice Score calculated successfully")
return mean_dice_score
def plot_dice_score(self, dice_score, phase, plot_name, title, epochs):
def plot_confusion_matrix(self): """Function plotting dice score for multiple epochs
pass
def dice_score_per_epoch(self): This function plots the dice score for each epoch.
pass
def plot_dice_score(self): Args:
pass dice_score (torch.tensor): Dice score value for each class
phase (str): Current run mode or phase
plot_name (str): Caption name for later refference
title (str): Plot title
epoch (int): Current epoch value
Returns:
None
Raises
None
"""
figure = matplotlib.figure.Figure() # Might add some arguments here later
ax = figure.add_subplot(1, 1, 1)
ax.set_xlabel(title)
ax.xaxis.set_label_position('top')
ax.bar(np.arange(self.number_of_classes), dice_score)
ax.set_xticks(np.arange(self.number_of_classes))
ax.set_xticklabels(self.labels)
ax.xaxis.tick_bottom()
if step:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure, global_step= epochs)
else:
self.log_writer[phase].add_figure(plot_name + '/' + phase, figure)
def plot_evaluation_box(self): def plot_evaluation_box(self):
pass pass
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment