uda.py
"""
@author: Baixu Chen
@contact: [email protected]
"""
import torch.nn as nn
import torch.nn.functional as F


class StrongWeakConsistencyLoss(nn.Module):
"""
Consistency loss between strong and weak augmented samples from `Unsupervised Data Augmentation for
Consistency Training (NIPS 2020) <https://arxiv.org/pdf/1904.12848v4.pdf>`_.
Args:
threshold (float): Confidence threshold.
temperature (float): Temperature.
Inputs:
- y_strong: unnormalized classifier predictions on strong augmented samples.
- y: unnormalized classifier predictions on weak augmented samples.
Shape:
- y, y_strong: :math:`(minibatch, C)` where C means the number of classes.
- Output: scalar.
"""
    def __init__(self, threshold: float, temperature: float):
        super(StrongWeakConsistencyLoss, self).__init__()
        self.threshold = threshold
        self.temperature = temperature

    def forward(self, y_strong, y):
        # confidence = max softmax probability of the (detached) weak-view predictions
        confidence, _ = F.softmax(y.detach(), dim=1).max(dim=1)
        # keep only samples whose weak-view confidence exceeds the threshold
        mask = (confidence > self.threshold).float()
        # temperature-scaled log-probabilities on the strongly augmented samples
        log_prob = F.log_softmax(y_strong / self.temperature, dim=1)
        # per-sample KL divergence between the weak (target) and strong distributions
        con_loss = F.kl_div(log_prob, F.softmax(y.detach(), dim=1), reduction='none').sum(dim=1)
        # average over confident samples; max(., 1) guards against an all-zero mask
        con_loss = (con_loss * mask).sum() / max(mask.sum(), 1)
        return con_loss
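

# Minimal usage sketch (illustrative): the tensors and hyperparameter values
# below are hypothetical stand-ins for classifier logits on the two augmented
# views, chosen only to show the expected call signature.
if __name__ == '__main__':
    import torch

    criterion = StrongWeakConsistencyLoss(threshold=0.8, temperature=0.4)
    # fake unnormalized predictions for a batch of 4 samples and 10 classes
    y_weak = torch.randn(4, 10)
    y_strong = torch.randn(4, 10)
    loss = criterion(y_strong, y_weak)
    print(loss.item())  # scalar consistency loss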