util.py
import numpy as np
import torch
import torch.nn.functional as F
class Poly1FocalLoss(torch.nn.Module):
def __init__(self,
epsilon: float = 1.0,
alpha: float = 0.25,
gamma: float = 2.0,
reduction: str = "mean",
weight: torch.Tensor = None,
pos_weight: torch.Tensor = None,
label_is_onehot: bool = False,
**kwargs
):
"""
Create instance of Poly1FocalLoss
:param num_classes: number of classes
:param epsilon: poly loss epsilon. the main one to finetune. larger values -> better performace in imagenet
:param alpha: focal loss alpha
:param gamma: focal loss gamma
:param reduction: one of none|sum|mean, apply reduction to final loss tensor
:param weight: manual rescaling weight for each class, passed to binary Cross-Entropy loss
:param label_is_onehot: set to True if labels are one-hot encoded
"""
super(Poly1FocalLoss, self).__init__()
self.epsilon = epsilon
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
self.weight = weight
self.pos_weight = pos_weight
self.label_is_onehot = label_is_onehot
def forward(self, logits, labels):
"""
Forward pass
        :param logits: output of neural network of shape [N, num_classes] or [N, num_classes, ...]
:param labels: ground truth tensor of shape [N] or [N, ...] with class ids if label_is_onehot was set to False, otherwise
one-hot encoded tensor of same shape as logits
:return: poly focal loss
"""
# focal loss implementation taken from
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py
num_classes = logits.shape[1]
p = torch.sigmoid(logits)
if not self.label_is_onehot:
# if labels are of shape [N]
# convert to one-hot tensor of shape [N, num_classes]
if labels.ndim == 1:
labels = F.one_hot(labels, num_classes=num_classes)
# if labels are of shape [N, ...] e.g. segmentation task
# convert to one-hot tensor of shape [N, num_classes, ...]
else:
labels = F.one_hot(labels.unsqueeze(1), num_classes).transpose(1, -1).squeeze_(-1)
labels = labels.to(device=logits.device, dtype=logits.dtype)
ce_loss = F.binary_cross_entropy_with_logits(input=logits,
target=labels,
reduction="none",
weight=self.weight,
pos_weight=self.pos_weight)
        # pt is the model's probability for the ground-truth label
        # (p for positive labels, 1 - p for negative labels)
        pt = labels * p + (1 - labels) * (1 - p)
        FL = ce_loss * ((1 - pt) ** self.gamma)
        if self.alpha >= 0:
            alpha_t = self.alpha * labels + (1 - self.alpha) * (1 - labels)
            FL = alpha_t * FL
        # Poly-1 formulation: focal loss plus epsilon * (1 - pt)^(gamma + 1)
        poly1 = FL + self.epsilon * torch.pow(1 - pt, self.gamma + 1)
if self.reduction == "mean":
poly1 = poly1.mean()
elif self.reduction == "sum":
poly1 = poly1.sum()
        # the final loss is scaled by a constant factor of 50 before returning
        return 50 * poly1
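
# Illustrative usage sketch (not part of the original module): assumes a
# 4-class problem with integer class-id labels of shape [N], matching the
# forward() docstring above.
def _example_poly1_focal_loss():
    criterion = Poly1FocalLoss(epsilon=1.0, alpha=0.25, gamma=2.0, reduction="mean")
    logits = torch.randn(8, 4)           # [N, num_classes] raw network outputs
    labels = torch.randint(0, 4, (8,))   # [N] integer class ids (label_is_onehot=False)
    return criterion(logits, labels)     # scalar loss tensor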
def cal_loss(pred, gold, smoothing=True):
''' Calculate cross entropy loss, apply label smoothing if needed. '''
gold = gold.contiguous().view(-1)
if smoothing:
eps = 0.3
n_class = pred.size(1)
        # label smoothing: put (1 - eps) on the true class and spread eps
        # uniformly over the remaining (n_class - 1) classes
        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)
loss = -(one_hot * log_prb).sum(dim=1).mean()
else:
loss = F.cross_entropy(pred, gold, reduction='mean')
return loss
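
# Quick sanity sketch for cal_loss (assumption: pred holds raw logits of
# shape [N, n_class] and gold holds integer class ids); with smoothing=False
# it reduces to plain F.cross_entropy.
def _example_cal_loss():
    pred = torch.randn(8, 10)            # [N, n_class] logits
    gold = torch.randint(0, 10, (8,))    # [N] class ids
    return cal_loss(pred, gold, smoothing=True), cal_loss(pred, gold, smoothing=False)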
class IOStream:
def __init__(self, path):
self.f = open(path, 'a')
def cprint(self, text):
print(text)
self.f.write(text+'\n')
self.f.flush()
def close(self):
self.f.close()
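
# Minimal usage sketch for IOStream (the log path "run.log" is only an
# example): text is echoed to stdout and appended to the log file.
def _example_iostream(log_path="run.log"):
    io = IOStream(log_path)
    io.cprint("epoch 1: train loss 0.123")  # hypothetical log line
    io.close()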