# run.py (forked from tolstikhin/wae)
# Entry point: selects an experiment config, applies command-line
# overrides, and trains a Wasserstein Auto-Encoder (WAE).
import os
import logging
import argparse

import configs
from wae import WAE
from datahandler import DataHandler
import utils
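
# Command-line interface. All flags are optional: --exp defaults to
# 'mnist_small', and any other flag left unset keeps the value from the
# selected config in configs.py.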
parser = argparse.ArgumentParser()
parser.add_argument("--exp", default='mnist_small',
                    help='experiment config [mnist/mnist_small/celebA/'
                         'celebA_small/dsprites/grassli/grassli_small]')
parser.add_argument("--zdim", type=int,
                    help='dimensionality of the latent space')
parser.add_argument("--lr", type=float,
                    help='auto-encoder learning rate')
parser.add_argument("--z_test",
                    help='method of choice for verifying Pz=Qz [mmd/gan]')
parser.add_argument("--wae_lambda", type=int,
                    help='WAE regularization coefficient')
parser.add_argument("--work_dir",
                    help='directory for checkpoints and logs')
parser.add_argument("--lambda_schedule",
                    help='lambda schedule [constant/adaptive]')
parser.add_argument("--enc_noise",
                    help="type of encoder noise:"
                         " 'deterministic': no noise whatsoever,"
                         " 'gaussian': gaussian encoder,"
                         " 'implicit': implicit encoder,"
                         " 'add_noise': add noise before feeding"
                         " to deterministic encoder")
FLAGS = parser.parse_args()
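
# Example invocation (flag values below are illustrative only):
#   python run.py --exp mnist --zdim 8 --lr 0.001 --z_test mmd --wae_lambda 10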


def main():
    # Select the base experiment configuration.
    if FLAGS.exp == 'celebA':
        opts = configs.config_celebA
    elif FLAGS.exp == 'celebA_small':
        opts = configs.config_celebA_small
    elif FLAGS.exp == 'mnist':
        opts = configs.config_mnist
    elif FLAGS.exp == 'mnist_small':
        opts = configs.config_mnist_small
    elif FLAGS.exp == 'dsprites':
        opts = configs.config_dsprites
    elif FLAGS.exp == 'grassli':
        opts = configs.config_grassli
    elif FLAGS.exp == 'grassli_small':
        opts = configs.config_grassli_small
    else:
        assert False, 'Unknown experiment configuration'
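
    # Command-line flags, when given, override the chosen config's defaults.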
    if FLAGS.zdim is not None:
        opts['zdim'] = FLAGS.zdim
    if FLAGS.lr is not None:
        opts['lr'] = FLAGS.lr
    if FLAGS.z_test is not None:
        opts['z_test'] = FLAGS.z_test
    if FLAGS.lambda_schedule is not None:
        opts['lambda_schedule'] = FLAGS.lambda_schedule
    if FLAGS.work_dir is not None:
        opts['work_dir'] = FLAGS.work_dir
    if FLAGS.wae_lambda is not None:
        opts['lambda'] = FLAGS.wae_lambda
    if FLAGS.enc_noise is not None:
        opts['e_noise'] = FLAGS.enc_noise
    if opts['verbose']:
        logging.basicConfig(level=logging.DEBUG,
                            format='%(asctime)s - %(message)s')
    else:
        logging.basicConfig(level=logging.INFO,
                            format='%(asctime)s - %(message)s')
    utils.create_dir(opts['work_dir'])
    utils.create_dir(os.path.join(opts['work_dir'], 'checkpoints'))

    # Dump all the configs to a text file
    with utils.o_gfile((opts['work_dir'], 'params.txt'), 'w') as text:
        text.write('Parameters:\n')
        for key in opts:
            text.write('%s : %s\n' % (key, opts[key]))
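    # Each resulting line of params.txt reads "key : value",
    # e.g. "zdim : 8" (the value shown is illustrative).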

    # Load the dataset
    data = DataHandler(opts)
    assert data.num_points >= opts['batch_size'], 'Training set too small'

    # Train the WAE
    wae = WAE(opts)
    wae.train(data)

if __name__ == '__main__':
    main()