From d3f072c182f8543944d740af4daa497ff0ff4ef1 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Thu, 19 Sep 2024 17:35:02 -0700 Subject: [PATCH] Jagged tensor micro-benchmarks (#3156) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3156 X-link: https://github.com/facebookresearch/FBGEMM/pull/250 - Add jagged tensor micro-benchmarks ``` (foo) bash-5.1$ python -W ignore jagged_tensor_benchmark.py device --embedding-dim 512 INFO:root:######## Jagged (2D) to Dense ######## INFO:root:FBGEMM JaggedTensor: 5.746198445558548e-05 sec 438.11657809101143 GB/s INFO:root:PyTorch NestedTensor: 6.370197981595993e-05 sec 395.1842010676863 GB/s INFO:root: INFO:root:######## Dense to Jagged (2D) ######## INFO:root:FBGEMM JaggedTensor: 3.12004815787077e-05 sec 806.880109734599 GB/s INFO:root:PyTorch NestedTensor: 0.0014418727159500122 sec 17.459249850229323 GB/s INFO:root: INFO:root:######## Jagged (x) Dense -> Jagged ######## INFO:root:(+) FBGEMM JaggedTensor: 4.031049832701683e-05 sec 624.9347699856205 GB/s INFO:root:(+) PyTorch NestedTensor: 0.001540895700454712 sec 16.348564015439923 GB/s INFO:root:(*) FBGEMM JaggedTensor: 4.03628796339035e-05 sec 624.1237550068162 GB/s INFO:root:(*) PyTorch NestedTensor: 0.0015746270418167114 sec 15.998348390445281 GB/s INFO:root: INFO:root:######## Jagged + Dense + Dense -> Jagged ######## INFO:root:FBGEMM JaggedTensor: 5.2013471722602845e-05 sec 645.7602403302756 GB/s INFO:root:PyTorch NestedTensor: 0.0028932960033416747 sec 11.608985724656774 GB/s INFO:root: INFO:root:######## Jagged (1D) to Dense ######## INFO:root:FBGEMM JaggedTensor: 1.526080071926117e-05 sec 6.511322821651443 GB/s INFO:root:PyTorch NestedTensor: 3.976528346538544e-05 sec 2.4729108264901147 GB/s INFO:root: INFO:root:######## Dense to Jagged (1D) ######## INFO:root:FBGEMM JaggedTensor: 1.5250975266098977e-05 sec 6.51551774665078 GB/s INFO:root:PyTorch NestedTensor: 0.0014563246965408326 sec 0.06752340342340878 GB/s INFO:root: (foo) bash-5.1$ ``` 
Differential Revision: D59973955 --- fbgemm_gpu/bench/bench_utils.py | 6 +- fbgemm_gpu/bench/jagged_tensor_benchmark.py | 359 ++++++++++++++++---- 2 files changed, 305 insertions(+), 60 deletions(-) diff --git a/fbgemm_gpu/bench/bench_utils.py b/fbgemm_gpu/bench/bench_utils.py index fd1eaaaede..a7f43a0ead 100644 --- a/fbgemm_gpu/bench/bench_utils.py +++ b/fbgemm_gpu/bench/bench_utils.py @@ -35,6 +35,8 @@ def benchmark_torch_function( # noqa: C901 f, # pyre-fixme[2]: Parameter must be annotated. args, + # pyre-fixme[2]: Parameter must be annotated. + kwargs={}, flush_gpu_cache_size_mb: int = 40, iters: int = 10, num_warmups: int = 2, @@ -43,11 +45,11 @@ def benchmark_torch_function( # noqa: C901 num_threads: int = 1, copy_f_for_multi_thread_test: bool = False, ) -> Tuple[float, torch.Tensor]: - logging.info(f"Start to benchmark {name}...") + logging.debug(f"Start to benchmark {name}...") if device != "cpu" and device != "" and device != "cuda": torch.cuda.set_device(device) for _ in range(num_warmups): - output = f(*args) + output = f(*args, **kwargs) assert num_threads > 0 if device != "cpu" and torch.cuda.is_available() and (num_threads == 1): diff --git a/fbgemm_gpu/bench/jagged_tensor_benchmark.py b/fbgemm_gpu/bench/jagged_tensor_benchmark.py index acbe22fb2c..46337701e3 100644 --- a/fbgemm_gpu/bench/jagged_tensor_benchmark.py +++ b/fbgemm_gpu/bench/jagged_tensor_benchmark.py @@ -6,9 +6,12 @@ # pyre-strict +from __future__ import annotations + import functools import logging import random +from dataclasses import dataclass from typing import List, Tuple import click @@ -17,7 +20,7 @@ from torch.profiler import profile logger: logging.Logger = logging.getLogger() -logger.setLevel(logging.DEBUG) +logger.setLevel(logging.INFO) # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
open_source: bool = getattr(fbgemm_gpu, "open_source", False) @@ -52,111 +55,351 @@ def cli() -> None: pass -@cli.command() -@click.option("--batch-size", type=int, default=128) -@click.option("--embedding-dim", type=int, default=128) -@click.option("--max-len", type=int, default=128) -@click.option("--elem-type", type=str, default="half") -def device( - batch_size: int, - embedding_dim: int, - max_len: int, - elem_type: str, -) -> None: - lengths = torch.randint(max_len, size=(batch_size,)) - total_lengths = lengths.sum().item() - offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) +@dataclass +class JaggedTensor: + """ + A simple wrapper class around jagged tensors for benchmarking purposes. + A jagged tensor is a tensor of variable-length vectors. It is + represented as a tuple of (values, lengths, offsets) where values is a 2D + tensor of shape (total_lengths, embedding_dim) and lengths is a 1D tensor + of shape (batch_size,) containing the length of each row in the batch. + Offsets is a 1D tensor of shape (batch_size + 1,) containing the offset of + each row. + """ + + values: torch.Tensor + lengths: torch.Tensor + offsets: torch.Tensor + batch_size: int + embedding_dim: int + max_len: int + + @property + def total_lengths(self) -> int: + return int(self.lengths.sum().item()) + + @staticmethod + def rand_2d( + batch_size: int, embedding_dim: int, max_len: int, elem_type: str + ) -> JaggedTensor: + """ + Generate a random JaggedTensor with 2D values. 
+ """ + # Each row in the batch has different length + lengths = torch.randint(max_len, size=(batch_size,)) + total_lengths = lengths.sum().item() + + # Compute the offsets + offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) + + # Set dtype + dtype = ( + torch.float16 + if elem_type == "half" or elem_type == "float16" + else torch.float32 + ) - dtype = ( - torch.float16 - if elem_type == "half" or elem_type == "float16" - else torch.float32 + # pyre-fixme[6]: For 1st param expected `int` but got `Union[bool, float, int]`. + values_2d = torch.rand(total_lengths, embedding_dim, dtype=dtype) + + if torch.cuda.is_available(): + values_2d = values_2d.cuda() + offsets = offsets.cuda() + + return JaggedTensor( + values_2d, lengths, offsets, batch_size, embedding_dim, max_len + ) + + def to_dense(self) -> torch.Tensor: + """ + Convert the JaggedTensor into a dense tensor. + """ + if self.values.dim() == 2: + return torch.ops.fbgemm.jagged_2d_to_dense( + self.values, self.offsets, self.max_len + ) + elif self.values.dim() == 1: + return torch.ops.fbgemm.jagged_1d_to_dense( + self.values, self.offsets, self.max_len, padding_value=0 + ) + else: + raise RuntimeError(f"Unsupported JaggedTensor dim {self.values.dim()}") + + def as_nested(self) -> torch.Tensor: + """ + Convert the JaggedTensor into a PyTorch NestedTensor. + """ + tensors = [] + + for i in range(1, len(self.offsets)): + tensors.append(self.values[self.offsets[i - 1] : self.offsets[i],]) + + return torch.nested.nested_tensor(tensors) + + def nbytes(self) -> int: + """ + Return the number of bytes used by the JaggedTensor. + """ + offsets_nbytes = self.offsets.numel() * self.offsets.element_size() + values_nbytes = self.values.numel() * self.values.element_size() + return offsets_nbytes + values_nbytes + + +def dense_to_nested(values: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor: + """ + Convert a dense tensor into a PyTorch NestedTensor. 
+ """ + return torch.nested.nested_tensor( + [values[i][: lengths[i],] for i in range(len(lengths))] ) - # pyre-fixme[6]: For 1st param expected `int` but got `Union[bool, float, int]`. - values_2d = torch.rand(total_lengths, embedding_dim, dtype=dtype) - if torch.cuda.is_available(): - offsets = offsets.cuda() - values_2d = values_2d.cuda() +def bench_jagged_2d_to_dense(jten: JaggedTensor) -> None: + logging.info("######## Jagged (2D) to Dense ########") time, output = benchmark_torch_function( - torch.ops.fbgemm.jagged_2d_to_dense, (values_2d, offsets, max_len), iters=1000 + jten.to_dense, + (), + iters=1000, ) - offsets_nbytes = offsets.numel() * offsets.element_size() - values_nbytes = values_2d.numel() * values_2d.element_size() dense_nbytes = output.numel() * output.element_size() + num_bytes = jten.nbytes() + dense_nbytes + logging.info(f"FBGEMM JaggedTensor: {time} sec {num_bytes / time / 1e9} GB/s") - num_bytes = offsets_nbytes + values_nbytes + dense_nbytes - logging.info(f"jagged_2d_to_dense {time} sec {num_bytes / time / 1e9} GB/s") + nten = jten.as_nested() + time, output = benchmark_torch_function( + torch.nested.to_padded_tensor, + (nten, 0.0, (jten.batch_size, jten.max_len, jten.embedding_dim)), + iters=1000, + ) - total_L = values_2d.size(0) - time, jagged_output = benchmark_torch_function( - torch.ops.fbgemm.dense_to_jagged, (output, [offsets], total_L), iters=1000 + nten_bytes = nten.numel() * nten.element_size() + dense_nbytes = output.numel() * output.element_size() + num_bytes = nten_bytes + dense_nbytes + logging.info(f"PyTorch NestedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + logging.info("") + + +def bench_dense_to_jagged_2d(jten: JaggedTensor) -> None: + logging.info("######## Dense to Jagged (2D) ########") + + dense_values = jten.to_dense() + + time, output = benchmark_torch_function( + torch.ops.fbgemm.dense_to_jagged, + (dense_values, [jten.offsets], jten.total_lengths), + iters=1000, + ) + + dense_nbytes = 
dense_values.numel() * dense_values.element_size() + output_nbytes = output[0].numel() * output[0].element_size() + offsets_nbytes = jten.offsets.numel() * jten.offsets.element_size() + num_bytes = dense_nbytes + output_nbytes + offsets_nbytes + logging.info(f"FBGEMM JaggedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + + time, output = benchmark_torch_function( + dense_to_nested, + (dense_values, jten.lengths), + iters=1000, ) - num_bytes = offsets_nbytes + 2 * values_nbytes - logging.info(f"dense_to_jagged (2d) {time} sec {num_bytes / time / 1e9} GB/s") + output_nbytes = output.numel() * output.element_size() + num_bytes = dense_nbytes + output_nbytes + logging.info(f"PyTorch NestedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + logging.info("") + + +def bench_jagged_dense_elementwise_op_jagged_output(jten: JaggedTensor) -> None: + logging.info("######## Jagged (x) Dense -> Jagged ########") + + def nested_tensor_add( + jagged_x: JaggedTensor, nested_x: torch.Tensor, dense_y: torch.Tensor + ) -> torch.Tensor: + return nested_x + dense_to_nested( + dense_y, + jagged_x.lengths, + ) + + def nested_tensor_mul( + jagged_x: JaggedTensor, nested_x: torch.Tensor, dense_y: torch.Tensor + ) -> torch.Tensor: + return nested_x * dense_to_nested( + dense_y, + jagged_x.lengths, + ) + + offsets_nbytes = jten.offsets.numel() * jten.offsets.element_size() + values_nbytes = jten.values.numel() * jten.values.element_size() + num_bytes = offsets_nbytes + 3 * values_nbytes time, jagged_output = benchmark_torch_function( torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output, - (values_2d, [offsets], output), + (jten.values, [jten.offsets], jten.to_dense()), + iters=1000, + ) + logging.info(f"(Add) FBGEMM JaggedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + + time, nested_output = benchmark_torch_function( + nested_tensor_add, + (jten, jten.as_nested(), jten.to_dense()), iters=1000, ) - num_bytes = offsets_nbytes + 3 * values_nbytes logging.info( - 
f"jagged_dense_elementwise_add_jagged_output {time} sec {num_bytes / time / 1e9} GB/s" + f"(Add) PyTorch NestedTensor: {time} sec {num_bytes / time / 1e9} GB/s" ) time, jagged_output = benchmark_torch_function( torch.ops.fbgemm.jagged_dense_elementwise_mul, - (values_2d, [offsets], output), + (jten.values, [jten.offsets], jten.to_dense()), + iters=1000, + ) + logging.info(f"(Mul) FBGEMM JaggedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + + time, nested_output = benchmark_torch_function( + nested_tensor_mul, + (jten, jten.as_nested(), jten.to_dense()), iters=1000, ) - num_bytes = offsets_nbytes + 3 * values_nbytes logging.info( - f"jagged_dense_elementwise_mul {time} sec {num_bytes / time / 1e9} GB/s" + f"(Mul) PyTorch NestedTensor: {time} sec {num_bytes / time / 1e9} GB/s" ) + logging.info("") + + +def bench_jagged_dense_dense_elementwise_add_jagged_output(jten: JaggedTensor) -> None: + logging.info("######## Jagged + Dense + Dense -> Jagged ########") + + def nested_tensor_add( + jagged_x: JaggedTensor, + nested_x: torch.Tensor, + dense_y0: torch.Tensor, + dense_y1: torch.Tensor, + ) -> torch.Tensor: + return ( + nested_x + + dense_to_nested( + dense_y0, + jagged_x.lengths, + ) + + dense_to_nested( + dense_y1, + jagged_x.lengths, + ) + ) - output_sq = output * output + offsets_nbytes = jten.offsets.numel() * jten.offsets.element_size() + values_nbytes = jten.values.numel() * jten.values.element_size() + num_bytes = offsets_nbytes + 4 * values_nbytes + + output = jten.to_dense() time, jagged_output = benchmark_torch_function( torch.ops.fbgemm.jagged_dense_dense_elementwise_add_jagged_output, - (values_2d, [offsets], output, output_sq), + (jten.values, [jten.offsets], output, output * output), iters=1000, ) - num_bytes = offsets_nbytes + 4 * values_nbytes - logging.info( - f"jagged_dense_dense_elementwise_add_jagged_output {time} sec {num_bytes / time / 1e9} GB/s" + logging.info(f"FBGEMM JaggedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + + time, 
nested_output = benchmark_torch_function( + nested_tensor_add, + (jten, jten.as_nested(), output, output * output), + iters=1000, ) + logging.info(f"PyTorch NestedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + logging.info("") + + +def bench_jagged_1d_to_dense(jten: JaggedTensor) -> None: + logging.info("######## Jagged (1D) to Dense ########") # pyre-fixme[6]: For 1st param expected `Union[List[int], Size, # typing.Tuple[int, ...]]` but got `Union[bool, float, int]`. - values_1d = torch.rand(total_lengths) + jten.values = torch.rand(jten.total_lengths) if torch.cuda.is_available(): - values_1d = values_1d.cuda() - values_nbytes = values_1d.numel() * values_1d.element_size() + jten.values = jten.values.cuda() time, output = benchmark_torch_function( - lambda: torch.ops.fbgemm.jagged_1d_to_dense( - values_1d, offsets, max_len, padding_value=0 - ), + jten.to_dense, (), iters=1000, ) + dense_nbytes = output.numel() * output.element_size() + num_bytes = jten.nbytes() + dense_nbytes + logging.info(f"FBGEMM JaggedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + + nten = jten.as_nested() + time, output = benchmark_torch_function( + torch.nested.to_padded_tensor, + (nten, 0.0, (jten.batch_size, jten.embedding_dim)), + iters=1000, + ) + + nten_bytes = nten.numel() * nten.element_size() + num_bytes = nten_bytes + dense_nbytes + logging.info(f"PyTorch NestedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + logging.info("") - num_bytes = offsets_nbytes + values_nbytes + dense_nbytes - logging.info(f"jagged_1d_to_dense {time} sec {num_bytes / time / 1e9} GB/s") - total_L = values_1d.size(0) - output_1d = torch.unsqueeze(output, -1) +def bench_dense_to_jagged_1d(jten: JaggedTensor) -> None: + logging.info("######## Dense to Jagged (1D) ########") + + # pyre-fixme[6]: For 1st param expected `Union[List[int], Size, + # typing.Tuple[int, ...]]` but got `Union[bool, float, int]`. 
+ jten.values = torch.rand(jten.total_lengths) + if torch.cuda.is_available(): + jten.values = jten.values.cuda() + dense_values = jten.to_dense() + + dense_1d = torch.unsqueeze(dense_values, -1) time, jagged_output = benchmark_torch_function( - torch.ops.fbgemm.dense_to_jagged, (output_1d, [offsets], total_L), iters=1000 + torch.ops.fbgemm.dense_to_jagged, + (dense_1d, [jten.offsets], jten.total_lengths), + iters=1000, ) - num_bytes = offsets_nbytes + 2 * values_nbytes - logging.info(f"dense_to_jagged (1d) {time} sec {num_bytes / time / 1e9} GB/s") + dense_1d_nbytes = dense_1d.numel() * dense_1d.element_size() + offsets_nbytes = jten.offsets.numel() * jten.offsets.element_size() + jagged_output_bytes = jagged_output[0].numel() * jagged_output[0].element_size() + num_bytes = offsets_nbytes + dense_1d_nbytes + jagged_output_bytes + logging.info(f"FBGEMM JaggedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + + time, output = benchmark_torch_function( + dense_to_nested, + (dense_1d, jten.lengths), + iters=1000, + ) + + nten_nbytes = output.numel() * output.element_size() + num_bytes = dense_1d_nbytes + nten_nbytes + logging.info(f"PyTorch NestedTensor: {time} sec {num_bytes / time / 1e9} GB/s") + logging.info("") + + +@cli.command() +@click.option("--batch-size", type=int, default=128) +@click.option("--embedding-dim", type=int, default=128) +@click.option("--max-len", type=int, default=128) +@click.option("--elem-type", type=str, default="half") +def device( + batch_size: int, + embedding_dim: int, + max_len: int, + elem_type: str, +) -> None: + jtensor = JaggedTensor.rand_2d(batch_size, embedding_dim, max_len, elem_type) + + bench_jagged_2d_to_dense(jtensor) + + bench_dense_to_jagged_2d(jtensor) + + bench_jagged_dense_elementwise_op_jagged_output(jtensor) + + bench_jagged_dense_dense_elementwise_add_jagged_output(jtensor) + + bench_jagged_1d_to_dense(jtensor) + + bench_dense_to_jagged_1d(jtensor) @cli.command()