From 40d96f00605a854fdedb1ea8b95fb907ea2b91a9 Mon Sep 17 00:00:00 2001
From: Sarunya Pumma
Date: Wed, 5 Jul 2023 17:44:55 -0700
Subject: [PATCH] Fix the hash_size_cumsum cond in TBE's init (#1862)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1862

This diff fixes the `ValueError: math domain error` introduced by D46618714.

Reviewed By: q10, jianyuh, yuguo68

Differential Revision: D47236684

fbshipit-source-id: d5375928ad5ee6bad58c183355b52fac71b4b211
---
 .../fbgemm_gpu/split_table_batched_embeddings_ops_training.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
index 021a7b3a3..17db584bb 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py
@@ -432,7 +432,7 @@ def __init__(  # noqa C901
         )
         hash_size_cumsum = [0] + list(accumulate(rows))
         self.total_hash_size: int = int(hash_size_cumsum[-1])
-        if hash_size_cumsum == 0:
+        if self.total_hash_size == 0:
             self.total_hash_size_bits: int = 0
         else:
             self.total_hash_size_bits: int = int(log2(float(self.total_hash_size)) + 1)
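
Note: a minimal sketch of the failure mode this patch fixes. The `rows` value here is hypothetical (chosen so the total hash size comes out to zero); the `accumulate`/`log2` usage mirrors the lines in the hunk above.

```python
from itertools import accumulate
from math import log2

rows = [0, 0]  # hypothetical: every embedding table has zero rows
hash_size_cumsum = [0] + list(accumulate(rows))
total_hash_size = int(hash_size_cumsum[-1])  # 0

# Old condition: hash_size_cumsum is a list, so comparing it to the
# int 0 is always False ...
assert not (hash_size_cumsum == 0)
# ... which means the else branch ran log2(0.0) and raised
# "ValueError: math domain error".

# Fixed condition guards on the scalar total instead:
if total_hash_size == 0:
    total_hash_size_bits = 0
else:
    total_hash_size_bits = int(log2(float(total_hash_size)) + 1)

print(total_hash_size_bits)  # 0, no exception
```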