-
Notifications
You must be signed in to change notification settings - Fork 339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add graph-based config and model-- ultragcn #251
base: master
Are you sure you want to change the base?
Conversation
CI PY3 Test Passed |
CI Test Failed |
self._nbr_weights = features.get('features')[5] | ||
self._neg_ids = features.get('features')[6] | ||
else: | ||
self._user_ids = features.get('id') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这边都是叫同名的id?
def get_outputs(self): | ||
# emb_1 = tf.reduce_join(tf.as_string(self._prediction_dict['user_embedding']), axis=-1, separator=',') | ||
# emb_2 = tf.reduce_join(tf.as_string(self._prediction_dict['item_embedding'] ), axis=-1, separator=',') | ||
return ['user_embedding','item_embedding'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议跟向量召回保持一致,user_emb, item_emb
def build_metric_graph(self, eval_config): | ||
metric_dict = {} | ||
for metric in eval_config.metrics_set: | ||
if metric.WhichOneof('metric') == 'recall_at_topk': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
metric会生效么?logits来自哪里?
@@ -25,3 +25,11 @@ message BinaryDataInput { | |||
repeated string dense_path = 2; | |||
repeated string label_path = 3; | |||
} | |||
|
|||
message GraphLearnInput { | |||
optional string user_node_input = 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这些是可枚举的么?node_name, node_input,这种kv的形式是不是通用一些?
import json | ||
import logging | ||
|
||
from easy_rec.python.utils import pai_util |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
跟core/sampler.py中的graph init复用
output_types = [tf.int64, tf.float32, tf.int64, tf.float32, | ||
tf.int64, tf.float32, tf.int64] | ||
# user ids, user degrees, item ids, item degrees, nbr item ids, nbr item weight, neg item ids | ||
output_shapes = [tf.TensorShape([None]), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
output不建议是list,可读性较差,建议是dict
break | ||
epoch_id += 1 | ||
|
||
self._nbr_num = self._data_config.ultra_gcn_sampler.nbr_num |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不应该修改GraphInput类的成员变量,如果需要建成员变量的话,建议单独建个ultra gcn sampler的类,被GraphInput实例化
if self._sampler is not None and self._mode != tf.estimator.ModeKeys.PREDICT: | ||
if self._mode != tf.estimator.ModeKeys.TRAIN: | ||
self._sampler.set_eval_num_sample() | ||
sampler_type = self._data_config.WhichOneof('sampler') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
负采样和feature config还需要么?
No description provided.