-
Notifications
You must be signed in to change notification settings - Fork 4
/
models.py
95 lines (77 loc) · 3.12 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import (BatchNorm2d, Conv2d, Dense, Flatten, Input, DeConv2d, Lambda, \
LocalResponseNorm, MaxPool2d, Elementwise, InstanceNorm2d)
from tensorlayer.models import Model
from data import flags
# def get_G(name=None):
# w_init = tf.random_normal_initializer(stddev=0.02)
#
# nx = Input((flags.batch_size, 256, 256, 3))
#
# n = Conv2d(64, (7, 7), (1, 1), W_init=w_init)(nx)
# n = InstanceNorm2d(act=tf.nn.relu)(n)
#
# n = Conv2d(128, (3, 3), (2, 2), W_init=w_init)(n)
# n = InstanceNorm2d(act=tf.nn.relu)(n)
#
# n = Conv2d(256, (3, 3), (2, 2), W_init=w_init)(n)
# n = InstanceNorm2d(act=tf.nn.relu)(n)
#
# for i in range(9):
# _n = Conv2d(256, (3, 3), (1, 1), W_init=w_init)(n)
# _n = InstanceNorm2d(act=tf.nn.relu)(_n)
# _n = Conv2d(256, (3, 3), (1, 1), W_init=w_init)(_n)
# _n = InstanceNorm2d()(_n)
# n = Elementwise(tf.add)([n, _n])
#
# n = DeConv2d(128, (3, 3), (2, 2), W_init=w_init)(n)
# n = InstanceNorm2d(act=tf.nn.relu)(n)
#
# n = DeConv2d(64, (3, 3), (2, 2), W_init=w_init)(n)
# n = InstanceNorm2d(act=tf.nn.relu)(n)
#
# n = Conv2d(3, (7, 7), (1, 1), act=tf.nn.tanh, W_init=w_init)(n)
#
# M = Model(inputs=nx, outputs=n, name=name)
# return M
def _residual_block(n, w_init, n_filter=128):
    """Apply one residual block: conv-IN-ReLU-conv-IN, then add the input back.

    Args:
        n: Input layer output; its channel count must equal ``n_filter`` so the
            element-wise addition is shape-compatible.
        w_init: Kernel initializer shared with the rest of the generator.
        n_filter: Number of filters of both 3x3 stride-1 convolutions.

    Returns:
        The element-wise sum of ``n`` and the block's transformed branch.
        No activation is applied after the addition.
    """
    shortcut = n
    n = Conv2d(n_filter, (3, 3), (1, 1), W_init=w_init)(n)
    n = InstanceNorm2d(act=tf.nn.relu)(n)
    n = Conv2d(n_filter, (3, 3), (1, 1), W_init=w_init)(n)
    n = InstanceNorm2d()(n)
    return Elementwise(tf.add)([shortcut, n])


def get_G(name=None, n_res_blocks=9):  # follow the paper
    """Build the image-to-image generator network.

    Architecture, as constructed below:
      * encoder: a 7x7 stride-1 conv (32 ch), then two 3x3 stride-2 convs
        (64, 128 ch), each followed by InstanceNorm + ReLU;
      * ``n_res_blocks`` residual blocks operating at 128 channels;
      * decoder: two 3x3 stride-2 transposed convs (64, 32 ch), each followed
        by InstanceNorm + ReLU, then a 7x7 conv to 3 channels with tanh, so
        outputs lie in [-1, 1].

    Args:
        name: Name given to the returned Model.
        n_res_blocks: Number of residual blocks in the bottleneck. Defaults to
            9, the value previously hard-coded here.

    Returns:
        A tensorlayer ``Model`` taking a (flags.batch_size, 256, 256, 3) input
        and producing a 3-channel output.
        NOTE(review): with the layers' default padding (presumably 'SAME') the
        output spatial size should match the input. Confirm this.

    Raises:
        ValueError: If ``n_res_blocks`` is negative.
    """
    if n_res_blocks < 0:
        raise ValueError("n_res_blocks must be >= 0, got %r" % (n_res_blocks,))

    w_init = tf.random_normal_initializer(stddev=0.02)
    # The batch dimension is fixed, so the model expects exactly flags.batch_size images.
    nx = Input((flags.batch_size, 256, 256, 3))

    # Encoder: full-resolution 7x7 conv, then two 2x downsampling convs.
    n = Conv2d(32, (7, 7), (1, 1), W_init=w_init)(nx)
    n = InstanceNorm2d(act=tf.nn.relu)(n)
    n = Conv2d(64, (3, 3), (2, 2), W_init=w_init)(n)
    n = InstanceNorm2d(act=tf.nn.relu)(n)
    n = Conv2d(128, (3, 3), (2, 2), W_init=w_init)(n)
    n = InstanceNorm2d(act=tf.nn.relu)(n)

    # Bottleneck: residual blocks keep the 128-channel width of the encoder output.
    for _ in range(n_res_blocks):
        n = _residual_block(n, w_init, n_filter=128)

    # Decoder: two 2x upsampling transposed convs, then project to RGB.
    n = DeConv2d(64, (3, 3), (2, 2), W_init=w_init)(n)
    n = InstanceNorm2d(act=tf.nn.relu)(n)
    n = DeConv2d(32, (3, 3), (2, 2), W_init=w_init)(n)
    n = InstanceNorm2d(act=tf.nn.relu)(n)
    n = Conv2d(3, (7, 7), (1, 1), act=tf.nn.tanh, W_init=w_init)(n)

    return Model(inputs=nx, outputs=n, name=name)
def get_D(name=None):
    """Build the discriminator network.

    Five 4x4 convolutions are applied in sequence:
      * 64 filters, stride 2, leaky ReLU (slope 0.2), no normalization;
      * 128, 256 and 512 filters with strides 2, 2 and 1, each followed by
        InstanceNorm + leaky ReLU;
      * a final 1-filter conv with 'VALID' padding and no activation.
    The resulting single-channel map is flattened. Each entry is an unbounded
    score, presumably one per image patch, in the style of a PatchGAN.

    Args:
        name: Name given to the returned Model.

    Returns:
        A tensorlayer ``Model`` taking a (flags.batch_size, 256, 256, 3) input
        and returning the flattened score map.
    """
    def _leaky(x):
        # Leaky ReLU with negative slope 0.2.
        return tl.act.lrelu(x, 0.2)

    init = tf.random_normal_initializer(stddev=0.02)
    images = Input((flags.batch_size, 256, 256, 3))

    # NOTE: a random 70x70 crop of the input used to precede this layer; it is disabled.
    feat = Conv2d(64, (4, 4), (2, 2), act=_leaky, W_init=init)(images)

    # Conv -> InstanceNorm -> leaky ReLU stages, created in the same order as before.
    for filters, stride in ((128, 2), (256, 2), (512, 1)):
        feat = Conv2d(filters, (4, 4), (stride, stride), W_init=init)(feat)
        feat = InstanceNorm2d(act=_leaky)(feat)

    scores = Conv2d(1, (4, 4), (1, 1), padding='VALID', W_init=init)(feat)
    scores = Flatten()(scores)
    return Model(inputs=images, outputs=scores, name=name)