-
Notifications
You must be signed in to change notification settings - Fork 0
/
cross.py
36 lines (25 loc) · 1.02 KB
/
cross.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
class label_smooth_loss(torch.nn.Module):
    """Cross-entropy loss with label smoothing.

    The one-hot target distribution is softened: the true class receives
    ``(1 - smoothing) + smoothing / num_classes`` probability mass, while
    every other class receives ``smoothing / num_classes``.
    """

    def __init__(self, num_classes, smoothing=0.0):
        super(label_smooth_loss, self).__init__()
        # Uniform share of the smoothing mass given to each class.
        uniform_share = smoothing / num_classes
        self.negative = uniform_share
        # The true class keeps the remaining mass plus its uniform share.
        self.positive = (1 - smoothing) + uniform_share

    def forward(self, pred, target):
        # Work in log-probability space for numerical stability.
        log_probs = pred.log_softmax(dim=1)
        # Start every entry at the smoothed "wrong class" mass, then put
        # the larger mass on each sample's true class.
        smoothed = torch.full_like(log_probs, self.negative)
        smoothed.scatter_(1, target.unsqueeze(1), self.positive)
        # Per-sample cross entropy, averaged over the batch.
        return (-smoothed * log_probs).sum(dim=1).mean()
# Demo: compare the custom smoothed loss against PyTorch's built-in
# CrossEntropyLoss with the same label_smoothing factor.  With these
# near-one-hot logits both should print the same value.
prediction = torch.as_tensor(
    [
        [-1000, -1000, 1000],
        [1000, -1000, -1000],
        [-1000, 1000, -1000]
    ],
    dtype=torch.float
)
target = torch.as_tensor([2, 0, 1])
loss1 = label_smooth_loss(num_classes=3, smoothing=0.1)
# Bug fix: CrossEntropyLoss has no `apply_softmax` argument (it always
# applies log-softmax internally), so passing it raised a TypeError
# before either loss could be evaluated.
loss2 = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
print(loss1(prediction, target), loss2(prediction, target))