diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst
index afc38a450..44d9e34ce 100644
--- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst
+++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst
@@ -12,3 +12,7 @@ Sparse Operators
 .. autofunction:: torch.ops.fbgemm.asynchronous_complete_cumsum
 
 .. autofunction:: torch.ops.fbgemm.offsets_range
+
+.. autofunction:: torch.ops.fbgemm.segment_sum_csr
+
+.. autofunction:: torch.ops.fbgemm.keyed_jagged_index_select_dim1
diff --git a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py
index ae307dc8f..5dffc308f 100644
--- a/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py
@@ -204,3 +204,118 @@
             4, 5, 6], device='cuda:0')
     """,
 )
+
+add_docs(
+    torch.ops.fbgemm.segment_sum_csr,
+    """
+segment_sum_csr(batch_size, csr_seg, values) -> Tensor
+
+Sum the values within each segment of the given CSR data, where every row has
+the same number of non-zero elements (`batch_size`).
+
+Args:
+    batch_size (int): The row stride (number of non-zero elements in each row)
+
+    csr_seg (Tensor): The complete cumulative sum of segment lengths, where a
+        segment length is the number of rows within a segment. The shape of
+        `csr_seg` is `num_segments + 1`, where `num_segments` is the number of
+        segments.
+
+    values (Tensor): The values tensor to be segment-summed. The number of
+        elements in the tensor must be a multiple of `batch_size`
+
+Returns:
+    A tensor containing the segment sums. Shape is `num_segments`.
+
+**Example:**
+
+    >>> batch_size = 2
+    >>> # Randomize inputs
+    >>> lengths = torch.tensor([3, 4, 1], dtype=torch.int, device="cuda")
+    >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+    >>> print(offsets)
+    tensor([0, 3, 7, 8], device='cuda:0', dtype=torch.int32)
+    >>> values = torch.randn(lengths.sum().item() * batch_size, dtype=torch.float32, device="cuda")
+    >>> print(values)
+    tensor([-2.8642e-01,  1.6451e+00,  1.1322e-01,  1.7335e+00, -8.4700e-02,
+            -1.2756e+00,  1.1206e+00,  9.6385e-01,  6.2122e-02,  1.3104e-03,
+             2.2667e-01,  2.3113e+00, -1.1948e+00, -1.5463e-01, -1.0031e+00,
+            -3.5531e-01], device='cuda:0')
+    >>> # Invoke
+    >>> torch.ops.fbgemm.segment_sum_csr(batch_size, offsets, values)
+    tensor([ 1.8451,  3.3365, -1.3584], device='cuda:0')
+    """,
+)
+
+add_docs(
+    torch.ops.fbgemm.keyed_jagged_index_select_dim1,
+    """
+keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size, weights=None, selected_lengths_sum=None) -> List[Tensor]
+
+Perform an index-select operation along the batch dimension (dim 1) of the
+given keyed jagged tensor (KJT) input. The same batch samples are selected for
+every key. Note that a KJT has 3 dimensions, (`num_keys`, `batch_size`, jagged
+dim), where `num_keys` is the number of keys and `batch_size` is the batch
+size. This operator is similar to a permute operator.
+
+Args:
+    values (Tensor): The KJT values tensor, which contains the concatenated
+        data of every key
+
+    lengths (Tensor): The KJT lengths tensor, which contains the jagged shapes
+        of every key (dim 0) and sample (dim 1). Shape is `num_keys *
+        batch_size`
+
+    offsets (Tensor): The KJT offsets tensor, which is the complete cumulative
+        sum of `lengths`. Shape is `num_keys * batch_size + 1`
+
+    indices (Tensor): The indices to select, i.e., the samples in the batch to
+        select. The values of `indices` must be >= 0 and < `batch_size`
+
+    batch_size (int): The batch size (dim 1 of the KJT)
+
+    weights (Optional[Tensor] = None): An optional float tensor that is
+        index-selected the same way as `values`; thus, it must have the same
+        shape as `values`
+
+    selected_lengths_sum (Optional[int] = None): An optional value that
+        represents the total number of elements in the index-selected output.
+        If not provided, the operator computes it itself, which may cause a
+        device-to-host synchronization when running on a GPU; it is therefore
+        recommended to supply this value to avoid the synchronization.
+
+Returns:
+    The index-selected KJT (as a list containing the output values, lengths,
+    and, if `weights` is not None, weights)
+
+**Example:**
+
+    >>> num_keys = 2
+    >>> batch_size = 4
+    >>> output_size = 3
+    >>> # Randomize inputs
+    >>> lengths = torch.randint(low=0, high=10, size=(batch_size * num_keys,), dtype=torch.int64, device="cuda")
+    >>> print(lengths)
+    tensor([8, 5, 1, 4, 2, 7, 5, 9], device='cuda:0')
+    >>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+    >>> print(offsets)
+    tensor([ 0,  8, 13, 14, 18, 20, 27, 32, 41], device='cuda:0')
+    >>> indices = torch.randint(low=0, high=batch_size, size=(output_size,), dtype=torch.int64, device="cuda")
+    >>> print(indices)
+    tensor([3, 3, 1], device='cuda:0')
+    >>> # Use torch.arange instead of torch.randn to simplify the example
+    >>> values = torch.arange(lengths.sum().item(), dtype=torch.float32, device="cuda")
+    >>> print(values)
+    tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
+            14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
+            28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40.],
+           device='cuda:0')
+    >>> # Invoke. The output is a list of [values, lengths]
+    >>> torch.ops.fbgemm.keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size)
+    [tensor([14., 15., 16., 17., 14., 15., 16., 17.,  8.,  9., 10., 11., 12., 32.,
+             33., 34., 35., 36., 37., 38., 39., 40., 32., 33., 34., 35., 36., 37.,
+             38., 39., 40., 20., 21., 22., 23., 24., 25., 26.], device='cuda:0'),
+     tensor([4, 4, 5, 9, 9, 7], device='cuda:0')]
+    """,
+)
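
Not part of the diff above: as a review aid, here is a minimal pure-PyTorch sketch of the semantics the two new docstrings describe (segment sums over `batch_size`-strided CSR rows, and selecting the same batch samples for every key of a KJT). The `*_ref` helper names are hypothetical, introduced only for illustration and not FBGEMM APIs, and they are written for clarity rather than performance.

import torch


def segment_sum_csr_ref(batch_size, csr_seg, values):
    # Segment i covers rows csr_seg[i]..csr_seg[i + 1]; each row holds
    # `batch_size` contiguous elements of `values`.
    seg = csr_seg.tolist()
    return torch.stack(
        [
            values[seg[i] * batch_size : seg[i + 1] * batch_size].sum()
            for i in range(len(seg) - 1)
        ]
    )


def keyed_jagged_index_select_dim1_ref(
    values, lengths, offsets, indices, batch_size, weights=None
):
    # Select the same batch samples (`indices`) for every key of the KJT.
    num_keys = lengths.numel() // batch_size
    offs = offsets.tolist()
    out_values, out_lengths, out_weights = [], [], []
    for k in range(num_keys):
        for i in indices.tolist():
            row = k * batch_size + i
            start, end = offs[row], offs[row + 1]
            out_values.append(values[start:end])
            out_lengths.append(lengths[row])
            if weights is not None:
                out_weights.append(weights[start:end])
    out = [torch.cat(out_values), torch.stack(out_lengths)]
    if weights is not None:
        out.append(torch.cat(out_weights))
    return out

On the example inputs shown in the docstrings, `segment_sum_csr_ref(2, offsets, values)` and `keyed_jagged_index_select_dim1_ref(values, lengths, offsets, indices, 4)` reproduce the printed outputs above.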