# DeepFM.py -- forked from ChenglongChen/tensorflow-DeepFM
"""
TensorFlow implementation of DeepFM [1].

Reference:
[1] DeepFM: A Factorization-Machine based Neural Network for CTR Prediction,
    Huifeng Guo, Ruiming Tang, Yunming Ye, Zhenguo Li, Xiuqiang He.
"""
import numpy as np
import tensorflow as tf
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import roc_auc_score
from time import time
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
from yellowfin import YFOptimizer
class DeepFM(BaseEstimator, TransformerMixin):
    """DeepFM model: a factorization-machine part plus a feed-forward part.

    The constructor stores every hyper-parameter on the instance and then
    immediately builds the TensorFlow graph, creates a session and initialises
    all variables (see ``_init_graph``).

    ``train_result`` and ``valid_result`` collect, one entry per epoch, the
    value returned by ``eval_metric`` during ``fit``.
    """

    def __init__(self, feature_size, field_size,
                 embedding_size=8, dropout_fm=[1.0, 1.0],
                 deep_layers=[32, 32], dropout_deep=[0.5, 0.5, 0.5],
                 deep_layers_activation=tf.nn.relu,
                 epoch=10, batch_size=256,
                 learning_rate=0.001, optimizer_type="adam",
                 batch_norm=0, batch_norm_decay=0.995,
                 verbose=False, random_seed=2016,
                 use_fm=True, use_deep=True,
                 loss_type="logloss", eval_metric=roc_auc_score,
                 l2_reg=0.0, greater_is_better=True):
        """
        :param feature_size: M, size of the feature dictionary
        :param field_size: F, number of feature fields per sample
        :param embedding_size: K, size of each feature embedding
        :param dropout_fm: two values fed to ``tf.nn.dropout`` for the
            first-order and the second-order FM terms respectively
        :param deep_layers: width of each hidden layer of the deep component
        :param dropout_deep: values fed to ``tf.nn.dropout`` in the deep
            component; entry 0 is applied to the flattened embeddings and
            entry ``1 + i`` after hidden layer ``i``
        :param deep_layers_activation: activation applied after each hidden layer
        :param epoch: number of training epochs run by ``fit``
        :param batch_size: mini-batch size for training and prediction
        :param learning_rate: learning rate handed to the optimizer
        :param optimizer_type: one of "adam", "adagrad", "gd", "momentum",
            "yellowfin"
        :param batch_norm: truthy to batch-normalise each hidden layer
        :param batch_norm_decay: ``decay`` argument of the batch-norm layers
        :param verbose: when > 0, print the parameter count and report
            progress every ``verbose`` epochs
        :param random_seed: seed passed to ``tf.set_random_seed``
        :param use_fm: include the FM terms in the final projection
        :param use_deep: include the deep component in the final projection
        :param loss_type: "logloss" (classification) or "mse" (regression)
        :param eval_metric: callable invoked as ``eval_metric(y, y_pred)``
        :param l2_reg: L2 regularisation strength; applied only when > 0
        :param greater_is_better: whether a larger ``eval_metric`` is better
        :raises ValueError: if ``optimizer_type`` is not recognised
        """
        # NOTE: the list defaults above are shared between instances; this
        # class only reads them, it never mutates them.
        assert (use_fm or use_deep)
        assert loss_type in ["logloss", "mse"], \
            "loss_type can be either 'logloss' for classification task or 'mse' for regression task"

        self.feature_size = feature_size        # denote as M, size of the feature dictionary
        self.field_size = field_size            # denote as F, size of the feature fields
        self.embedding_size = embedding_size    # denote as K, size of the feature embedding

        self.dropout_fm = dropout_fm
        self.deep_layers = deep_layers
        self.dropout_deep = dropout_deep
        self.deep_layers_activation = deep_layers_activation
        self.use_fm = use_fm
        self.use_deep = use_deep
        self.l2_reg = l2_reg

        self.epoch = epoch
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.optimizer_type = optimizer_type

        self.batch_norm = batch_norm
        self.batch_norm_decay = batch_norm_decay

        self.verbose = verbose
        self.random_seed = random_seed
        self.loss_type = loss_type
        self.eval_metric = eval_metric
        self.greater_is_better = greater_is_better
        self.train_result, self.valid_result = [], []

        self._init_graph()

    def _init_graph(self):
        """Build the graph, loss and train op, then open a session and initialise variables.

        :raises ValueError: if ``self.optimizer_type`` is not a supported name
        """
        self.graph = tf.Graph()
        with self.graph.as_default():
            tf.set_random_seed(self.random_seed)

            self.feat_index = tf.placeholder(tf.int32, shape=[None, None],
                                             name="feat_index")  # None * F
            self.feat_value = tf.placeholder(tf.float32, shape=[None, None],
                                             name="feat_value")  # None * F
            self.label = tf.placeholder(tf.float32, shape=[None, 1], name="label")  # None * 1
            self.dropout_keep_fm = tf.placeholder(tf.float32, shape=[None], name="dropout_keep_fm")
            self.dropout_keep_deep = tf.placeholder(tf.float32, shape=[None], name="dropout_keep_deep")
            self.train_phase = tf.placeholder(tf.bool, name="train_phase")

            self.weights = self._initialize_weights()

            # model: look up one embedding per field and scale it by the feature value
            self.embeddings = tf.nn.embedding_lookup(self.weights["feature_embeddings"],
                                                     self.feat_index)  # None * F * K
            feat_value = tf.reshape(self.feat_value, shape=[-1, self.field_size, 1])
            self.embeddings = tf.multiply(self.embeddings, feat_value)

            # ---------- first order term ----------
            self.y_first_order = tf.nn.embedding_lookup(self.weights["feature_bias"], self.feat_index)  # None * F * 1
            self.y_first_order = tf.reduce_sum(tf.multiply(self.y_first_order, feat_value), 2)  # None * F
            self.y_first_order = tf.nn.dropout(self.y_first_order, self.dropout_keep_fm[0])  # None * F

            # ---------- second order term ---------------
            # sum_square part
            self.summed_features_emb = tf.reduce_sum(self.embeddings, 1)  # None * K
            self.summed_features_emb_square = tf.square(self.summed_features_emb)  # None * K

            # square_sum part
            self.squared_features_emb = tf.square(self.embeddings)
            self.squared_sum_features_emb = tf.reduce_sum(self.squared_features_emb, 1)  # None * K

            # second order: 0.5 * ((sum of embeddings)^2 - sum of (embeddings^2))
            self.y_second_order = 0.5 * tf.subtract(self.summed_features_emb_square,
                                                    self.squared_sum_features_emb)  # None * K
            self.y_second_order = tf.nn.dropout(self.y_second_order, self.dropout_keep_fm[1])  # None * K

            # ---------- Deep component ----------
            self.y_deep = tf.reshape(self.embeddings,
                                     shape=[-1, self.field_size * self.embedding_size])  # None * (F*K)
            self.y_deep = tf.nn.dropout(self.y_deep, self.dropout_keep_deep[0])
            for i in range(0, len(self.deep_layers)):
                self.y_deep = tf.add(tf.matmul(self.y_deep, self.weights["layer_%d" % i]),
                                     self.weights["bias_%d" % i])  # None * layer[i]
                if self.batch_norm:
                    self.y_deep = self.batch_norm_layer(self.y_deep, train_phase=self.train_phase,
                                                        scope_bn="bn_%d" % i)  # None * layer[i]
                self.y_deep = self.deep_layers_activation(self.y_deep)
                self.y_deep = tf.nn.dropout(self.y_deep, self.dropout_keep_deep[1 + i])  # dropout at each Deep layer

            # ---------- DeepFM ----------
            if self.use_fm and self.use_deep:
                concat_input = tf.concat([self.y_first_order, self.y_second_order, self.y_deep], axis=1)
            elif self.use_fm:
                concat_input = tf.concat([self.y_first_order, self.y_second_order], axis=1)
            elif self.use_deep:
                concat_input = self.y_deep
            self.out = tf.add(tf.matmul(concat_input, self.weights["concat_projection"]),
                              self.weights["concat_bias"])

            # loss
            if self.loss_type == "logloss":
                self.out = tf.nn.sigmoid(self.out)
                self.loss = tf.losses.log_loss(self.label, self.out)
            elif self.loss_type == "mse":
                self.loss = tf.nn.l2_loss(tf.subtract(self.label, self.out))

            # l2 regularization on weights
            if self.l2_reg > 0:
                self.loss += tf.contrib.layers.l2_regularizer(
                    self.l2_reg)(self.weights["concat_projection"])
                if self.use_deep:
                    for i in range(len(self.deep_layers)):
                        self.loss += tf.contrib.layers.l2_regularizer(
                            self.l2_reg)(self.weights["layer_%d" % i])

            # optimizer
            if self.optimizer_type == "adam":
                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.9, beta2=0.999,
                                                   epsilon=1e-8)
            elif self.optimizer_type == "adagrad":
                optimizer = tf.train.AdagradOptimizer(learning_rate=self.learning_rate,
                                                      initial_accumulator_value=1e-8)
            elif self.optimizer_type == "gd":
                optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate)
            elif self.optimizer_type == "momentum":
                optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.95)
            elif self.optimizer_type == "yellowfin":
                optimizer = YFOptimizer(learning_rate=self.learning_rate, momentum=0.0)
            else:
                # Fail here rather than leaving self.optimizer undefined, which
                # would only surface later as an AttributeError in fit_on_batch.
                raise ValueError(
                    "optimizer_type must be one of 'adam', 'adagrad', 'gd', "
                    "'momentum', 'yellowfin'; got %r" % (self.optimizer_type,))
            self.optimizer = optimizer.minimize(self.loss)

            # init
            self.saver = tf.train.Saver()
            init = tf.global_variables_initializer()
            self.sess = self._init_session()
            self.sess.run(init)

            # number of params
            total_parameters = 0
            for variable in self.weights.values():
                variable_parameters = 1
                for dim in variable.get_shape():
                    variable_parameters *= dim.value
                total_parameters += variable_parameters
            if self.verbose > 0:
                print("#params: %d" % total_parameters)

    def _init_session(self):
        """Create the ``tf.Session`` used for training and prediction."""
        # NOTE(review): TensorFlow documents device type names in upper case
        # ("CPU"/"GPU"); confirm whether this lower-case "gpu" key has any
        # effect before relying on it to limit GPU usage.
        config = tf.ConfigProto(device_count={"gpu": 0})
        config.gpu_options.allow_growth = True
        return tf.Session(config=config)

    def _initialize_weights(self):
        """Create all trainable variables and return them in a dict keyed by name.

        NOTE: the dense-layer weights are drawn with ``np.random.normal``, which
        is not affected by the ``tf.set_random_seed`` call in ``_init_graph``.
        """
        weights = dict()

        # embeddings
        weights["feature_embeddings"] = tf.Variable(
            tf.random_normal([self.feature_size, self.embedding_size], 0.0, 0.01),
            name="feature_embeddings")  # feature_size * K
        weights["feature_bias"] = tf.Variable(
            tf.random_uniform([self.feature_size, 1], 0.0, 1.0), name="feature_bias")  # feature_size * 1

        # deep layers
        num_layer = len(self.deep_layers)
        input_size = self.field_size * self.embedding_size
        glorot = np.sqrt(2.0 / (input_size + self.deep_layers[0]))
        weights["layer_0"] = tf.Variable(
            np.random.normal(loc=0, scale=glorot, size=(input_size, self.deep_layers[0])), dtype=np.float32)
        weights["bias_0"] = tf.Variable(np.random.normal(loc=0, scale=glorot, size=(1, self.deep_layers[0])),
                                        dtype=np.float32)  # 1 * layers[0]
        for i in range(1, num_layer):
            glorot = np.sqrt(2.0 / (self.deep_layers[i - 1] + self.deep_layers[i]))
            weights["layer_%d" % i] = tf.Variable(
                np.random.normal(loc=0, scale=glorot, size=(self.deep_layers[i - 1], self.deep_layers[i])),
                dtype=np.float32)  # layers[i-1] * layers[i]
            weights["bias_%d" % i] = tf.Variable(
                np.random.normal(loc=0, scale=glorot, size=(1, self.deep_layers[i])),
                dtype=np.float32)  # 1 * layer[i]

        # final concat projection layer: its input width is the width of the
        # tensors concatenated in _init_graph (F for first order, K for second
        # order, layers[-1] for the deep output)
        if self.use_fm and self.use_deep:
            input_size = self.field_size + self.embedding_size + self.deep_layers[-1]
        elif self.use_fm:
            input_size = self.field_size + self.embedding_size
        elif self.use_deep:
            input_size = self.deep_layers[-1]
        glorot = np.sqrt(2.0 / (input_size + 1))
        weights["concat_projection"] = tf.Variable(
            np.random.normal(loc=0, scale=glorot, size=(input_size, 1)),
            dtype=np.float32)  # input_size * 1
        weights["concat_bias"] = tf.Variable(tf.constant(0.01), dtype=np.float32)

        return weights

    def batch_norm_layer(self, x, train_phase, scope_bn):
        """Batch-normalise ``x``, choosing the training or inference branch via ``train_phase``.

        Both branches are created under the same ``scope_bn``; the inference
        branch is built with ``reuse=True``.
        """
        bn_train = batch_norm(x, decay=self.batch_norm_decay, center=True, scale=True, updates_collections=None,
                              is_training=True, reuse=None, trainable=True, scope=scope_bn)
        bn_inference = batch_norm(x, decay=self.batch_norm_decay, center=True, scale=True, updates_collections=None,
                                  is_training=False, reuse=True, trainable=True, scope=scope_bn)
        z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
        return z

    def get_batch(self, Xi, Xv, y, batch_size, index):
        """Return the ``index``-th mini-batch as ``(Xi, Xv, labels)``.

        The slice is clamped to ``len(y)``, so the last batch may be shorter
        and an out-of-range ``index`` yields empty slices. Each label is
        wrapped in its own list, giving a column of shape (n, 1).
        """
        start = index * batch_size
        end = min((index + 1) * batch_size, len(y))
        return Xi[start:end], Xv[start:end], [[label] for label in y[start:end]]

    # shuffle three lists simultaneously
    def shuffle_in_unison_scary(self, a, b, c):
        """Shuffle ``a``, ``b`` and ``c`` in place with the same permutation.

        The global NumPy RNG state is captured once and restored before each
        later shuffle, so all three shuffles start from the same state.
        """
        rng_state = np.random.get_state()
        np.random.shuffle(a)
        np.random.set_state(rng_state)
        np.random.shuffle(b)
        np.random.set_state(rng_state)
        np.random.shuffle(c)

    def fit_on_batch(self, Xi, Xv, y):
        """Run one optimisation step on a single mini-batch and return its loss."""
        feed_dict = {self.feat_index: Xi,
                     self.feat_value: Xv,
                     self.label: y,
                     self.dropout_keep_fm: self.dropout_fm,
                     self.dropout_keep_deep: self.dropout_deep,
                     self.train_phase: True}
        loss, _ = self.sess.run((self.loss, self.optimizer), feed_dict=feed_dict)
        return loss

    def _train_one_epoch(self, Xi, Xv, y):
        """Shuffle the data in place and run one training step per mini-batch."""
        self.shuffle_in_unison_scary(Xi, Xv, y)
        # Ceiling division: the trailing partial batch is trained on too, and a
        # dataset smaller than batch_size still yields one batch
        # (get_batch clamps the final slice).
        total_batch = (len(y) + self.batch_size - 1) // self.batch_size
        for i in range(total_batch):
            Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, y, self.batch_size, i)
            self.fit_on_batch(Xi_batch, Xv_batch, y_batch)

    def fit(self, Xi_train, Xv_train, y_train,
            Xi_valid=None, Xv_valid=None, y_valid=None,
            early_stopping=False, refit=False):
        """
        Train the model. The training data passed in is shuffled in place.

        :param Xi_train: [[ind1_1, ind1_2, ...], [ind2_1, ind2_2, ...], ..., [indi_1, indi_2, ..., indi_j, ...], ...]
                         indi_j is the feature index of feature field j of sample i in the training set
        :param Xv_train: [[val1_1, val1_2, ...], [val2_1, val2_2, ...], ..., [vali_1, vali_2, ..., vali_j, ...], ...]
                         vali_j is the feature value of feature field j of sample i in the training set
                         vali_j can be either binary (1/0, for binary/categorical features) or float (e.g., 10.24, for numerical features)
        :param y_train: label of each sample in the training set
        :param Xi_valid: list of list of feature indices of each sample in the validation set
        :param Xv_valid: list of list of feature values of each sample in the validation set
        :param y_valid: label of each sample in the validation set
        :param early_stopping: perform early stopping or not
        :param refit: refit the model on the train+valid dataset or not
        :return: None
        """
        has_valid = Xv_valid is not None
        for epoch in range(self.epoch):
            t1 = time()
            self._train_one_epoch(Xi_train, Xv_train, y_train)

            # evaluate training and validation datasets
            train_result = self.evaluate(Xi_train, Xv_train, y_train)
            self.train_result.append(train_result)
            if has_valid:
                valid_result = self.evaluate(Xi_valid, Xv_valid, y_valid)
                self.valid_result.append(valid_result)
            if self.verbose > 0 and epoch % self.verbose == 0:
                if has_valid:
                    print("[%d] train-result=%.4f, valid-result=%.4f [%.1f s]"
                          % (epoch + 1, train_result, valid_result, time() - t1))
                else:
                    print("[%d] train-result=%.4f [%.1f s]"
                          % (epoch + 1, train_result, time() - t1))
            if has_valid and early_stopping and self.training_termination(self.valid_result):
                break

        # fit a few more epoch on train+valid until result reaches the best_train_score
        if has_valid and refit:
            if self.greater_is_better:
                best_valid_score = max(self.valid_result)
            else:
                best_valid_score = min(self.valid_result)
            best_epoch = self.valid_result.index(best_valid_score)
            best_train_score = self.train_result[best_epoch]
            # Copy to lists before joining: "+" on NumPy arrays would add
            # element-wise instead of concatenating the two datasets.
            Xi_train = list(Xi_train) + list(Xi_valid)
            Xv_train = list(Xv_train) + list(Xv_valid)
            y_train = list(y_train) + list(y_valid)
            for epoch in range(100):
                self._train_one_epoch(Xi_train, Xv_train, y_train)
                # check: stop once the score on train+valid is within 0.001 of,
                # or better than, the training score at the best validation epoch
                train_result = self.evaluate(Xi_train, Xv_train, y_train)
                if abs(train_result - best_train_score) < 0.001 or \
                        (self.greater_is_better and train_result > best_train_score) or \
                        ((not self.greater_is_better) and train_result < best_train_score):
                    break

    def training_termination(self, valid_result):
        """Return True when the last five validation results got strictly worse each epoch.

        "Worse" means decreasing when ``greater_is_better`` is set, increasing
        otherwise. Always False until more than five results are recorded.
        """
        if len(valid_result) <= 5:
            return False
        last_five = valid_result[-5:]
        consecutive = list(zip(last_five, last_five[1:]))
        if self.greater_is_better:
            return all(later < earlier for earlier, later in consecutive)
        return all(later > earlier for earlier, later in consecutive)

    def predict(self, Xi, Xv):
        """
        :param Xi: list of list of feature indices of each sample in the dataset
        :param Xv: list of list of feature values of each sample in the dataset
        :return: predicted probability of each sample as a 1-D array
                 (an empty array when no samples are given)
        """
        # dummy y: the label placeholder is fed but does not influence self.out
        dummy_y = [1] * len(Xi)
        batch_outputs = []
        batch_index = 0
        Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, dummy_y, self.batch_size, batch_index)
        while len(Xi_batch) > 0:
            num_batch = len(y_batch)
            # all dropout values are fed as 1.0 and batch norm uses its inference branch
            feed_dict = {self.feat_index: Xi_batch,
                         self.feat_value: Xv_batch,
                         self.label: y_batch,
                         self.dropout_keep_fm: [1.0] * len(self.dropout_fm),
                         self.dropout_keep_deep: [1.0] * len(self.dropout_deep),
                         self.train_phase: False}
            batch_out = self.sess.run(self.out, feed_dict=feed_dict)
            batch_outputs.append(np.reshape(batch_out, (num_batch,)))

            batch_index += 1
            Xi_batch, Xv_batch, y_batch = self.get_batch(Xi, Xv, dummy_y, self.batch_size, batch_index)

        if not batch_outputs:
            return np.array([], dtype=np.float32)
        # Concatenate once at the end instead of re-copying the growing result per batch.
        return np.concatenate(batch_outputs)

    def evaluate(self, Xi, Xv, y):
        """
        :param Xi: list of list of feature indices of each sample in the dataset
        :param Xv: list of list of feature values of each sample in the dataset
        :param y: label of each sample in the dataset
        :return: metric of the evaluation, i.e. ``eval_metric(y, y_pred)``
        """
        y_pred = self.predict(Xi, Xv)
        return self.eval_metric(y, y_pred)