diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..d015226d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+__pycache__
+GANs/*.pyc
+checkpoint
+logs
+data/minist
+
+results
diff --git a/ACGAN.py b/GANs/ACGAN.py
similarity index 99%
rename from ACGAN.py
rename to GANs/ACGAN.py
index 208cea54..c3b054b3 100644
--- a/ACGAN.py
+++ b/GANs/ACGAN.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class ACGAN(object):
     model_name = "ACGAN"     # name for checkpoint
diff --git a/BEGAN.py b/GANs/BEGAN.py
similarity index 99%
rename from BEGAN.py
rename to GANs/BEGAN.py
index e06de8bf..5baa9e06 100644
--- a/BEGAN.py
+++ b/GANs/BEGAN.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class BEGAN(object):
     model_name = "BEGAN"     # name for checkpoint
diff --git a/CGAN.py b/GANs/CGAN.py
similarity index 99%
rename from CGAN.py
rename to GANs/CGAN.py
index 646b8359..c8e15e4e 100644
--- a/CGAN.py
+++ b/GANs/CGAN.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class CGAN(object):
     model_name = "CGAN"     # name for checkpoint
diff --git a/CVAE.py b/GANs/CVAE.py
similarity index 99%
rename from CVAE.py
rename to GANs/CVAE.py
index e76e71de..3f7a47f8 100644
--- a/CVAE.py
+++ b/GANs/CVAE.py
@@ -5,10 +5,10 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

-import prior_factory as prior
+import GANs.prior_factory as prior

 class CVAE(object):
     model_name = "CVAE"     # name for checkpoint
diff --git a/DRAGAN.py b/GANs/DRAGAN.py
similarity index 99%
rename from DRAGAN.py
rename to GANs/DRAGAN.py
index 8aa85605..22128ea9 100644
--- a/DRAGAN.py
+++ b/GANs/DRAGAN.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class DRAGAN(object):
     model_name = "DRAGAN"     # name for checkpoint
diff --git a/EBGAN.py b/GANs/EBGAN.py
similarity index 99%
rename from EBGAN.py
rename to GANs/EBGAN.py
index 279314dd..cad4ee2b 100644
--- a/EBGAN.py
+++ b/GANs/EBGAN.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class EBGAN(object):
     model_name = "EBGAN"     # name for checkpoint
diff --git a/GAN.py b/GANs/GAN.py
similarity index 99%
rename from GAN.py
rename to GANs/GAN.py
index 6d45c55e..1679f210 100644
--- a/GAN.py
+++ b/GANs/GAN.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class GAN(object):
     model_name = "GAN"     # name for checkpoint
diff --git a/LSGAN.py b/GANs/LSGAN.py
similarity index 99%
rename from LSGAN.py
rename to GANs/LSGAN.py
index 18beb85a..d38cf65b 100644
--- a/LSGAN.py
+++ b/GANs/LSGAN.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class LSGAN(object):
     model_name = "LSGAN"     # name for checkpoint
diff --git a/VAE.py b/GANs/VAE.py
similarity index 98%
rename from VAE.py
rename to GANs/VAE.py
index 65f20e78..823e7a95 100644
--- a/VAE.py
+++ b/GANs/VAE.py
@@ -5,10 +5,10 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

-import prior_factory as prior
+import GANs.prior_factory as prior

 class VAE(object):
     model_name = "VAE"     # name for checkpoint
@@ -96,7 +96,7 @@ def build_model(self):
         """ Loss Function """

         # encoding
-        self.mu, sigma = self.encoder(self.inputs, is_training=True, reuse=False)
+        self.mu, sigma = self.encoder(self.inputs, is_training=True, reuse=False)

         # sampling by re-parameterization technique
         z = self.mu + sigma * tf.random_normal(tf.shape(self.mu), 0, 1, dtype=tf.float32)
@@ -241,7 +241,7 @@ def visualize_results(self, epoch):
             else:
                 z_tot = np.concatenate((z_tot, z), axis=0)
                 id_tot = np.concatenate((id_tot, batch_labels), axis=0)
-
+        # in conda, py2.7 and matpltlib version not match
         save_scattered_image(z_tot, id_tot, -4, 4, name=check_folder(
             self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_learned_manifold.png')
diff --git a/WGAN.py b/GANs/WGAN.py
similarity index 99%
rename from WGAN.py
rename to GANs/WGAN.py
index 148527c8..9b685ad9 100644
--- a/WGAN.py
+++ b/GANs/WGAN.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class WGAN(object):
     model_name = "WGAN"     # name for checkpoint
diff --git a/WGAN_GP.py b/GANs/WGAN_GP.py
similarity index 99%
rename from WGAN_GP.py
rename to GANs/WGAN_GP.py
index 347004c5..2f5653e6 100644
--- a/WGAN_GP.py
+++ b/GANs/WGAN_GP.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class WGAN_GP(object):
     model_name = "WGAN_GP"     # name for checkpoint
diff --git a/infoGAN.py b/GANs/infoGAN.py
similarity index 99%
rename from infoGAN.py
rename to GANs/infoGAN.py
index 84ff3bdc..45eb4576 100644
--- a/infoGAN.py
+++ b/GANs/infoGAN.py
@@ -5,8 +5,8 @@
 import tensorflow as tf
 import numpy as np

-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *

 class infoGAN(object):
     model_name = "infoGAN"     # name for checkpoint
diff --git a/ops.py b/GANs/ops.py
similarity index 98%
rename from ops.py
rename to GANs/ops.py
index 53dccac1..c4fdf329 100644
--- a/ops.py
+++ b/GANs/ops.py
@@ -2,12 +2,12 @@
 Most codes from https://github.com/carpedm20/DCGAN-tensorflow
 """
 import math
-import numpy as np 
+import numpy as np
 import tensorflow as tf

 from tensorflow.python.framework import ops

-from utils import *
+from GANs.utils import *

 if "concat_v2" in dir(tf):
     def concat(tensors, axis, *args, **kwargs):
diff --git a/prior_factory.py b/GANs/prior_factory.py
similarity index 100%
rename from prior_factory.py
rename to GANs/prior_factory.py
diff --git a/utils.py b/GANs/utils.py
similarity index 78%
rename from utils.py
rename to GANs/utils.py
index 8d8e5ba3..7a3f586d 100644
--- a/utils.py
+++ b/GANs/utils.py
@@ -9,7 +9,7 @@
 import numpy as np
 from time import gmtime, strftime
 from six.moves import xrange
-import matplotlib.pyplot as plt
+#import matplotlib.pyplot as plt
 import os, gzip

 import tensorflow as tf
@@ -19,6 +19,7 @@ def load_mnist(dataset_name):
     data_dir = os.path.join("./data", dataset_name)

     def extract_data(filename, num_data, head_size, data_size):
+        print('reading gz data {}'.format(filename))
         with gzip.open(filename) as bytestream:
             bytestream.read(head_size)
             buf = bytestream.read(data_size * num_data)
@@ -123,27 +124,31 @@ def inverse_transform(images):
     return (images+1.)/2.
""" Drawing Tools """ +# in conda, py2.7 and matpltlib version not match, del + # borrowed from https://github.com/ykwon0407/variational_autoencoder/blob/master/variational_bayes.ipynb -def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'): - N = 10 - plt.figure(figsize=(8, 6)) - plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet')) - plt.colorbar(ticks=range(N)) - axes = plt.gca() - axes.set_xlim([-z_range_x, z_range_x]) - axes.set_ylim([-z_range_y, z_range_y]) - plt.grid(True) - plt.savefig(name) - -# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a -def discrete_cmap(N, base_cmap=None): - """Create an N-bin discrete colormap from the specified input map""" - - # Note that if base_cmap is a string or None, you can simply do - # return plt.cm.get_cmap(base_cmap, N) - # The following works for string, None, or a colormap instance: - - base = plt.cm.get_cmap(base_cmap) - color_list = base(np.linspace(0, 1, N)) - cmap_name = base.name + str(N) - return base.from_list(cmap_name, color_list, N) \ No newline at end of file +# def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'): +# N = 10 +# plt.figure(figsize=(8, 6)) +# plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet')) +# plt.colorbar(ticks=range(N)) +# axes = plt.gca() +# axes.set_xlim([-z_range_x, z_range_x]) +# axes.set_ylim([-z_range_y, z_range_y]) +# plt.grid(True) +# plt.savefig(name) + +# # borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a +# def discrete_cmap(N, base_cmap=None): +# """Create an N-bin discrete colormap from the specified input map""" + +# # Note that if base_cmap is a string or None, you can simply do +# # return plt.cm.get_cmap(base_cmap, N) +# # The following works for string, None, or a colormap instance: + +# base = plt.cm.get_cmap(base_cmap) +# color_list = base(np.linspace(0, 1, N)) +# cmap_name = base.name + str(N) +# return base.from_list(cmap_name, color_list, N) + + diff --git a/README.md b/README.md index bba08bfa..7b590605 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,17 @@ # tensorflow-generative-model-collections Tensorflow implementation of various GANs and VAEs. + ## Related Repositories ### Pytorch version Pytorch version of this repository is availabel at https://github.com/znxlwm/pytorch-generative-model-collections -### "Are GANs Created Equal? A Large-Scale Study" Paper -https://github.com/google/compare_gan is the code that was used in [the paper](https://arxiv.org/abs/1711.10337). -It provides IS/FID and rich experimental results for all gan-variants. +### "Are GANs Created Equal? A Large-Scale Study" Paper +https://github.com/google/compare_gan is the code that was used in [the paper](https://arxiv.org/abs/1711.10337). +It provides IS/FID and rich experimental results for all gan-variants. ## Generative Adversarial Networks (GANs) -### Lists +### Lists *Name* | *Paper Link* | *Value Function* :---: | :---: | :--- | @@ -23,16 +24,16 @@ It provides IS/FID and rich experimental results for all gan-variants. 
 **infoGAN**| [Arxiv](https://arxiv.org/abs/1606.03657) |
 **ACGAN**| [Arxiv](https://arxiv.org/abs/1610.09585) |
 **EBGAN**| [Arxiv](https://arxiv.org/abs/1609.03126) |
-**BEGAN**| [Arxiv](https://arxiv.org/abs/1702.08431) |
+**BEGAN**| [Arxiv](https://arxiv.org/abs/1702.08431) |

 #### Variants of GAN structure

 ### Results for mnist
 Network architecture of generator and discriminator is the exaclty sames as in [infoGAN paper](https://arxiv.org/abs/1606.03657).
 For fair comparison of core ideas in all gan variants, all implementations for network architecture are kept same except EBGAN and BEGAN. Small modification is made for EBGAN/BEGAN, since those adopt auto-encoder strucutre for discriminator. But I tried to keep the capacity of discirminator.

-The following results can be reproduced with command:
+The following results can be reproduced with command:
 ```
 python main.py --dataset mnist --gan_type <TYPE> --epoch 25 --batch_size 64
 ```
@@ -68,10 +69,10 @@ infoGAN |  |  |
 infoGAN |  |  |

-Without hyper-parameter tuning from mnist-version, ACGAN/infoGAN does not work well as compared with CGAN.
-ACGAN tends to fall into mode-collapse.
+Without hyper-parameter tuning from mnist-version, ACGAN/infoGAN does not work well as compared with CGAN.
+ACGAN tends to fall into mode-collapse.
 infoGAN tends to ignore noise-vector. It results in that various style within the same class can not be represented.

 #### InfoGAN : Manipulating two continous codes
@@ -122,7 +123,7 @@
 **VAE**| [Arxiv](https://arxiv.org/abs/1312.6114) |
 **CVAE**| [Arxiv](https://arxiv.org/abs/1406.5298) |
 **DVAE**| [Arxiv](https://arxiv.org/abs/1511.06406) | (to be added)
-**AAE**| [Arxiv](https://arxiv.org/abs/1511.05644) | (to be added)
+**AAE**| [Arxiv](https://arxiv.org/abs/1511.05644) | (to be added)

 #### Variants of VAE structure
@@ -130,7 +131,7 @@
 ### Results for mnist
 Network architecture of decoder(generator) and encoder(discriminator) is the exaclty sames as in [infoGAN paper](https://arxiv.org/abs/1606.0365). The number of output nodes in encoder is different. (2x z_dim for VAE, 1 for GAN)

-The following results can be reproduced with command:
+The following results can be reproduced with command:
 ```
 python main.py --dataset mnist --gan_type <TYPE> --epoch 25 --batch_size 64
 ```
@@ -158,7 +159,7 @@ Results of CGAN is also given to compare images generated from CVAE and CGAN.

 #### Learned manifold

-The following results can be reproduced with command:
+The following results can be reproduced with command:
 ```
 python main.py --dataset mnist --gan_type VAE --epoch 25 --batch_size 64 --dim_z 2
 ```
@@ -169,9 +170,9 @@ Please notice that dimension of noise-vector z is 2.
 VAE |  |  |

 ### Results for fashion-mnist
-Comments on network architecture in mnist are also applied to here.
+Comments on network architecture in mnist are also applied to here.

-The following results can be reproduced with command:
+The following results can be reproduced with command:
 ```
 python main.py --dataset fashion-mnist --gan_type <TYPE> --epoch 40 --batch_size 64
 ```
@@ -198,7 +199,7 @@ Results of CGAN is also given to compare images generated from CVAE and CGAN.

 #### Learned manifold

-The following results can be reproduced with command:
+The following results can be reproduced with command:
 ```
 python main.py --dataset fashion-mnist --gan_type VAE --epoch 25 --batch_size 64 --dim_z 2
 ```
diff --git a/data/get_mnist_data.py b/data/get_mnist_data.py
new file mode 100644
index 00000000..688e2017
--- /dev/null
+++ b/data/get_mnist_data.py
@@ -0,0 +1,6 @@
+from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
+import tensorflow as tf
+
+# run code in gan_env
+mnist = read_data_sets("mnist", one_hot=True)
+print(mnist)
\ No newline at end of file
diff --git a/data/mnist/t10k-images-idx3-ubyte.gz b/data/mnist/t10k-images-idx3-ubyte.gz
new file mode 100644
index 00000000..5ace8ea9
Binary files /dev/null and b/data/mnist/t10k-images-idx3-ubyte.gz differ
diff --git a/data/mnist/t10k-labels-idx1-ubyte.gz b/data/mnist/t10k-labels-idx1-ubyte.gz
new file mode 100644
index 00000000..a7e14154
Binary files /dev/null and b/data/mnist/t10k-labels-idx1-ubyte.gz differ
diff --git a/data/mnist/train-images-idx3-ubyte.gz b/data/mnist/train-images-idx3-ubyte.gz
new file mode 100644
index 00000000..b50e4b6b
Binary files /dev/null and b/data/mnist/train-images-idx3-ubyte.gz differ
diff --git a/data/mnist/train-labels-idx1-ubyte.gz b/data/mnist/train-labels-idx1-ubyte.gz
new file mode 100644
index 00000000..707a576b
Binary files /dev/null and b/data/mnist/train-labels-idx1-ubyte.gz differ
diff --git a/main.py b/main.py
index 3decd107..3f67d19e 100644
--- a/main.py
+++ b/main.py
@@ -1,23 +1,23 @@
 import os

 ## GAN Variants
-from GAN import GAN
-from CGAN import CGAN
-from infoGAN import infoGAN
-from ACGAN import ACGAN
-from EBGAN import EBGAN
-from WGAN import WGAN
-from WGAN_GP import WGAN_GP
-from DRAGAN import DRAGAN
-from LSGAN import LSGAN
-from BEGAN import BEGAN
+from GANs.GAN import GAN
+from GANs.CGAN import CGAN
+from GANs.infoGAN import infoGAN
+from GANs.ACGAN import ACGAN
+from GANs.EBGAN import EBGAN
+from GANs.WGAN import WGAN
+from GANs.WGAN_GP import WGAN_GP
+from GANs.DRAGAN import DRAGAN
+from GANs.LSGAN import LSGAN
+from GANs.BEGAN import BEGAN

 ## VAE Variants
-from VAE import VAE
-from CVAE import CVAE
+from GANs.VAE import VAE
+from GANs.CVAE import CVAE

-from utils import show_all_variables
-from utils import check_folder
+from GANs.utils import show_all_variables
+from GANs.utils import check_folder

 import tensorflow as tf
 import argparse
@@ -29,11 +29,11 @@ def parse_args():
     parser.add_argument('--gan_type', type=str, default='GAN',
                         choices=['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN', 'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN', 'VAE', 'CVAE'],
-                        help='The type of GAN', required=True)
+                        help='The type of GAN')
     parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion-mnist', 'celebA'],
                         help='The name of dataset')
-    parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run')
-    parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
+    parser.add_argument('--epoch', type=int, default=2, help='The number of epochs to run')
+    parser.add_argument('--batch_size', type=int, default=1024, help='The size of batch')
     parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector')
     parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                         help='Directory name to save the checkpoints')