Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
AndreiClaudiu Roibu
BrainMapper
Commits
fc18699b
Commit
fc18699b
authored
Mar 28, 2020
by
AndreiClaudiu Roibu
🖥
Browse files
added dice and combined crossentropydice losses
parent
999def81
Changes
1
Hide whitespace changes
Inline
Sidebyside
utils/losses.py
View file @
fc18699b
...
...
@@ 57,7 +57,7 @@ class L1Loss(_WeightedLoss):
return
self
.
loss
(
X
,
y
)
class
MSELoss
():
class
MSELoss
(
_WeightedLoss
):
"""MSE Loss
Standard PyTorch implementation of L2 norm error (Mean Squared Error)
...
...
@@ 94,14 +94,14 @@ class MSELoss():
return
self
.
loss
(
X
,
y
)
class
CrossEntropyLoss
():
class
CrossEntropyLoss
(
_WeightedLoss
):
"""Cross Entropy Loss
Standard PyTorch implementation of Cross Entropy Loss.
The weighted component is utilized in case this loss is combined with another.
Args:
None
weights (torch.tensor): manual rescaling weight given to each class
Retruns:
float: scalar output of forward passing through the function
...
...
@@ 132,9 +132,174 @@ class CrossEntropyLoss():
return
self
.
loss
(
X
,
y
)
class
DiceLoss
():
pass
class
DiceLoss
(
_WeightedLoss
):
"""Dice Loss
This represents an implementation of a binary and a multichannel dice loss.
As this function has both a binary and a multichannel form, no constructor is present.
The two methods are defined after the forward function, as StaticMethods.
Args:
None
Retruns:
torch.tensor: output of forward passing through the function
Raises:
None
"""
def
forward
(
self
,
X
,
y
,
weights
=
None
,
ignore_index
=
None
,
binary
=
False
):
"""Forward pass
Forward pass throught the loss function.
Args:
X (torch.tensor): input of size NxCxHxW
y (torch.tensor): output target of size (NxHxW)
weights (torch.tensor): manual rescaling weight given to each class
ignore_index (int): flag that specifies a target value that is ignored and does not contribute to the input gradient
binary (bool): flat that specified if the input is a one channel binarized
class
CrossDiceLoss
():
pass
Returns:
torch.tensor: output of forward passing throught the function
Raises:
None
"""
y
=
F
.
softmax
(
y
,
dim
=
1
)
if
binary
:
return
self
.
_dice_loss_binary
(
X
,
y
)
else
:
return
self
.
dice_loss_multichannel
(
X
,
y
,
weights
,
ignore_index
)
@
staticmethod
def
_dice_loss_binary
(
X
,
y
):
"""Dice loss for binarized input
Implementation of dice loss for one channel binarized input
Args:
X (torch.tensor): input of size Nx1xHxW
y (torch.tensor): output target of size (NxHxW)
Returns:
torch.tensor: output of forward passing throught the function
Raises:
None
"""
epsillon
=
1e4
# This is to prevent the denominator = 0
intersection
=
X
*
y
numerator
=
2
*
intersection
.
sum
(
0
).
sum
(
1
).
sum
(
1
)
reunion
=
X
+
y
denominator
=
reunion
.
sum
(
0
).
sum
(
1
).
sum
(
1
)
+
epsillon
loss
=
1

(
numerator
/
denominator
)
return
loss
.
sum
()
/
X
.
size
(
1
)
@
staticmethod
def
_dice_loss_multichannel
(
X
,
y
,
weights
,
ignore_index
):
"""Dice loss for binarized input
Implementation of dice loss for one channel binarized input
Args:
X (torch.tensor): input of size NxCxHxW
y (torch.tensor): output target of size (NxHxW)
weights (torch.tensor): manual rescaling weight given to each class
ignore_index (int): flag that specifies a target value that is ignored and does not contribute to the input gradient
Returns:
torch.tensor: output of forward passing throught the function
Raises:
None
"""
epsillon
=
1e4
# This is to prevent the denominator = 0
# First, we detach the tensor from the computational graph, so no gradient is backpropagated along these variables.
# We then initialize it to 0
y_encoded
=
X
.
detach
()
*
0
if
ignore_index
is
not
None
:
# We mask the target elements to be ignored
mask
=
y
==
ignore_index
y
=
y
.
clone
()
y
[
mask
]
=
0
# We now split the input & output into the composing channels
y_encoded
.
scatter_
(
1
,
y
.
unsqeeze
(
1
),
1
)
# Expand the mask to have the same dimensions
mask
=
mask
.
unsqueeze
(
1
).
exaxpand_as
(
y_encoded
)
y_encoded
[
mask
]
=
0
else
:
y_encoded
.
scatter_
(
1
,
y
.
unsqeeze
(
1
),
1
)
intersection
=
X
*
y_encoded
numerator
=
2
*
intersection
.
sum
(
0
).
sum
(
1
).
sum
(
1
)
reunion
=
X
+
y_encoded
if
ignore_index
is
not
None
:
# We are also masking the output elements to correspond
X
[
mask
]
=
0
denominator
=
reunion
.
sum
(
0
).
sum
(
1
).
sum
(
1
)
+
epsillon
loss
=
1

(
numerator
/
denominator
)
return
loss
.
sum
()
/
X
.
size
(
1
)
class
CrossDiceLoss
(
_Loss
):
"""Combination of cross entropy and dice loss
An implementation of a combined loss between cross_entropy and dice losses.
Previous work on segmentation suggests that a combination between the two losses could produce good results.
This function can also serve as a template for other combinations of loss functions
Args:
None
Retruns:
float: scalar output of forward passing through the function
Raises:
None
"""
def
__init__
(
self
):
super
(
CrossDiceLoss
,
self
).
__init__
()
self
.
cross_entropy_loss
=
CrossEntropyLoss
()
self
.
dice_loss
=
DiceLoss
()
def
forward
(
self
,
X
,
y
,
weight
=
None
):
"""Forward pass
Forward pass throught the loss function.
Args:
X (torch.tensor): input of size NxC
y (torch.tensor): output of size (N)
weights (torch.tensor): manual rescaling weight given to each class
Returns:
float: scalar output of forward passing throught the function
Raises:
None
"""
if
weight
is
None
:
y_cross_entropy
=
torch
.
mean
(
self
.
cross_entropy_loss
.
forward
(
X
,
y
))
else
:
y_cross_entropy
=
torch
.
mean
(
torch
.
mul
(
self
.
cross_entropy_loss
.
forward
(
X
,
y
)),
weigth
)
y_dice
=
torch
.
mean
(
self
.
dice_loss
.
forward
(
X
,
y
))
return
y_cross_entropy
+
y_dice
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment