forked from JiahuiYu/generative_inpainting
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
112 lines (104 loc) · 4.04 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import glob
import socket
import logging
import tensorflow as tf
import neuralgym as ng
from inpaint_model import InpaintCAModel
# Module-wide logger; getLogger() is called without a name, which yields the
# root logger.
logger = logging.getLogger()
def multigpu_graph_def(model, data, config, gpu_id=0, loss_type='g'):
    """Build the loss graph for one device and return the requested loss.

    Args:
        model: Object exposing ``build_graph_with_losses``.
        data: Object exposing ``data_pipeline``; it is called with
            ``config.BATCH_SIZE`` to obtain the input batch.
        config: Configuration object. ``BATCH_SIZE`` is read here and the
            whole object is forwarded to the model.
        gpu_id: Index of the device the graph is built for. Summaries are
            requested only when this is 0 and ``loss_type`` is ``'g'``.
        loss_type: ``'g'`` to return ``losses['g_loss']``, ``'d'`` to return
            ``losses['d_loss']``.

    Returns:
        The entry of the model's losses mapping selected by ``loss_type``.

    Raises:
        ValueError: If ``loss_type`` is neither ``'g'`` nor ``'d'``.
    """
    # Validate first so an unsupported loss type fails before the input
    # pipeline or any model ops are created.
    if loss_type not in ('g', 'd'):
        raise ValueError('loss type is not supported.')

    # The input pipeline is created under an explicit CPU device scope.
    with tf.device('/cpu:0'):
        images = data.data_pipeline(config.BATCH_SIZE)

    # ``summary`` is only passed on the first device's generator graph; in
    # every other case the model's own default for it is left in effect.
    build_kwargs = {'reuse': True}
    if gpu_id == 0 and loss_type == 'g':
        build_kwargs['summary'] = True
    _, _, losses = model.build_graph_with_losses(
        images, config, **build_kwargs)

    return losses['g_loss' if loss_type == 'g' else 'd_loss']
def _read_file_list(flist_path):
    """Return the lines of the file at ``flist_path`` without terminators."""
    with open(flist_path) as flist_file:
        return flist_file.read().splitlines()


def _build_static_views(model, config, val_fnames):
    """Build the static inference graphs used to monitor progress.

    One single-image pipeline and one ``static_view/<i>`` inference graph
    are created per view. The number of views is ``config.STATIC_VIEW_SIZE``
    capped at ``len(val_fnames)``: slicing ``val_fnames[i:i+1]`` past the end
    of the list yields an empty list, and a pipeline must not be built over
    an empty file list.
    """
    num_views = min(config.STATIC_VIEW_SIZE, len(val_fnames))
    if num_views < config.STATIC_VIEW_SIZE:
        logger.warning(
            'STATIC_VIEW_SIZE is %s but only %d validation files are '
            'listed; building %d static views.',
            config.STATIC_VIEW_SIZE, len(val_fnames), num_views)
    for i in range(num_views):
        static_images = ng.data.DataFromFNames(
            val_fnames[i:i+1], config.IMG_SHAPES, nthreads=1,
            random_crop=config.RANDOM_CROP).data_pipeline(1)
        # Called for its graph-building side effect; the returned tensor is
        # not used by this script.
        model.build_static_infer_graph(
            static_images, config, name='static_view/%d' % i)


def main():
    """Build the training graph from ``inpaint.yml`` and run the trainer.

    The construction order below is significant and mirrors the original
    script: GPU selection, training data pipeline, model losses, optional
    static validation views, optimizer and gradient processor, secondary
    (discriminator) trainer, primary (generator) trainer, callbacks, train.
    """
    config = ng.Config('inpaint.yml')
    # A GPU_ID of -1 takes the ng.get_gpus(NUM_GPUS) path instead of
    # selecting the configured GPU_ID directly.
    if config.GPU_ID != -1:
        ng.set_gpus(config.GPU_ID)
    else:
        ng.get_gpus(config.NUM_GPUS)

    # training data
    fnames = _read_file_list(config.DATA_FLIST[config.DATASET][0])
    data = ng.data.DataFromFNames(
        fnames, config.IMG_SHAPES, random_crop=config.RANDOM_CROP)
    images = data.data_pipeline(config.BATCH_SIZE)

    # main model
    model = InpaintCAModel()
    g_vars, d_vars, losses = model.build_graph_with_losses(
        images, config=config)

    # validation images: progress monitor by visualizing static images
    if config.VAL:
        val_fnames = _read_file_list(config.DATA_FLIST[config.DATASET][1])
        _build_static_views(model, config, val_fnames)

    # training settings: a non-trainable scalar variable initialised to 1e-4
    # holds the learning rate.
    lr = tf.get_variable(
        'lr', shape=[], trainable=False,
        initializer=tf.constant_initializer(1e-4))
    d_optimizer = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9)
    # Generator and discriminator deliberately share one optimizer instance.
    g_optimizer = d_optimizer

    # gradient processor: clips element 0 of each pair it receives and passes
    # element 1 through unchanged (presumably (gradient, variable) pairs --
    # confirm against neuralgym's Trainer).
    if config.GRADIENT_CLIP:
        gradient_processor = lambda grad_var: (
            tf.clip_by_average_norm(grad_var[0], config.GRADIENT_CLIP_VALUE),
            grad_var[1])
    else:
        gradient_processor = None

    # log dir
    log_prefix = 'model_logs/' + '_'.join([
        ng.date_uid(), socket.gethostname(), config.DATASET,
        'MASKED' if config.GAN_WITH_MASK else 'NORMAL',
        config.GAN, config.LOG_DIR])

    # train discriminator with secondary trainer, should initialize before
    # primary trainer.
    discriminator_training_callback = ng.callbacks.SecondaryTrainer(
        pstep=1,
        optimizer=d_optimizer,
        var_list=d_vars,
        max_iters=5,
        graph_def=multigpu_graph_def,
        graph_def_kwargs={
            'model': model, 'data': data, 'config': config, 'loss_type': 'd'},
    )

    # train generator with primary trainer
    trainer = ng.train.Trainer(
        optimizer=g_optimizer,
        var_list=g_vars,
        max_iters=config.MAX_ITERS,
        graph_def=multigpu_graph_def,
        grads_summary=config.GRADS_SUMMARY,
        gradient_processor=gradient_processor,
        graph_def_kwargs={
            'model': model, 'data': data, 'config': config, 'loss_type': 'g'},
        spe=config.TRAIN_SPE,
        log_dir=log_prefix,
    )

    # add all callbacks; the discriminator trainer is attached only when
    # PRETRAIN_COARSE_NETWORK is falsy.
    if not config.PRETRAIN_COARSE_NETWORK:
        trainer.add_callbacks(discriminator_training_callback)
    saver = trainer.context['saver']
    trainer.add_callbacks([
        ng.callbacks.WeightsViewer(),
        ng.callbacks.ModelRestorer(
            saver,
            dump_prefix='model_logs/' + config.MODEL_RESTORE + '/snap',
            optimistic=True),
        ng.callbacks.ModelSaver(
            config.TRAIN_SPE, saver, log_prefix + '/snap'),
        ng.callbacks.SummaryWriter(
            (config.VAL_PSTEPS // 1),
            trainer.context['summary_writer'],
            tf.summary.merge_all()),
    ])

    # launch training
    trainer.train()


if __name__ == "__main__":
    main()