Skip to content

Commit

Permalink
Improve performance (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
DLPerf authored Aug 24, 2021
1 parent fe4393e commit 8cd4087
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions server/embedding_as_service/text/xlnet/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,10 @@ def get_cache_fn(mem_len):
def cache_fn(batch_size):
mems = []
if FLAGS.mem_len > 0:
for _ in range(FLAGS.n_layer):
zeros = tf.zeros(
zeros = tf.zeros(
[mem_len, batch_size, FLAGS.d_model],
dtype=tf_float)
for _ in range(FLAGS.n_layer):
mems.append(zeros)

return mems
Expand Down

0 comments on commit 8cd4087

Please sign in to comment.