From 4bbfb492b872c5a3290a2bce1ed5c160162558a3 Mon Sep 17 00:00:00 2001 From: ZiyaoGeng <593947521@qq.com> Date: Fri, 29 Apr 2022 14:10:30 +0800 Subject: [PATCH] update deepfm --- example/train_small_criteo_demo.py | 2 +- reclearn/models/ranking/deepfm.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/example/train_small_criteo_demo.py b/example/train_small_criteo_demo.py index e28106a..de619f2 100644 --- a/example/train_small_criteo_demo.py +++ b/example/train_small_criteo_demo.py @@ -24,7 +24,7 @@ # TODO: Hyper Parameters file = 'data/criteo/train.txt' read_part = True - sample_num = 5000000 + sample_num = 50000 test_size = 0.2 model_params = { diff --git a/reclearn/models/ranking/deepfm.py b/reclearn/models/ranking/deepfm.py index dea85eb..2e91d92 100644 --- a/reclearn/models/ranking/deepfm.py +++ b/reclearn/models/ranking/deepfm.py @@ -54,10 +54,10 @@ def call(self, inputs): sparse_inputs = index_mapping(inputs, self.map_dict) wide_inputs = {'sparse_inputs': sparse_inputs, 'embed_inputs': tf.reshape(sparse_embed, shape=(-1, self.field_num, self.embed_dim))} - wide_outputs = self.fm(wide_inputs) # (batch_size, 1) + wide_outputs = tf.reshape(self.fm(wide_inputs), [-1, 1]) # (batch_size, 1) # deep deep_outputs = self.mlp(sparse_embed) - deep_outputs = self.dense(deep_outputs) # (batch_size, 1) + deep_outputs = tf.reshape(self.dense(deep_outputs), [-1, 1]) # (batch_size, 1) # outputs outputs = tf.nn.sigmoid(tf.add(wide_outputs, deep_outputs)) return outputs