mdn.py
"""
A Mixture Density Layer for Keras
cpmpercussion: Charles Martin (University of Oslo) 2018
https://github.com/cpmpercussion/keras-mdn-layer
Hat tip to [Omimo's Keras MDN layer](https://github.com/omimo/Keras-MDN) for a starting point for this code.
"""
import keras
from keras import backend as K
from keras.layers import Dense
from keras.engine.topology import Layer
import numpy as np
from tensorflow.contrib.distributions import Categorical, Mixture, MultivariateNormalDiag
import tensorflow as tf


def elu_plus_one_plus_epsilon(x):
    """ELU activation with a very small addition to help prevent NaN in loss."""
    return (K.elu(x) + 1 + 1e-8)


class MDN(Layer):
    """A Mixture Density Network Layer for Keras.

    This layer has a few tricks to avoid NaNs in the loss function when training:
        - Activation for variances is ELU + 1 + 1e-8 (to avoid very small values)
        - Mixture weights (pi) are trained as logits, not in the softmax space.

    A loss function needs to be constructed with the same output dimension and number of mixtures.
    A sampling function is also provided to sample from the distribution parametrised by the MDN outputs.
    """

    def __init__(self, output_dimension, num_mixtures, **kwargs):
        self.output_dim = output_dimension
        self.num_mix = num_mixtures
        with tf.name_scope('MDN'):
            self.mdn_mus = Dense(self.num_mix * self.output_dim, name='mdn_mus')  # mix*output vals, no activation
            self.mdn_sigmas = Dense(self.num_mix * self.output_dim, activation=elu_plus_one_plus_epsilon, name='mdn_sigmas')  # mix*output vals, elu+1 activation
            self.mdn_pi = Dense(self.num_mix, name='mdn_pi')  # mix vals, logits
        super(MDN, self).__init__(**kwargs)

    def build(self, input_shape):
        self.mdn_mus.build(input_shape)
        self.mdn_sigmas.build(input_shape)
        self.mdn_pi.build(input_shape)
        self.trainable_weights = self.mdn_mus.trainable_weights + self.mdn_sigmas.trainable_weights + self.mdn_pi.trainable_weights
        self.non_trainable_weights = self.mdn_mus.non_trainable_weights + self.mdn_sigmas.non_trainable_weights + self.mdn_pi.non_trainable_weights
        super(MDN, self).build(input_shape)

    def call(self, x, mask=None):
        with tf.name_scope('MDN'):
            mdn_out = keras.layers.concatenate([self.mdn_mus(x),
                                                self.mdn_sigmas(x),
                                                self.mdn_pi(x)],
                                               name='mdn_outputs')
        return mdn_out

    def compute_output_shape(self, input_shape):
        # The layer outputs mus and sigmas for every mixture, plus one logit per mixture.
        return (input_shape[0], 2 * self.output_dim * self.num_mix + self.num_mix)

    def get_config(self):
        config = {
            "output_dimension": self.output_dim,
            "num_mixtures": self.num_mix
        }
        base_config = super(MDN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


def get_mixture_loss_func(output_dim, num_mixes):
    """Construct a loss function for the MDN layer parametrised by number of mixtures."""
    # Construct a loss function with the right number of mixtures and outputs
    def loss_func(y_true, y_pred):
        out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim,
                                                                         num_mixes * output_dim,
                                                                         num_mixes],
                                             axis=1, name='mdn_coef_split')
        cat = Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
                in zip(mus, sigs)]
        mixture = Mixture(cat=cat, components=coll)
        loss = mixture.log_prob(y_true)
        loss = tf.negative(loss)
        loss = tf.reduce_mean(loss)
        return loss
    # Actually return the loss_func
    with tf.name_scope('MDN'):
        return loss_func
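

# Usage sketch (an illustrative assumption, not part of the original file): the constructed
# loss is passed to `model.compile` for a model whose final layer is MDN(output_dim, num_mixes):
#   model.compile(loss=get_mixture_loss_func(output_dim, num_mixes), optimizer='adam')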


def get_mixture_sampling_fun(output_dim, num_mixes):
    """Construct a sampling function for the MDN layer parametrised by mixtures and output dimension."""
    # Construct a sampling function with the right number of mixtures and outputs
    def sampling_func(y_pred):
        out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim,
                                                                         num_mixes * output_dim,
                                                                         num_mixes],
                                             axis=1, name='mdn_coef_split')
        cat = Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
                in zip(mus, sigs)]
        mixture = Mixture(cat=cat, components=coll)
        samp = mixture.sample()
        # Todo: temperature adjustment for sampling function.
        return samp
    # Actually return the sampling_func
    with tf.name_scope('MDNLayer'):
        return sampling_func


def get_mixture_mse_accuracy(output_dim, num_mixes):
    """Construct an MSE accuracy function for the MDN layer
    that takes one sample and compares to the true value."""
    # Construct an MSE metric with the right number of mixtures and outputs
    def mse_func(y_true, y_pred):
        out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim,
                                                                         num_mixes * output_dim,
                                                                         num_mixes],
                                             axis=1, name='mdn_coef_split')
        cat = Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
                in zip(mus, sigs)]
        mixture = Mixture(cat=cat, components=coll)
        samp = mixture.sample()
        mse = tf.reduce_mean(tf.square(samp - y_true), axis=-1)
        # Todo: temperature adjustment for sampling function.
        return mse
    # Actually return the mse_func
    with tf.name_scope('MDNLayer'):
        return mse_func
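

# Usage sketch (an assumption): since mse_func takes (y_true, y_pred), it can be supplied
# as a Keras metric, e.g.
#   model.compile(loss=..., optimizer='adam', metrics=[get_mixture_mse_accuracy(output_dim, num_mixes)])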


def split_mixture_params(params, output_dim, num_mixes):
    """Splits up an array of mixture parameters into mus, sigmas, and pis
    depending on the number of mixtures and output dimension."""
    mus = params[:num_mixes * output_dim]
    sigs = params[num_mixes * output_dim:2 * num_mixes * output_dim]
    pi_logits = params[-num_mixes:]
    return mus, sigs, pi_logits


def softmax(w, t=1.0):
    """Softmax function for a list or numpy array of logits. Also adjusts temperature."""
    e = np.array(w) / t  # adjust temperature
    e -= e.max()  # subtract max to protect from exploding exp values.
    e = np.exp(e)
    dist = e / np.sum(e)
    return dist


def sample_from_categorical(dist):
    """Samples an index from a categorical distribution given by `dist`."""
    r = np.random.rand(1)  # uniform random number in [0, 1]
    accumulate = 0
    for i in range(0, dist.size):
        accumulate += dist[i]
        if accumulate >= r:
            return i
    tf.logging.info('Error sampling mixture model.')
    return -1


def sample_from_output(params, output_dim, num_mixes, temp=1.0):
    """Sample from an MDN output with temperature adjustment."""
    mus = params[:num_mixes * output_dim]
    sigs = params[num_mixes * output_dim:2 * num_mixes * output_dim]
    pis = softmax(params[-num_mixes:], t=temp)
    m = sample_from_categorical(pis)
    # Alternative way to sample from categorical:
    # m = np.random.choice(range(len(pis)), p=pis)
    mus_vector = mus[m * output_dim:(m + 1) * output_dim]
    sig_vector = sigs[m * output_dim:(m + 1) * output_dim] * temp  # adjust for temperature
    # sig_vector holds standard deviations, so square it to form the diagonal covariance.
    cov_matrix = np.identity(output_dim) * np.square(sig_vector)
    sample = np.random.multivariate_normal(mus_vector, cov_matrix, 1)
    return sample
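

if __name__ == '__main__':
    # Illustrative end-to-end sketch (an assumption, not part of the original module):
    # fit a small MDN to noisy sinusoidal data and draw a sample from the predicted
    # mixture. Layer sizes, epochs, and the toy data are arbitrary demonstration choices.
    from keras.models import Sequential

    OUTPUT_DIM = 1
    N_MIXES = 3
    model = Sequential()
    model.add(Dense(64, activation='relu', input_shape=(1,)))
    model.add(MDN(OUTPUT_DIM, N_MIXES))
    model.compile(loss=get_mixture_loss_func(OUTPUT_DIM, N_MIXES), optimizer='adam')

    x_train = np.random.uniform(-1.0, 1.0, size=(1000, 1))
    y_train = np.sin(3.0 * x_train) + np.random.normal(0.0, 0.1, size=x_train.shape)
    model.fit(x_train, y_train, epochs=10, batch_size=32, verbose=0)

    # Each prediction row is the concatenation [mus, sigmas, pi_logits].
    params = model.predict(np.array([[0.5]]))
    print(sample_from_output(params[0], OUTPUT_DIM, N_MIXES, temp=1.0))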