Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Andrei-Claudiu Roibu
BrainMapper
Commits
6ef6d440
Commit
6ef6d440
authored
Mar 30, 2020
by
Andrei-Claudiu Roibu
🖥
Browse files
added dice score calculator function
parent
a1a5c934
Changes
1
Hide whitespace changes
Inline
Side-by-side
utils/data_evaluation_utils.py
0 → 100644
View file @
6ef6d440
"""Data Evaluation Functions
Description:
-------------
This folder contains several functions which, either on their own or included in larger pieces of software, perform data evaluation tasks.
Usage
-------------
To use content from this folder, import the functions and instantiate them as you wish to use them:
from utils.data_evaluation_utils import function_name
"""
import
os
import
numpy
as
np
import
torch
def
dice_score_calculator
(
outputs
,
correct_labels
,
number_of_classes
,
number_of_samples
=
10
,
mode
=
'train'
):
"""Dice Score Calculator
This function calculates the confusion matrix for the given data.
The function returns the average dice score and the confusion matrix
Args:
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
correct_labels (torch.tensor): Output ground-truth labelled data (Y)
number_of_classes (int): Number of classes
number_of_samples (int): Output shape for randomly generated samples
mode (str): Current run mode or phase
Returns:
dice_score (torch.tensor): Dice score value for each class
Raises:
None
"""
dice_score
=
torch
.
zeros
(
number_of_classes
)
if
mode
==
'train'
:
samples
=
np
.
random
.
choice
(
len
(
outputs
),
number_of_samples
)
outputs
,
correct_labels
=
outputs
[
samples
],
correct_labels
[
samples
]
for
i
in
range
(
number_of_classes
):
ground_truth
=
(
correct_labels
==
i
).
float
()
predictions
=
(
outputs
==
i
).
float
()
intersection
=
2.0
*
torch
.
sum
(
torch
.
mul
(
ground_truth
,
predictions
))
reunion
=
torch
.
sum
(
ground_truth
)
+
torch
.
sum
(
predictions
)
+
1e-4
dice_score
[
i
]
=
torch
.
div
(
intersection
,
reunion
)
return
dice_score
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment