From 768776607b5cfe5c6c0ae0f8a31f669855f66180 Mon Sep 17 00:00:00 2001
From: Sarunya Pumma
Date: Thu, 26 Sep 2024 02:31:18 -0700
Subject: [PATCH] Add permute pooled embs docstrings (#3176)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/272

As title

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3176

Test Plan:
See examples

Module
https://deploy-preview-3176--pytorch-fbgemm-docs.netlify.app/fbgemm_gpu-python-api/pooled_embedding_modules

Op
https://deploy-preview-3176--pytorch-fbgemm-docs.netlify.app/fbgemm_gpu-python-api/pooled_embedding_ops

Differential Revision: D63440824

Pulled By: sryap
---
 .../pooled_embedding_modules.rst              |   7 ++
 .../pooled_embedding_ops.rst                  |   2 +
 fbgemm_gpu/docs/src/index.rst                 |  14 ++-
 fbgemm_gpu/fbgemm_gpu/docs/__init__.py        |   6 +-
 .../docs/permute_pooled_embedding_ops.py      | 108 ++++++++++++++++++
 .../permute_pooled_embedding_modules.py       |  76 ++++++++++++
 6 files changed, 209 insertions(+), 4 deletions(-)
 create mode 100644 fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst
 create mode 100644 fbgemm_gpu/fbgemm_gpu/docs/permute_pooled_embedding_ops.py

diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst
new file mode 100644
index 000000000..654373f40
--- /dev/null
+++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_modules.rst
@@ -0,0 +1,7 @@
+Pooled Embedding Modules
+========================
+
+.. automodule:: fbgemm_gpu
+
+.. autoclass:: fbgemm_gpu.permute_pooled_embedding_modules.PermutePooledEmbeddings
+    :members: __call__
diff --git a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst
index 519b74e6b..52e2fd47d 100644
--- a/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst
+++ b/fbgemm_gpu/docs/src/fbgemm_gpu-python-api/pooled_embedding_ops.rst
@@ -4,3 +4,5 @@ Pooled Embedding Operators
 .. automodule:: fbgemm_gpu
 
 .. autofunction:: torch.ops.fbgemm.merge_pooled_embeddings
+
+.. autofunction:: torch.ops.fbgemm.permute_pooled_embs
diff --git a/fbgemm_gpu/docs/src/index.rst b/fbgemm_gpu/docs/src/index.rst
index c4d98c720..ba0d8ba6b 100644
--- a/fbgemm_gpu/docs/src/index.rst
+++ b/fbgemm_gpu/docs/src/index.rst
@@ -83,12 +83,20 @@ Table of Contents
    fbgemm_gpu-cpp-api/ssd_embedding_ops.rst
    fbgemm_gpu-cpp-api/experimental_ops.rst
 
-.. _fbgemm-gpu.toc.api.python:
+.. _fbgemm-gpu.toc.api.python.ops:
 
 .. toctree::
    :maxdepth: 1
-   :caption: FBGEMM_GPU Python API
+   :caption: FBGEMM_GPU Python Operators API
 
-   fbgemm_gpu-python-api/table_batched_embedding_ops.rst
    fbgemm_gpu-python-api/jagged_tensor_ops.rst
    fbgemm_gpu-python-api/pooled_embedding_ops.rst
+
+.. _fbgemm-gpu.toc.api.python.modules:
+
+.. toctree::
+   :maxdepth: 1
+   :caption: FBGEMM_GPU Python Modules API
+
+   fbgemm_gpu-python-api/table_batched_embedding_ops.rst
+   fbgemm_gpu-python-api/pooled_embedding_modules.rst
diff --git a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py
index 5077a5ba3..4b621cbe3 100644
--- a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py
@@ -7,6 +7,10 @@
 
 # Trigger the manual addition of docstrings to pybind11-generated operators
 try:
-    from . import jagged_tensor_ops, merge_pooled_embedding_ops  # noqa: F401
+    from . import (  # noqa: F401
+        jagged_tensor_ops,
+        merge_pooled_embedding_ops,
+        permute_pooled_embedding_ops,
+    )
 except Exception:
     pass
diff --git a/fbgemm_gpu/fbgemm_gpu/docs/permute_pooled_embedding_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/permute_pooled_embedding_ops.py
new file mode 100644
index 000000000..825686002
--- /dev/null
+++ b/fbgemm_gpu/fbgemm_gpu/docs/permute_pooled_embedding_ops.py
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .common import add_docs
+
+add_docs(
+    torch.ops.fbgemm.permute_pooled_embs,
+    """
+permute_pooled_embs(pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list, inv_permute_list) -> Tensor
+
+Permute embedding outputs along the feature dimension.
+
+The embedding output tensor `pooled_embs` contains the embedding outputs
+for all features in a batch. It is represented in a 2D format, where the
+rows are the batch size dimension and the columns are the feature *
+embedding dimension. Permuting along the feature dimension is
+essentially permuting along the second dimension (dim 1).
+
+Args:
+    pooled_embs (Tensor): The embedding outputs to permute. Shape is
+        `(B_local, total_global_D)`, where `B_local` is the local batch size
+        and `total_global_D` is the total embedding dimension across all
+        features (global)
+
+    offset_dim_list (Tensor): The complete cumulative sum of embedding
+        dimensions of all features. Shape is `T + 1`, where `T` is the
+        total number of features
+
+    permute_list (Tensor): A tensor that describes how each feature is
+        permuted. `permute_list[i]` indicates that feature
+        `permute_list[i]` is permuted to position `i`
+
+    inv_offset_dim_list (Tensor): The complete cumulative sum of inverse
+        embedding dimensions, which are the permuted embedding dimensions.
+        `inv_offset_dim_list[i]` represents the starting embedding position of
+        feature `permute_list[i]`
+
+    inv_permute_list (Tensor): The inverse permute list, which contains the
+        permuted positions of each feature. `inv_permute_list[i]` represents
+        the permuted position of feature `i`
+
+Returns:
+    Permuted embedding outputs (Tensor). Same shape as `pooled_embs`
+
+**Example:**
+
+    >>> import torch
+    >>> from itertools import accumulate
+    >>>
+    >>> # Suppose batch size = 3 and there are 3 features
+    >>> batch_size = 3
+    >>>
+    >>> # Embedding dimensions for each feature
+    >>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
+    >>>
+    >>> # Permute list, i.e., move feature 2 to position 0, move feature 0
+    >>> # to position 1, and so on
+    >>> permute = torch.tensor([2, 0, 1], dtype=torch.int64, device="cuda")
+    >>>
+    >>> # Compute embedding dim offsets
+    >>> offset_dim_list = torch.tensor([0] + list(accumulate(embs_dims)), dtype=torch.int64, device="cuda")
+    >>> print(offset_dim_list)
+    >>>
+    tensor([ 0,  4,  8, 16], device='cuda:0')
+    >>>
+    >>> # Compute inverse embedding dims
+    >>> inv_embs_dims = [embs_dims[p] for p in permute]
+    >>> # Compute complete cumulative sum of inverse embedding dims
+    >>> inv_offset_dim_list = torch.tensor([0] + list(accumulate(inv_embs_dims)), dtype=torch.int64, device="cuda")
+    >>> print(inv_offset_dim_list)
+    >>>
+    tensor([ 0,  8, 12, 16], device='cuda:0')
+    >>>
+    >>> # Compute inverse permutes
+    >>> inv_permute = [0] * len(permute)
+    >>> for i, p in enumerate(permute):
+    >>>     inv_permute[p] = i
+    >>> inv_permute_list = torch.tensor([inv_permute], dtype=torch.int64, device="cuda")
+    >>> print(inv_permute_list)
+    >>>
+    tensor([[1, 2, 0]], device='cuda:0')
+    >>>
+    >>> # Generate an example input
+    >>> pooled_embs = torch.arange(embs_dims.sum().item() * batch_size, dtype=torch.float32, device="cuda").reshape(batch_size, -1)
+    >>> print(pooled_embs)
+    >>>
+    tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
+             14., 15.],
+            [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
+             30., 31.],
+            [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
+             46., 47.]], device='cuda:0')
+    >>>
+    >>> torch.ops.fbgemm.permute_pooled_embs_auto_grad(pooled_embs, offset_dim_list, permute, inv_offset_dim_list, inv_permute_list)
+    >>>
+    tensor([[ 8.,  9., 10., 11., 12., 13., 14., 15.,  0.,  1.,  2.,  3.,  4.,  5.,
+              6.,  7.],
+            [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
+             22., 23.],
+            [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
+             38., 39.]], device='cuda:0')
+    """,
+)
diff --git a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
index cbdf1b0bc..2f26c3547 100644
--- a/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
+++ b/fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
@@ -27,6 +27,70 @@
 
 
 class PermutePooledEmbeddings:
+    """
+    A module for permuting embedding outputs along the feature dimension.
+
+    An embedding output tensor contains the embedding outputs for all features
+    in a batch. It is represented in a 2D format, where the rows are the batch
+    size dimension and the columns are the feature * embedding dimension.
+    Permuting along the feature dimension is essentially permuting along the
+    second dimension (dim 1).
+
+    **Example:**
+
+        >>> import torch
+        >>> import fbgemm_gpu
+        >>> from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+        >>>
+        >>> # Suppose batch size = 3 and there are 3 features
+        >>> batch_size = 3
+        >>>
+        >>> # Embedding dimensions for each feature
+        >>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
+        >>>
+        >>> # Permute list, i.e., move feature 2 to position 0, move feature 0
+        >>> # to position 1, and so on
+        >>> permute = [2, 0, 1]
+        >>>
+        >>> # Instantiate the module
+        >>> perm = PermutePooledEmbeddings(embs_dims, permute)
+        >>>
+        >>> # Generate an example input
+        >>> pooled_embs = torch.arange(
+        >>>     embs_dims.sum().item() * batch_size,
+        >>>     dtype=torch.float32, device="cuda"
+        >>> ).reshape(batch_size, -1)
+        >>> print(pooled_embs)
+        >>>
+        tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
+                 14., 15.],
+                [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
+                 30., 31.],
+                [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
+                 46., 47.]], device='cuda:0')
+        >>>
+        >>> # Invoke
+        >>> perm(pooled_embs)
+        >>>
+        tensor([[ 8.,  9., 10., 11., 12., 13., 14., 15.,  0.,  1.,  2.,  3.,  4.,  5.,
+                  6.,  7.],
+                [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
+                 22., 23.],
+                [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
+                 38., 39.]], device='cuda:0')
+
+    Args:
+        embs_dims (List[int]): A list of embedding dimensions for all features.
+            Length is the number of features
+
+        permute (List[int]): A list that describes how each feature is
+            permuted. `permute[i]` indicates that feature `permute[i]` is
+            permuted to position `i`
+
+        device (Optional[torch.device] = None): The device to run this module
+            on
+    """
+
     def __init__(
         self,
         embs_dims: List[int],
@@ -56,6 +120,18 @@ def __init__(
         )
 
     def __call__(self, pooled_embs: torch.Tensor) -> torch.Tensor:
+        """
+        Permutes the pooled embedding outputs along the feature dimension.
+
+        Args:
+            pooled_embs (Tensor): The embedding outputs to permute. Shape is
+                `(B_local, total_global_D)`, where `B_local` is the local batch
+                size and `total_global_D` is the total embedding dimension
+                across all features (global)
+
+        Returns:
+            Permuted embedding outputs (Tensor). Same shape as `pooled_embs`
+        """
         result = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
             pooled_embs,
             self._offset_dim_list.to(device=pooled_embs.device),
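
For readers who want to sanity-check the permutation semantics documented above without a GPU, the short sketch below reproduces the same result with plain PyTorch ops on CPU. It is illustrative only and not part of the patch: the helper name permute_pooled_embs_reference and the use of torch.split/torch.cat are assumptions made for this example, not the FBGEMM kernel's implementation.

# Reference-only sketch (not from the patch): emulate what
# torch.ops.fbgemm.permute_pooled_embs computes, using plain PyTorch on CPU.
from typing import List

import torch


def permute_pooled_embs_reference(
    pooled_embs: torch.Tensor, embs_dims: List[int], permute: List[int]
) -> torch.Tensor:
    # Split the (B, total_D) tensor into one chunk per feature along dim 1,
    # then concatenate the chunks in the permuted order: position i receives
    # feature permute[i].
    chunks = torch.split(pooled_embs, embs_dims, dim=1)
    return torch.cat([chunks[p] for p in permute], dim=1)


# Same data as the docstring examples: 3 features with dims [4, 4, 8].
batch_size = 3
embs_dims = [4, 4, 8]
permute = [2, 0, 1]
pooled_embs = torch.arange(
    sum(embs_dims) * batch_size, dtype=torch.float32
).reshape(batch_size, -1)
# Row 0 becomes [8..15, 0..3, 4..7], matching the expected output shown above.
print(permute_pooled_embs_reference(pooled_embs, embs_dims, permute))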