Showing 6 changed files with 278 additions and 3 deletions.
@@ -0,0 +1,83 @@
""" | ||
Created on Apr 26, 2022 | ||
train MIND demo | ||
@author: Ziyao Geng([email protected]) | ||
""" | ||
import os | ||
from absl import flags, app | ||
from time import time | ||
from tensorflow.keras.optimizers import Adam | ||
|
||
from reclearn.models.matching import MIND | ||
from reclearn.data.datasets import movielens as ml | ||
from reclearn.evaluator import eval_pos_neg | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' | ||
os.environ['CUDA_VISIBLE_DEVICES'] = '6' | ||
|
||
# Setting training parameters | ||
flags.DEFINE_string("file_path", "data/ml-1m/ratings.dat", "file path.") | ||
flags.DEFINE_string("train_path", "data/ml-1m/ml_seq_train.txt", "train path. If set to None, the program will split the dataset.") | ||
flags.DEFINE_string("val_path", "data/ml-1m/ml_seq_val.txt", "val path.") | ||
flags.DEFINE_string("test_path", "data/ml-1m/ml_seq_test.txt", "test path.") | ||
flags.DEFINE_string("meta_path", "data/ml-1m/ml_seq_meta.txt", "meta path.") | ||
flags.DEFINE_integer("embed_dim", 64, "The size of embedding dimension.") | ||
flags.DEFINE_float("embed_reg", 0.0, "The value of embedding regularization.") | ||
flags.DEFINE_integer("num_interest", 1, "The number of user interests.") | ||
flags.DEFINE_bool("stop_grad", True, "The weights in the capsule network are updated without gradient descent.") | ||
flags.DEFINE_bool("label_attention", True, "Whether using label-aware attention or not.") | ||
flags.DEFINE_float("learning_rate", 0.001, "Learning rate.") | ||
flags.DEFINE_integer("neg_num", 2, "The number of negative sample for each positive sample.") | ||
flags.DEFINE_integer("seq_len", 100, "The length of user's behavior sequence.") | ||
flags.DEFINE_integer("epochs", 20, "train steps.") | ||
flags.DEFINE_integer("batch_size", 512, "Batch Size.") | ||
flags.DEFINE_integer("test_neg_num", 100, "The number of test negative samples.") | ||
flags.DEFINE_integer("k", 10, "recall k items at test stage.") | ||
|
||
|
||
def main(argv): | ||
# TODO: 1. Split Data | ||
if FLAGS.train_path == "None": | ||
train_path, val_path, test_path, meta_path = ml.split_seq_data(file_path=FLAGS.file_path) | ||
else: | ||
train_path, val_path, test_path, meta_path = FLAGS.train_path, FLAGS.val_path, FLAGS.test_path, FLAGS.meta_path | ||
with open(meta_path) as f: | ||
_, max_item_num = [int(x) for x in f.readline().strip('\n').split('\t')] | ||
# TODO: 2. Load Sequence Data | ||
train_data = ml.load_seq_data(train_path, "train", FLAGS.seq_len, FLAGS.neg_num, max_item_num) | ||
val_data = ml.load_seq_data(val_path, "val", FLAGS.seq_len, FLAGS.neg_num, max_item_num) | ||
test_data = ml.load_seq_data(test_path, "test", FLAGS.seq_len, FLAGS.test_neg_num, max_item_num) | ||
# TODO: 3. Set Model Hyper Parameters. | ||
model_params = { | ||
'item_num': max_item_num + 1, | ||
'embed_dim': FLAGS.embed_dim, | ||
'seq_len': FLAGS.seq_len, | ||
'num_interest': FLAGS.num_interest, | ||
'stop_grad': FLAGS.stop_grad, | ||
'label_attention': FLAGS.label_attention, | ||
'neg_num': FLAGS.neg_num, | ||
'batch_size': FLAGS.batch_size, | ||
'embed_reg': FLAGS.embed_reg | ||
} | ||
# TODO: 4. Build Model | ||
model = MIND(**model_params) | ||
model.compile(optimizer=Adam(learning_rate=FLAGS.learning_rate)) | ||
# TODO: 5. Fit Model | ||
for epoch in range(1, FLAGS.epochs + 1): | ||
t1 = time() | ||
model.fit( | ||
x=train_data, | ||
epochs=1, | ||
validation_data=val_data, | ||
batch_size=FLAGS.batch_size | ||
) | ||
t2 = time() | ||
eval_dict = eval_pos_neg(model, test_data, ['hr', 'mrr', 'ndcg'], FLAGS.k, FLAGS.batch_size) | ||
print('Iteration %d Fit [%.1f s], Evaluate [%.1f s]: HR = %.4f, MRR = %.4f, NDCG = %.4f' | ||
% (epoch, t2 - t1, time() - t2, eval_dict['hr'], eval_dict['mrr'], eval_dict['ndcg'])) | ||
|
||
|
||
if __name__ == '__main__': | ||
app.run(main) |
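A typical invocation overrides the absl flags on the command line. The diff does not show this demo's file name, so train_mind_demo.py below is only a placeholder:

python train_mind_demo.py --embed_dim=32 --num_interest=4 --epochs=10

Passing --train_path=None makes the script call ml.split_seq_data on file_path first, which produces the train/val/test/meta files used above.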
@@ -0,0 +1,109 @@
""" | ||
Created on Apr 25, 2022 | ||
Reference: "Multi-Interest Network with Dynamic Routing for Recommendation at Tmall", CIKM, 2019 | ||
@author: Ziyao Geng([email protected]) | ||
""" | ||
import tensorflow as tf | ||
from tensorflow.keras import Model | ||
from tensorflow.keras.layers import Input | ||
from tensorflow.keras.regularizers import l2 | ||
|
||
from reclearn.layers.core import CapsuleNetwork | ||
|
||
|
||
class MIND(Model): | ||
def __init__(self, item_num, embed_dim, seq_len=100, num_interest=4, stop_grad=True, label_attention=True, | ||
neg_num=4, batch_size=512, embed_reg=0., seed=None): | ||
"""MIND | ||
Args: | ||
:param item_num: An integer type. The largest item index + 1. | ||
:param embed_dim: An integer type. Embedding dimension of item vector. | ||
:param seq_len: An integer type. The length of the input sequence. | ||
:param bilinear_type: An integer type. The number of user interests. | ||
:param num_interest: An integer type. The number of user interests. | ||
:param stop_grad: A boolean type. The weights in the capsule network are updated without gradient descent. | ||
:param label_attention: A boolean type. Whether using label-aware attention or not. | ||
:param neg_num: A integer type. The number of negative samples for each positive sample. | ||
:param batch_size: A integer type. The number of samples per batch. | ||
:param embed_reg: A float type. The regularizer of embedding. | ||
:param seed: A Python integer to use as random seed. | ||
:return | ||
""" | ||
super(MIND, self).__init__() | ||
with tf.name_scope("Embedding_layer"): | ||
# item embedding | ||
self.item_embedding_table = self.add_weight(name='item_embedding_table', | ||
shape=(item_num, embed_dim), | ||
initializer='random_normal', | ||
regularizer=l2(embed_reg), | ||
trainable=True) | ||
# embedding bias | ||
self.embedding_bias = self.add_weight(name='embedding_bias', | ||
shape=(item_num,), | ||
initializer=tf.zeros_initializer(), | ||
trainable=False) | ||
self.capsule_network = CapsuleNetwork(embed_dim, seq_len, 0, num_interest, stop_grad) | ||
self.seq_len = seq_len | ||
self.num_interest = num_interest | ||
self.label_attention = label_attention | ||
self.item_num = item_num | ||
self.embed_dim = embed_dim | ||
self.neg_num = neg_num | ||
self.batch_size = batch_size | ||
# seed | ||
tf.random.set_seed(seed) | ||
|
||
def call(self, inputs, training=False): | ||
user_hist_emb = tf.nn.embedding_lookup(self.item_embedding_table, inputs['click_seq']) | ||
mask = tf.cast(tf.not_equal(inputs['click_seq'], 0), dtype=tf.float32) # (None, seq_len) | ||
user_hist_emb = tf.multiply(user_hist_emb, tf.expand_dims(mask, axis=-1)) # (None, seq_len, embed_dim) | ||
# capsule network | ||
interest_capsule = self.capsule_network(user_hist_emb, mask) # (None, num_inter, embed_dim) | ||
|
||
if training: | ||
if self.label_attention: | ||
item_embed = tf.nn.embedding_lookup(self.item_embedding_table, tf.reshape(inputs['pos_item'], [-1, ])) | ||
inter_att = tf.matmul(interest_capsule, tf.reshape(item_embed, [-1, self.embed_dim, 1])) # (None, num_inter, 1) | ||
inter_att = tf.nn.softmax(tf.pow(tf.reshape(inter_att, [-1, self.num_interest]), 1)) | ||
|
||
user_info = tf.matmul(tf.reshape(inter_att, [-1, 1, self.num_interest]), interest_capsule) # (None, 1, embed_dim) | ||
user_info = tf.reshape(user_info, [-1, self.embed_dim]) | ||
else: | ||
user_info = tf.reduce_max(interest_capsule, axis=1) # (None, embed_dim) | ||
# train, sample softmax loss | ||
loss = tf.reduce_mean(tf.nn.sampled_softmax_loss( | ||
weights=self.item_embedding_table, | ||
biases=self.embedding_bias, | ||
labels=tf.reshape(inputs['pos_item'], shape=[-1, 1]), | ||
inputs=user_info, | ||
num_sampled=self.neg_num * self.batch_size, | ||
num_classes=self.item_num | ||
)) | ||
# add loss | ||
self.add_loss(loss) | ||
return loss | ||
else: | ||
# predict/eval | ||
pos_info = tf.nn.embedding_lookup(self.item_embedding_table, inputs['pos_item']) # (None, embed_dim) | ||
neg_info = tf.nn.embedding_lookup(self.item_embedding_table, inputs['neg_item']) # (None, neg_num, embed_dim) | ||
|
||
if self.label_attention: | ||
user_info = tf.reduce_max(interest_capsule, axis=1) # (None, embed_dim) | ||
else: | ||
user_info = tf.reduce_max(interest_capsule, axis=1) # (None, embed_dim) | ||
|
||
# calculate similar scores. | ||
pos_scores = tf.reduce_sum(tf.multiply(user_info, pos_info), axis=-1, keepdims=True) # (None, 1) | ||
neg_scores = tf.reduce_sum(tf.multiply(tf.expand_dims(user_info, axis=1), neg_info), | ||
axis=-1) # (None, neg_num) | ||
logits = tf.concat([pos_scores, neg_scores], axis=-1) | ||
return logits | ||
|
||
def summary(self): | ||
inputs = { | ||
'click_seq': Input(shape=(self.seq_len,), dtype=tf.int32), | ||
'pos_item': Input(shape=(), dtype=tf.int32), | ||
'neg_item': Input(shape=(1,), dtype=tf.int32) # suppose neg_num=1 | ||
} | ||
Model(inputs=inputs, outputs=self.call(inputs)).summary() | ||
|
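A minimal smoke-test sketch (not part of the commit) that exercises both branches of call() with toy sizes. The input keys match what call() expects, and the shapes follow the comments above; the toy dimensions are assumptions, not values from the diff:

import tensorflow as tf
from reclearn.models.matching import MIND

# toy configuration: 1,000 items, 16-dim embeddings, sequences of 20 clicks
model = MIND(item_num=1000, embed_dim=16, seq_len=20, num_interest=4, neg_num=2, batch_size=8)
inputs = {
    'click_seq': tf.random.uniform((8, 20), maxval=1000, dtype=tf.int32),  # 0 acts as padding
    'pos_item': tf.random.uniform((8,), minval=1, maxval=1000, dtype=tf.int32),
    'neg_item': tf.random.uniform((8, 100), minval=1, maxval=1000, dtype=tf.int32)
}
loss = model(inputs, training=True)     # scalar sampled-softmax loss
logits = model(inputs, training=False)  # shape (8, 101): positive score first, then 100 negatives
print(float(loss), logits.shape)

Note that num_sampled = neg_num * batch_size (16 here) must stay below num_classes, which holds for these toy sizes just as it does for the ml-1m defaults in the demo.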