Repository containing resources from our paper:
Kernel Mean Matching for Content Addressability of GANs
Wittawat Jitkrittum*, Patsorn Sangkloy*, Muhammad Waleed Gondal, Amit Raj, James Hays, Bernhard Schölkopf
ICML 2019
(* Equal contribution)
https://arxiv.org/abs/1905.05882
- Full paper: main text + supplement on arXiv (file size: 36MB)
- Main text only here (file size: 7.3MB)
- Supplementary file only here (file size: 32MB)
We propose a novel procedure which adds content-addressability to any given unconditional implicit model, e.g., a generative adversarial network (GAN). The procedure allows users to control the generative process by specifying a set (of arbitrary size) of desired examples, based on which similar samples are generated from the model. The proposed approach, based on kernel mean matching, is applicable to any generative model which transforms latent vectors to samples, and does not require retraining of the model. Experiments on various high-dimensional image generation problems (CelebA-HQ, LSUN bedroom, bridge, tower) show that our approach is able to generate images which are consistent with the input set, while retaining the image quality of the original model. To our knowledge, this is the first work that attempts to construct, at test time, a content-addressable generative model from a trained marginal model.
We consider a GAN model from Mescheder et al., 2018, pretrained on CelebA-HQ. We run our proposed procedure using the three images (with borders) at the corners as the input. All images inside the triangle are outputs of our procedure. Each output image is positioned such that its closeness to a corner (an input image) indicates the importance (weight) of the corresponding input image.
For a simple demo on MNIST, check out this Colab notebook. No local installation is required.
- Supports Python 3.6+.
- Requires PyTorch 0.4.1 and a GPU, ideally with no less than 4GB of memory.
- Automatic dependency resolution only works with a recent version of pip. First upgrade your pip with `pip install --upgrade pip`.
- If you use Anaconda, consider creating a new environment before installing `cadgan`:

  ```
  conda create -n cadgan pytorch=0.4.1
  ```

  where `cadgan` in the above command is an arbitrary name for the environment.
- Activate the environment with `conda activate cadgan`. You might want to install Jupyter notebook with `conda install jupyter`.
- Make sure you activate the environment first. Then, install the `cadgan` package. This repo is set up so that once you clone it, you can run `pip install -e /path/to/the/folder/of/this/repo/` to install it as a Python package. In Python, we can then do `import cadgan as cdg`, and all the code in the `cadgan` folder is accessible through `cdg`.
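Putting these steps together, a minimal end-to-end install might look like the following sketch. The clone URL and local folder name are assumptions; adjust them to your setup:

```bash
# Upgrade pip so that automatic dependency resolution works
pip install --upgrade pip

# Optional: create and activate a fresh Anaconda environment with PyTorch 0.4.1
conda create -n cadgan pytorch=0.4.1
conda activate cadgan

# Clone the repo (URL assumed) and install cadgan as an editable package
git clone https://github.com/wittawatj/cadgan.git
pip install -e ./cadgan

# Sanity check: the package should now be importable
python -c "import cadgan as cdg; print(cdg.__name__)"
```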
You will need to change the values in `settings.ini` to your local paths. This is important since we will be using relative paths in the scripts.

- Results will be saved in `expr_results_path`.
- `data_path` should point to where you store all your input data.
- `problem_model_path` will be used for storing various pre-trained models (warning: this can be quite large).
- See the comments in `settings.ini` for more details.
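For illustration, a filled-in `settings.ini` might look like the sketch below. The three key names come from this repo; the section header and all values are hypothetical and should be replaced with your own local paths:

```ini
[DEFAULT]
; Where experiment results (Tensorboard logs) are written
expr_results_path = /home/me/cadgan/results
; Where you store all your input (conditional) images
data_path = /home/me/cadgan/data
; Where pre-trained models are downloaded to (can be quite large)
problem_model_path = /home/me/cadgan/models
```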
We provide an example script to run CADGAN in `ex/run_gkmm.py`. For example, here is the command to run CADGAN on the CelebA-HQ dataset with the pre-trained model of Mescheder et al., 2018:
```bash
python3 run_gkmm.py \
  --extractor_type vgg_face \
  --extractor_layers 35 \
  --texture 0 \
  --depth_process no \
  --g_path celebAHQ_00/chkpts/model.pt \
  --g_type celebAHQ.yaml \
  --g_min -1.0 \
  --g_max 1.0 \
  --logdir log_celeba_face/ \
  --device gpu \
  --n_sample 1 \
  --n_opt_iter 1000 \
  --lr 5e-2 \
  --seed 99 \
  --img_log_steps 500 \
  --cond_path celebaHQ/ \
  --kernel imq \
  --kparams -0.5 1e+2 \
  --img_size 224
```
- The above command will use all images in `[data_path]/celebaHQ/` as conditional images, with the generator from `[problem_model_path]/celebAHQ_00/chkpts/model.pt`, and then store results in `[expr_results_path]/log_celeba_face/`. When this is run for the first time, the GAN model will be downloaded automatically. The required feature extractor (VGG face, in this case) will also be downloaded automatically. Downloading these models may take some time; the size of each model is roughly 300-600 MB. The results are written to a Tensorboard log folder. Simply use Tensorboard to see the results. This can be done by, for instance, `./cadgan/ex/start_tensorboard.sh [expr_results_path]/log_celeba_face/`.
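Equivalently, if you prefer not to use the helper script, you can point TensorBoard at the log folder directly (a sketch, assuming TensorBoard is installed in your environment):

```bash
# Serve the experiment's logs, then open http://localhost:6006 in a browser.
# Replace [expr_results_path] with the value you set in settings.ini.
tensorboard --logdir [expr_results_path]/log_celeba_face/
```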
- Note that the possible values of `g_type` are `lsun_bedroom.yaml`, `lsun_bridge.yaml`, `celebAHQ.yaml`, `lsun_tower.yaml`, `mnist_dcgan`, and `colormnist_dcgan`. If the specified generator doesn't exist yet, the code will download the pre-trained model used in the paper into the specified location.
See `run_lars_bedroom.sh`, `run_lars_bridge.sh`, `run_lars_tower.sh`, `run_mnist.sh`, and `run_mnist_color.sh` for other model options; a rough sketch of switching generators is shown below.
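For example, retargeting the CelebA-HQ command above to the LSUN bedroom generator could look roughly like this. Treat it as a sketch: `--g_type` is one of the documented values, but the extractor choice, checkpoint path, input folder, and the remaining flag values are illustrative assumptions carried over from the CelebA-HQ example; `run_lars_bedroom.sh` contains the exact settings used in the paper.

```bash
# Illustrative sketch only -- see run_lars_bedroom.sh for the real flag values.
python3 run_gkmm.py \
  --extractor_type vgg_face \
  --extractor_layers 35 \
  --texture 0 \
  --depth_process no \
  --g_path lsun_bedroom_00/chkpts/model.pt \
  --g_type lsun_bedroom.yaml \
  --g_min -1.0 \
  --g_max 1.0 \
  --logdir log_lsun_bedroom/ \
  --device gpu \
  --n_sample 1 \
  --n_opt_iter 1000 \
  --lr 5e-2 \
  --seed 99 \
  --img_log_steps 500 \
  --cond_path lsun_bedroom/ \
  --kernel imq \
  --kparams -0.5 1e+2 \
  --img_size 224
```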
We also provide two example images for each dataset in `data/` that can be used for testing.
In case you want to experiment with the parameters, we use `ex/cmd_gkmm.py` to generate commands for multiple combinations of parameters. This requires the `cmdprod` package, available at https://github.com/wittawatj/cmdprod.
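A sketch of getting `cmdprod` with pip, assuming the repository is pip-installable straight from GitHub (check its README for the recommended installation method):

```bash
pip install git+https://github.com/wittawatj/cmdprod.git
```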
If you have questions or comments, please contact Wittawat and Patsorn.
- support running cadgan on celebaHQ
- support running cadgan on LSUN
- clean up code & readme
- test that all scripts can successfully run
- run_mnist.sh
- run_lars_bridge.sh
- run_lars_bedroom.sh
- run_lars_tower.sh
- run_lars_celeba.sh
- run_mnist_color.sh
- upload and share data/model files