Commit 76965ce7 authored by Andrei-Claudiu Roibu 🖥

added evaluation function based on dice score

parent fce7cb7c
@@ -15,12 +15,16 @@ To use content from this folder, import the functions and instantiate them as yo
import os
import numpy as np
import torch
import logging
import utils.data_utils as data_utils
import nibabel as nib
log = logging.getLogger(__name__)
def dice_score_calculator(outputs, correct_labels, number_of_classes, number_of_samples=10, mode='train'):
"""Dice Score Calculator
This function calculates the dice score for the given data.
The function returns the average dice score
Args:
outputs (torch.tensor): Tensor of all the network outputs (Y-hat)
@@ -49,4 +53,143 @@ def dice_score_calculator(outputs, correct_labels, number_of_classes, number_of_
reunion = torch.sum(ground_truth) + torch.sum(predictions) + 1e-4
dice_score[i] = torch.div(intersection, reunion)
return dice_score
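# A hedged sanity-check sketch for dice_score_calculator. The helper name,
# the tensor shapes and the two-class setting are illustrative assumptions,
# not part of this commit:
def _dice_score_sanity_check():
    fake_predictions = torch.randint(0, 2, (10, 64, 64))
    fake_labels = torch.randint(0, 2, (10, 64, 64))
    return dice_score_calculator(fake_predictions, fake_labels, number_of_classes=2)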
def evaluate_dice_score(trained_model_path,
number_of_classes,
data_directory,
targets_directory,
data_list,
orientation,
prediction_output_path,
device,
LogWriter,
mode):
"""Dice Score Evaluator
This function evaluates a given trained model by calculating its dice score on the provided data.
Args:
trained_model_path (str): Path to the location of the trained model
number_of_classes (int): Number of classes
data_directory (str): Path to input data directory
targets_directory (str): Path to labelled data (Y-equivalent)
data_list (str): Path to a .txt file containing the input files for consideration
orientation (str): String detailing the current view (COR, SAG, AXL)
prediction_output_path (str): Output prediction path
device (str/int): Device type used for training (int - GPU id, str- CPU)
LogWriter (class): Log Writer class for the BrainMapper U-net
mode (str): Current run mode or phase
Returns:
average_dice_score (float): Average dice score for the tested data
Raises:
None
"""
log.info("Started Evaluation. Check tensorboard for plots (if a LogWriter is provided)")
slice_size = 20 ## TODO: CHECK IF THIS IS ACCURATE!
with open(data_list) as data_list_file:
volumes_to_be_used = data_list_file.read().splitlines()
# Test if cuda is available and attempt to run on GPU
cuda_available = torch.cuda.is_available()
if type(device) == int:
if cuda_available:
model = torch.load(trained_model_path)
torch.cuda.empty_cache()
model.cuda(device)
else:
log.warning("CUDA not available. Switching to CPU. Investigate behaviour!")
device = 'cpu'
if (type(device) == str) or not cuda_available:
model = torch.load(trained_model_path, map_location=torch.device(device))
model.eval()
# Create the prediction path folder if this is not available
data_utils.create_folder(prediction_output_path)
# Initiate the evaluation
volume_dice_score_list = []
log.info("Dice Score Evaluation Started")
file_paths = data_utils.load_file_paths(data_directory, targets_directory, data_list)
with torch.no_grad():
for volume_index, file_path in enumerate(file_paths):
volume, label_map, header = data_utils.load_and_preprocess(file_path, orientation)
# Add a channel dimension if the volume is 3D
if len(volume.shape) != 4:
volume = volume[:, np.newaxis, :, :]
volume = torch.tensor(volume).type(torch.FloatTensor)
label_map = torch.tensor(label_map).type(torch.LongTensor)
volume_predictions = []
for i in range(0, len(volume), slice_size):
input_slice = volume[i: i + slice_size]
if cuda_available and (type(device) == int):
input_slice = input_slice.cuda(device)
output = model(input_slice)
_, slice_output = torch.max(output, dim=1)
# This needs to be checked - torch.max returns max values and locations.
# For segmentations, we are interested in the locations
# For the functional data, we might be interested in the actual values.
# The strength of the value represents the strength of the activation
# A threshold might also be required!
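# A hedged sketch of the thresholding alternative described above. The softmax
# normalisation, the foreground channel index and the 0.5 cut-off are
# illustrative assumptions, not part of this commit:
#   activation_map = torch.nn.functional.softmax(output, dim=1)
#   slice_output = (activation_map[:, 1] > 0.5).long()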
volume_predictions.append(slice_output)
volume_predictions = torch.cat(volume_predictions)
# Move the label map to the GPU only when one is actually in use
if cuda_available and (type(device) == int):
label_map = label_map.cuda(device)
volume_dice_score = dice_score_calculator(volume_predictions, label_map, number_of_classes=number_of_classes, mode=mode)
volume_predictions = (volume_predictions.cpu().numpy()).astype('float32')
# We copy the header affines
header_affines = np.array([header['srow_x'], header['srow_y'], header['srow_z'], [0,0,0,1]])
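# (srow_x, srow_y and srow_z are the first three rows of the NIfTI voxel-to-world
# affine, so stacking them with [0, 0, 0, 1] rebuilds the full 4x4 matrix)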
# We apply the affine and save the image
output_nifti_image = nib.MGHImage(np.squeeze(volume_predictions), header_affines, header=header)
output_nifti_path = os.path.join(prediction_output_path, volumes_to_be_used[volume_index] + str('.mgz'))
nib.save(output_nifti_image, output_nifti_path)
if LogWriter:
LogWriter.plot_dice_score(volume_dice_score, phase='test', plot_name='Evaluation Dice Score', title=volumes_to_be_used[volume_index], epochs=volume_index)
# We convert the dice score to numpy arrays and save them in a list
volume_dice_score = volume_dice_score.cpu().numpy()
volume_dice_score_list.append(volume_dice_score)
log.info("Volume dice score: {}; mean: {}".format(volume_dice_score, np.mean(volume_dice_score)))
dice_score_array = np.asarray(volume_dice_score_list)
average_dice_score = dice_score_array.mean()
log.info("Mean dice score: {}".format(average_dice_score))
return average_dice_score
def evaluate_single_path():
pass
def evaluate_multiple_paths():
pass
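# A hedged usage sketch for evaluate_dice_score. All paths, the class count and
# the orientation below are hypothetical placeholders, not part of this commit;
# LogWriter is left as None so only the returned score is logged:
if __name__ == "__main__":
    mean_score = evaluate_dice_score(trained_model_path="saved_models/brainmapper.pth.tar",
                                     number_of_classes=2,
                                     data_directory="datasets/inputs",
                                     targets_directory="datasets/targets",
                                     data_list="datasets/test_volumes.txt",
                                     orientation="AXL",
                                     prediction_output_path="predictions",
                                     device=0,
                                     LogWriter=None,
                                     mode="test")
    log.info("Evaluation finished with mean dice score: {}".format(mean_score))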