-
Notifications
You must be signed in to change notification settings - Fork 207
/
solver_encoder.py
127 lines (85 loc) · 4.29 KB
/
solver_encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from model_vc import Generator
import torch
import torch.nn.functional as F
import time
import datetime
class Solver(object):
    """Training driver for the ``Generator`` model imported from ``model_vc``.

    Owns the generator, its Adam optimizer and a training loop that minimises
    two identity-mapping reconstruction losses plus a weighted code
    consistency loss.
    """

    def __init__(self, vcc_loader, config):
        """Initialize configurations.

        Args:
            vcc_loader: Re-iterable data source (e.g. a ``DataLoader``) whose
                items unpack to ``(x_real, emb_org)`` tensors. It is iterated
                again from the start each time it is exhausted.
            config: Object exposing ``lambda_cd``, ``dim_neck``, ``dim_emb``,
                ``dim_pre``, ``freq``, ``batch_size``, ``num_iters`` and
                ``log_step`` as attributes.
        """
        # Data loader.
        self.vcc_loader = vcc_loader

        # Model configurations.
        self.lambda_cd = config.lambda_cd  # weight of the code (L1) loss term
        self.dim_neck = config.dim_neck
        self.dim_emb = config.dim_emb
        self.dim_pre = config.dim_pre
        self.freq = config.freq

        # Training configurations.
        # NOTE(review): batch_size is stored but not read anywhere in this
        # class; batching presumably happens inside vcc_loader.
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters

        # Miscellaneous.
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:0' if self.use_cuda else 'cpu')
        self.log_step = config.log_step

        # Build the model and its optimizer.
        self.build_model()

    def build_model(self):
        """Create the generator on ``self.device`` and its Adam optimizer (lr 1e-4)."""
        self.G = Generator(self.dim_neck, self.dim_emb, self.dim_pre, self.freq)
        # Move the model first, as the torch.optim docs recommend, so the
        # optimizer is constructed over parameters in their final placement.
        self.G.to(self.device)
        self.g_optimizer = torch.optim.Adam(self.G.parameters(), 0.0001)

    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()

    @staticmethod
    def _format_elapsed(seconds):
        """Format a duration in seconds as ``str(timedelta)`` without the fraction.

        Truncating to whole seconds first replaces the old
        ``str(timedelta)[:-7]`` slice, which chopped real digits whenever the
        duration had no microsecond part (``str`` then omits the fraction).

        >>> Solver._format_elapsed(65.25)
        '0:01:05'
        """
        return str(datetime.timedelta(seconds=int(seconds)))

    #=====================================================================================================================================#

    def train(self):
        """Run ``self.num_iters`` optimisation steps on the generator.

        Each step draws one ``(x_real, emb_org)`` batch and reconstructs
        ``x_real`` conditioned on its own embedding (identity mapping). It
        then minimises::

            mse(x_real, x_identic) + mse(x_real, x_identic_psnt)
                + lambda_cd * l1(code_real, code_reconst)

        Here ``code_reconst`` is the code the generator extracts from
        ``x_identic_psnt``. Every ``log_step`` iterations the elapsed time
        and the three loss terms are printed.

        Raises:
            RuntimeError: if ``vcc_loader`` yields no batches at all.
        """
        # Set data loader.
        data_loader = self.vcc_loader

        # Print logs in specified order
        keys = ['G/loss_id', 'G/loss_id_psnt', 'G/loss_cd']

        # Nothing in the loop switches the model to eval mode, so set train
        # mode once instead of on every iteration.
        self.G = self.G.train()

        # Start training.
        print('Start training...')
        start_time = time.time()
        data_iter = iter(data_loader)
        for i in range(self.num_iters):
            # ======================================================= #
            #                1. Preprocess input data                 #
            # ======================================================= #
            # Fetch data, restarting the loader once it is exhausted. Only
            # StopIteration is caught, so genuine loader errors (and Ctrl+C)
            # propagate instead of silently restarting the epoch.
            try:
                x_real, emb_org = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                try:
                    x_real, emb_org = next(data_iter)
                except StopIteration:
                    raise RuntimeError('vcc_loader yielded no batches') from None

            x_real = x_real.to(self.device)
            emb_org = emb_org.to(self.device)

            # ======================================================= #
            #                2. Train the generator                   #
            # ======================================================= #
            # Identity mapping loss: source and target embedding are the same.
            # NOTE(review): "psnt" is presumably the post-net refined output;
            # confirm against model_vc.Generator.
            x_identic, x_identic_psnt, code_real = self.G(x_real, emb_org, emb_org)
            g_loss_id = F.mse_loss(x_real, x_identic)
            g_loss_id_psnt = F.mse_loss(x_real, x_identic_psnt)

            # Code semantic loss: with a None target the generator returns only
            # the code of its input, which should match the original code.
            code_reconst = self.G(x_identic_psnt, emb_org, None)
            g_loss_cd = F.l1_loss(code_real, code_reconst)

            # Backward and optimize.
            g_loss = g_loss_id + g_loss_id_psnt + self.lambda_cd * g_loss_cd
            self.reset_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # ======================================================= #
            #                3. Miscellaneous                         #
            # ======================================================= #
            # Print out training information. .item() copies the scalar to
            # the host (a sync on CUDA), so it is only done on logging steps.
            if (i + 1) % self.log_step == 0:
                loss = {
                    'G/loss_id': g_loss_id.item(),
                    'G/loss_id_psnt': g_loss_id_psnt.item(),
                    'G/loss_cd': g_loss_cd.item(),
                }
                et = self._format_elapsed(time.time() - start_time)
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i + 1, self.num_iters)
                for tag in keys:
                    log += ", {}: {:.4f}".format(tag, loss[tag])
                print(log)