[ModelZoo] Support Co_Action Network #344

Open · wants to merge 8 commits into base: main

Changes from 1 commit
9 changes: 9 additions & 0 deletions modelzoo/CAN/prepare_data.sh
@@ -0,0 +1,9 @@
export PATH="~/anaconda4/bin:$PATH"
wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/reviews_Books.json.gz
wget http://snap.stanford.edu/data/amazon/productGraph/categoryFiles/meta_Books.json.gz
gunzip reviews_Books.json.gz
gunzip meta_Books.json.gz
python script/process_data.py meta_Books.json reviews_Books_5.json
python script/local_aggretor.py
python script/split_by_user.py
python script/generate_voc.py
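
For orientation, the script downloads the Amazon Books reviews and metadata, then runs the processing chain that produces the training split and vocabulary pickles consumed by the scripts below. A quick sanity check after running it, as a Python sketch (the file names are assumptions inferred from generate_voc.py and the DIEN-style layout, not confirmed outputs of split_by_user.py):

import os

# Sanity-check sketch: names are inferred from the scripts in this PR,
# not confirmed outputs of process_data.py / split_by_user.py.
expected = [
    "local_train_splitByUser",
    "uid_voc.pkl", "mid_voc.pkl", "cat_voc.pkl",
    "item_carte_voc.pkl", "cate_carte_voc.pkl",
]
for name in expected:
    print("%-24s %s" % (name, "ok" if os.path.exists(name) else "missing"))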
35 changes: 35 additions & 0 deletions modelzoo/CAN/script/Dice.py
@@ -0,0 +1,35 @@
import tensorflow as tf

def dice(_x, axis=-1, epsilon=0.000000001, name=''):
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        alphas = tf.get_variable('alpha'+name, _x.get_shape()[-1],
                                 initializer=tf.constant_initializer(0.0),
                                 dtype=tf.float32)
        input_shape = list(_x.get_shape())

        reduction_axes = list(range(len(input_shape)))
        del reduction_axes[axis]
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[axis] = input_shape[axis]

    # case: train mode (uses stats of the current batch)
    mean = tf.reduce_mean(_x, axis=reduction_axes)
    brodcast_mean = tf.reshape(mean, broadcast_shape)
    std = tf.reduce_mean(tf.square(_x - brodcast_mean) + epsilon, axis=reduction_axes)
    std = tf.sqrt(std)
    brodcast_std = tf.reshape(std, broadcast_shape)
    x_normed = (_x - brodcast_mean) / (brodcast_std + epsilon)
    # x_normed = tf.layers.batch_normalization(_x, center=False, scale=False)
    x_p = tf.sigmoid(x_normed)

    return alphas * (1.0 - x_p) * _x + x_p * _x


def parametric_relu(_x):
    alphas = tf.get_variable('alpha', _x.get_shape()[-1],
                             initializer=tf.constant_initializer(0.0),
                             dtype=tf.float32)
    pos = tf.nn.relu(_x)
    neg = alphas * (_x - abs(_x)) * 0.5

    return pos + neg
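
For context, dice implements the Dice activation from DIN/DIEN: it standardizes the input with batch statistics, gates it through a sigmoid, and blends the identity and suppressed branches with a learnable per-channel alpha. A minimal usage sketch, assuming TensorFlow 1.x graph mode (the placeholder shape and session setup are illustrative, not part of the PR):

import numpy as np
import tensorflow as tf
from Dice import dice

x = tf.placeholder(tf.float32, [None, 16])
y = dice(x, name='fc1')  # creates a per-channel 'alphafc1' variable under scope 'fc1'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.random.randn(4, 16).astype(np.float32)})
    print(out.shape)  # (4, 16)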
14 changes: 14 additions & 0 deletions modelzoo/CAN/script/calc_ckpt.py
@@ -0,0 +1,14 @@

import tensorflow as tf

ckpt = tf.train.get_checkpoint_state("./ckpt_path/").model_checkpoint_path
saver = tf.train.import_meta_graph(ckpt + '.meta')
variables = tf.trainable_variables()
total_parameters = 0
for variable in variables:
    shape = variable.get_shape()
    variable_parameters = 1
    for dim in shape:
        # print(dim)
Collaborator: Remove the unused print statements or comments.

        variable_parameters *= dim.value
    # print(variable_parameters)
    total_parameters += variable_parameters
print(total_parameters)
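
The script multiplies out each trainable variable's dimensions after importing the checkpoint's meta graph. Note that dim.value is TensorFlow 1.x behavior; TF 2.x shape dimensions are plain ints, so dim.value would fail there. The same count as a compact TF 1.x sketch (the helper name is hypothetical, not part of the PR):

import numpy as np
import tensorflow as tf

def count_trainable_parameters():
    # Sum of products of each trainable variable's static shape (TF 1.x API).
    return sum(int(np.prod(v.get_shape().as_list()))
               for v in tf.trainable_variables())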
228 changes: 228 additions & 0 deletions modelzoo/CAN/script/data_iterator.py
@@ -0,0 +1,228 @@
import numpy
import json
#import cPickle as pkl
Collaborator: Same as above.

import _pickle as cPickle
import random

import gzip

import shuffle

def unicode_to_utf8(d):
    return dict((key.encode("UTF-8"), value) for (key, value) in d.items())

def dict_unicode_to_utf8(d):
    print('d={}'.format(d))
    return dict(((key[0].encode("UTF-8"), key[1].encode("UTF-8")), value) for (key, value) in d.items())

def load_dict(filename):
    try:
        with open(filename, 'rb') as f:
            return unicode_to_utf8(json.load(f))
    except:
        try:
            with open(filename, 'rb') as f:
                return unicode_to_utf8(cPickle.load(f))
        except:
            with open(filename, 'rb') as f:
                return dict_unicode_to_utf8(cPickle.load(f))


def fopen(filename, mode='r'):
    if filename.endswith('.gz'):
        return gzip.open(filename, mode)
    return open(filename, mode)


class DataIterator:

    def __init__(self, source,
                 uid_voc,
                 mid_voc,
                 cat_voc,
                 batch_size=128,
                 maxlen=100,
                 skip_empty=False,
                 shuffle_each_epoch=False,
                 sort_by_length=True,
                 max_batch_size=20,
                 minlen=None,
                 label_type=1):
        if shuffle_each_epoch:
            self.source_orig = source
            self.source = shuffle.main(self.source_orig, temporary=True)
        else:
            self.source = fopen(source, 'r')
        self.source_dicts = []
        #for source_dict in [uid_voc, mid_voc, cat_voc, cat_voc, cat_voc]:# 'item_carte_voc.pkl', 'cate_carte_voc.pkl']:
        for source_dict in [uid_voc, mid_voc, cat_voc, '/home/test/modelzoo/CAN/data/item_carte_voc.pkl', '/home/test/modelzoo/CAN/data/cate_carte_voc.pkl']:
            self.source_dicts.append(load_dict(source_dict))

        f_meta = open("/home/test/modelzoo/CAN/data/item-info", "r")
        meta_map = {}
        for line in f_meta:
            arr = line.strip().split("\t")
            if arr[0] not in meta_map:
                meta_map[arr[0]] = arr[1]
        self.meta_id_map = {}
        for key in meta_map:
            val = meta_map[key]
            if key in self.source_dicts[1]:
                mid_idx = self.source_dicts[1][key]
            else:
                mid_idx = 0
            if val in self.source_dicts[2]:
                cat_idx = self.source_dicts[2][val]
            else:
                cat_idx = 0
            self.meta_id_map[mid_idx] = cat_idx

        f_review = open("/home/test/modelzoo/CAN/data/reviews-info", "r")
Collaborator: Do not use an absolute path here.

        self.mid_list_for_random = []
        for line in f_review:
            arr = line.strip().split("\t")
            tmp_idx = 0
            if arr[1] in self.source_dicts[1]:
                tmp_idx = self.source_dicts[1][arr[1]]
            self.mid_list_for_random.append(tmp_idx)

        self.batch_size = batch_size
        self.maxlen = maxlen
        self.minlen = minlen
        self.skip_empty = skip_empty

        self.n_uid = len(self.source_dicts[0])
        self.n_mid = len(self.source_dicts[1])
        self.n_cat = len(self.source_dicts[2])
        self.n_carte = [len(self.source_dicts[3]), len(self.source_dicts[4])]
        print("n_uid=%d, n_mid=%d, n_cat=%d" % (self.n_uid, self.n_mid, self.n_cat))
Collaborator: Redundant print.


        self.shuffle = shuffle_each_epoch
        self.sort_by_length = sort_by_length

        self.source_buffer = []
        self.k = batch_size * max_batch_size

        self.end_of_data = False
        self.label_type = label_type

    def get_n(self):
        return self.n_uid, self.n_mid, self.n_cat, self.n_carte

    def __iter__(self):
        return self

    def reset(self):
        if self.shuffle:
            self.source = shuffle.main(self.source_orig, temporary=True)
        else:
            self.source.seek(0)

    def __next__(self):
        if self.end_of_data:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        source = []
        target = []

        if len(self.source_buffer) == 0:
            for k_ in range(self.k):
                ss = self.source.readline()
                if ss == "":
                    break
                self.source_buffer.append(ss.strip("\n").split("\t"))

            # sort by history behavior length
            # (history fields are joined with the ASCII \x02 control character
            # by the preprocessing scripts)
            if self.sort_by_length:
                his_length = numpy.array([len(s[4].split("\x02")) for s in self.source_buffer])
                tidx = his_length.argsort()

                _sbuf = [self.source_buffer[i] for i in tidx]
                self.source_buffer = _sbuf
            else:
                self.source_buffer.reverse()

        if len(self.source_buffer) == 0:
            self.end_of_data = False
            self.reset()
            raise StopIteration

        try:

            # actual work here
            while True:

                # read from source file and map to word index
                try:
                    ss = self.source_buffer.pop()
                except IndexError:
                    break

                uid = self.source_dicts[0][ss[1]] if ss[1] in self.source_dicts[0] else 0
                mid = self.source_dicts[1][ss[2]] if ss[2] in self.source_dicts[1] else 0
                cat = self.source_dicts[2][ss[3]] if ss[3] in self.source_dicts[2] else 0

                tmp = []
                item_carte = []
                for fea in ss[4].split("\x02"):
                    m = self.source_dicts[1][fea] if fea in self.source_dicts[1] else 0
                    tmp.append(m)
                    i_c = self.source_dicts[3][(ss[2], fea)] if (ss[2], fea) in self.source_dicts[3] else 0
                    item_carte.append(i_c)
                mid_list = tmp

                tmp1 = []
                cate_carte = []
                for fea in ss[5].split("\x02"):
                    c = self.source_dicts[2][fea] if fea in self.source_dicts[2] else 0
                    tmp1.append(c)
                    c_c = self.source_dicts[4][(ss[3], fea)] if (ss[3], fea) in self.source_dicts[4] else 0
                    cate_carte.append(c_c)
                cat_list = tmp1

                if self.minlen != None:
                    if len(mid_list) <= self.minlen:
                        continue
                if self.skip_empty and (not mid_list):
                    continue

                noclk_mid_list = []
                noclk_cat_list = []
                for pos_mid in mid_list:
                    noclk_tmp_mid = []
                    noclk_tmp_cat = []
                    noclk_index = 0
                    while True:
                        noclk_mid_indx = random.randint(0, len(self.mid_list_for_random)-1)
                        noclk_mid = self.mid_list_for_random[noclk_mid_indx]
                        if noclk_mid == pos_mid:
                            continue
                        noclk_tmp_mid.append(noclk_mid)
                        noclk_tmp_cat.append(self.meta_id_map[noclk_mid])
                        noclk_index += 1
                        if noclk_index >= 5:
                            break
                    noclk_mid_list.append(noclk_tmp_mid)
                    noclk_cat_list.append(noclk_tmp_cat)
                carte_list = [item_carte, cate_carte]
                source.append([uid, mid, cat, mid_list, cat_list, noclk_mid_list, noclk_cat_list, carte_list])
                if self.label_type == 1:
                    target.append([float(ss[0])])
                else:
                    target.append([float(ss[0]), 1-float(ss[0])])

                if len(source) >= self.batch_size or len(target) >= self.batch_size:
                    break
        except IOError:
            self.end_of_data = True

        # all sentence pairs in maxibatch filtered out because of length
        if len(source) == 0 or len(target) == 0:
            source, target = self.__next__()

        return source, target
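
A minimal usage sketch of DataIterator (the split and vocabulary file names mirror what prepare_data.sh and generate_voc.py produce, but are assumptions here; note the iterator also loads the hard-coded /home/test/... paths flagged by the reviewer, which must exist or be fixed first):

from data_iterator import DataIterator

train_data = DataIterator("local_train_splitByUser",
                          "uid_voc.pkl", "mid_voc.pkl", "cat_voc.pkl",
                          batch_size=128, maxlen=100)
n_uid, n_mid, n_cat, n_carte = train_data.get_n()

for source, target in train_data:
    # Each source row: [uid, mid, cat, mid_list, cat_list,
    #                   noclk_mid_list, noclk_cat_list, carte_list]
    # Each target row: [label] when label_type=1
    print(len(source), len(target))
    break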


91 changes: 91 additions & 0 deletions modelzoo/CAN/script/generate_voc.py
@@ -0,0 +1,91 @@
import pickle as pk

f_train = open("/home/test/modelzoo/DIEN/data/local_train_splitByUser", "r")
Collaborator: Same as above, do not use absolute paths.

uid_dict = {}
mid_dict = {}
cat_dict = {}
item_carte_dict = {}
cate_carte_dict = {}

iddd = 0
Collaborator: Is this iddd variable actually used?

for line in f_train:
    arr = line.strip("\n").split("\t")
    clk = arr[0]
    uid = arr[1]
    mid = arr[2]
    cat = arr[3]
    mid_list = arr[4]
    cat_list = arr[5]
    if uid not in uid_dict:
        uid_dict[uid] = 0
    uid_dict[uid] += 1
    if mid not in mid_dict:
        mid_dict[mid] = 0
    mid_dict[mid] += 1
    if cat not in cat_dict:
        cat_dict[cat] = 0
    cat_dict[cat] += 1
    if len(mid_list) == 0:
        continue
    # history fields are joined with the ASCII \x02 control character
    for m in mid_list.split("\x02"):
        if m not in mid_dict:
            mid_dict[m] = 0
        mid_dict[m] += 1
        if (mid, m) not in item_carte_dict:
            item_carte_dict[(mid, m)] = 0
        item_carte_dict[(mid, m)] += 1
    #print iddd
    iddd += 1
    for c in cat_list.split("\x02"):
        if c not in cat_dict:
            cat_dict[c] = 0
        cat_dict[c] += 1
        if (cat, c) not in cate_carte_dict:
            cate_carte_dict[(cat, c)] = 0
        cate_carte_dict[(cat, c)] += 1

sorted_uid_dict = sorted(uid_dict.items(), key=lambda x:x[1], reverse=True)
sorted_mid_dict = sorted(mid_dict.items(), key=lambda x:x[1], reverse=True)
sorted_cat_dict = sorted(cat_dict.items(), key=lambda x:x[1], reverse=True)
sorted_item_carte_dict = sorted(item_carte_dict.items(), key=lambda x:x[1], reverse=True)
sorted_cate_carte_dict = sorted(cate_carte_dict.items(), key=lambda x:x[1], reverse=True)

uid_voc = {}
index = 0
for key, value in sorted_uid_dict:
    uid_voc[key] = index
    index += 1

mid_voc = {}
mid_voc["default_mid"] = 0
index = 1
for key, value in sorted_mid_dict:
    mid_voc[key] = index
    index += 1

cat_voc = {}
cat_voc["default_cat"] = 0
index = 1
for key, value in sorted_cat_dict:
    cat_voc[key] = index
    index += 1

item_carte_voc = {}
item_carte_voc["default_item_carte"] = 0
index = 1
for key, value in sorted_item_carte_dict:
    item_carte_voc[key] = index
    index += 1

cate_carte_voc = {}
cate_carte_voc["default_cate_carte"] = 0
index = 1
for key, value in sorted_cate_carte_dict:
    cate_carte_voc[key] = index
    index += 1

pk.dump(uid_voc, open("uid_voc.pkl", "wb"))
pk.dump(mid_voc, open("mid_voc.pkl", "wb"))
pk.dump(cat_voc, open("cat_voc.pkl", "wb"))
pk.dump(item_carte_voc, open("item_carte_voc.pkl", "wb"))
pk.dump(cate_carte_voc, open("cate_carte_voc.pkl", "wb"))
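
Each vocabulary maps a raw id (or an (id, id) cartesian pair for the co-action features) to a dense index, sorted by frequency, with index 0 reserved for the default entry in all but uid_voc. Reading one back, as a sketch mirroring the dumps above:

import pickle as pk

with open("mid_voc.pkl", "rb") as f:
    mid_voc = pk.load(f)
print(len(mid_voc), mid_voc["default_mid"])  # vocab size, 0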