-
Notifications
You must be signed in to change notification settings - Fork 14
/
conditional_gan_mnist.py
74 lines (62 loc) · 2.36 KB
/
conditional_gan_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import numpy as np
import tensorflow as tf
from easydict import EasyDict as edict
from gans.callbacks import saver
from gans.datasets import mnist
from gans.models.discriminators import discriminator
from gans.models.generators.latent_to_image import latent_to_image
from gans.trainers import conditional_gan_trainer
from gans.trainers import optimizers
# Hyper-parameters shared by the dataset, both networks and the trainer,
# exposed with attribute access (e.g. model_parameters.batch_size).
model_parameters = edict(dict(
    # Input image geometry: 28x28 pixels, one channel.
    img_height=28,
    img_width=28,
    num_channels=1,
    # Training-loop sizing.
    batch_size=16,
    num_epochs=10,
    buffer_size=1000,
    # Generator noise dimensionality and number of label classes.
    latent_size=100,
    num_classes=10,
    # Per-network optimizer step sizes.
    learning_rate_generator=0.0001,
    learning_rate_discriminator=0.0001,
    # Step interval at which generated images are written out.
    save_images_every_n_steps=10,
))

# MNIST input pipeline; with_labels=True asks for labels alongside the images
# (presumably consumed as the conditioning signal — confirm in gans.datasets.mnist).
dataset = mnist.MnistDataset(model_parameters, with_labels=True)
def make_validation_dataset(params=None):
    """Build a fixed batch of validation inputs for the conditional generator.

    The batch holds ``num_classes ** 2`` samples. Each class label appears
    ``num_classes`` times in a contiguous run (0, 0, ..., 1, 1, ..., n-1),
    paired with one standard-normal latent vector per sample.

    Args:
        params: Object exposing ``num_classes`` and ``latent_size``
            attributes. Defaults to the module-level ``model_parameters``.

    Returns:
        A two-element list ``[latent, labels]``:

        * ``latent`` is the result of ``tf.random.normal`` with shape
          ``(num_classes ** 2, latent_size)``.
        * ``labels`` is a 1-D integer NumPy array of the same length.
    """
    if params is None:
        params = model_parameters
    num_classes = params.num_classes
    # np.repeat keeps each label's copies adjacent: [0]*n + [1]*n + ... .
    labels = np.repeat(np.arange(num_classes), num_classes)
    latent = tf.random.normal([labels.shape[0], params.latent_size])
    return [latent, labels]


# The factory has its own name, so binding the data here does not shadow it.
validation_dataset = make_validation_dataset()
# Both networks are built from the shared hyper-parameter set.
generator = latent_to_image.LatentToImageGenerator(model_parameters)
# NOTE(review): this assignment rebinds the name of the imported
# `discriminator` module to the model instance; the module is no longer
# reachable under that name afterwards.
discriminator = discriminator.Discriminator(model_parameters)


def _adam(learning_rate):
    """Return an Adam optimizer with beta_1=0.5 and the given step size."""
    return optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)


# Each network gets its own optimizer instance and learning rate.
generator_optimizer = _adam(model_parameters.learning_rate_generator)
discriminator_optimizer = _adam(model_parameters.learning_rate_discriminator)
# Image-saving callback; it uses the same step interval that is also passed
# to the trainer below.
callbacks = [saver.ImageProblemSaver(
    save_images_every_n_steps=model_parameters.save_images_every_n_steps,
)]

# Wire the networks, optimizers and fixed validation inputs into the
# conditional GAN trainer. continue_training=False requests a fresh run
# (presumably not resuming from saved state — confirm in ConditionalGANTrainer).
gan_trainer = conditional_gan_trainer.ConditionalGANTrainer(
    # Run identification and batching.
    training_name='CONDITIONAL_GAN_MNIST',
    batch_size=model_parameters.batch_size,
    # Networks paired with their optimizers.
    generator=generator,
    generator_optimizer=generator_optimizer,
    discriminator=discriminator,
    discriminator_optimizer=discriminator_optimizer,
    # Conditioning configuration and fixed validation inputs.
    latent_size=model_parameters.latent_size,
    num_classes=model_parameters.num_classes,
    validation_dataset=validation_dataset,
    # Run control and monitoring.
    continue_training=False,
    save_images_every_n_steps=model_parameters.save_images_every_n_steps,
    callbacks=callbacks,
)

# Training starts as soon as this module executes (there is no __main__ guard).
gan_trainer.train(dataset=dataset, num_epochs=model_parameters.num_epochs)