Add docstrings for sparse ops (2)
sryap committed Sep 27, 2024
1 parent d056aa3 commit 8c6f493
Showing 2 changed files with 119 additions and 0 deletions.
4 changes: 4 additions & 0 deletions fbgemm_gpu/docs/src/fbgemm_gpu-python-api/sparse_ops.rst
@@ -12,3 +12,7 @@ Sparse Operators
.. autofunction:: torch.ops.fbgemm.asynchronous_complete_cumsum

.. autofunction:: torch.ops.fbgemm.offsets_range

.. autofunction:: torch.ops.fbgemm.segment_sum_csr

.. autofunction:: torch.ops.fbgemm.keyed_jagged_index_select_dim1
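
The directives above document operators that live in the `torch.ops.fbgemm` namespace; they become callable once the `fbgemm_gpu` package has been imported. A minimal usage sketch (an assumption-laden illustration, not part of this commit; it presumes `fbgemm_gpu` is installed with GPU support and reuses the values from the doctest below):

import torch
import fbgemm_gpu  # noqa: F401  -- importing registers the torch.ops.fbgemm operators

lengths = torch.tensor([3, 4, 1], dtype=torch.int32, device="cuda")
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
# offsets -> tensor([0, 3, 7, 8], device='cuda:0', dtype=torch.int32)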
115 changes: 115 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/docs/sparse_ops.py
@@ -204,3 +204,118 @@
4, 5, 6], device='cuda:0')
""",
)

add_docs(
torch.ops.fbgemm.segment_sum_csr,
"""
segment_sum_csr(batch_size, csr_seg, values) -> Tensor
Sum the values within each segment of the given CSR data, where every row has
the same number of non-zero elements.
Args:
batch_size (int): The row stride (number of non-zero elements in each row)
csr_seg (Tensor): The complete cumulative sum of segment lengths. A segment
length is the number of rows within each segment. The shape of the
`csr_seg` tensor is `num_segments + 1` where `num_segments` is the
number of segments.
values (Tensor): The values tensor to be segment summed. The number of
elements in the tensor must be a multiple of `batch_size`
Returns:
A tensor containing the segment sum results. Its shape is the number of
segments.
**Example:**
>>> batch_size = 2
>>> # Randomize inputs
>>> lengths = torch.tensor([3, 4, 1], dtype=torch.int, device="cuda")
>>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
>>> print(offsets)
tensor([0, 3, 7, 8], device='cuda:0', dtype=torch.int32)
>>> values = torch.randn(lengths.sum().item() * batch_size, dtype=torch.float32, device="cuda")
>>> print(values)
tensor([-2.8642e-01, 1.6451e+00, 1.1322e-01, 1.7335e+00, -8.4700e-02,
-1.2756e+00, 1.1206e+00, 9.6385e-01, 6.2122e-02, 1.3104e-03,
2.2667e-01, 2.3113e+00, -1.1948e+00, -1.5463e-01, -1.0031e+00,
-3.5531e-01], device='cuda:0')
>>> # Invoke
>>> torch.ops.fbgemm.segment_sum_csr(batch_size, offsets, values)
tensor([ 1.8451, 3.3365, -1.3584], device='cuda:0')
""",
)
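
For orientation, the behavior documented above can be reproduced in plain PyTorch. The sketch below is illustrative only and is not part of this commit; the helper name `segment_sum_csr_ref` is hypothetical.

import torch

def segment_sum_csr_ref(batch_size: int, csr_seg: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    # Segment i spans rows csr_seg[i] .. csr_seg[i + 1]; each row contributes
    # `batch_size` consecutive elements of `values`.
    starts = (csr_seg[:-1].long() * batch_size).tolist()
    ends = (csr_seg[1:].long() * batch_size).tolist()
    return torch.stack([values[s:e].sum() for s, e in zip(starts, ends)])

# With the example above, segment_sum_csr_ref(2, offsets, values)
# should match torch.ops.fbgemm.segment_sum_csr(2, offsets, values).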

add_docs(
torch.ops.fbgemm.keyed_jagged_index_select_dim1,
"""
keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size, weights=None, selected_lengths_sum=None) -> List[Tensor]
Perform an index-select operation along the batch dimension (dim 1) of the
given keyed jagged tensor (KJT) input. The same batch samples are selected for
every key. Note that a KJT has 3 dimensions: (`num_keys`, `batch_size`, jagged
dim), where `num_keys` is the number of keys and `batch_size` is the batch
size. This operator is similar to a permute operator.
Args:
values (Tensor): The KJT values tensor which contains concatenated data of
every key
lengths (Tensor): The KJT lengths tensor which contains the jagged shapes
of every key (dim 0) and sample (dim 1). Shape is `num_keys *
batch_size`
offsets (Tensor): The KJT offsets tensor which is the complete cumulative
sum of `lengths`. Shape is `num_keys * batch_size + 1`
indices (Tensor): The indices to select, i.e., samples in the batch to
select. The values of `indices` must be >= 0 and < `batch_size`
batch_size (int): The batch size (dim 1 of KJT)
weights (Optional[Tensor] = None): An optional float tensor which will be
selected the same way as `values`. Thus, it must have the same shape as
`values`
selected_lengths_sum (Optional[int] = None): An optional value that
represents the total number of elements in the index-selected data (the
output shape). If not provided, the operator will compute this value
itself, which may cause a device-host synchronization (when running on a
GPU). Thus, it is recommended to supply this value to avoid the
synchronization.
Returns:
The index-selected KJT (as a list of values, lengths, and weights if
`weights` is not None)
**Example:**
>>> num_keys = 2
>>> batch_size = 4
>>> output_size = 3
>>> # Randomize inputs
>>> lengths = torch.randint(low=0, high=10, size=(batch_size * num_keys,), dtype=torch.int64, device="cuda")
>>> print(lengths)
tensor([8, 5, 1, 4, 2, 7, 5, 9], device='cuda:0')
>>> offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
>>> print(offsets)
tensor([ 0, 8, 13, 14, 18, 20, 27, 32, 41], device='cuda:0')
>>> indices = torch.randint(low=0, high=batch_size, size=(output_size,), dtype=torch.int64, device="cuda")
>>> print(indices)
tensor([3, 3, 1], device='cuda:0')
>>> # Use torch.arange instead of torch.randn to simplify the example
>>> values = torch.arange(lengths.sum().item(), dtype=torch.float32, device="cuda")
>>> print(values)
tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13.,
14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27.,
28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40.],
device='cuda:0')
>>> # Invoke. Output = [output values, output lengths]
>>> torch.ops.fbgemm.keyed_jagged_index_select_dim1(values, lengths, offsets, indices, batch_size)
[tensor([14., 15., 16., 17., 14., 15., 16., 17., 8., 9., 10., 11., 12., 32.,
33., 34., 35., 36., 37., 38., 39., 40., 32., 33., 34., 35., 36., 37.,
38., 39., 40., 20., 21., 22., 23., 24., 25., 26.], device='cuda:0'),
tensor([4, 4, 5, 9, 9, 7], device='cuda:0')]
""",
)
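
As above, a plain-PyTorch sketch of the documented semantics (illustrative only, not part of this commit; the helper name is hypothetical, and handling of the optional `weights` tensor is omitted for brevity):

import torch
from typing import List

def kjt_index_select_dim1_ref(
    values: torch.Tensor,
    lengths: torch.Tensor,
    offsets: torch.Tensor,
    indices: torch.Tensor,
    batch_size: int,
) -> List[torch.Tensor]:
    # For every key, pick the jagged rows whose batch positions appear in `indices`.
    num_keys = lengths.numel() // batch_size
    out_values, out_lengths = [], []
    for k in range(num_keys):
        for i in indices.tolist():
            row = k * batch_size + i
            start, end = int(offsets[row]), int(offsets[row + 1])
            out_values.append(values[start:end])
            out_lengths.append(end - start)
    return [
        torch.cat(out_values),
        torch.tensor(out_lengths, dtype=lengths.dtype, device=lengths.device),
    ]

# With the example above, kjt_index_select_dim1_ref(values, lengths, offsets, indices, 4)
# should match the [values, lengths] output shown for keyed_jagged_index_select_dim1.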
