-
Notifications
You must be signed in to change notification settings - Fork 23
/
model.py
259 lines (210 loc) · 11 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
import metric
import tfutil as tfu
import numpy as np
import tensorflow as tf
class RCAN:
    """Residual Channel Attention Network that up-scales LR image patches to HR.

    Instantiating the class builds the whole TF1 graph in one go: input
    placeholders, optimizer, network, L1 loss, PSNR/SSIM metrics, summaries,
    two savers and a summary writer bound to ``sess.graph``.
    """

    def __init__(self,
                 sess,                                # tensorflow session
                 batch_size=16,                       # batch size
                 n_channel=3,                         # number of image channel, 3 for RGB, 1 for gray-scale
                 img_scaling_factor=4,                # image scale factor to up
                 lr_img_size=(48, 48),                # input patch image size for LR
                 hr_img_size=(192, 192),              # input patch image size for HR
                 n_res_blocks=20,                     # number of residual block
                 n_res_groups=10,                     # number of residual group
                 res_scale=1,                         # scaling factor of res block
                 n_filters=64,                        # number of conv2d filter size
                 kernel_size=3,                       # number of conv2d kernel size
                 activation='relu',                   # activation function
                 use_bn=False,                        # using batch_norm or not
                 reduction=16,                        # reduction rate at CA layer
                 # The same DIV2K statistics on the 0-255 scale would be
                 # mean=(114.2430, 111.4502, 103.0450), std=(69.6606, 66.0210, 72.1786).
                 rgb_mean=(0.4480, 0.4371, 0.4041),   # RGB mean, for DIV2K DataSet
                 rgb_std=(0.2732, 0.2589, 0.2831),    # RGB std, for DIV2K DataSet
                 optimizer='adam',                    # name of optimizer
                 lr=1e-4,                             # learning rate
                 lr_decay=.5,                         # learning rate decay factor
                 lr_decay_step=2e5,                   # learning rate decay step
                 momentum=.9,                         # SGD momentum value
                 beta1=.9,                            # Adam beta1 value
                 beta2=.999,                          # Adam beta2 value
                 opt_eps=1e-8,                        # Adam epsilon value
                 eps=1.1e-5,                          # epsilon
                 tf_log="./model/",                   # path saved tensor summary / model
                 n_gpu=1,                             # number of GPU
                 ):
        """Store the hyper-parameters, create the placeholders and build the graph.

        ``lr`` becomes the default value of the ``learning_rate`` placeholder
        exposed as ``self.lr``; feed ``self.lr`` to override it for a step.

        :raises ValueError: if ``rgb_mean`` does not hold one value per channel
        :raises NotImplementedError: for an unsupported activation, optimizer
            or scaling factor
        """
        # Fail fast, before any graph op is created: the mean is applied
        # channel-wise, so a length mismatch would either break graph
        # construction or silently broadcast to the wrong channel count.
        if len(rgb_mean) != n_channel:
            raise ValueError("[-] rgb_mean must have one value per channel "
                             "(got {} values for {} channels)".format(len(rgb_mean), n_channel))

        self.sess = sess
        self.batch_size = batch_size

        self.n_channel = n_channel
        self.img_scale = img_scaling_factor
        # tuple() lets callers pass the sizes as lists as well as tuples
        self.lr_img_size = tuple(lr_img_size) + (self.n_channel,)
        self.hr_img_size = tuple(hr_img_size) + (self.n_channel,)

        self.n_res_blocks = n_res_blocks
        self.n_res_groups = n_res_groups
        self.res_scale = res_scale

        self.n_filters = n_filters
        self.kernel_size = kernel_size
        self.activation = activation
        self.use_bn = use_bn
        self.reduction = reduction

        self.rgb_mean = tf.constant(rgb_mean, dtype=tf.float32)
        # NOTE(review): rgb_std is stored but never read inside this class.
        self.rgb_std = tf.constant(rgb_std, dtype=tf.float32)

        self.optimizer = optimizer
        self.lr_decay = lr_decay
        self.lr_decay_step = lr_decay_step
        self.momentum = momentum
        self.beta1 = beta1
        self.beta2 = beta2
        self.opt_eps = opt_eps

        self._eps = eps  # epsilon handed to the BatchNormalization layers

        self.tf_log = tf_log
        self.n_gpu = n_gpu

        # Filled in by setup() / build_model()
        self.act = None
        self.opt = None
        self.train_op = None
        self.loss = None
        self.output = None
        self.psnr = None
        self.ssim = None
        self.saver = None
        self.best_saver = None
        self.merged = None
        self.writer = None

        self.global_step = tf.Variable(0, trainable=False, name='global_step')

        # tensor placeholder for input
        self.x_lr = tf.placeholder(tf.float32, shape=(None,) + self.lr_img_size, name='x-lr-img')
        self.x_hr = tf.placeholder(tf.float32, shape=(None,) + self.hr_img_size, name='x-hr-img')
        # The constructor's `lr` used to be overwritten by a bare placeholder and
        # was therefore lost; it now serves as the default, so the placeholder can
        # still be fed exactly as before but no longer has to be.
        self.lr = tf.placeholder_with_default(tf.cast(lr, tf.float32), shape=(), name='learning_rate')
        # self.is_train = tf.placeholder(tf.bool, name='is_train')

        # setting stuffs
        self.setup()

        # build a network
        self.build_model()

    def setup(self):
        """Resolve the activation function and the optimizer from their names.

        :raises NotImplementedError: if either name is not supported
        """
        # Activation Function Setting
        if self.activation == 'relu':
            self.act = tf.nn.relu
        elif self.activation == 'leaky_relu':
            self.act = tf.nn.leaky_relu
        elif self.activation == 'elu':
            self.act = tf.nn.elu
        else:
            raise NotImplementedError("[-] Not supported activation function {}".format(self.activation))

        # Optimizer
        if self.optimizer == 'adam':
            self.opt = tf.train.AdamOptimizer(learning_rate=self.lr,
                                              beta1=self.beta1, beta2=self.beta2, epsilon=self.opt_eps)
        elif self.optimizer == 'sgd':  # SGD + momentum with Nesterov
            self.opt = tf.train.MomentumOptimizer(learning_rate=self.lr, momentum=self.momentum, use_nesterov=True)
        else:
            raise NotImplementedError("[-] Not supported optimizer {}".format(self.optimizer))

    def image_processing(self, x, sign, name):
        """Shift every channel of ``x`` by ``sign * rgb_mean``.

        :param x: image tensor whose last axis is the channel axis
        :param sign: -1 to subtract the mean (pre-processing), 1 to add it back
        :param name: scope name
        :return: shifted image tensor
        """
        with tf.variable_scope(name):
            # rgb_mean holds one value per channel, so it broadcasts over the
            # last axis; this works for any channel count, not only for RGB.
            return x + sign * self.rgb_mean

    def channel_attention(self, x, f, reduction, name):
        """
        Channel Attention (CA) Layer
        :param x: input layer
        :param f: conv2d filter size
        :param reduction: conv2d filter reduction rate
        :param name: scope name
        :return: output layer
        """
        with tf.variable_scope("CA-%s" % name):
            skip_conn = tf.identity(x, name='identity')

            # squeeze: pool, then bottleneck the channels down to f // reduction
            x = tfu.adaptive_global_average_pool_2d(x)
            x = tfu.conv2d(x, f=f // reduction, k=1, name="conv2d-1")
            x = self.act(x)

            # excite: back to f channels, squashed into (0, 1) by the sigmoid
            x = tfu.conv2d(x, f=f, k=1, name="conv2d-2")
            x = tf.nn.sigmoid(x)
            # re-weight the untouched input with the attention values
            return tf.multiply(skip_conn, x)

    def residual_channel_attention_block(self, x, f, kernel_size, reduction, use_bn, name):
        """
        Residual Channel Attention Block (RCAB): conv - act - conv - CA, plus a skip connection
        :param x: input layer
        :param f: conv2d filter size
        :param kernel_size: conv2d kernel size
        :param reduction: reduction rate of the CA layer
        :param use_bn: insert a batch-norm layer after each conv2d or not
        :param name: scope name
        :return: output layer
        """
        with tf.variable_scope("RCAB-%s" % name):
            skip_conn = tf.identity(x, name='identity')

            # NOTE(review): the BatchNormalization layers are called without a
            # `training` argument - confirm that this is intended.
            x = tfu.conv2d(x, f=f, k=kernel_size, name="conv2d-1")
            if use_bn:
                x = tf.layers.BatchNormalization(epsilon=self._eps, name="bn-1")(x)
            x = self.act(x)

            x = tfu.conv2d(x, f=f, k=kernel_size, name="conv2d-2")
            if use_bn:
                x = tf.layers.BatchNormalization(epsilon=self._eps, name="bn-2")(x)

            # the "RCAB-" prefix in the CA name is part of the variable scope and
            # therefore of the variable names; keep it as it is
            x = self.channel_attention(x, f, reduction, name="RCAB-%s" % name)
            return self.res_scale * x + skip_conn  # tf.math.add(self.res_scale * x, skip_conn)

    def residual_group(self, x, f, kernel_size, reduction, use_bn, name):
        """
        Residual Group (RG): `n_res_blocks` RCABs and one conv2d, plus a skip connection
        :param x: input layer
        :param f: conv2d filter size
        :param kernel_size: conv2d kernel size
        :param reduction: reduction rate of the CA layer
        :param use_bn: use batch-norm inside the blocks or not
        :param name: scope name
        :return: output layer
        """
        with tf.variable_scope("RG-%s" % name):
            skip_conn = tf.identity(x, name='identity')

            for i in range(self.n_res_blocks):
                x = self.residual_channel_attention_block(x, f, kernel_size, reduction, use_bn, name=str(i))

            x = tfu.conv2d(x, f=f, k=kernel_size, name='rg-conv-1')
            return x + skip_conn  # tf.math.add(x, skip_conn)

    def up_scaling(self, x, f, scale_factor, name):
        """
        Up-scale by repeated conv2d + pixel shuffle
        :param x: image
        :param f: conv2d filter
        :param scale_factor: scale factor, either 3 or a positive power of two
        :param name: scope name
        :return: up-scaled layer
        :raises NotImplementedError: for any other scale factor
        """
        # Validate before touching the graph. `scale_factor > 0` matters: 0 also
        # satisfies `n & (n - 1) == 0` and would otherwise blow up in log2.
        if scale_factor == 3:
            n_steps, step = 1, 3
        elif scale_factor > 0 and scale_factor & (scale_factor - 1) == 0:  # is it 2^n?
            n_steps, step = int(np.log2(scale_factor)), 2
        else:
            raise NotImplementedError("[-] Not supported scaling factor (%d)" % scale_factor)

        with tf.variable_scope(name):
            for i in range(n_steps):
                # step^2 times the filters feed a pixel shuffle by `step`
                x = tfu.conv2d(x, f * step * step, k=1, name='conv2d-image_scaling-%d' % i)
                x = tfu.pixel_shuffle(x, step)
        return x

    def residual_channel_attention_network(self, x, f, kernel_size, reduction, use_bn, scale):
        """
        The full network: mean shift - head - body (residual groups) - up-scaling tail - mean shift
        :param x: LR image
        :param f: conv2d filter size
        :param kernel_size: conv2d kernel size
        :param reduction: reduction rate of the CA layers
        :param use_bn: use batch-norm or not
        :param scale: up-scaling factor
        :return: generated HR image
        """
        with tf.variable_scope("Residual_Channel_Attention_Network"):
            x = self.image_processing(x, sign=-1, name='pre-processing')

            # 1. head
            head = tfu.conv2d(x, f=f, k=kernel_size, name="conv2d-head")

            # 2. body
            x = head
            for i in range(self.n_res_groups):
                x = self.residual_group(x, f, kernel_size, reduction, use_bn, name=str(i))

            body = tfu.conv2d(x, f=f, k=kernel_size, name="conv2d-body")
            body += head  # tf.math.add(body, head)

            # 3. tail
            x = self.up_scaling(body, f, scale, name='up-scaling')
            # presumably (-1, lr_h * scale, lr_w * scale, n_channel) - depends on tfu.pixel_shuffle
            tail = tfu.conv2d(x, f=self.n_channel, k=kernel_size, name="conv2d-tail")

            x = self.image_processing(tail, sign=1, name='post-processing')
            return x

    def build_model(self):
        """Wire up network output, L1 loss, train op, metrics, summaries, savers and writer."""
        # RCAN model
        self.output = self.residual_channel_attention_network(x=self.x_lr,
                                                              f=self.n_filters,
                                                              kernel_size=self.kernel_size,
                                                              reduction=self.reduction,
                                                              use_bn=self.use_bn,
                                                              scale=self.img_scale,
                                                              )
        # bring the output to the 0-255 range and cut off whatever falls outside
        self.output = tf.clip_by_value(self.output * 255., 0., 255.)

        # l1 loss
        self.loss = tf.reduce_mean(tf.abs(self.output - self.x_hr))

        self.train_op = self.opt.minimize(self.loss, global_step=self.global_step)

        # metrics
        # NOTE(review): the output was scaled to [0, 255] above while the metrics
        # receive m_val=1 - confirm against the `metric` module.
        self.psnr = tf.reduce_mean(metric.psnr(self.output, self.x_hr, m_val=1))
        self.ssim = tf.reduce_mean(metric.ssim(self.output, self.x_hr, m_val=1))

        # summaries
        tf.summary.image('lr', self.x_lr, max_outputs=self.batch_size)
        tf.summary.image('hr', self.x_hr, max_outputs=self.batch_size)
        tf.summary.image('generated-hr', self.output, max_outputs=self.batch_size)

        tf.summary.scalar("loss/l1_loss", self.loss)
        tf.summary.scalar("metric/psnr", self.psnr)
        tf.summary.scalar("metric/ssim", self.ssim)
        tf.summary.scalar("misc/lr", self.lr)

        # merge summary
        self.merged = tf.summary.merge_all()

        # model saver
        self.saver = tf.train.Saver(max_to_keep=1)
        self.best_saver = tf.train.Saver(max_to_keep=1)
        self.writer = tf.summary.FileWriter(self.tf_log, self.sess.graph)