-
Notifications
You must be signed in to change notification settings - Fork 173
/
Copy pathdomain_mix.py
81 lines (67 loc) · 2.67 KB
/
domain_mix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from torch.nn import functional as F
from dassl.engine import TRAINER_REGISTRY, TrainerX
from dassl.metrics import compute_accuracy
# Public API of this module: only the DomainMix trainer is exported.
__all__ = [
    "DomainMix",
]
@TRAINER_REGISTRY.register()
class DomainMix(TrainerX):
    """DomainMix.

    Dynamic Domain Generalization.

    https://github.com/MetaVisionLab/DDG

    Each training input is mixed with a partner input from the same batch,
    using a weight ``lam`` drawn from ``Beta(ALPHA, BETA)``. Partners are
    chosen by ``cfg.TRAINER.DOMAINMIX.TYPE``:

    * ``"random"``: a random shuffle of the batch.
    * ``"crossdomain"``: a partner from a different domain, whenever the
      batch contains more than one domain.

    The loss is ``lam * CE(out, y_a) + (1 - lam) * CE(out, y_b)``.
    """

    def __init__(self, cfg):
        """Read the DomainMix options from ``cfg.TRAINER.DOMAINMIX``."""
        super(DomainMix, self).__init__(cfg)
        self.mix_type = cfg.TRAINER.DOMAINMIX.TYPE
        self.alpha = cfg.TRAINER.DOMAINMIX.ALPHA
        self.beta = cfg.TRAINER.DOMAINMIX.BETA
        # Only build the Beta distribution when mixing is enabled. With
        # PyTorch's default argument validation, Beta(alpha <= 0, ...)
        # raises, which would make the ``alpha <= 0`` (no mixing) fallback
        # in ``domain_mix`` unreachable.
        self.dist_beta = (
            torch.distributions.Beta(self.alpha, self.beta)
            if self.alpha > 0 else None
        )

    def forward_backward(self, batch):
        """Run one mixed training step on ``batch``.

        Returns:
            dict with ``"loss"`` (the mixed loss value) and ``"acc"``
            (accuracy measured against ``label_a`` only).
        """
        images, label_a, label_b, lam = self.parse_batch_train(batch)
        output = self.model(images)
        loss = lam * F.cross_entropy(output, label_a) + (
            1 - lam
        ) * F.cross_entropy(output, label_b)
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss": loss.item(),
            "acc": compute_accuracy(output, label_a)[0].item(),
        }

        # update_lr() runs once the last batch of the loop has been processed.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch):
        """Move the batch to ``self.device`` and apply domain mixing.

        Args:
            batch: mapping holding ``"img"``, ``"label"`` and ``"domain"``
                tensors.

        Returns:
            ``(mixed_images, target_a, target_b, lam)`` as produced by
            ``domain_mix``.
        """
        images = batch["img"].to(self.device)
        target = batch["label"].to(self.device)
        domain = batch["domain"].to(self.device)
        return self.domain_mix(images, target, domain)

    def domain_mix(self, x, target, domain):
        """Mix every sample of ``x`` with a partner sample from the batch.

        Args:
            x: batch of inputs; partners are gathered along dim 0.
            target: labels of ``x``.
            domain: domain labels of ``x``.

        Returns:
            ``(mixed_x, target_a, target_b, lam)``, where:

            * ``mixed_x = lam * x + (1 - lam) * x[perm]``;
            * ``target_a`` is ``target``;
            * ``target_b`` is ``target[perm]``;
            * ``lam`` is the mixing weight (``1.0``, i.e. no mixing, when
              ``alpha <= 0``).

        Raises:
            NotImplementedError: if ``mix_type`` is neither ``"random"``
                nor ``"crossdomain"``.
        """
        if self.alpha > 0:
            lam = self.dist_beta.rsample((1, )).to(x.device)
        else:
            # Float, so lam has the same dtype as the sampled branch.
            lam = torch.tensor(1.0, device=x.device)

        # Default partner assignment: a uniform random shuffle of the batch.
        perm = torch.randperm(x.size(0), dtype=torch.int64, device=x.device)

        if self.mix_type == "crossdomain":
            domain_list = torch.unique(domain)
            # With a single domain in the batch there is no cross-domain
            # partner, so the random shuffle above is kept.
            if len(domain_list) > 1:
                for idx in domain_list:
                    mask_a = domain == idx
                    cnt_a = int(mask_a.sum())
                    idx_b = (~mask_a).nonzero().squeeze(-1)
                    cnt_b = idx_b.shape[0]
                    # Draw partners uniformly from the other domains. Sample
                    # with replacement only when there are fewer candidates
                    # than samples that need a partner.
                    perm_b = torch.ones(cnt_b).multinomial(
                        num_samples=cnt_a, replacement=cnt_a > cnt_b
                    )
                    perm[mask_a] = idx_b[perm_b]
        elif self.mix_type != "random":
            # Doubled braces print the choices literally. A single pair would
            # evaluate a set, whose printed order is not deterministic.
            raise NotImplementedError(
                f"Chooses {{'random', 'crossdomain'}}, but got {self.mix_type}."
            )

        mixed_x = lam * x + (1 - lam) * x[perm, :]
        target_a, target_b = target, target[perm]
        return mixed_x, target_a, target_b, lam