Commit fc18699b authored by Andrei-Claudiu Roibu

added dice and combined crossentropy-dice losses

parent 999def81
@@ -57,7 +57,7 @@ class L1Loss(_WeightedLoss):
        return self.loss(X, y)

class MSELoss():
class MSELoss(_WeightedLoss):
    """MSE Loss

    Standard PyTorch implementation of L2 norm error (Mean Squared Error)
@@ -94,14 +94,14 @@ class MSELoss():
        return self.loss(X, y)
class CrossEntropyLoss():
class CrossEntropyLoss(_WeightedLoss):
    """Cross Entropy Loss

    Standard PyTorch implementation of Cross Entropy Loss.
    The weighted component is used when this loss is combined with another loss.

    Args:
        None
        weights (torch.tensor): manual rescaling weight given to each class

    Returns:
        float: scalar output of forward passing through the function
@@ -132,9 +132,174 @@ class CrossEntropyLoss():
        return self.loss(X, y)
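
A minimal usage sketch of the weighted wrapper above, not part of the commit: the constructor sits in the collapsed lines of this hunk, so the `weight` keyword is an assumption that it simply forwards the tensor to torch.nn.CrossEntropyLoss (which accepts NxCxHxW logits against NxHxW integer targets).

import torch

predictions = torch.randn(4, 3, 32, 32)        # hypothetical logits, NxCxHxW
targets = torch.randint(0, 3, (4, 32, 32))     # integer class labels, NxHxW
class_weights = torch.tensor([1.0, 2.0, 0.5])  # manual per-class rescaling

loss_fn = CrossEntropyLoss(weight=class_weights)  # assumed constructor keyword
loss = loss_fn.forward(predictions, targets)
print(loss.item())  # scalar output, as the docstring states
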
class DiceLoss():
    pass
class DiceLoss(_WeightedLoss):
    """Dice Loss

    This is an implementation of both a binary and a multi-channel dice loss.
    As this function has both a binary and a multi-channel form, no constructor is present.
    The two methods are defined after the forward function, as static methods.

    Args:
        None

    Returns:
        torch.tensor: output of forward passing through the function

    Raises:
        None
    """

    def forward(self, X, y, weights=None, ignore_index=None, binary=False):
        """Forward pass

        Forward pass through the loss function.

        Args:
            X (torch.tensor): input of size NxCxHxW
            y (torch.tensor): target of size NxHxW
            weights (torch.tensor): manual rescaling weight given to each class
            ignore_index (int): specifies a target value that is ignored and does not contribute to the input gradient
            binary (bool): flag that specifies whether the input is a one-channel binarized map
        Returns:
            torch.tensor: output of forward passing through the function

        Raises:
            None
        """

        # The softmax is applied to the network output X; y holds integer class indices
        X = F.softmax(X, dim=1)

        if binary:
            return self._dice_loss_binary(X, y)
        else:
            return self._dice_loss_multichannel(X, y, weights, ignore_index)
    @staticmethod
    def _dice_loss_binary(X, y):
        """Dice loss for binarized input

        Implementation of the dice loss for a one-channel binarized input.

        Args:
            X (torch.tensor): input of size Nx1xHxW
            y (torch.tensor): target of size NxHxW

        Returns:
            torch.tensor: output of forward passing through the function

        Raises:
            None
        """

        epsilon = 1e-4  # prevents the denominator from being 0

        intersection = X * y
        numerator = 2 * intersection.sum(0).sum(1).sum(1)

        reunion = X + y
        denominator = reunion.sum(0).sum(1).sum(1) + epsilon

        loss = 1 - (numerator / denominator)

        return loss.sum() / X.size(1)
    @staticmethod
    def _dice_loss_multichannel(X, y, weights=None, ignore_index=None):
        """Dice loss for multi-channel input

        Implementation of the dice loss for a multi-channel input.

        Args:
            X (torch.tensor): input of size NxCxHxW
            y (torch.tensor): target of size NxHxW
            weights (torch.tensor): manual rescaling weight given to each class
            ignore_index (int): specifies a target value that is ignored and does not contribute to the input gradient

        Returns:
            torch.tensor: output of forward passing through the function

        Raises:
            None
        """

        epsilon = 1e-4  # prevents the denominator from being 0

        # First, we detach the tensor from the computational graph, so no gradient is backpropagated along these variables.
        # We then initialise it to 0.
        y_encoded = X.detach() * 0

        if ignore_index is not None:
            # We mask the target elements to be ignored
            mask = y == ignore_index
            y = y.clone()
            y[mask] = 0

            # One-hot encode the targets along the channel dimension
            y_encoded.scatter_(1, y.unsqueeze(1), 1)

            # Expand the mask to have the same dimensions as the encoding
            mask = mask.unsqueeze(1).expand_as(y_encoded)
            y_encoded[mask] = 0
        else:
            y_encoded.scatter_(1, y.unsqueeze(1), 1)

        intersection = X * y_encoded
        numerator = 2 * intersection.sum(0).sum(1).sum(1)

        reunion = X + y_encoded

        if ignore_index is not None:
            # We also mask the denominator elements to correspond
            reunion[mask] = 0

        denominator = reunion.sum(0).sum(1).sum(1) + epsilon

        loss = 1 - (numerator / denominator)

        if weights is not None:
            # Apply the manual per-class rescaling weights
            loss = weights * loss

        return loss.sum() / X.size(1)
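
As a quick sanity check of the new DiceLoss, the sketch below drives it with random tensors; it is illustrative only and not part of the commit. Shapes follow the docstrings above (NxCxHxW logits, NxHxW integer targets).

import torch

predictions = torch.randn(2, 4, 8, 8, requires_grad=True)  # NxCxHxW logits
targets = torch.randint(0, 4, (2, 8, 8))                    # NxHxW labels

dice = DiceLoss()
# forward applies the softmax internally, so raw logits can be passed;
# binary=False selects the multi-channel path
loss = dice.forward(predictions, targets, weights=None, ignore_index=None)
loss.backward()  # the one-hot targets are detached, so gradients flow only through X
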
class CrossDiceLoss():
    pass
class CrossDiceLoss(_Loss):
    """Combination of cross entropy and dice loss

    An implementation of a loss combining the cross entropy and dice losses.
    Previous work on segmentation suggests that a combination of the two losses can produce good results.
    This function can also serve as a template for other combinations of loss functions.

    Args:
        None

    Returns:
        float: scalar output of forward passing through the function

    Raises:
        None
    """

    def __init__(self):
        super(CrossDiceLoss, self).__init__()
        self.cross_entropy_loss = CrossEntropyLoss()
        self.dice_loss = DiceLoss()
    def forward(self, X, y, weight=None):
        """Forward pass

        Forward pass through the loss function.

        Args:
            X (torch.tensor): input of size NxCxHxW
            y (torch.tensor): target of size NxHxW
            weight (torch.tensor): manual rescaling weight given to each class

        Returns:
            float: scalar output of forward passing through the function

        Raises:
            None
        """

        if weight is None:
            y_cross_entropy = torch.mean(self.cross_entropy_loss.forward(X, y))
        else:
            y_cross_entropy = torch.mean(torch.mul(self.cross_entropy_loss.forward(X, y), weight))

        y_dice = torch.mean(self.dice_loss.forward(X, y))

        return y_cross_entropy + y_dice
\ No newline at end of file
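
End to end, the combined loss might then be used as below; this mirrors the two calls inside CrossDiceLoss.forward and is a sketch under the same assumed shapes, not part of the commit.

import torch

model_output = torch.randn(2, 4, 8, 8, requires_grad=True)  # NxCxHxW logits
labels = torch.randint(0, 4, (2, 8, 8))                      # NxHxW targets

criterion = CrossDiceLoss()  # builds CrossEntropyLoss() and DiceLoss() internally
loss = criterion.forward(model_output, labels)
loss.backward()  # gradients flow through both the cross entropy and dice terms
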