diff --git a/fbgemm_gpu/bench/bench_utils.py b/fbgemm_gpu/bench/bench_utils.py
index 02ffda22f..be5b06287 100644
--- a/fbgemm_gpu/bench/bench_utils.py
+++ b/fbgemm_gpu/bench/bench_utils.py
@@ -41,13 +41,13 @@ def benchmark_torch_function(  # noqa: C901
     copy_f_for_multi_thread_test: bool = False,
 ) -> Tuple[float, torch.Tensor]:
     logging.info(f"Start to benchmark {name}...")
-    if device != "" and device != "cuda":
+    if device != "cpu" and device != "" and device != "cuda":
         torch.cuda.set_device(device)
     for _ in range(num_warmups):
         output = f(*args)
 
     assert num_threads > 0
-    if torch.cuda.is_available() and (num_threads == 1):
+    if device != "cpu" and torch.cuda.is_available() and (num_threads == 1):
         cache = torch.empty(
             int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
             dtype=torch.float,
@@ -69,7 +69,7 @@ def benchmark_torch_function(  # noqa: C901
             [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
         )
         elapsed_time = torch.mean(times).item() * 1.0e-3
-    elif torch.cuda.is_available() and (num_threads > 1):
+    elif device != "cpu" and torch.cuda.is_available() and (num_threads > 1):
         cache = torch.empty(
             int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
             dtype=torch.float,
diff --git a/fbgemm_gpu/bench/sparse_ops_benchmark.py b/fbgemm_gpu/bench/sparse_ops_benchmark.py
index bf0ad96c1..b06ace2f7 100644
--- a/fbgemm_gpu/bench/sparse_ops_benchmark.py
+++ b/fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -7,6 +7,7 @@
 
 import contextlib
 import functools
 import logging
+import math
 import random
 from typing import List
@@ -868,5 +869,68 @@ def ben(fn, name, ad_indices, ad_lengths, batch_offsets, num_ads_in_batch):
     ben(pass_4, "pass_4", ad_indices, ad_lengths, batch_offsets, num_ads_in_batch)
 
 
+@cli.command()
+@click.option("--row-size", default=2560000)
+@click.option("--batch-size", default=4096)
+@click.option("--bucket-num", default=16)
+@click.option("--input-precision", type=str, default="long")
+def block_bucketize_sparse_features_bench(
+    row_size: int, batch_size: int, bucket_num: int, input_precision: str
+) -> None:
+
+    dtype = torch.int
+    if input_precision == "int":
+        dtype = torch.int
+    elif input_precision == "long":
+        dtype = torch.long
+    else:
+        raise RuntimeError(f"Unsupported data type {input_precision}")
+
+    indices = torch.randint(0, row_size, (batch_size,), dtype=dtype)
+    weights = torch.randint(0, row_size, (batch_size,), dtype=torch.float)
+    total = 0
+    lengths = []
+    for _ in range(batch_size):
+        length = random.randint(0, 10)
+        lengths.append(min(length, batch_size - total))
+        total += length
+        if total > batch_size:
+            break
+    lengths = torch.tensor(lengths, dtype=dtype)
+    bucket_size = math.ceil(row_size / bucket_num)
+    block_sizes = torch.tensor([bucket_size] * lengths.numel(), dtype=dtype)
+
+    bucket_pos = [j * bucket_size for j in range(bucket_num + 1)]
+    block_bucketize_pos = [torch.tensor(bucket_pos)] * lengths.numel()
+    test_param = {"uneven": block_bucketize_pos, "even": None}
+    for name, is_block_bucketize_pos in test_param.items():
+        time, output = benchmark_torch_function(
+            torch.ops.fbgemm.block_bucketize_sparse_features,
+            (
+                lengths,
+                indices,
+                False,
+                True,
+                block_sizes,
+                bucket_num,
+                weights,
+                None,
+                -1,  # unused
+                is_block_bucketize_pos,
+            ),
+            iters=100,
+            device="cpu",
+        )
+
+        num_bytes = 0
+        for tensor in [lengths, indices, weights, *block_bucketize_pos, *output]:
+            if isinstance(tensor, torch.Tensor):
+                num_bytes += tensor.numel() * tensor.element_size()
+
+        logging.info(
+            f"{name}_block_bucketize_sparse_features forward: {dtype}, {num_bytes} bytes read/write, {time * 1e3} ms, {num_bytes / time / 1e9} GB/s"
+        )
+
+
 if __name__ == "__main__":
     cli()
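
For reference, a minimal sketch of the op the new benchmark exercises, runnable on CPU. The positional argument layout mirrors the benchmark's call above; the `import fbgemm_gpu` line assumes an OSS install where importing the package registers the `torch.ops.fbgemm.*` operators, and the tiny sizes here are illustrative only, not the benchmark's defaults.

    # Sketch, not part of the patch: a tiny CPU run of block_bucketize_sparse_features.
    import torch
    import fbgemm_gpu  # noqa: F401  # assumed to register torch.ops.fbgemm.*

    lengths = torch.tensor([2, 2], dtype=torch.long)  # one feature, two bags
    indices = torch.tensor([1, 9, 3, 14], dtype=torch.long)  # row ids in [0, 16)
    block_sizes = torch.tensor([8], dtype=torch.long)  # even bucket width per feature

    # Same positional layout as the benchmark's call: bucketize_pos=False,
    # sequence=True, my_size=2 buckets, no weights, default per-feature batch
    # sizes, max_B unused, and block_bucketize_pos=None -> even bucket boundaries.
    output = torch.ops.fbgemm.block_bucketize_sparse_features(
        lengths,
        indices,
        False,
        True,
        block_sizes,
        2,
        None,
        None,
        -1,
        None,
    )
    new_lengths, new_indices = output[0], output[1]
    # Bucket b receives indices in [8 * b, 8 * (b + 1)), re-offset to bucket-local
    # values: bucket 0 gets [1, 3] and bucket 1 gets [1, 6].
    print(new_lengths.tolist(), new_indices.tolist())

To invoke the new command itself (assuming Click >= 7.0, which derives the dashed command name from the function name), something like:

    python fbgemm_gpu/bench/sparse_ops_benchmark.py block-bucketize-sparse-features-bench --row-size 2560000 --batch-size 4096 --bucket-num 16 --input-precision long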