Commit 6ef6d440 by Andrei-Claudiu Roibu 🖥

### added dice score calculator function

parent a1a5c934
 """Data Evaluation Functions Description: ------------- This folder contains several functions which, either on their own or included in larger pieces of software, perform data evaluation tasks. Usage ------------- To use content from this folder, import the functions and instantiate them as you wish to use them: from utils.data_evaluation_utils import function_name """ import os import numpy as np import torch def dice_score_calculator(outputs, correct_labels, number_of_classes, number_of_samples=10, mode='train'): """Dice Score Calculator This function calculates the confusion matrix for the given data. The function returns the average dice score and the confusion matrix Args: outputs (torch.tensor): Tensor of all the network outputs (Y-hat) correct_labels (torch.tensor): Output ground-truth labelled data (Y) number_of_classes (int): Number of classes number_of_samples (int): Output shape for randomly generated samples mode (str): Current run mode or phase Returns: dice_score (torch.tensor): Dice score value for each class Raises: None """ dice_score = torch.zeros(number_of_classes) if mode == 'train': samples = np.random.choice(len(outputs), number_of_samples) outputs, correct_labels = outputs[samples], correct_labels[samples] for i in range(number_of_classes): ground_truth = (correct_labels == i).float() predictions = (outputs == i).float() intersection = 2.0 * torch.sum(torch.mul(ground_truth, predictions)) reunion = torch.sum(ground_truth) + torch.sum(predictions) + 1e-4 dice_score[i] = torch.div(intersection, reunion) return dice_score \ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!