Commit 6ef6d440 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added dice score calculator function

parent a1a5c934
"""Data Evaluation Functions
Description:
-------------
This folder contains several functions which, either on their own or included in larger pieces of software, perform data evaluation tasks.
Usage
-------------
To use content from this folder, import the functions and instantiate them as you wish to use them:
from utils.data_evaluation_utils import function_name
"""
import os
import numpy as np
import torch
def dice_score_calculator(outputs, correct_labels, number_of_classes, number_of_samples=10, mode='train'):
"""Dice Score Calculator
This function calculates the confusion matrix for the given data.
The function returns the average dice score and the confusion matrix
Args:
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
number_of_classes (int): Number of classes
number_of_samples (int): Output shape for randomly generated samples
mode (str): Current run mode or phase
Returns:
dice_score (torch.tensor): Dice score value for each class
Raises:
None
"""
dice_score = torch.zeros(number_of_classes)
if mode == 'train':
samples = np.random.choice(len(outputs), number_of_samples)
outputs, correct_labels = outputs[samples], correct_labels[samples]
for i in range(number_of_classes):
ground_truth = (correct_labels == i).float()
predictions = (outputs == i).float()
intersection = 2.0 * torch.sum(torch.mul(ground_truth, predictions))
reunion = torch.sum(ground_truth) + torch.sum(predictions) + 1e-4
dice_score[i] = torch.div(intersection, reunion)
return dice_score
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment