From 6c578c1f1040eaf0c2a74632460255821142dd86 Mon Sep 17 00:00:00 2001
From: Sarunya Pumma
Date: Wed, 25 Sep 2024 08:07:46 -0700
Subject: [PATCH] Merge pooled emb docstring (#3172)

Summary:
Add a docstring for `torch.ops.fbgemm.merge_pooled_embeddings` and expose
it in the Sphinx documentation under a new "Pooled Embedding Operators"
page.

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3172

Differential Revision: D63391637

Pulled By: sryap
---
 .../pooled_embedding_ops.rst                  |  6 ++++
 fbgemm_gpu/docs/src/index.rst                 |  1 +
 fbgemm_gpu/fbgemm_gpu/docs/__init__.py        |  2 +-
 .../docs/merge_pooled_embedding_ops.py        | 36 +++++++++++++++++++
 4 files changed, 44 insertions(+), 1 deletion(-)
 create mode 100644 fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst
 create mode 100644 fbgemm_gpu/fbgemm_gpu/docs/merge_pooled_embedding_ops.py

diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst
new file mode 100644
index 000000000..519b74e6b
--- /dev/null
+++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst
@@ -0,0 +1,6 @@
+Pooled Embedding Operators
+==========================
+
+.. automodule:: fbgemm_gpu
+
+.. autofunction:: torch.ops.fbgemm.merge_pooled_embeddings
diff --git a/fbgemm_gpu/docs/src/index.rst b/fbgemm_gpu/docs/src/index.rst
index a71a58995..c4d98c720 100644
--- a/fbgemm_gpu/docs/src/index.rst
+++ b/fbgemm_gpu/docs/src/index.rst
@@ -91,3 +91,4 @@ Table of Contents
 
    fbgemm_gpu-python-api/table_batched_embedding_ops.rst
    fbgemm_gpu-python-api/jagged_tensor_ops.rst
+   fbgemm_gpu-python-api/pooled_embedding_ops.rst
diff --git a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py
index 250f9d58e..5077a5ba3 100644
--- a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py
@@ -7,6 +7,6 @@
 
 # Trigger the manual addition of docstrings to pybind11-generated operators
 try:
-    from . import jagged_tensor_ops, table_batched_embedding_ops  # noqa: F401
+    from . import jagged_tensor_ops, merge_pooled_embedding_ops, table_batched_embedding_ops  # noqa: F401
 except Exception:
     pass
diff --git a/fbgemm_gpu/fbgemm_gpu/docs/merge_pooled_embedding_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/merge_pooled_embedding_ops.py
new file mode 100644
index 000000000..6990946fb
--- /dev/null
+++ b/fbgemm_gpu/fbgemm_gpu/docs/merge_pooled_embedding_ops.py
@@ -0,0 +1,36 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .common import add_docs
+
+add_docs(
+    torch.ops.fbgemm.merge_pooled_embeddings,
+    """
+merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim=1) -> Tensor
+
+Concatenate embedding outputs from different devices (on the same host)
+onto the target device.
+
+Args:
+    pooled_embeddings (List[Tensor]): A list of embedding outputs from
+        different devices on the same host. Each output has 2
+        dimensions.
+
+    uncat_dim_size (int): The size of the dimension that is not
+        concatenated, i.e., if `cat_dim=0`, `uncat_dim_size` is the size
+        of dim 1, and vice versa.
+
+    target_device (torch.device): The target device that aggregates all
+        the embedding outputs.
+
+    cat_dim (int = 1): The dimension along which the tensors are concatenated.
+
+Returns:
+    The concatenated embedding output (2D) on the target device.
+    """,
+)
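
Usage sketch for the operator documented by this patch, following the
signature given in the docstring above. This is a minimal illustration, not
part of the patch: it assumes an FBGEMM-GPU build where importing
`fbgemm_gpu` registers the `torch.ops.fbgemm` operators and at least two
CUDA devices are available; the tensor shapes and device names are
illustrative only.

    import torch
    import fbgemm_gpu  # noqa: F401  # importing registers torch.ops.fbgemm.*

    # Two hypothetical pooled embedding outputs on different GPUs on the
    # same host. Both share batch size 4 along dim 0 (the un-concatenated
    # dimension) and have widths 8 and 16 along cat_dim=1.
    a = torch.randn(4, 8, device="cuda:0")
    b = torch.randn(4, 16, device="cuda:1")

    merged = torch.ops.fbgemm.merge_pooled_embeddings(
        [a, b],                  # pooled_embeddings
        4,                       # uncat_dim_size: size of dim 0 when cat_dim=1
        torch.device("cuda:0"),  # target_device
        1,                       # cat_dim (the default)
    )
    # merged has shape (4, 24) and resides on cuda:0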