Add permute pooled embs docstrings (#3176)
sryap authored and facebook-github-bot committed Sep 26, 2024
1 parent b152339 commit c00a127
Showing 6 changed files with 209 additions and 4 deletions.
@@ -0,0 +1,7 @@
Pooled Embedding Modules
========================

.. automodule:: fbgemm_gpu

.. autoclass:: fbgemm_gpu.permute_pooled_embedding_modules.PermutePooledEmbeddings
:members: __call__
@@ -4,3 +4,5 @@ Pooled Embedding Operators
.. automodule:: fbgemm_gpu

.. autofunction:: torch.ops.fbgemm.merge_pooled_embeddings

.. autofunction:: torch.ops.fbgemm.permute_pooled_embs
14 changes: 11 additions & 3 deletions fbgemm_gpu/docs/src/index.rst
@@ -83,12 +83,20 @@ Table of Contents
fbgemm_gpu-cpp-api/ssd_embedding_ops.rst
fbgemm_gpu-cpp-api/experimental_ops.rst

.. _fbgemm-gpu.toc.api.python:
.. _fbgemm-gpu.toc.api.python.ops:

.. toctree::
:maxdepth: 1
:caption: FBGEMM_GPU Python API
:caption: FBGEMM_GPU Python Operators API

fbgemm_gpu-python-api/table_batched_embedding_ops.rst
fbgemm_gpu-python-api/jagged_tensor_ops.rst
fbgemm_gpu-python-api/pooled_embedding_ops.rst

.. _fbgemm-gpu.toc.api.python.modules:

.. toctree::
:maxdepth: 1
:caption: FBGEMM_GPU Python Modules API

fbgemm_gpu-python-api/table_batched_embedding_ops.rst
fbgemm_gpu-python-api/pooled_embedding_modules.rst
6 changes: 5 additions & 1 deletion fbgemm_gpu/fbgemm_gpu/docs/__init__.py
@@ -7,6 +7,10 @@

# Trigger the manual addition of docstrings to pybind11-generated operators
try:
from . import jagged_tensor_ops, merge_pooled_embedding_ops # noqa: F401
from . import ( # noqa: F401
jagged_tensor_ops,
merge_pooled_embedding_ops,
permute_pooled_embedding_ops,
)
except Exception:
pass
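
For context, the `add_docs` helper imported from `.common` is what attaches these strings to the pybind11-generated operators under `torch.ops.fbgemm`. A minimal sketch of how such a helper can be written is shown below; the actual implementation in `fbgemm_gpu.docs.common` may differ.

# Minimal sketch of an add_docs-style helper; the real
# fbgemm_gpu.docs.common.add_docs may be implemented differently.
# Operators under torch.ops are generated at runtime, so their docstrings
# must be attached manually after the native library has been loaded.
from typing import Any

def add_docs(op: Any, docstr: str) -> None:
    # Attach the documentation to the runtime-generated operator object so
    # that help(op) and Sphinx's autofunction directive can pick it up.
    op.__doc__ = docstr

# Intended usage (requires fbgemm_gpu to be installed and its ops loaded):
#   import torch, fbgemm_gpu
#   add_docs(torch.ops.fbgemm.permute_pooled_embs, "...docstring text...")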
108 changes: 108 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/docs/permute_pooled_embedding_ops.py
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .common import add_docs

add_docs(
torch.ops.fbgemm.permute_pooled_embs,
"""
permute_pooled_embs(pooled_embs, offset_dim_list, permute_list, inv_offset_dim_list, inv_permute_list) -> Tensor

Permute embedding outputs along the feature dimension.

The embedding output tensor `pooled_embs` contains the embedding outputs
for all features in a batch. It is represented in a 2D format, where the
rows are the batch size dimension and the columns are the feature *
embedding dimension. Permuting along the feature dimension is
essentially permuting along the second dimension (dim 1).

Args:
    pooled_embs (Tensor): The embedding outputs to permute. Shape is
        `(B_local, total_global_D)`, where `B_local` = a local batch size
        and `total_global_D` is the total embedding dimension across all
        features (global)
    offset_dim_list (Tensor): The complete cumulative sum of embedding
        dimensions of all features. Shape is `T + 1` where `T` is the
        total number of features
    permute_list (Tensor): A tensor that describes how each feature is
        permuted. `permute_list[i]` indicates that the feature
        `permute_list[i]` is permuted to position `i`
    inv_offset_dim_list (Tensor): The complete cumulative sum of inverse
        embedding dimensions, which are the permuted embedding dimensions.
        `inv_offset_dim_list[i]` represents the starting embedding position
        of feature `permute_list[i]`
    inv_permute_list (Tensor): The inverse permute list, which contains the
        permuted positions of each feature. `inv_permute_list[i]` represents
        the permuted position of feature `i`

Returns:
    Permuted embedding outputs (Tensor). Same shape as `pooled_embs`

**Example:**

>>> import torch
>>> from itertools import accumulate
>>>
>>> # Suppose batch size = 3 and there are 3 features
>>> batch_size = 3
>>>
>>> # Embedding dimensions for each feature
>>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
>>>
>>> # Permute list, i.e., move feature 2 to position 0, move feature 0
>>> # to position 1, and so on
>>> permute = torch.tensor([2, 0, 1], dtype=torch.int64, device="cuda")
>>>
>>> # Compute embedding dim offsets
>>> offset_dim_list = torch.tensor([0] + list(accumulate(embs_dims)), dtype=torch.int64, device="cuda")
>>> print(offset_dim_list)
tensor([ 0,  4,  8, 16], device='cuda:0')
>>>
>>> # Compute inverse embedding dims
>>> inv_embs_dims = [embs_dims[p] for p in permute]
>>> # Compute complete cumulative sum of inverse embedding dims
>>> inv_offset_dim_list = torch.tensor([0] + list(accumulate(inv_embs_dims)), dtype=torch.int64, device="cuda")
>>> print(inv_offset_dim_list)
tensor([ 0,  8, 12, 16], device='cuda:0')
>>>
>>> # Compute inverse permutes
>>> inv_permute = [0] * len(permute)
>>> for i, p in enumerate(permute):
...     inv_permute[p] = i
>>> inv_permute_list = torch.tensor(inv_permute, dtype=torch.int64, device="cuda")
>>> print(inv_permute_list)
tensor([1, 2, 0], device='cuda:0')
>>>
>>> # Generate an example input
>>> pooled_embs = torch.arange(embs_dims.sum().item() * batch_size, dtype=torch.float32, device="cuda").reshape(batch_size, -1)
>>> print(pooled_embs)
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15.],
        [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
         30., 31.],
        [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
         46., 47.]], device='cuda:0')
>>>
>>> torch.ops.fbgemm.permute_pooled_embs_auto_grad(pooled_embs, offset_dim_list, permute, inv_offset_dim_list, inv_permute_list)
tensor([[ 8.,  9., 10., 11., 12., 13., 14., 15.,  0.,  1.,  2.,  3.,  4.,  5.,
          6.,  7.],
        [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
         22., 23.],
        [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
         38., 39.]], device='cuda:0')
""",
)
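
To make the semantics above concrete, the permutation in the example can also be reproduced with plain tensor slicing. The following CPU-only sketch uses the same values as the docstring example; the variable names are illustrative and not part of the FBGEMM API.

# Illustrative reference for the permutation semantics; not FBGEMM code.
import torch

embs_dims = [4, 4, 8]   # per-feature embedding dimensions
permute = [2, 0, 1]     # feature permute[i] is moved to output position i
batch_size = 3

# Starting column of each feature in the unpermuted layout: [0, 4, 8, 16]
offs = [0]
for d in embs_dims:
    offs.append(offs[-1] + d)

pooled_embs = torch.arange(
    batch_size * sum(embs_dims), dtype=torch.float32
).reshape(batch_size, -1)

# For each output slot i, copy the column block that belongs to feature permute[i].
reference = torch.cat(
    [pooled_embs[:, offs[f] : offs[f + 1]] for f in permute], dim=1
)
print(reference)  # matches the permuted output shown in the docstring example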
76 changes: 76 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py
@@ -27,6 +27,70 @@


class PermutePooledEmbeddings:
"""
A module for permuting embedding outputs along the feature dimension.

An embedding output tensor contains the embedding outputs for all features
in a batch. It is represented in a 2D format, where the rows are the batch
size dimension and the columns are the feature * embedding dimension.
Permuting along the feature dimension is essentially permuting along the
second dimension (dim 1).

**Example:**

>>> import torch
>>> import fbgemm_gpu
>>> from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
>>>
>>> # Suppose batch size = 3 and there are 3 features
>>> batch_size = 3
>>>
>>> # Embedding dimensions for each feature
>>> embs_dims = torch.tensor([4, 4, 8], dtype=torch.int64, device="cuda")
>>>
>>> # Permute list, i.e., move feature 2 to position 0, move feature 0
>>> # to position 1, and so on
>>> permute = [2, 0, 1]
>>>
>>> # Instantiate the module
>>> perm = PermutePooledEmbeddings(embs_dims, permute)
>>>
>>> # Generate an example input
>>> pooled_embs = torch.arange(
...     embs_dims.sum().item() * batch_size,
...     dtype=torch.float32, device="cuda"
... ).reshape(batch_size, -1)
>>> print(pooled_embs)
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13.,
         14., 15.],
        [16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29.,
         30., 31.],
        [32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42., 43., 44., 45.,
         46., 47.]], device='cuda:0')
>>>
>>> # Invoke
>>> perm(pooled_embs)
tensor([[ 8.,  9., 10., 11., 12., 13., 14., 15.,  0.,  1.,  2.,  3.,  4.,  5.,
          6.,  7.],
        [24., 25., 26., 27., 28., 29., 30., 31., 16., 17., 18., 19., 20., 21.,
         22., 23.],
        [40., 41., 42., 43., 44., 45., 46., 47., 32., 33., 34., 35., 36., 37.,
         38., 39.]], device='cuda:0')

Args:
    embs_dims (List[int]): A list of embedding dimensions for all features.
        Length = the number of features
    permute (List[int]): A list that describes how each feature is
        permuted. `permute[i]` indicates that feature `permute[i]` is
        permuted to position `i`
    device (Optional[torch.device] = None): The device to run this module
        on
"""

def __init__(
self,
embs_dims: List[int],
@@ -56,6 +120,18 @@ def __init__(
)

def __call__(self, pooled_embs: torch.Tensor) -> torch.Tensor:
"""
Performs pooled embedding output permutation along the feature dimension.

Args:
    pooled_embs (Tensor): The embedding outputs to permute. Shape is
        `(B_local, total_global_D)`, where `B_local` = a local batch
        size and `total_global_D` is the total embedding dimension
        across all features (global)

Returns:
    Permuted embedding outputs (Tensor). Same shape as `pooled_embs`
"""
result = torch.ops.fbgemm.permute_pooled_embs_auto_grad(
pooled_embs,
self._offset_dim_list.to(device=pooled_embs.device),
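
The constructor body is elided in this diff, but it precomputes the offset and inverse-permute tensors that `__call__` forwards to `permute_pooled_embs_auto_grad`. The sketch below shows one way such buffers can be derived from `embs_dims` and `permute`, mirroring the manual computation in the operator docstring above; the function name is hypothetical and the actual constructor logic may differ.

# Hypothetical helper mirroring the manual computation in the operator
# docstring; illustrative only, not the module's actual constructor code.
from itertools import accumulate
from typing import List, Tuple

import torch

def build_permute_buffers(
    embs_dims: List[int], permute: List[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    # Cumulative dims: starting column of each feature in the unpermuted layout.
    offset_dim_list = torch.tensor([0] + list(accumulate(embs_dims)), dtype=torch.int64)
    # Dims reordered by the permutation, and their cumulative sum.
    inv_embs_dims = [embs_dims[p] for p in permute]
    inv_offset_dim_list = torch.tensor([0] + list(accumulate(inv_embs_dims)), dtype=torch.int64)
    # Inverse permutation: the permuted position of each original feature.
    inv_permute = [0] * len(permute)
    for i, p in enumerate(permute):
        inv_permute[p] = i
    return (
        offset_dim_list,
        torch.tensor(permute, dtype=torch.int64),
        inv_offset_dim_list,
        torch.tensor(inv_permute, dtype=torch.int64),
    )

# For embs_dims=[4, 4, 8] and permute=[2, 0, 1] this yields
# [0, 4, 8, 16], [2, 0, 1], [0, 8, 12, 16], and [1, 2, 0],
# matching the tensors built by hand in the operator docstring example.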
