Commit 999def81 authored by Andrei-Claudiu Roibu's avatar Andrei-Claudiu Roibu 🖥
Browse files

added L1, L2 and CrossEntropy losses

parent 32db8f36
......@@ -17,17 +17,122 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss, _WeightedLoss
import numpy as np
class L1Loss():
pass
class DiceLoss():
pass
class L1Loss(_WeightedLoss):
"""L1 Loss
Standard PyTorch implementation of wighted L1 loss (Mean Absolut Error)
Args:
None
Retruns:
float: scalar output of forward passing through the function
Raises:
None
"""
def __init__(self):
super(L1Loss, self).__init__()
self.loss = nn.L1Loss()
def forward(self, X, y):
"""Forward pass
Forward pass throught the loss function.
Args:
X (torch.tensor): input of size NxC
y (torch.tensor): output of size (N)
Returns:
float: scalar output of forward passing throught the function
Raises:
None
"""
return self.loss(X, y)
class MSELoss():
pass
"""MSE Loss
Standard PyTorch implementation of L2 norm error (Mean Squared Error)
Args:
None
Retruns:
float: scalar output of forward passing through the function
Raises:
None
"""
def __init__(self):
super(MSELoss, self).__init__()
self.loss = nn.MSELoss()
def forward(self, X, y):
"""Forward pass
Forward pass throught the loss function.
Args:
X (torch.tensor): input of size NxC
y (torch.tensor): output of size (N)
Returns:
float: scalar output of forward passing throught the function
Raises:
None
"""
return self.loss(X, y)
class CrossEntropyLoss():
"""Cross Entropy Loss
Standard PyTorch implementation of Cross Entropy Loss.
The weighted component is utilized in case this loss is combined with another.
Args:
None
Retruns:
float: scalar output of forward passing through the function
Raises:
None
"""
def __init__(self, weight= None):
super(CrossEntropyLoss, self).__init__()
self.loss = nn.CrossEntropyLoss(weight= weight)
def forward(self, X, y):
"""Forward pass
Forward pass throught the loss function.
Args:
X (torch.tensor): input of size NxC
y (torch.tensor): output of size (N)
Returns:
float: scalar output of forward passing throught the function
Raises:
None
"""
return self.loss(X, y)
class DiceLoss():
pass
class CrossDiceLoss():
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment