diff --git a/.gitignore b/.gitignore
new file mode 100644
index 00000000..d015226d
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+__pycache__
+GANs/*.pyc
+checkpoint
+logs
+data/mnist
+
+results
diff --git a/ACGAN.py b/GANs/ACGAN.py
similarity index 99%
rename from ACGAN.py
rename to GANs/ACGAN.py
index 208cea54..c3b054b3 100644
--- a/ACGAN.py
+++ b/GANs/ACGAN.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class ACGAN(object):
model_name = "ACGAN" # name for checkpoint
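
Every model file below repeats the same two-line change: the shared helpers are imported through the new `GANs` package instead of the old top-level `ops`/`utils` modules. A minimal sketch of how to check that the new layout resolves, assuming the working directory is the repository root and, for Python 2.7, an empty `GANs/__init__.py` (the diff itself does not show one):

```python
# sketch, not part of this diff: confirm the GANs package is importable
# assumes you run it from the repository root (so the root is on sys.path)
# and, on Python 2.7, that an empty GANs/__init__.py exists
import importlib

for module in ("GANs.ops", "GANs.utils", "GANs.GAN"):
    importlib.import_module(module)  # raises ImportError if the layout is wrong
print("GANs package imports resolve")
```
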
diff --git a/BEGAN.py b/GANs/BEGAN.py
similarity index 99%
rename from BEGAN.py
rename to GANs/BEGAN.py
index e06de8bf..5baa9e06 100644
--- a/BEGAN.py
+++ b/GANs/BEGAN.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class BEGAN(object):
model_name = "BEGAN" # name for checkpoint
diff --git a/CGAN.py b/GANs/CGAN.py
similarity index 99%
rename from CGAN.py
rename to GANs/CGAN.py
index 646b8359..c8e15e4e 100644
--- a/CGAN.py
+++ b/GANs/CGAN.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class CGAN(object):
model_name = "CGAN" # name for checkpoint
diff --git a/CVAE.py b/GANs/CVAE.py
similarity index 99%
rename from CVAE.py
rename to GANs/CVAE.py
index e76e71de..3f7a47f8 100644
--- a/CVAE.py
+++ b/GANs/CVAE.py
@@ -5,10 +5,10 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
-import prior_factory as prior
+import GANs.prior_factory as prior
class CVAE(object):
model_name = "CVAE" # name for checkpoint
diff --git a/DRAGAN.py b/GANs/DRAGAN.py
similarity index 99%
rename from DRAGAN.py
rename to GANs/DRAGAN.py
index 8aa85605..22128ea9 100644
--- a/DRAGAN.py
+++ b/GANs/DRAGAN.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class DRAGAN(object):
model_name = "DRAGAN" # name for checkpoint
diff --git a/EBGAN.py b/GANs/EBGAN.py
similarity index 99%
rename from EBGAN.py
rename to GANs/EBGAN.py
index 279314dd..cad4ee2b 100644
--- a/EBGAN.py
+++ b/GANs/EBGAN.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class EBGAN(object):
model_name = "EBGAN" # name for checkpoint
diff --git a/GAN.py b/GANs/GAN.py
similarity index 99%
rename from GAN.py
rename to GANs/GAN.py
index 6d45c55e..1679f210 100644
--- a/GAN.py
+++ b/GANs/GAN.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class GAN(object):
model_name = "GAN" # name for checkpoint
diff --git a/LSGAN.py b/GANs/LSGAN.py
similarity index 99%
rename from LSGAN.py
rename to GANs/LSGAN.py
index 18beb85a..d38cf65b 100644
--- a/LSGAN.py
+++ b/GANs/LSGAN.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class LSGAN(object):
model_name = "LSGAN" # name for checkpoint
diff --git a/VAE.py b/GANs/VAE.py
similarity index 98%
rename from VAE.py
rename to GANs/VAE.py
index 65f20e78..823e7a95 100644
--- a/VAE.py
+++ b/GANs/VAE.py
@@ -5,10 +5,10 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
-import prior_factory as prior
+import GANs.prior_factory as prior
class VAE(object):
model_name = "VAE" # name for checkpoint
@@ -96,7 +96,7 @@ def build_model(self):
""" Loss Function """
# encoding
- self.mu, sigma = self.encoder(self.inputs, is_training=True, reuse=False)
+ self.mu, sigma = self.encoder(self.inputs, is_training=True, reuse=False)
# sampling by re-parameterization technique
z = self.mu + sigma * tf.random_normal(tf.shape(self.mu), 0, 1, dtype=tf.float32)
@@ -241,7 +241,7 @@ def visualize_results(self, epoch):
else:
z_tot = np.concatenate((z_tot, z), axis=0)
id_tot = np.concatenate((id_tot, batch_labels), axis=0)
-
+ # the matplotlib version does not match py2.7 in the conda env
save_scattered_image(z_tot, id_tot, -4, 4, name=check_folder(
self.result_dir + '/' + self.model_dir) + '/' + self.model_name + '_epoch%03d' % epoch + '_learned_manifold.png')
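
The VAE hunk above keeps the re-parameterization step `z = mu + sigma * eps` with `eps ~ N(0, 1)` drawn by `tf.random_normal`. A standalone NumPy sketch of that sampling trick, with illustrative names and shapes, just to make the step explicit:

```python
import numpy as np

def reparameterize(mu, sigma):
    # z = mu + sigma * eps, eps ~ N(0, 1): the randomness is isolated in eps,
    # so z stays differentiable with respect to mu and sigma
    eps = np.random.normal(0.0, 1.0, size=np.shape(mu))
    return mu + sigma * eps

# example: a batch of 4 latent codes with dim_z = 2 (as used for the learned-manifold plots)
z = reparameterize(np.zeros((4, 2)), np.ones((4, 2)))
```
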
diff --git a/WGAN.py b/GANs/WGAN.py
similarity index 99%
rename from WGAN.py
rename to GANs/WGAN.py
index 148527c8..9b685ad9 100644
--- a/WGAN.py
+++ b/GANs/WGAN.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class WGAN(object):
model_name = "WGAN" # name for checkpoint
diff --git a/WGAN_GP.py b/GANs/WGAN_GP.py
similarity index 99%
rename from WGAN_GP.py
rename to GANs/WGAN_GP.py
index 347004c5..2f5653e6 100644
--- a/WGAN_GP.py
+++ b/GANs/WGAN_GP.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class WGAN_GP(object):
model_name = "WGAN_GP" # name for checkpoint
diff --git a/infoGAN.py b/GANs/infoGAN.py
similarity index 99%
rename from infoGAN.py
rename to GANs/infoGAN.py
index 84ff3bdc..45eb4576 100644
--- a/infoGAN.py
+++ b/GANs/infoGAN.py
@@ -5,8 +5,8 @@
import tensorflow as tf
import numpy as np
-from ops import *
-from utils import *
+from GANs.ops import *
+from GANs.utils import *
class infoGAN(object):
model_name = "infoGAN" # name for checkpoint
diff --git a/ops.py b/GANs/ops.py
similarity index 98%
rename from ops.py
rename to GANs/ops.py
index 53dccac1..c4fdf329 100644
--- a/ops.py
+++ b/GANs/ops.py
@@ -2,12 +2,12 @@
Most codes from https://github.com/carpedm20/DCGAN-tensorflow
"""
import math
-import numpy as np
+import numpy as np
import tensorflow as tf
from tensorflow.python.framework import ops
-from utils import *
+from GANs.utils import *
if "concat_v2" in dir(tf):
def concat(tensors, axis, *args, **kwargs):
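
The ops.py hunk ends at the TensorFlow version shim for `concat`; the `else` branch lies outside the hunk. A sketch of how such a shim is typically completed, offered as an assumption about the rest of the file rather than a quote from it:

```python
import tensorflow as tf

if "concat_v2" in dir(tf):          # pre-1.0 TF exposes tf.concat_v2(values, axis)
    def concat(tensors, axis, *args, **kwargs):
        return tf.concat_v2(tensors, axis, *args, **kwargs)
else:                               # TF >= 1.0 uses tf.concat(values, axis)
    def concat(tensors, axis, *args, **kwargs):
        return tf.concat(tensors, axis, *args, **kwargs)
```
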
diff --git a/prior_factory.py b/GANs/prior_factory.py
similarity index 100%
rename from prior_factory.py
rename to GANs/prior_factory.py
diff --git a/utils.py b/GANs/utils.py
similarity index 78%
rename from utils.py
rename to GANs/utils.py
index 8d8e5ba3..7a3f586d 100644
--- a/utils.py
+++ b/GANs/utils.py
@@ -9,7 +9,7 @@
import numpy as np
from time import gmtime, strftime
from six.moves import xrange
-import matplotlib.pyplot as plt
+#import matplotlib.pyplot as plt
import os, gzip
import tensorflow as tf
@@ -19,6 +19,7 @@ def load_mnist(dataset_name):
data_dir = os.path.join("./data", dataset_name)
def extract_data(filename, num_data, head_size, data_size):
+ print('reading gz data {}'.format(filename))
with gzip.open(filename) as bytestream:
bytestream.read(head_size)
buf = bytestream.read(data_size * num_data)
@@ -123,27 +124,31 @@ def inverse_transform(images):
return (images+1.)/2.
""" Drawing Tools """
+# the matplotlib version does not match py2.7 in the conda env, so the plotting helpers below are disabled
+
# borrowed from https://github.com/ykwon0407/variational_autoencoder/blob/master/variational_bayes.ipynb
-def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'):
- N = 10
- plt.figure(figsize=(8, 6))
- plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
- plt.colorbar(ticks=range(N))
- axes = plt.gca()
- axes.set_xlim([-z_range_x, z_range_x])
- axes.set_ylim([-z_range_y, z_range_y])
- plt.grid(True)
- plt.savefig(name)
-
-# borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
-def discrete_cmap(N, base_cmap=None):
- """Create an N-bin discrete colormap from the specified input map"""
-
- # Note that if base_cmap is a string or None, you can simply do
- # return plt.cm.get_cmap(base_cmap, N)
- # The following works for string, None, or a colormap instance:
-
- base = plt.cm.get_cmap(base_cmap)
- color_list = base(np.linspace(0, 1, N))
- cmap_name = base.name + str(N)
- return base.from_list(cmap_name, color_list, N)
\ No newline at end of file
+# def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'):
+# N = 10
+# plt.figure(figsize=(8, 6))
+# plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o', edgecolor='none', cmap=discrete_cmap(N, 'jet'))
+# plt.colorbar(ticks=range(N))
+# axes = plt.gca()
+# axes.set_xlim([-z_range_x, z_range_x])
+# axes.set_ylim([-z_range_y, z_range_y])
+# plt.grid(True)
+# plt.savefig(name)
+
+# # borrowed from https://gist.github.com/jakevdp/91077b0cae40f8f8244a
+# def discrete_cmap(N, base_cmap=None):
+# """Create an N-bin discrete colormap from the specified input map"""
+
+# # Note that if base_cmap is a string or None, you can simply do
+# # return plt.cm.get_cmap(base_cmap, N)
+# # The following works for string, None, or a colormap instance:
+
+# base = plt.cm.get_cmap(base_cmap)
+# color_list = base(np.linspace(0, 1, N))
+# cmap_name = base.name + str(N)
+# return base.from_list(cmap_name, color_list, N)
+
+
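
The comments above say the matplotlib build in the py2.7 conda env cannot be imported, so the plotting helpers were commented out, while `VAE.visualize_results` still calls `save_scattered_image`. A guarded import with a no-op fallback is one alternative; this is only a sketch, not a change made in this diff:

```python
# sketch of a guarded alternative (not part of this diff)
import numpy as np

try:
    import matplotlib
    matplotlib.use('Agg')            # headless backend; must be set before importing pyplot
    import matplotlib.pyplot as plt
    HAS_MPL = True
except ImportError:
    HAS_MPL = False

def save_scattered_image(z, id, z_range_x, z_range_y, name='scattered_image.jpg'):
    if not HAS_MPL:
        print('matplotlib unavailable, skipping {}'.format(name))
        return
    N = 10
    plt.figure(figsize=(8, 6))
    plt.scatter(z[:, 0], z[:, 1], c=np.argmax(id, 1), marker='o',
                edgecolor='none', cmap=plt.cm.get_cmap('jet', N))
    plt.colorbar(ticks=range(N))
    plt.gca().set_xlim([-z_range_x, z_range_x])
    plt.gca().set_ylim([-z_range_y, z_range_y])
    plt.grid(True)
    plt.savefig(name)
```
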
diff --git a/README.md b/README.md
index bba08bfa..7b590605 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,17 @@
# tensorflow-generative-model-collections
Tensorflow implementation of various GANs and VAEs.
+
## Related Repositories
### Pytorch version
Pytorch version of this repository is availabel at https://github.com/znxlwm/pytorch-generative-model-collections
-### "Are GANs Created Equal? A Large-Scale Study" Paper
-https://github.com/google/compare_gan is the code that was used in [the paper](https://arxiv.org/abs/1711.10337).
-It provides IS/FID and rich experimental results for all gan-variants.
+### "Are GANs Created Equal? A Large-Scale Study" Paper
+https://github.com/google/compare_gan is the code that was used in [the paper](https://arxiv.org/abs/1711.10337).
+It provides IS/FID scores and rich experimental results for all GAN variants.
## Generative Adversarial Networks (GANs)
-### Lists
+### Lists
*Name* | *Paper Link* | *Value Function*
:---: | :---: | :--- |
@@ -23,16 +24,16 @@ It provides IS/FID and rich experimental results for all gan-variants.
**infoGAN**| [Arxiv](https://arxiv.org/abs/1606.03657) |
**ACGAN**| [Arxiv](https://arxiv.org/abs/1610.09585) |
**EBGAN**| [Arxiv](https://arxiv.org/abs/1609.03126) |
-**BEGAN**| [Arxiv](https://arxiv.org/abs/1702.08431) |
+**BEGAN**| [Arxiv](https://arxiv.org/abs/1702.08431) |
#### Variants of GAN structure
### Results for mnist
-Network architecture of generator and discriminator is the exaclty sames as in [infoGAN paper](https://arxiv.org/abs/1606.03657).
+The network architecture of the generator and discriminator is exactly the same as in the [infoGAN paper](https://arxiv.org/abs/1606.03657).
For fair comparison of core ideas in all gan variants, all implementations for network architecture are kept same except EBGAN and BEGAN. Small modification is made for EBGAN/BEGAN, since those adopt auto-encoder strucutre for discriminator. But I tried to keep the capacity of discirminator.
-The following results can be reproduced with command:
+The following results can be reproduced with command:
```
python main.py --dataset mnist --gan_type --epoch 25 --batch_size 64
```
@@ -68,10 +69,10 @@ infoGAN | | |
infoGAN | | |
-Without hyper-parameter tuning from mnist-version, ACGAN/infoGAN does not work well as compared with CGAN.
-ACGAN tends to fall into mode-collapse.
+Without hyper-parameter tuning from the mnist version, ACGAN/infoGAN do not work as well as CGAN.
+ACGAN tends to fall into mode collapse.
infoGAN tends to ignore noise-vector. It results in that various style within the same class can not be represented.
#### InfoGAN : Manipulating two continous codes
@@ -122,7 +123,7 @@ infoGAN tends to ignore noise-vector. It results in that various style within th
**VAE**| [Arxiv](https://arxiv.org/abs/1312.6114) |
**CVAE**| [Arxiv](https://arxiv.org/abs/1406.5298) |
**DVAE**| [Arxiv](https://arxiv.org/abs/1511.06406) | (to be added)
-**AAE**| [Arxiv](https://arxiv.org/abs/1511.05644) | (to be added)
+**AAE**| [Arxiv](https://arxiv.org/abs/1511.05644) | (to be added)
#### Variants of VAE structure
@@ -130,7 +131,7 @@ infoGAN tends to ignore noise-vector. It results in that various style within th
### Results for mnist
Network architecture of decoder(generator) and encoder(discriminator) is the exaclty sames as in [infoGAN paper](https://arxiv.org/abs/1606.0365). The number of output nodes in encoder is different. (2x z_dim for VAE, 1 for GAN)
-The following results can be reproduced with command:
+The following results can be reproduced with command:
```
python main.py --dataset mnist --gan_type --epoch 25 --batch_size 64
```
@@ -158,7 +159,7 @@ Results of CGAN is also given to compare images generated from CVAE and CGAN.
#### Learned manifold
-The following results can be reproduced with command:
+The following results can be reproduced with command:
```
python main.py --dataset mnist --gan_type VAE --epoch 25 --batch_size 64 --dim_z 2
```
@@ -169,9 +170,9 @@ Please notice that dimension of noise-vector z is 2.
VAE | | |
### Results for fashion-mnist
-Comments on network architecture in mnist are also applied to here.
+The comments on the mnist network architecture also apply here.
-The following results can be reproduced with command:
+The following results can be reproduced with command:
```
python main.py --dataset fashion-mnist --gan_type --epoch 40 --batch_size 64
```
@@ -198,7 +199,7 @@ Results of CGAN is also given to compare images generated from CVAE and CGAN.
#### Learned manifold
-The following results can be reproduced with command:
+The following results can be reproduced with command:
```
python main.py --dataset fashion-mnist --gan_type VAE --epoch 25 --batch_size 64 --dim_z 2
```
diff --git a/data/get_mnist_data.py b/data/get_mnist_data.py
new file mode 100644
index 00000000..688e2017
--- /dev/null
+++ b/data/get_mnist_data.py
@@ -0,0 +1,6 @@
+from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
+import tensorflow as tf
+
+# run this script inside the gan_env environment
+mnist = read_data_sets("mnist", one_hot=True)
+print(mnist)
\ No newline at end of file
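
The helper script above fetches the four MNIST archives; on the training side, `GANs/utils.load_mnist` reads them back with the `extract_data` helper shown in the utils.py hunk. A small standalone sketch of that consumption path, assuming the archives end up under `./data/mnist` as committed in this diff:

```python
import gzip
import os
import numpy as np

def extract_data(filename, num_data, head_size, data_size):
    # mirrors GANs/utils.extract_data: skip the idx header, then read raw pixel bytes
    print('reading gz data {}'.format(filename))
    with gzip.open(filename) as bytestream:
        bytestream.read(head_size)
        buf = bytestream.read(data_size * num_data)
        return np.frombuffer(buf, dtype=np.uint8).astype(np.float32)

data_dir = os.path.join('./data', 'mnist')
# 60000 training images, 16-byte idx header, 28x28 pixels each
trX = extract_data(os.path.join(data_dir, 'train-images-idx3-ubyte.gz'),
                   60000, 16, 28 * 28).reshape((60000, 28, 28, 1))
print(trX.shape)
```
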
diff --git a/data/mnist/t10k-images-idx3-ubyte.gz b/data/mnist/t10k-images-idx3-ubyte.gz
new file mode 100644
index 00000000..5ace8ea9
Binary files /dev/null and b/data/mnist/t10k-images-idx3-ubyte.gz differ
diff --git a/data/mnist/t10k-labels-idx1-ubyte.gz b/data/mnist/t10k-labels-idx1-ubyte.gz
new file mode 100644
index 00000000..a7e14154
Binary files /dev/null and b/data/mnist/t10k-labels-idx1-ubyte.gz differ
diff --git a/data/mnist/train-images-idx3-ubyte.gz b/data/mnist/train-images-idx3-ubyte.gz
new file mode 100644
index 00000000..b50e4b6b
Binary files /dev/null and b/data/mnist/train-images-idx3-ubyte.gz differ
diff --git a/data/mnist/train-labels-idx1-ubyte.gz b/data/mnist/train-labels-idx1-ubyte.gz
new file mode 100644
index 00000000..707a576b
Binary files /dev/null and b/data/mnist/train-labels-idx1-ubyte.gz differ
diff --git a/main.py b/main.py
index 3decd107..3f67d19e 100644
--- a/main.py
+++ b/main.py
@@ -1,23 +1,23 @@
import os
## GAN Variants
-from GAN import GAN
-from CGAN import CGAN
-from infoGAN import infoGAN
-from ACGAN import ACGAN
-from EBGAN import EBGAN
-from WGAN import WGAN
-from WGAN_GP import WGAN_GP
-from DRAGAN import DRAGAN
-from LSGAN import LSGAN
-from BEGAN import BEGAN
+from GANs.GAN import GAN
+from GANs.CGAN import CGAN
+from GANs.infoGAN import infoGAN
+from GANs.ACGAN import ACGAN
+from GANs.EBGAN import EBGAN
+from GANs.WGAN import WGAN
+from GANs.WGAN_GP import WGAN_GP
+from GANs.DRAGAN import DRAGAN
+from GANs.LSGAN import LSGAN
+from GANs.BEGAN import BEGAN
## VAE Variants
-from VAE import VAE
-from CVAE import CVAE
+from GANs.VAE import VAE
+from GANs.CVAE import CVAE
-from utils import show_all_variables
-from utils import check_folder
+from GANs.utils import show_all_variables
+from GANs.utils import check_folder
import tensorflow as tf
import argparse
@@ -29,11 +29,11 @@ def parse_args():
parser.add_argument('--gan_type', type=str, default='GAN',
choices=['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN', 'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN', 'VAE', 'CVAE'],
- help='The type of GAN', required=True)
+ help='The type of GAN')
parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion-mnist', 'celebA'],
help='The name of dataset')
- parser.add_argument('--epoch', type=int, default=20, help='The number of epochs to run')
- parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
+ parser.add_argument('--epoch', type=int, default=2, help='The number of epochs to run')
+ parser.add_argument('--batch_size', type=int, default=1024, help='The size of batch')
parser.add_argument('--z_dim', type=int, default=62, help='Dimension of noise vector')
parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
help='Directory name to save the checkpoints')
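
With `--gan_type` no longer required, a bare `python main.py` falls back to the plain GAN and the new defaults of 2 epochs and batch size 1024. A minimal sketch of how those defaults behave, mirroring the hunk above:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gan_type', type=str, default='GAN',
                    choices=['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN',
                             'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN', 'VAE', 'CVAE'],
                    help='The type of GAN')
parser.add_argument('--epoch', type=int, default=2, help='The number of epochs to run')
parser.add_argument('--batch_size', type=int, default=1024, help='The size of batch')

args = parser.parse_args([])   # no CLI flags -> the new defaults
assert (args.gan_type, args.epoch, args.batch_size) == ('GAN', 2, 1024)
```
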