-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy pathrefinement_net.py
38 lines (33 loc) · 1.43 KB
/
refinement_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf
class Refinement_net(object):
    """Encoder-decoder CNN with skip connections that refines an image.

    The network runs two full-resolution convolutions, four ``down`` stages
    (each halving the spatial size), and four ``up`` stages that resize back
    and merge the matching encoder feature map.  A final 3-filter convolution
    produces a residual that is added to the first three channels of the
    input.
    """

    # Negative-side slope shared by every leaky ReLU in the network.
    _LEAKY_SLOPE = 0.1

    def __init__(self, is_train=True):
        """Create the network wrapper.

        Args:
            is_train: Flag stored on the instance as ``self.is_train``.  It
                is not read by any method of this class as written.
        """
        self.is_train = is_train

    def inference(self, input_images):
        """Build the network on ``input_images`` and return its output.

        Args:
            input_images: 4-D channels-last tensor (axes 1 and 2 are treated
                as the spatial axes) with at least three channels; the first
                three channels are the ones the predicted residual is added
                to.

        Returns:
            A tensor with three channels and the same spatial size as
            ``input_images``: ``input_images[..., 0:3]`` plus the residual.
        """
        return self._build_model(input_images)

    def _conv_lrelu(self, x, outChannels, filterSize):
        """Stride-1 'same' convolution followed by a leaky ReLU."""
        conv = tf.layers.conv2d(x, outChannels, filterSize, 1, 'same')
        return tf.nn.leaky_relu(conv, self._LEAKY_SLOPE)

    def down(self, x, outChannels, filterSize):
        """Encoder stage: 2x2 average pooling, then two conv + leaky-ReLU.

        Args:
            x: Input feature map.
            outChannels: Number of filters for both convolutions.
            filterSize: Kernel size for both convolutions.

        Returns:
            The feature map after pooling (stride 2) and both convolutions.
        """
        x = tf.layers.average_pooling2d(x, 2, 2)
        x = self._conv_lrelu(x, outChannels, filterSize)
        x = self._conv_lrelu(x, outChannels, filterSize)
        return x

    def up(self, x, outChannels, skpCn):
        """Decoder stage: bilinear resize, conv, merge skip tensor, conv.

        Args:
            x: Input feature map (lower resolution).
            outChannels: Number of filters for both 3x3 convolutions.
            skpCn: Skip-connection feature map from the encoder; it is
                concatenated with the resized features on the last axis.

        Returns:
            The feature map at the spatial size of ``skpCn``.
        """
        # Resize to the skip tensor's spatial size rather than a fixed 2x.
        # Pooling in ``down`` floors odd sizes, so a plain 2x resize would
        # be one pixel short and the concat below would fail for inputs
        # whose height/width is not a multiple of 16.  When the sizes are
        # exact multiples this is the same 2x resize as before.
        target_size = tf.shape(skpCn)[1:3]
        x = tf.image.resize_bilinear(x, target_size)
        x = self._conv_lrelu(x, outChannels, 3)
        x = self._conv_lrelu(tf.concat([x, skpCn], -1), outChannels, 3)
        return x

    def _build_model(self, input_images):
        """Assemble the full graph; see ``inference`` for the contract."""
        # NOTE: the order in which layers are created is kept exactly as in
        # the original code, since tf.layers assigns default variable names
        # by creation order and existing checkpoints may depend on them.
        x = self._conv_lrelu(input_images, 16, 7)
        s1 = self._conv_lrelu(x, 16, 7)

        # Encoder: each stage halves the spatial size.
        s2 = self.down(s1, 32, 5)
        s3 = self.down(s2, 64, 3)
        s4 = self.down(s3, 128, 3)
        x = self.down(s4, 128, 3)

        # Decoder: each stage returns to the size of its skip tensor.
        x = self.up(x, 128, s4)
        x = self.up(x, 64, s3)
        x = self.up(x, 32, s2)
        x = self.up(x, 16, s1)

        # Linear 3-channel head: the predicted residual (no activation).
        x = tf.layers.conv2d(x, 3, 3, 1, 'same')
        output = input_images[..., 0:3] + x
        return output