Skip to content

Commit

Permalink
Update TBE training benchmark (pytorch#3112)
Browse files: browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#202

Pull Request resolved: pytorch#3112

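As title: adds a `--cache-load-factor` option (default 0.2) to the `device` benchmark and a `--uvm-host-mapped` flag to the `cache` benchmark, and passes each value through to the embedding modules those benchmarks construct.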

Reviewed By: spcyppt

Differential Revision: D62484938

fbshipit-source-id: 27327ff50d0ba31b616b0ed426dd03e21edbaccc
  • Loading branch information
sryap authored and facebook-github-bot committed Sep 11, 2024
1 parent 08b1965 commit 49b5e55
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def cli() -> None:
@click.option(
"--ssd-prefix", type=str, default="/tmp/ssd_benchmark", help="SSD directory prefix"
)
@click.option("--cache-load-factor", default=0.2)
def device( # noqa C901
alpha: float,
bag_size: int,
Expand Down Expand Up @@ -163,6 +164,7 @@ def device( # noqa C901
uvm_host_mapped: bool,
ssd: bool,
ssd_prefix: str,
cache_load_factor: float,
) -> None:
assert not ssd or not dense, "--ssd cannot be used together with --dense"
np.random.seed(42)
Expand Down Expand Up @@ -278,6 +280,7 @@ def device( # noqa C901
],
cache_precision=weights_precision,
cache_algorithm=CacheAlgorithm.LRU,
cache_load_factor=cache_load_factor,
**common_split_args,
)
emb = emb.to(get_device())
Expand Down Expand Up @@ -736,6 +739,12 @@ def run_bench(indices: Tensor, offsets: Tensor, per_sample_weights: Tensor) -> N
@click.option("--flush-gpu-cache-size-mb", default=0)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option(
"--uvm-host-mapped",
is_flag=True,
default=False,
help="Use host mapped UVM buffers in SSD-TBE (malloc+cudaHostRegister)",
)
def cache( # noqa C901
alpha: float,
bag_size: int,
Expand All @@ -756,6 +765,7 @@ def cache( # noqa C901
flush_gpu_cache_size_mb: int,
requests_data_file: Optional[str],
tables: Optional[str],
uvm_host_mapped: bool,
) -> None:
np.random.seed(42)
torch.manual_seed(42)
Expand Down Expand Up @@ -788,6 +798,7 @@ def cache( # noqa C901
optimizer=optimizer,
weights_precision=weights_precision,
stochastic_rounding=stoc,
uvm_host_mapped=uvm_host_mapped,
).cuda()

if weights_precision == SparseType.INT8:
Expand All @@ -808,6 +819,7 @@ def cache( # noqa C901
stochastic_rounding=stoc,
cache_load_factor=cache_load_factor,
cache_algorithm=cache_alg,
uvm_host_mapped=uvm_host_mapped,
).cuda()

if weights_precision == SparseType.INT8:
Expand Down

0 comments on commit 49b5e55

Please sign in to comment.