Benchmark block_bucketize_sparse_features uneven sharding (pytorch#2140)
Summary:
Pull Request resolved: pytorch#2140

Benchmark block_bucketize_sparse_features, comparing uneven sharding (via block_bucketize_pos) against the even-sharding baseline.

Reviewed By: jiayisuse

Differential Revision: D51288847

fbshipit-source-id: dbb8dc705f32bc90fbdb316bdba8923c89d4f606
tissue3 authored and facebook-github-bot committed Nov 20, 2023
1 parent 4c0fad5 commit f65d7e2
Showing 2 changed files with 67 additions and 3 deletions.
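
For context before the diff: a minimal, illustrative sketch of the bucketing rule the new benchmark exercises, assuming even sharding uses a fixed block size of ceil(row_size / bucket_num) and uneven sharding uses explicit bucket boundaries (as passed via block_bucketize_pos in the benchmark below). This is not the fbgemm kernel; the helper names and the clamping of the last bucket are assumptions for illustration only.

import math

import torch


def bucketize_even(indices: torch.Tensor, row_size: int, bucket_num: int):
    # Even sharding: every bucket owns a contiguous range of block_size rows.
    block_size = math.ceil(row_size / bucket_num)
    bucket_ids = torch.div(indices, block_size, rounding_mode="floor")
    bucket_ids = bucket_ids.clamp(max=bucket_num - 1)  # defensive clamp (assumption)
    local_indices = indices - bucket_ids * block_size
    return bucket_ids, local_indices


def bucketize_uneven(indices: torch.Tensor, bucket_pos: torch.Tensor):
    # Uneven sharding: bucket_pos holds bucket_num + 1 increasing boundaries;
    # bucket i owns rows in [bucket_pos[i], bucket_pos[i + 1]).
    bucket_pos = bucket_pos.to(indices.dtype)
    bucket_ids = torch.searchsorted(bucket_pos, indices, right=True) - 1
    local_indices = indices - bucket_pos[bucket_ids]
    return bucket_ids, local_indices
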
6 changes: 3 additions & 3 deletions fbgemm_gpu/bench/bench_utils.py
@@ -41,13 +41,13 @@ def benchmark_torch_function(  # noqa: C901
     copy_f_for_multi_thread_test: bool = False,
 ) -> Tuple[float, torch.Tensor]:
     logging.info(f"Start to benchmark {name}...")
-    if device != "" and device != "cuda":
+    if device != "cpu" and device != "" and device != "cuda":
         torch.cuda.set_device(device)
     for _ in range(num_warmups):
         output = f(*args)
 
     assert num_threads > 0
-    if torch.cuda.is_available() and (num_threads == 1):
+    if device != "cpu" and torch.cuda.is_available() and (num_threads == 1):
         cache = torch.empty(
             int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
             dtype=torch.float,
@@ -69,7 +69,7 @@ def benchmark_torch_function(  # noqa: C901
             [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
         )
         elapsed_time = torch.mean(times).item() * 1.0e-3
-    elif torch.cuda.is_available() and (num_threads > 1):
+    elif device != "cpu" and torch.cuda.is_available() and (num_threads > 1):
         cache = torch.empty(
             int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
             dtype=torch.float,
64 changes: 64 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -7,6 +7,7 @@
 import contextlib
 import functools
 import logging
+import math
 import random
 from typing import List

@@ -868,5 +869,68 @@ def ben(fn, name, ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
     ben(pass_4, "pass_4", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
 
 
+@cli.command()
+@click.option("--row-size", default=2560000)
+@click.option("--batch-size", default=4096)
+@click.option("--bucket-num", default=16)
+@click.option("--input-precision", type=str, default="long")
+def block_bucketize_sparse_features_bench(
+    row_size: int, batch_size: int, bucket_num: int, input_precision: str
+) -> None:
+
+    dtype = torch.int
+    if input_precision == "int":
+        dtype = torch.int
+    elif input_precision == "long":
+        dtype = torch.long
+    else:
+        raise RuntimeError(f"Does not support data type {input_precision}")
+
+    indices = torch.randint(0, row_size, (batch_size,), dtype=dtype)
+    weights = torch.randint(0, row_size, (batch_size,), dtype=torch.float)
+    total = 0
+    lengths = []
+    for _ in range(batch_size):
+        length = random.randint(0, 10)
+        lengths.append(min(length, batch_size - total))
+        total += length
+        if total > batch_size:
+            break
+    lengths = torch.tensor(lengths, dtype=dtype)
+    bucket_size = math.ceil(row_size / bucket_num)
+    block_sizes = torch.tensor([bucket_size] * lengths.numel(), dtype=dtype)
+
+    bucket_pos = [j * bucket_size for j in range(bucket_num + 1)]
+    block_bucketize_pos = [torch.tensor(bucket_pos)] * lengths.numel()
+    test_param = {"uneven": block_bucketize_pos, "even": None}
+    for name, is_block_bucketize_pos in test_param.items():
+        time, output = benchmark_torch_function(
+            torch.ops.fbgemm.block_bucketize_sparse_features,
+            (
+                lengths,
+                indices,
+                False,
+                True,
+                block_sizes,
+                bucket_num,
+                weights,
+                None,
+                -1,  # unused
+                is_block_bucketize_pos,
+            ),
+            iters=100,
+            device="cpu",
+        )
+
+        num_bytes = 0
+        for tensor in [lengths, indices, weights, *block_bucketize_pos, *output]:
+            if isinstance(tensor, torch.Tensor):
+                num_bytes += (tensor.numel()) * tensor.element_size()
+
+        logging.info(
+            f"{name}_block_bucketize_sparse_features forward: {dtype}, {num_bytes} bytes read/write, {time * 1e3} ms, {num_bytes / time / 1e9} GB/s"
+        )
+
+
 if __name__ == "__main__":
     cli()
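
As a usage note, a hedged sketch of exercising the new command locally is below. The module import path and the use of click's test runner are assumptions for illustration; in practice the script would be invoked from the command line with the same options.

# Hypothetical local run of the new benchmark command; the import path is assumed.
from click.testing import CliRunner

from fbgemm_gpu.bench.sparse_ops_benchmark import block_bucketize_sparse_features_bench

result = CliRunner().invoke(
    block_bucketize_sparse_features_bench,
    ["--row-size", "2560000", "--batch-size", "4096", "--bucket-num", "16", "--input-precision", "long"],
)
assert result.exit_code == 0  # timings are reported through logging.info inside the command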
