forked from ZFancy/TARF
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLS.py
More file actions
22 lines (19 loc) · 725 Bytes
/
LS.py
File metadata and controls
22 lines (19 loc) · 725 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss with uniform label smoothing.

    The returned loss blends two terms built from ``log_softmax`` over the
    last dimension of the logits:

    * ``(1 - eps)`` times ``F.nll_loss`` against the hard ``target``, and
    * ``eps / C`` times the negated sum of log-probabilities over all ``C``
      classes (the uniform-target component).

    Args:
        eps: Weight of the uniform component; the hard-target term receives
            ``1 - eps``.
        reduction: ``"sum"`` sums the uniform term over every element,
            ``"mean"`` averages the per-sample uniform term, and any other
            value keeps it per-sample. The value is also forwarded verbatim
            to ``F.nll_loss``.
    """

    def __init__(self, eps=0.1, reduction="mean"):
        super().__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, output, target):
        """Compute the label-smoothed loss of ``output`` logits vs ``target``.

        Args:
            output: Logits whose last dimension indexes the classes.
                NOTE(review): ``F.nll_loss`` treats dim 1 as the class
                dimension, so this presumably assumes ``(N, C)`` or ``(C,)``
                inputs — confirm before feeding higher-rank tensors.
            target: Class indices in the form ``F.nll_loss`` accepts.
                NOTE(review): ``F.nll_loss`` skips targets equal to its default
                ``ignore_index``, while the uniform term below does not.

        Returns:
            The blended loss, reduced according to ``self.reduction``.
        """
        num_classes = output.size(-1)
        log_probs = F.log_softmax(output, dim=-1)

        # Hard-target component; also where an unsupported reduction errors.
        hard_term = F.nll_loss(log_probs, target, reduction=self.reduction)

        if self.reduction == "sum":
            uniform_term = -log_probs.sum()
        else:
            per_sample = -log_probs.sum(dim=-1)
            uniform_term = (
                per_sample.mean() if self.reduction == "mean" else per_sample
            )

        smoothed = uniform_term * self.eps / num_classes
        return smoothed + (1 - self.eps) * hard_term