Commit 6ef6d440 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added dice score calculator function

parent a1a5c934
"""Data Evaluation Functions
This folder contains several functions which, either on their own or included in larger pieces of software, perform data evaluation tasks.
To use content from this folder, import the functions and instantiate them as you wish to use them:
from utils.data_evaluation_utils import function_name
import os
import numpy as np
import torch
def dice_score_calculator(outputs, correct_labels, number_of_classes, number_of_samples=10, mode='train'):
"""Dice Score Calculator
This function calculates the confusion matrix for the given data.
The function returns the average dice score and the confusion matrix
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
number_of_classes (int): Number of classes
number_of_samples (int): Output shape for randomly generated samples
mode (str): Current run mode or phase
dice_score (torch.tensor): Dice score value for each class
dice_score = torch.zeros(number_of_classes)
if mode == 'train':
samples = np.random.choice(len(outputs), number_of_samples)
outputs, correct_labels = outputs[samples], correct_labels[samples]
for i in range(number_of_classes):
ground_truth = (correct_labels == i).float()
predictions = (outputs == i).float()
intersection = 2.0 * torch.sum(torch.mul(ground_truth, predictions))
reunion = torch.sum(ground_truth) + torch.sum(predictions) + 1e-4
dice_score[i] = torch.div(intersection, reunion)
return dice_score
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment