diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
index 4d29b01a5..19ee84f66 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py
@@ -191,6 +191,7 @@ def __init__(  # noqa C901
         cacheline_alignment: bool = True,
         uvm_host_mapped: bool = False,  # True to use cudaHostAlloc; False to use cudaMallocManaged.
         reverse_qparam: bool = False,  # True to load qparams at end of each row; False to load qparam at begnning of each row.
+        feature_names_per_table: Optional[List[List[str]]] = None,
     ) -> None:  # noqa C901
         # tuple of (rows, dims,)
         super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()
@@ -214,6 +215,7 @@ def __init__(  # noqa C901
         self.embedding_specs = embedding_specs
         self.output_dtype: int = output_dtype.as_int()
         self.uvm_host_mapped = uvm_host_mapped
+        self.feature_names_per_table = feature_names_per_table
         # (feature_names, rows, dims, weights_tys, locations) = zip(*embedding_specs)
         # Pyre workaround
         self.feature_names: List[str] = [e[0] for e in embedding_specs]
@@ -465,6 +467,12 @@ def get_table_wise_cache_miss(self) -> Tensor:
         # table_wise_cache_miss contains all the cache miss count for each table in this embedding table object:
         return self.table_wise_cache_miss
 
+    @torch.jit.export
+    def get_feature_num_per_table(self) -> List[int]:
+        if self.feature_names_per_table is None:
+            return []
+        return [len(feature_names) for feature_names in self.feature_names_per_table]
+
     def reset_cache_miss_counter(self) -> None:
         assert (
             self.record_cache_metrics.record_cache_miss_counter
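
For reference, a minimal standalone sketch (not part of the patch) of the mapping the new jit-exported get_feature_num_per_table() implements: one feature count per table, taken from the new feature_names_per_table constructor argument, with an empty list when that argument is left as None. The free function and feature names below are hypothetical illustrations, not FBGEMM APIs.

from typing import List, Optional

def feature_num_per_table(
    feature_names_per_table: Optional[List[List[str]]],
) -> List[int]:
    # Mirrors the method added in the diff: no mapping configured -> empty list.
    if feature_names_per_table is None:
        return []
    # Otherwise, report how many features each table backs.
    return [len(feature_names) for feature_names in feature_names_per_table]

# Two tables: the first backs two features, the second backs one (illustrative names).
example = [["f_user_id", "f_user_age"], ["f_item_id"]]
assert feature_num_per_table(example) == [2, 1]
assert feature_num_per_table(None) == []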