Skip to content

Commit

Permalink
Add a docstring for merge_pooled_embeddings
Browse files Browse the repository at this point in the history
  • Loading branch information
sryap committed Sep 25, 2024
1 parent fd2524a commit 2f8d76d
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Pooled Embedding Operators
==========================

.. automodule:: fbgemm_gpu

.. autofunction:: torch.ops.fbgemm.merge_pooled_embeddings
1 change: 1 addition & 0 deletions fbgemm_gpu/docs/src/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,4 @@ Table of Contents

fbgemm_gpu-python-api/table_batched_embedding_ops.rst
fbgemm_gpu-python-api/jagged_tensor_ops.rst
fbgemm_gpu-python-api/pooled_embedding_ops.rst
2 changes: 1 addition & 1 deletion fbgemm_gpu/fbgemm_gpu/docs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@

# Trigger the manual addition of docstrings to pybind11-generated operators
try:
from . import jagged_tensor_ops, table_batched_embedding_ops # noqa: F401
from . import jagged_tensor_ops, merge_pooled_embedding_ops # noqa: F401
except Exception:
pass
36 changes: 36 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/docs/merge_pooled_embedding_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .common import add_docs

add_docs(
torch.ops.fbgemm.merge_pooled_embeddings,
"""
merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim=1) -> Tensor
Concatenate embedding outputs from different devices (on the same host)
on to the target device.
Args:
pooled_embeddings (List[Tensor]): A list of embedding outputs from
different devices on the same host. Each output has 2
dimensions.
uncat_dim_size (int): The size of the dimension that is not
concatenated, i.e., if `cat_dim=0`, `uncat_dim_size` is the size
of dim 1 and vice versa.
target_device (torch.device): The target device that aggregates all
the embedding outputs.
cat_dim (int = 1): The dimension that the tensors are concatenated
Returns:
The concatenated embedding output (2D) on the target device
""",
)

0 comments on commit 2f8d76d

Please sign in to comment.