From 4ea163f8ff13d3656aa60a8aff19538e6add3a71 Mon Sep 17 00:00:00 2001 From: sryap <17482891+sryap@users.noreply.github.com> Date: Sat, 28 Sep 2024 06:04:58 -0700 Subject: [PATCH] Add the block_bucketize_sparse_features docstring (#3191) Summary: A title Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3191 Differential Revision: D63583026 Pulled By: sryap --- .../src/fbgemm_gpu-python-api/sparse_ops.rst | 2 + fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py | 147 ++++++++++++++++++ 2 files changed, 149 insertions(+) diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst index 44d9e34ce..b95b6dda4 100644 --- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst +++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst @@ -16,3 +16,5 @@ Sparse Operators .. autofunction:: torch.ops.fbgemm.segment_sum_csr .. autofunction:: torch.ops.fbgemm.keyed_jagged_index_select_dim1 + +.. autofunction:: torch.ops.fbgemm.block_bucketize_sparse_features \ No newline at end of file diff --git a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py index 5dffc308f..76470e32b 100644 --- a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py @@ -319,3 +319,150 @@ tensor([4, 4, 5, 9, 9, 7], device='cuda:0')] """, ) + +add_docs( + torch.ops.fbgemm.block_bucketize_sparse_features, + """ +block_bucketize_sparse_features(lengths, indices, bucketize_pos, sequence, block_sizes, my_size, weights=None, batch_size_per_feature=None, max_B= -1, block_bucketize_pos=None, keep_orig_idx=False) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor]] + +Preprocess sparse features by partitioning sparse features into multiple +buckets. Every feature is split into the same number of buckets, but the bucket +sizes (widths) for the different features can be different. Moreover, the +bucket sizes within each feature can be different. 
+ +Args: + lengths (Tensor): The lengths of the sparse features. The tensor contains + the lengths of each sample in a batch and each feature. Shape is `B * + T` where `B` is the batch size and `T` is the number of features + + indices (Tensor): The sparse data. Only integer types are supported. + Shape is the sum of `lengths` + + bucketize_pos (bool): If True, return the original relative indices within + a sample. For example, `indices = [9, 8, 2, 1, 0, 8, 9]` and `lengths = + [3, 4]`. The original relative indices within a sample for the indices + are `[0, 1, 2, 0, 1, 2, 3]` + + sequence (bool): If True, return `unbucketize_permute_data`, the new + position of each index, listed in the original index order. + + block_sizes (Tensor): This tensor is used for the case where the bucket + size within a feature is uniform (i.e., when + `block_bucketize_pos=None`). The tensor contains bucket sizes (i.e., + bucket widths) for each feature. `block_sizes[t]` represents the + bucket size of feature `t`. Shape is the number of features. + + my_size (int): The number of buckets for each feature. Note that every + feature has the same number of buckets. + + weights (Optional[Tensor] = None): An optional float tensor that will be + bucketized the same way as `indices`. This tensor must have the same + shape as `indices` + + batch_size_per_feature (Optional[Tensor] = None): An optional tensor that + contains batch sizes for different features. If not None, batch sizes + are not uniform among features. Otherwise, the operator will assume + that the batch size is uniform and infer it from the `lengths` and + `block_sizes` tensors + + max_B (int = -1): The max batch size. Must be set if + `batch_size_per_feature` is not None + + block_bucketize_pos (Optional[List[Tensor]] = None): The input is used for + non-uniform bucket sizes within a feature. `block_bucketize_pos` is a + list of tensors. Each tensor contains the range offsets of buckets for + each feature. 
These range offsets are equivalent to the complete + cumulative sum of the bucket sizes. For example, `[0, 4, 20]` represents + two buckets. The first bucket size is `(4 - 0) = 4`, and the second + bucket size is `(20 - 4) = 16`. The length of `block_bucketize_pos` + must be equal to the number of features. + + keep_orig_idx (bool = False): If True, return original indices instead of + the relative indices within each bucket + +Return: + A tuple of tensors containing + + (1) Bucketized lengths. Shape is `lengths.numel() * my_size`. + + (2) Bucketized indices. Same shape as `indices`. + + (3) Bucketized weights or None if `weights` is None. Same shape as + `indices`. + + (4) Bucketized positions or None if `bucketize_pos=False`. Same shape as + `indices`. + + (5) `unbucketize_permute` or None if `sequence=False`. Same shape as + `indices` + +**Example**: + + >>> # Generate input example. Batch size = 2. Number of features = 4 + >>> lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=torch.int, device="cuda") + >>> indices = torch.tensor([3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20], dtype=torch.int, device="cuda") + >>> block_sizes = torch.tensor([5, 15, 10, 20], dtype=torch.int, device="cuda") + >>> my_size = 2 # Number of buckets + >>> # Invoke with keep_orig_idx=False, bucketize_pos=False, and + >>> # sequence=False + >>> torch.ops.fbgemm.block_bucketize_sparse_features( + >>> lengths, + >>> indices, + >>> bucketize_pos=False, + >>> sequence=False, + >>> block_sizes=block_sizes, + >>> my_size=my_size, + >>> keep_orig_idx=False) + >>> # The first 8 values in the returned lengths are the lengths for bucket + >>> # 0 and the rest are the lengths for bucket 1 + (tensor([0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 1, 3, 2, 1], device='cuda:0', + dtype=torch.int32), + tensor([ 3, 4, 11, 1, 11, 0, 13, 14, 0, 1, 2, 3, 2, 0, 0], + device='cuda:0', dtype=torch.int32), + None, + None, + None) + >>> # Invoke with keep_orig_idx=True, bucketize_pos=True, and + >>> # 
sequence=True + >>> torch.ops.fbgemm.block_bucketize_sparse_features( + >>> lengths, + >>> indices, + >>> bucketize_pos=True, + >>> sequence=True, + >>> block_sizes=block_sizes, + >>> my_size=my_size, + >>> keep_orig_idx=True) + (tensor([0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 1, 3, 2, 1], device='cuda:0', + dtype=torch.int32), + tensor([ 3, 4, 11, 1, 11, 15, 28, 29, 10, 11, 12, 13, 22, 20, 20], + device='cuda:0', dtype=torch.int32), + None, + tensor([0, 1, 0, 0, 0, 0, 1, 2, 1, 0, 1, 2, 1, 2, 0], device='cuda:0', + dtype=torch.int32), + tensor([ 0, 1, 5, 2, 6, 7, 3, 8, 9, 10, 11, 4, 12, 13, 14], + device='cuda:0', dtype=torch.int32)) + >>> # Invoke with block_bucketize_pos + >>> block_bucketize_pos = [ + >>> torch.tensor([0, 2, 8], dtype=torch.int), + >>> torch.tensor([0, 5, 10], dtype=torch.int), + >>> torch.tensor([0, 7, 12], dtype=torch.int), + >>> torch.tensor([0, 2, 16], dtype=torch.int), + >>> ] + >>> torch.ops.fbgemm.block_bucketize_sparse_features( + >>> lengths, + >>> indices, + >>> bucketize_pos=False, + >>> sequence=False, + >>> block_sizes=block_sizes, + >>> my_size=my_size, + >>> block_bucketize_pos=block_bucketize_pos, + >>> keep_orig_idx=False) + (tensor([0, 0, 0, 1, 1, 1, 2, 1, 0, 2, 1, 2, 1, 2, 1, 0], device='cuda:0', + dtype=torch.int32), + tensor([14, 1, 6, 11, 10, 10, 1, 2, 7, 5, 14, 3, 4, 6, 9], + device='cuda:0', dtype=torch.int32), + None, + None, + None) + """, +)