Commit 2cb1b2b (1 parent: 3f442fb), 9 changed files with 1,641 additions and 0 deletions; two of the added files are shown below.
New file (+187 lines): CelebA training script
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
GPUID = 1
os.environ["CUDA_VISIBLE_DEVICES"] = str(GPUID)
import numpy as np
import tensorflow as tf
import scipy.ndimage.interpolation
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from celeb_model import encoder1, encoder2, discriminator
# from model_fancyCelebA_utils import encoder1, encoder2, discriminator
import scipy.io as sio
import pdb
import h5py
import json
import time

""" parameters """
n_epochs = 16
mb_size = 64
# X_dim = ()
lr = 1e-4
Z_dim = 64
pa, pb = 0., 0.
lamb_list = [1e-6, 1e-4, 1e-2, 1e-1, 0.5, 50, 100, 1000]

#####################################

def log(x):
    return tf.log(x + 1e-8)

""" data pre-process """
hdf5_root = '/home/lqchen/work/pixel-cnn-3/data/CelebA/'
Images = h5py.File('%sceleba_64.hdf5' % hdf5_root)['features']
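# The HDF5 'features' array is sliced lazily per minibatch below; the training
# loop transposes from (N, 3, 64, 64) to NHWC and rescales pixel values from
# [0, 255] to [-1, 1].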
|
||
# Images_all = np.transpose(Images, [0, 2, 3, 1]) | ||
# Images = Images_all[:162770] | ||
# Images_val = Images_all[162770: 182637] | ||
# Images_test = Images_all[182637:] | ||
# del Images_all | ||
|
||
num_train = 192000 | ||
|
||
""" function utilities """ | ||
def sample_XY(X, Y, size): | ||
start_idx = np.random.randint(0, X.shape[0] - size) | ||
return X[start_idx:start_idx + size], Y[start_idx:start_idx + size] | ||
|
||
|
||
def sample_X(X, size, num_train=192000): | ||
start_idx = np.random.randint(0, num_train - size) | ||
return X[start_idx:start_idx + size] | ||
|
||
|
||
def sample_Y(Y, size): | ||
start_idx = np.random.randint(0, Y.shape[0] - size) | ||
return Y[start_idx:start_idx + size] | ||
|
||
def sample_Z(m, n): | ||
return np.random.uniform(-1., 1., size=[m, n]) | ||
|
||
|
||
def plot(samples): | ||
fig = plt.figure(figsize=(8, 8)) | ||
gs = gridspec.GridSpec(8, 8) | ||
gs.update(wspace=0.05, hspace=0.05) | ||
|
||
for i, sample in enumerate(samples): | ||
ax = plt.subplot(gs[i]) | ||
plt.axis('off') | ||
ax.set_xticklabels([]) | ||
ax.set_yticklabels([]) | ||
ax.set_aspect('equal') | ||
plt.imshow(sample) | ||
return fig | ||
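# plot() lays a batch of 64 samples out on an 8x8 grid; imshow expects float
# images in [0, 1], so rescale decoder outputs from [-1, 1] before calling it.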

""" Networks """
def generative_Y2X(z, reuse=None):
    with tf.variable_scope("Y2X", reuse=reuse):
        h = encoder2(z)
    return h

def generative_X2Y(x, reuse=None):
    with tf.variable_scope("X2Y", reuse=reuse):
        h = encoder1(x)
    return h

def data_network(x, y, reuse=None):
    with tf.variable_scope('D', reuse=reuse):
        f, d = discriminator(x, y)
    return tf.squeeze(f, squeeze_dims=[1]), tf.squeeze(d, squeeze_dims=[1])

# def data_network_2(x, y, reuse=None):
#     """Approximate z log data density."""
#     with tf.variable_scope('D2', reuse=reuse) as scope:
#         d = discriminator(x, y)
#
#     return tf.squeeze(d, squeeze_dims=[1])

""" Construct model and training ops """
tf.reset_default_graph()

X = tf.placeholder(tf.float32, shape=[mb_size, 64, 64, 3])
z = tf.placeholder(tf.float32, shape=[mb_size, Z_dim])
lamb = tf.placeholder(tf.float32)
# Generator
z_gen = generative_X2Y(X)
X_gen = generative_Y2X(z)
# Discriminator
fxz, Dxz = data_network(X, z_gen)
fzx, Dzx = data_network(X_gen, z, reuse=True)

# Discriminator loss
D_loss = -tf.reduce_mean(log(Dzx) + log(1 - Dxz))

# Generator loss
L_x = -tf.reduce_mean(Dxz)
L_z = tf.reduce_mean(Dzx)
G_loss = L_x + L_z

## reconstruct
X_rec = generative_Y2X(z_gen, reuse=True)
z_rec = generative_X2Y(X_gen, reuse=True)
Xrec_loss = tf.reduce_mean(tf.abs(X - X_rec))
Zrec_loss = tf.reduce_mean(tf.abs(z - z_rec))

""" Solvers """ | ||
gvar1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "Y2X") | ||
gvar2 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "X2Y") | ||
dvars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "D") | ||
|
||
opt = tf.train.AdamOptimizer(lr, beta1=0.5) | ||
|
||
# G_loss = G_loss + pa * Xrec_loss + pb * Zrec_loss | ||
G_loss = G_loss + lamb * Xrec_loss + lamb * Zrec_loss | ||
# G_loss = G_loss + lamb * X_ce_loss + lamb * Z_ce_loss | ||
D_solver = opt.minimize(D_loss, var_list = dvars1) | ||
G_solver = opt.minimize(G_loss, var_list = gvar1 + gvar2) | ||
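# Note on the objective: the discriminator scores joint pairs, (X, z_gen) for
# real images with inferred codes and (X_gen, z) for decoded samples with their
# codes, in an ALI-style adversarial game between the two mappings. The
# lamb-weighted L1 terms Xrec_loss and Zrec_loss add cycle-consistency
# penalties in image and latent space; the loop below sweeps lamb over
# lamb_list and trains a separate model for each value.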

# Call this after declaring all tf.Variables.
saver = tf.train.Saver()

""" Training """
config = tf.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.3
sess = tf.Session(config=config)

# Load pretrained Model
# try:
#     saver.restore(sess=sess, save_path="../model/celeba_model_R.ckpt")
#     print("\n--------model restored--------\n")
# except:
#     print("\n--------model Not restored--------\n")
#     pass

disc_steps = 2
gen_steps = 1
for num_test in range(len(lamb_list)):
    init = tf.global_variables_initializer()
    sess.run(init)
    for it in range(n_epochs):

        # TODO: dynamic control of the steps
        # if it >= 4:
        #     disc_steps = 1
        #     gen_steps = 1
        for idx in range(0, num_train // mb_size):

            # _x = Images[idx * mb_size: (idx + 1) * mb_size]
            _x = sample_X(Images, mb_size)
            _x = np.transpose(_x, [0, 2, 3, 1]) / 127.5 - 1
            z_sample = sample_Z(mb_size, Z_dim)
            for k in range(disc_steps):
                _, D_loss_curr = sess.run([D_solver, D_loss],
                                          feed_dict={X: _x, z: z_sample, lamb: lamb_list[num_test]})
            for j in range(gen_steps):
                _, G_loss_curr = sess.run([G_solver, G_loss],
                                          feed_dict={X: _x, z: z_sample, lamb: lamb_list[num_test]})

            if idx % 200 == 0:
                saver.save(sess, './model/celeba_model_ali_%d.ckpt' % num_test)
                print(num_test)
                print('epoch: {}; iter: {}; D_loss: {:.4}; G_loss: {:.4}'.format(
                    it, idx, D_loss_curr, G_loss_curr))
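A minimal sketch of sampling from one of the checkpoints the loop above writes, reusing the graph and helpers already defined in this script. The checkpoint index 0 and the output filename samples_0.png are illustrative assumptions; the decoder output is in the tanh range [-1, 1], so it is rescaled to [0, 1] for plot().

    # Sketch only: restore a checkpoint saved above and decode random codes.
    saver.restore(sess, './model/celeba_model_ali_0.ckpt')   # assumes the num_test = 0 run finished
    z_sample = sample_Z(mb_size, Z_dim)
    samples = sess.run(X_gen, feed_dict={z: z_sample})        # shape [mb_size, 64, 64, 3], values in [-1, 1]
    fig = plot((samples + 1.) / 2.)                            # rescale to [0, 1] for imshow
    plt.savefig('samples_0.png', bbox_inches='tight')
    plt.close(fig)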
New file (+151 lines): model definitions (encoder1, encoder2, discriminator)
import tensorflow as tf
from tensorflow.contrib import layers
import pdb
import numpy as np


def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis."""
    x_shapes = x.get_shape()
    y_shapes = y.get_shape()
    return tf.concat([
        x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)

# initializer = tf.truncated_normal_initializer(stddev=0.02)
initializer = tf.contrib.layers.xavier_initializer()

Z_dim = 64
mb_size = 64
noise_dim = 20

def lrelu(x, leak=0.2, name="lrelu"):
    with tf.variable_scope(name):
        f1 = 0.5 * (1 + leak)
        f2 = 0.5 * (1 - leak)
        return f1 * x + f2 * abs(x)

def discriminator(x, y):

    h = tf.reshape(x, [-1, 64, 64, 3])
    # noise = tf.random_normal([mb_size, 64, 64, 1])
    noise = tf.random_uniform([mb_size, 64, 64, 1], -1, 1)
    h = tf.concat([h, noise], axis=3)

    h = layers.conv2d(h, 64, 5, stride=2, padding='SAME',
                      activation_fn=None, weights_initializer=initializer)
    h = layers.batch_norm(h, activation_fn=lrelu)

    h = layers.conv2d(h, 64 * 2, 5, stride=2, padding='SAME',
                      activation_fn=None, weights_initializer=initializer)
    h = layers.batch_norm(h, activation_fn=lrelu)

    h = layers.conv2d(h, 64 * 4, 5, stride=2, padding='SAME',
                      activation_fn=None, weights_initializer=initializer)
    h = layers.batch_norm(h, activation_fn=lrelu)

    h = layers.conv2d(h, 64 * 8, 5, stride=2, padding='SAME',
                      activation_fn=None, weights_initializer=initializer)
    h = layers.batch_norm(h, activation_fn=lrelu)
    # average pooling
    h = layers.avg_pool2d(h, 2, stride=2)
    h = layers.flatten(h)

    noise_z = tf.random_uniform([mb_size, noise_dim], -1, 1)
    y = tf.concat([y, noise_z], axis=1)
    zh = layers.fully_connected(y, 2 * 2 * 512, activation_fn=lrelu)

    h = tf.concat([h, zh], axis=1)
    h = layers.fully_connected(h, 1, activation_fn=None)

    return h, tf.sigmoid(h)

# def discriminator2(x, y):
#     # with tf.variable_scope('Discriminator', reuse=None) as scope:
#     yb = tf.reshape(y, [-1, 1, 1, Z_dim])
#     h = tf.reshape(x, [-1, 64, 64, 3])

#     h = conv_cond_concat(h, yb)
#     h = layers.conv2d(h, 64, 5, stride=2, padding='SAME', activation_fn=None, weights_initializer=initializer)
#     h = layers.batch_norm(h, activation_fn=tf.nn.relu)

#     # h = conv_cond_concat(h, yb)
#     h = layers.conv2d(h, 64 * 2, 5, stride=2, padding='SAME', activation_fn=None, weights_initializer=initializer)
#     h = layers.batch_norm(h, activation_fn=tf.nn.relu)

#     # h = conv_cond_concat(h, yb)
#     h = layers.conv2d(h, 64 * 4, 5, stride=2, padding='SAME', activation_fn=None, weights_initializer=initializer)
#     h = layers.batch_norm(h, activation_fn=tf.nn.relu)

#     # h = conv_cond_concat(h, yb)
#     h = layers.conv2d(h, 64 * 8, 5, stride=2, padding='SAME', activation_fn=None, weights_initializer=initializer)
#     h = layers.batch_norm(h, activation_fn=tf.nn.relu)
#     h = layers.flatten(h)

#     h = layers.fully_connected(h, 1, activation_fn=None)

#     return h, tf.sigmoid(h)

def encoder1(tensor):
    # noise = tf.random_normal([mb_size, 64, 64, 1])
    noise = tf.random_uniform([mb_size, 64, 64, 1], -1, 1)
    tensor = tf.concat([tensor, noise], axis=3)
    conv1 = layers.conv2d(tensor, 32, 5, stride=2,
                          activation_fn=None, weights_initializer=initializer)
    conv1 = layers.batch_norm(conv1, activation_fn=tf.nn.relu)

    # note: conv2-conv5 also pass normalizer_fn=layers.batch_norm, so batch
    # norm is applied twice for these layers (inside conv2d and again below)
    conv2 = layers.conv2d(conv1, 64, 5, stride=2, activation_fn=None,
                          normalizer_fn=layers.batch_norm, weights_initializer=initializer)
    conv2 = layers.batch_norm(conv2, activation_fn=tf.nn.relu)

    conv3 = layers.conv2d(conv2, 128, 5, stride=2, activation_fn=None,
                          normalizer_fn=layers.batch_norm, weights_initializer=initializer)
    conv3 = layers.batch_norm(conv3, activation_fn=tf.nn.relu)

    conv4 = layers.conv2d(conv3, 256, 5, stride=2, activation_fn=None,
                          normalizer_fn=layers.batch_norm, weights_initializer=initializer)
    conv4 = layers.batch_norm(conv4, activation_fn=tf.nn.relu)

    conv5 = layers.conv2d(conv4, 512, 5, stride=2, activation_fn=None,
                          normalizer_fn=layers.batch_norm, weights_initializer=initializer)
    conv5 = layers.batch_norm(conv5, activation_fn=tf.nn.relu)

    # fc1 = tf.reshape(conv4, shape=[-1, 2 * 2 * 512])
    fc1 = layers.flatten(conv5)
    fc1 = layers.fully_connected(
        inputs=fc1, num_outputs=512, activation_fn=None, weights_initializer=initializer)
    fc1 = layers.batch_norm(fc1, activation_fn=lrelu)

    fc2 = layers.fully_connected(inputs=fc1, num_outputs=Z_dim,
                                 activation_fn=tf.nn.tanh, weights_initializer=initializer)

    return fc2


def encoder2(y):
    # noise = tf.random_normal([mb_size, 10])
    noise = tf.random_uniform([mb_size, noise_dim], -1, 1)
    h = tf.concat([y, noise], axis=1)

    h = layers.fully_connected(h, 1024, weights_initializer=initializer)
    h = layers.batch_norm(h, activation_fn=lrelu)

    h = layers.fully_connected(
        h, 64 * 8 * 4 * 4, activation_fn=None, weights_initializer=initializer)
    h = tf.reshape(h, [-1, 4, 4, 64 * 8])
    h = layers.batch_norm(h, activation_fn=lrelu)

    h = layers.conv2d_transpose(h, 64 * 4, 5, stride=2, padding='SAME',
                                activation_fn=None, weights_initializer=initializer)
    h = layers.batch_norm(h, activation_fn=lrelu)

    h = layers.conv2d_transpose(h, 64 * 2, 5, stride=2, padding='SAME',
                                activation_fn=None, weights_initializer=initializer)
    h = layers.batch_norm(h, activation_fn=lrelu)

    h = layers.conv2d_transpose(h, 64 * 1, 5, stride=2, padding='SAME',
                                activation_fn=None, weights_initializer=initializer)
    h = layers.batch_norm(h, activation_fn=lrelu)

    h = layers.conv2d_transpose(h, 3, 5, stride=2, padding='SAME',
                                activation_fn=tf.nn.tanh, weights_initializer=initializer)
    return h
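A minimal standalone shape check for the three networks above, assuming TensorFlow 1.x with tf.contrib available and that this file is saved as celeb_model.py (the name the training script imports). The batch size must match the hard-coded mb_size = 64; the placeholder names are illustrative only.

    # Sketch only: build each network once on dummy inputs and print static shapes.
    # Expected: codes [64, 64], images [64, 64, 64, 3], discriminator logits [64, 1].
    import tensorflow as tf
    from celeb_model import encoder1, encoder2, discriminator  # assumed module name

    x_in = tf.placeholder(tf.float32, [64, 64, 64, 3])  # NHWC CelebA crops
    z_in = tf.placeholder(tf.float32, [64, 64])          # latent codes, Z_dim = 64
    z_out = encoder1(x_in)                   # image -> code
    x_out = encoder2(z_in)                   # code -> image
    logit, prob = discriminator(x_in, z_in)  # joint (image, code) score
    print(z_out.get_shape(), x_out.get_shape(), logit.get_shape())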