Skip to content

Commit

Permalink
add feature names per table to tbe module (#2508)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2508

add a new attribute `feature_names_per_table` that can be retrieved from predictor to record pooling factor stuff

Reviewed By: 842974287

Differential Revision: D56123412

fbshipit-source-id: 2dda5ac750826eb900ec822fb5c888d109b634db
  • Loading branch information
jiayisuse authored and facebook-github-bot committed May 8, 2024
1 parent 7e81391 commit 7d15c59
Showing 1 changed file with 8 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ def __init__( # noqa C901
cacheline_alignment: bool = True,
uvm_host_mapped: bool = False, # True to use cudaHostAlloc; False to use cudaMallocManaged.
reverse_qparam: bool = False, # True to load qparams at end of each row; False to load qparam at begnning of each row.
feature_names_per_table: Optional[List[List[str]]] = None,
) -> None: # noqa C901 # tuple of (rows, dims,)
super(IntNBitTableBatchedEmbeddingBagsCodegen, self).__init__()

Expand All @@ -214,6 +215,7 @@ def __init__( # noqa C901
self.embedding_specs = embedding_specs
self.output_dtype: int = output_dtype.as_int()
self.uvm_host_mapped = uvm_host_mapped
self.feature_names_per_table = feature_names_per_table
# (feature_names, rows, dims, weights_tys, locations) = zip(*embedding_specs)
# Pyre workaround
self.feature_names: List[str] = [e[0] for e in embedding_specs]
Expand Down Expand Up @@ -465,6 +467,12 @@ def get_table_wise_cache_miss(self) -> Tensor:
# table_wise_cache_miss contains all the cache miss count for each table in this embedding table object:
return self.table_wise_cache_miss

@torch.jit.export
def get_feature_num_per_table(self) -> List[int]:
if self.feature_names_per_table is None:
return []
return [len(feature_names) for feature_names in self.feature_names_per_table]

def reset_cache_miss_counter(self) -> None:
assert (
self.record_cache_metrics.record_cache_miss_counter
Expand Down

0 comments on commit 7d15c59

Please sign in to comment.