From 441697c0481f82b9c328d39f70e4b34fdc890758 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Tue, 2 Jan 2024 17:00:59 -0800 Subject: [PATCH] Re-structure the Python documentation (#2239) Summary: - Re-structure the Python documentation Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2239 Reviewed By: spcyppt Differential Revision: D52495567 Pulled By: q10 fbshipit-source-id: a46406c8755c61cee0dae6d6e06805f5f31f6afd --- .github/scripts/fbgemm_gpu_build.bash | 1 + CONTRIBUTING.md | 12 +- fbgemm_gpu/docs/src/conf.py | 1 + fbgemm_gpu/docs/src/general/ContactInfo.rst | 67 ----- fbgemm_gpu/docs/src/general/ContactUs.rst | 17 ++ fbgemm_gpu/docs/src/general/Contributing.rst | 2 + .../docs/src/general/DocsInstructions.rst | 7 +- fbgemm_gpu/docs/src/index.rst | 2 + .../jagged-tensor-ops/JaggedTensorOps.rst | 2 +- .../docs/src/python-api/jagged_tensor_ops.rst | 13 - fbgemm_gpu/fbgemm_gpu/__init__.py | 16 +- fbgemm_gpu/fbgemm_gpu/docs/__init__.py | 3 + fbgemm_gpu/fbgemm_gpu/docs/common.py | 9 + .../fbgemm_gpu/docs/jagged_tensor_ops.py | 256 +++++++++++++++++ .../table_batched_embedding_ops.py} | 266 +----------------- fbgemm_gpu/setup.py | 5 +- 16 files changed, 318 insertions(+), 361 deletions(-) delete mode 100644 fbgemm_gpu/docs/src/general/ContactInfo.rst create mode 100644 fbgemm_gpu/docs/src/general/ContactUs.rst create mode 100644 fbgemm_gpu/docs/src/general/Contributing.rst create mode 100644 fbgemm_gpu/fbgemm_gpu/docs/common.py create mode 100644 fbgemm_gpu/fbgemm_gpu/docs/jagged_tensor_ops.py rename fbgemm_gpu/fbgemm_gpu/{_fbgemm_gpu_docs.py => docs/table_batched_embedding_ops.py} (51%) diff --git a/.github/scripts/fbgemm_gpu_build.bash b/.github/scripts/fbgemm_gpu_build.bash index 51388037b..3d754c2f4 100644 --- a/.github/scripts/fbgemm_gpu_build.bash +++ b/.github/scripts/fbgemm_gpu_build.bash @@ -420,6 +420,7 @@ build_fbgemm_gpu_install () { # fbgemm_gpu/ subdirectory present cd - || return 1 (test_python_import_package "${env_name}" fbgemm_gpu) || return 1 + (test_python_import_symbol "${env_name}" fbgemm_gpu __version__) || return 1 cd - || return 1 echo "[BUILD] FBGEMM-GPU build + install completed" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2c431cdad..bcd4f835e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,27 +1,32 @@ -# Contributing to FBGEMM +# Contributing to FBGEMM / FBGEMM_GPU + We want to make contributing to this project as easy and transparent as possible. ## Code of Conduct + The code of conduct is described in [`CODE_OF_CONDUCT.md`](CODE_OF_CONDUCT.md). ## Pull Requests + We actively welcome your pull requests. -1. Fork the repo and create your branch from `main`. +1. **Fork** the repository and create your branch from `main`. 2. If you've added code that should be tested, add tests. -3. If you've changed APIs, update the documentation. +3. If you've added or changed APIs, update the documentation. 4. Ensure the test suite passes. 5. Make sure your code lints. 6. If you haven't already, complete the Contributor License Agreement ("CLA"). ## Contributor License Agreement ("CLA") + In order to accept your pull request, we need you to submit a CLA. You only need to do this once to work on any of Facebook's open source projects. Complete your CLA here: ## Issues + We use GitHub issues to track public bugs. Please ensure your description is clear and has sufficient instructions to be able to reproduce the issue. @@ -30,5 +35,6 @@ disclosure of security bugs. 
In those cases, please go through the process outlined on that page and do not file a public issue.
 
 ## License
+
 By contributing to FBGEMM, you agree that your contributions will be licensed
 under the LICENSE file in the root directory of this source tree.
diff --git a/fbgemm_gpu/docs/src/conf.py b/fbgemm_gpu/docs/src/conf.py
index 79a19aa50..bbc09d1a9 100644
--- a/fbgemm_gpu/docs/src/conf.py
+++ b/fbgemm_gpu/docs/src/conf.py
@@ -52,6 +52,7 @@
     "myst_parser",
     "sphinx.ext.autodoc",
     "sphinx.ext.autosectionlabel",
+    "sphinx.ext.autosummary",
     "sphinx.ext.intersphinx",
     "sphinx.ext.mathjax",
     "sphinx.ext.napoleon",
diff --git a/fbgemm_gpu/docs/src/general/ContactInfo.rst b/fbgemm_gpu/docs/src/general/ContactInfo.rst
deleted file mode 100644
index b36bb4b6e..000000000
--- a/fbgemm_gpu/docs/src/general/ContactInfo.rst
+++ /dev/null
@@ -1,67 +0,0 @@
-Testing FBGEMM_GPU
-------------------
-
-The tests (in the ``fbgemm_gpu/test/`` directory) and benchmarks (in the
-``fbgemm_gpu/bench/`` directory) provide good examples on how to use FBGEMM_GPU.
-
-FBGEMM_GPU Tests
-~~~~~~~~~~~~~~~~
-
-To run the tests after building / installing the FBGEMM_GPU package:
-
-.. code:: sh
-
-  # From the /fbgemm_gpu/ directory
-  cd test
-
-  python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning split_table_batched_embeddings_test.py
-  python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning quantize_ops_test.py
-  python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning sparse_ops_test.py
-  python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning split_embedding_inference_converter_test.py
-
-Testing with the CUDA Variant
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-For the FBGEMM_GPU CUDA package, GPUs will be automatically detected and
-used for testing. To run the tests and benchmarks on a GPU-capable
-device in CPU-only mode, ``CUDA_VISIBLE_DEVICES=-1`` must be set in the
-environment:
-
-.. code:: sh
-
-  # Enable for running in CPU-only mode (when on a GPU-capable machine)
-  export CUDA_VISIBLE_DEVICES=-1
-
-  # Enable for debugging failed kernel executions
-  export CUDA_LAUNCH_BLOCKING=1
-
-  python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning split_table_batched_embeddings_test.py
-
-Testing with the ROCm Variant
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-For ROCm machines, testing against a ROCm GPU needs to be enabled with
-``FBGEMM_TEST_WITH_ROCM=1`` set in the environment:
-
-.. code:: sh
-
-  # From the /fbgemm_gpu/ directory
-  cd test
-
-  export FBGEMM_TEST_WITH_ROCM=1
-  # Enable for debugging failed kernel executions
-  export HIP_LAUNCH_BLOCKING=1
-
-  python -m pytest -v -rsx -s -W ignore::pytest.PytestCollectionWarning split_table_batched_embeddings_test.py
-
-FBGEMM_GPU Benchmarks
-~~~~~~~~~~~~~~~~~~~~~
-
-To run the benchmarks:
-
-.. code:: sh
-
-  # From the /fbgemm_gpu/ directory
-  cd bench
-
-  python split_table_batched_embeddings_benchmark.py uvm
diff --git a/fbgemm_gpu/docs/src/general/ContactUs.rst b/fbgemm_gpu/docs/src/general/ContactUs.rst
new file mode 100644
index 000000000..494aeb9e3
--- /dev/null
+++ b/fbgemm_gpu/docs/src/general/ContactUs.rst
@@ -0,0 +1,17 @@
+Contact Us
+==========
+
+GitHub
+------
+
+* `GitHub Issues `__: Use this to file
+  questions, issues, and feature requests concerning FBGEMM_GPU.
+
+* `GitHub Discussions `__: Use
+  this avenue to kick off longer discussions regarding FBGEMM_GPU.
+
+Slack
+-----
+
+Feel free to reach out to us on the ``#fbgemm`` channel in
+`PyTorch Slack `__.
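As an aside for readers of the testing instructions in the ContactInfo.rst page removed above: those shell commands translate directly into a scripted form. The following is only an illustrative sketch and is not part of this patch; the pytest flags, environment variables, and test file name are taken verbatim from the removed text.

```python
# Illustrative sketch only: mirrors the CPU-only test invocation described in
# the removed ContactInfo.rst text; not part of this patch.
import os
import subprocess

# CUDA_VISIBLE_DEVICES=-1 forces CPU-only mode on a GPU-capable machine;
# CUDA_LAUNCH_BLOCKING=1 helps debug failed kernel executions.
env = dict(os.environ, CUDA_VISIBLE_DEVICES="-1", CUDA_LAUNCH_BLOCKING="1")

subprocess.run(
    [
        "python", "-m", "pytest", "-v", "-rsx", "-s",
        "-W", "ignore::pytest.PytestCollectionWarning",
        "split_table_batched_embeddings_test.py",
    ],
    cwd="test",  # run from within the fbgemm_gpu/ directory; tests live in test/
    env=env,
    check=True,
)
```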
diff --git a/fbgemm_gpu/docs/src/general/Contributing.rst b/fbgemm_gpu/docs/src/general/Contributing.rst new file mode 100644 index 000000000..40d94c264 --- /dev/null +++ b/fbgemm_gpu/docs/src/general/Contributing.rst @@ -0,0 +1,2 @@ +.. include:: ../../../../CONTRIBUTING.md + :parser: myst_parser.sphinx_ diff --git a/fbgemm_gpu/docs/src/general/DocsInstructions.rst b/fbgemm_gpu/docs/src/general/DocsInstructions.rst index 9ee4a5198..4deb915ea 100644 --- a/fbgemm_gpu/docs/src/general/DocsInstructions.rst +++ b/fbgemm_gpu/docs/src/general/DocsInstructions.rst @@ -1,5 +1,5 @@ -Contributing Documentation -========================== +Building Documentation +====================== FBGEMM_GPU provides extensive comments in its source files, which provide the most authoritative and up-to-date documentation available for the package. @@ -170,7 +170,8 @@ When you add descriptionss to a function, make sure that the ``#ifndef`` and ``#endif`` are configured correctly. All functions are grouped by a specific group for better organization. -Make sure you add ``@defgroup`` to the code comments. +Make sure you add ``@defgroup`` to the code comments to define the group, and +``@ingroup`` in each docstring to associate the target method with the group. Follow these instructions to document, generate, and publish a new C++ description: diff --git a/fbgemm_gpu/docs/src/index.rst b/fbgemm_gpu/docs/src/index.rst index 057a1d670..a52e56681 100644 --- a/fbgemm_gpu/docs/src/index.rst +++ b/fbgemm_gpu/docs/src/index.rst @@ -19,6 +19,8 @@ library. general/InstallationInstructions.rst general/TestInstructions.rst general/DocsInstructions.rst + general/Contributing.rst + general/ContactUs.rst .. _fbgemm-gpu.docs.toc.overview: diff --git a/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorOps.rst b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorOps.rst index 4ab34f509..a56de4fbe 100644 --- a/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorOps.rst +++ b/fbgemm_gpu/docs/src/overview/jagged-tensor-ops/JaggedTensorOps.rst @@ -24,7 +24,7 @@ Jagged Tensor Format ------------------- Jagged tensors are effectively represented in FBGEMm_GPU as a three-tensor -object. The three tensors are: **Values**, **Max Lengths**, and **Offsets**. +object. The three tensors are: **Values**, **MaxLengths**, and **Offsets**. Values ~~~~~~ diff --git a/fbgemm_gpu/docs/src/python-api/jagged_tensor_ops.rst b/fbgemm_gpu/docs/src/python-api/jagged_tensor_ops.rst index 9bdcd098d..ca1cbe522 100644 --- a/fbgemm_gpu/docs/src/python-api/jagged_tensor_ops.rst +++ b/fbgemm_gpu/docs/src/python-api/jagged_tensor_ops.rst @@ -1,37 +1,24 @@ Jagged Tensor Operators ======================= -.. automodule:: fbgemm_gpu - - .. autofunction:: torch.ops.fbgemm.jagged_2d_to_dense - .. autofunction:: torch.ops.fbgemm.jagged_1d_to_dense - .. autofunction:: torch.ops.fbgemm.dense_to_jagged - .. autofunction:: torch.ops.fbgemm.jagged_to_padded_dense - .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_add - .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output - .. autofunction:: torch.ops.fbgemm.jagged_dense_dense_elementwise_add_jagged_output - .. autofunction:: torch.ops.fbgemm.jagged_dense_elementwise_mul - .. autofunction:: torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul - .. autofunction:: torch.ops.fbgemm.stacked_jagged_1d_to_dense - .. 
autofunction:: torch.ops.fbgemm.stacked_jagged_2d_to_dense
diff --git a/fbgemm_gpu/fbgemm_gpu/__init__.py b/fbgemm_gpu/fbgemm_gpu/__init__.py
index 3ca97566b..555c8574e 100644
--- a/fbgemm_gpu/fbgemm_gpu/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/__init__.py
@@ -14,13 +14,15 @@
 except Exception as e:
     print(e)
 
-# __init__.py is only used in OSS
-# Use existence to check if fbgemm_gpu_py.so has already been loaded
+# Since __init__.py is only used in OSS context, we define `open_source` here
+# and use its existence to determine whether or not we are in OSS context
 open_source: bool = True
 
-# Re-export docs
-# Trigger meta registrations
-from . import _fbgemm_gpu_docs, sparse_ops  # noqa: F401, E402  # noqa: F401, E402
+# Trigger the manual addition of docstrings to pybind11-generated operators
+import fbgemm_gpu.docs  # noqa: F401, E402
 
-# Re-export the version string from the auto-generated version file
-from ._fbgemm_gpu_version import __version__  # noqa: F401, E402
+# Export the version string from the version file auto-generated by setup.py
+from fbgemm_gpu.docs.version import __version__  # noqa: F401, E402
+
+# Trigger meta operator registrations
+from . import sparse_ops  # noqa: F401, E402
diff --git a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py
index a9fdb3b99..22cf06544 100644
--- a/fbgemm_gpu/fbgemm_gpu/docs/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/docs/__init__.py
@@ -4,3 +4,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# Trigger the manual addition of docstrings to pybind11-generated operators
+from . import jagged_tensor_ops, table_batched_embedding_ops  # noqa: F401
diff --git a/fbgemm_gpu/fbgemm_gpu/docs/common.py b/fbgemm_gpu/fbgemm_gpu/docs/common.py
new file mode 100644
index 000000000..cf283b1b1
--- /dev/null
+++ b/fbgemm_gpu/fbgemm_gpu/docs/common.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+def add_docs(method, docstr: str):
+    method.__doc__ = docstr
diff --git a/fbgemm_gpu/fbgemm_gpu/docs/jagged_tensor_ops.py b/fbgemm_gpu/fbgemm_gpu/docs/jagged_tensor_ops.py
new file mode 100644
index 000000000..a00b66079
--- /dev/null
+++ b/fbgemm_gpu/fbgemm_gpu/docs/jagged_tensor_ops.py
@@ -0,0 +1,256 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from .common import add_docs
+
+add_docs(
+    torch.ops.fbgemm.jagged_2d_to_dense,
+    """
+jagged_2d_to_dense(values, x_offsets, max_sequence_length) -> Tensor
+
+Converts a jagged tensor, with a 2D values array, into a dense tensor, padding with zeros.
+
+Args:
+    values (Tensor): 2D tensor containing the values of the jagged tensor.
+
+    x_offsets (Tensor): 1D tensor containing the starting point of each jagged row in the values tensor.
+
+    max_sequence_length (int): Maximum length of any row in the jagged dimension.
+
+Returns:
+    Tensor: The padded dense tensor
+
+Example:
+    >>> values = torch.tensor([[1,1],[2,2],[3,3],[4,4]])
+    >>> x_offsets = torch.tensor([0, 1, 3])
+    >>> torch.ops.fbgemm.jagged_2d_to_dense(values, x_offsets, 3)
+    tensor([[[1, 1],
+             [0, 0],
+             [0, 0]],
+            [[2, 2],
+             [3, 3],
+             [0, 0]]])
+
+""",
+)
+
+add_docs(
+    torch.ops.fbgemm.jagged_1d_to_dense,
+    """
+jagged_1d_to_dense(values, offsets, max_sequence_length, padding_value) -> Tensor
+
+Converts a jagged tensor, with a 1D values array, into a dense tensor, padding with a specified padding value.
+
+Args:
+    values (Tensor): 1D tensor containing the values of the jagged tensor.
+
+    offsets (Tensor): 1D tensor containing the starting point of each jagged row in the values tensor.
+
+    max_sequence_length (int): Maximum length of any row in the jagged dimension.
+
+    padding_value (int): Value to set in the empty areas of the dense output, outside of the jagged tensor coverage.
+
+Returns:
+    Tensor: The padded dense tensor
+
+Example:
+    >>> values = torch.tensor([1,2,3,4])
+    >>> offsets = torch.tensor([0, 1, 3])
+    >>> torch.ops.fbgemm.jagged_1d_to_dense(values, offsets, 3, 0)
+    tensor([[1, 0, 0],
+            [2, 3, 0]])
+
+""",
+)
+
+add_docs(
+    torch.ops.fbgemm.dense_to_jagged,
+    """
+dense_to_jagged(dense, x_offsets, total_L) -> (Tensor, Tensor[])
+
+Converts a dense tensor into a jagged tensor, given the desired offsets of the resulting jagged tensor.
+
+Args:
+    dense (Tensor): A dense input tensor to be converted
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    total_L (int, Optional): Total number of values in the resulting jagged tensor.
+
+Returns:
+    (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+Example:
+    >>> dense = torch.tensor([[[1, 1], [0, 0], [0, 0]], [[2, 2], [3, 3], [0, 0]]])
+    >>> x_offsets = torch.tensor([0, 1, 3])
+    >>> torch.ops.fbgemm.dense_to_jagged(dense, [x_offsets])
+    (tensor([[1, 1],
+             [2, 2],
+             [3, 3]]), [tensor([0, 1, 3])])
+
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_to_padded_dense,
+    """
+jagged_to_padded_dense(values, offsets, max_lengths, padding_value=0) -> Tensor
+
+Converts a jagged tensor into a dense tensor, padding with a specified padding value.
+
+Args:
+    values (Tensor): Jagged tensor values
+
+    offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    max_lengths (int[]): A list with max_length for each jagged dimension.
+
+    padding_value (float): Value to set in the empty areas of the dense output, outside of the jagged tensor coverage.
+
+Returns:
+    Tensor: The padded dense tensor
+
+Example:
+    >>> values = torch.tensor([[1,1],[2,2],[3,3],[4,4]])
+    >>> offsets = torch.tensor([0, 1, 3])
+    >>> torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], [3], 7)
+    tensor([[[1, 1],
+             [7, 7],
+             [7, 7]],
+            [[2, 2],
+             [3, 3],
+             [7, 7]]])
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_dense_elementwise_add,
+    """
+jagged_dense_elementwise_add(x_values, x_offsets, y) -> Tensor
+
+Adds a jagged tensor to a dense tensor, resulting in a dense tensor. The jagged
+tensor input will be padded with zeros for the purposes of the addition.
+
+Args:
+    x_values (Tensor): Jagged tensor values
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    y (Tensor): A dense tensor
+
+Returns:
+    Tensor: The sum of jagged input tensor + y
+
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output,
+    """
+jagged_dense_elementwise_add_jagged_output(x_values, x_offsets, y) -> (Tensor, Tensor[])
+
+Adds a jagged tensor to a dense tensor, resulting in a jagged tensor with the same structure as the input jagged tensor.
+
+Args:
+    x_values (Tensor): Jagged tensor values
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    y (Tensor): A dense tensor
+
+Returns:
+    (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_dense_dense_elementwise_add_jagged_output,
+    """
+jagged_dense_dense_elementwise_add_jagged_output(x_values, x_offsets, y_0, y_1) -> (Tensor, Tensor[])
+
+Adds a jagged tensor to the sum of two dense tensors, resulting in a jagged tensor with the same structure as the input jagged tensor.
+
+Args:
+    x_values (Tensor): Jagged tensor values
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    y_0 (Tensor): A dense tensor
+
+    y_1 (Tensor): A dense tensor
+
+Returns:
+    (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+""",
+)
+
+
+add_docs(
+    torch.ops.fbgemm.jagged_dense_elementwise_mul,
+    """
+jagged_dense_elementwise_mul(x_values, x_offsets, y) -> (Tensor, Tensor[])
+
+Elementwise-multiplies a jagged tensor with a dense tensor, resulting in a jagged tensor with the same structure as the input jagged tensor.
+
+Args:
+    x_values (Tensor): Jagged tensor values
+
+    x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+    y (Tensor): A dense tensor
+
+Returns:
+    (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identical to those that were input.
+
+""",
+)
+
+add_docs(
+    torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul,
+    """
+batched_dense_vec_jagged_2d_mul(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor
+
+Batched vector-matrix multiplication of a batched dense vector with a jagged tensor. The dense
+vector has size (B * H, max_N) and the jagged tensor has size (B, max_N, H * D), where max_N is
+the maximum size of the jagged dimension. B * H is the batch size, and each multiplication is
+between a vector of length max_N and a matrix of size [max_N, D].
+
+Args:
+    v (Tensor): Dense vector tensor
+
+    a_values (Tensor): Jagged tensor values
+
+    a_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension.
+
+Returns:
+    Tensor: Output of the batched matmul, with size (B * H, D)
+
+""",
+)
+
+# add_docs(
+#     torch.ops.fbgemm.stacked_jagged_1d_to_dense,
+#     """Args:
+#                 {input}
+#             Keyword args:
+#                 {out}""",
+# )
+#
+#
+# add_docs(
+#     torch.ops.fbgemm.stacked_jagged_2d_to_dense,
+#     """Args:
+#                 {input}
+#             Keyword args:
+#                 {out}""",
+# )
diff --git a/fbgemm_gpu/fbgemm_gpu/_fbgemm_gpu_docs.py b/fbgemm_gpu/fbgemm_gpu/docs/table_batched_embedding_ops.py
similarity index 51%
rename from fbgemm_gpu/fbgemm_gpu/_fbgemm_gpu_docs.py
rename to fbgemm_gpu/fbgemm_gpu/docs/table_batched_embedding_ops.py
index 7572010d1..47cda90e4 100644
--- a/fbgemm_gpu/fbgemm_gpu/_fbgemm_gpu_docs.py
+++ b/fbgemm_gpu/fbgemm_gpu/docs/table_batched_embedding_ops.py
@@ -4,230 +4,9 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-import fbgemm_gpu import fbgemm_gpu.split_table_batched_embeddings_ops_training -import torch # usort:skip -Tensor = torch.Tensor - - -def add_docs(method, docstr): - method.__doc__ = docstr - - -add_docs( - torch.ops.fbgemm.jagged_2d_to_dense, - """ -jagged_2d_to_dense(values, x_offsets, max_sequence_length) -> Tensor - -Converts a jagged tensor, with a 2D values array into a dense tensor, padding with zeros. - -Args: - values (Tensor): 2D tensor containing the values of the jagged tensor. - - x_offsets (Tensor): 1D tensor containing the starting point of each jagged row in the values tensor. - - max_sequence_length (int): Maximum length of any row in the jagged dimension. - -Returns: - Tensor: The padded dense tensor - -Example: - >>> values = torch.tensor([[1,1],[2,2],[3,3],[4,4]]) - >>> x_offsets = torch.tensor([0, 1, 3]) - >>> torch.ops.fbgemm.jagged_2d_to_dense(values, x_offsets, 3) - tensor([[[1, 1], - [0, 0], - [0, 0]], - [[2, 2], - [3, 3], - [0, 0]]]) - -""", -) - -# Example: -# -# >>> t = torch.arange(4) - - -add_docs( - torch.ops.fbgemm.jagged_1d_to_dense, - """ -jagged_1d_to_dense(values, offsets, max_sequence_length, padding_value) -> Tensor) - -Converts a jagged tensor, with a 1D values array, into a dense tensor, padding with a specified padding value. - -Args: - values (Tensor): 1D tensor containing the values of the jagged tensor. - - offsets (Tensor): 1D tensor containing the starting point of each jagged row in the values tensor. - - max_sequence_length (int): Maximum length of any row in the jagged dimension. - - padding_value (int): Value to set in the empty areas of the dense output, outside of the jagged tensor coverage. - -Returns: - Tensor: the padded dense tensor - -Example: - >>> values = torch.tensor([1,2,3,4]) - >>> offsets = torch.tensor([0, 1, 3]) - >>> torch.ops.fbgemm.jagged_1d_to_dense(values, x_offsets, 3, 0) - tensor([[1, 0, 0], - [2, 3, 0]]) - -""", -) - - -add_docs( - torch.ops.fbgemm.dense_to_jagged, - """ -dense_to_jagged(dense, x_offsets, total_L) -> (Tensor, Tensor[]) - -Converts a dense tensor into a jagged tensor, given the desired offsets of the resulting dense tensor. - -Args: - dense (Tensor): A dense input tensor to be converted - - x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension. - - total_L (int, Optional): Total number of values in the resulting jagged tensor. - -Returns: - (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identital to those that were input. - -Example: - >>> dense = torch.tensor([[[1, 1], [0, 0], [0, 0]], [[2, 2], [3, 3], [0, 0]]]) - >>> x_offsets = torch.tensor([0, 1, 3]) - >>> torch.ops.fbgemm.dense_to_jagged(dense, [x_offsets]) - (tensor([[1, 1], - [2, 2], - [3, 3]]), [tensor([0, 1, 3])]) - -""", -) - - -add_docs( - torch.ops.fbgemm.jagged_to_padded_dense, - """ -jagged_to_padded_dense(values, offsets, max_lengths, padding_value=0) -> Tensor - -Converts a jagged tensor into a dense tensor, padding with a specified padding value. - -Args: - values (Tensor): Jagged tensor values - - offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension. - - max_lengths (int[]): A list with max_length for each jagged dimension. - - padding_value (float): Value to set in the empty areas of the dense output, outside of the jagged tensor coverage. 
- -Returns: - Tensor: the padded dense tensor - -Example: - >>> values = torch.tensor([[1,1],[2,2],[3,3],[4,4]]) - >>> offsets = torch.tensor([0, 1, 3]) - >>> torch.ops.fbgemm.jagged_to_padded_dense(values, [offsets], [3], 7) - tensor([[[1, 1], - [7, 7], - [7, 7]], - [[2, 2], - [3, 3], - [7, 7]]]) -""", -) - - -add_docs( - torch.ops.fbgemm.jagged_dense_elementwise_add, - """ -jagged_dense_elementwise_add(x_values, x_offsets, y) -> Tensor - -Adds a jagged tensor to a dense tensor, resulting in dense tensor. Jagged -tensor input will be padded with zeros for the purposes of the addition. - -Args: - x_values (Tensor): Jagged tensor values - - offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension. - - y (Tensor): A dense tensor - -Returns: - Tensor: The sum of jagged input tensor + y - -""", -) - - -add_docs( - torch.ops.fbgemm.jagged_dense_elementwise_add_jagged_output, - """ -jagged_dense_elementwise_add_jagged_output(x_values, x_offsets, y) -> (Tensor, Tensor[]) - -Adds a jagged tensor to a dense tensor and, resulting in a jagged tensor with the same structure as the input jagged tensor. - -Args: - x_values (Tensor): Jagged tensor values - - x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension. - - y (Tensor): A dense tensor - -Returns: - (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identital to those that were input. - -""", -) - - -add_docs( - torch.ops.fbgemm.jagged_dense_dense_elementwise_add_jagged_output, - """ -jagged_dense_dense_elementwise_add_jagged_output(x_values, x_offsets, y_0, y_1) -> (Tensor, Tensor[]) - -Adds a jagged tensor to the sum of two dense tensors, resulting in a jagged tensor with the same structure as the input jagged tensor. - -Args: - x_values (Tensor): Jagged tensor values - - x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension. - - y_0 (Tensor): A dense tensor - - y_1 (Tensor): A dense tensor - -Returns: - (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identital to those that were input. - -""", -) - - -add_docs( - torch.ops.fbgemm.jagged_dense_elementwise_mul, - """ -jagged_dense_elementwise_mul(x_values, x_offsets, y) -> (Tensor, Tensor[]) - -Elementwise-multiplies a jagged tensor a dense tensor and, resulting in a jagged tensor with the same structure as the input jagged tensor. - -Args: - x_values (Tensor): Jagged tensor values - - x_offsets (Tensor[]): A list of jagged offset tensors, one for each jagged dimension. - - y (Tensor): A dense tensor - -Returns: - (Tensor, Tensor[]): Values and offsets of the resulting jagged tensor. Offsets are identital to those that were input. - -""", -) +from .common import add_docs add_docs( @@ -357,46 +136,3 @@ def add_docs(method, docstr): grad_fn=>) """, ) - - -add_docs( - torch.ops.fbgemm.batched_dense_vec_jagged_2d_mul, - """ -batched_dense_vec_jagged_2d_mul(Tensor v, Tensor a_values, Tensor a_offsets) -> Tensor - -Batched vector matrix multiplication of a batched dense vector with a jagged tensor, dense vector is in -size (B * H, max_N) and jagged tensor is in size (B, max_N, H * D) where max_N is the maximum size of -jagged dimension. B * H is the batch size and each multiplies is max_N with [max_N, D] - -Args: - v (Tensor): dense vector tensor - - a_values (Tensor): Jagged tensor values - - a_offsets (Tensor []): A list of jagged offset tensors, one for each jagged dimension. 
- -Returns: - Tensor: output of batch matmul in size (B * H, D) - -""", -) - - -# -# -# add_docs( -# torch.ops.fbgemm.stacked_jagged_1d_to_dense, -# """Args: -# {input} -# Keyword args: -# {out}""", -# ) -# -# -# add_docs( -# torch.ops.fbgemm.stacked_jagged_2d_to_dense, -# """Args: -# {input} -# Keyword args: -# {out}""", -# ) diff --git a/fbgemm_gpu/setup.py b/fbgemm_gpu/setup.py index deee82319..549757983 100644 --- a/fbgemm_gpu/setup.py +++ b/fbgemm_gpu/setup.py @@ -17,6 +17,7 @@ from datetime import date from typing import List, Optional +import setuptools import setuptools_git_versioning as gitversion import torch from setuptools.command.install import install as PipInstall @@ -266,7 +267,7 @@ def generate_package_version(cls, package_name: str, variant_version: str): @classmethod def generate_version_file(cls, package_version: str) -> None: - with open("fbgemm_gpu/_fbgemm_gpu_version.py", "w") as file: + with open("fbgemm_gpu/docs/version.py", "w") as file: print( f"[SETUP.PY] Generating version file at: {os.path.realpath(file.name)}" ) @@ -382,7 +383,7 @@ def main(argv: List[str]) -> None: "GPU", "CUDA", ], - packages=["fbgemm_gpu"], + packages=setuptools.find_packages(), install_requires=[ # Only specify numpy, as specifying torch will auto-install the # release version of torch, which is not what we want for the
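Taken together, the changes above mean that a plain `import fbgemm_gpu` now pulls in `fbgemm_gpu.docs` (which attaches the docstrings defined in `jagged_tensor_ops.py` and `table_batched_embedding_ops.py` to the pybind11-generated operators) and re-exports `__version__` from the generated `fbgemm_gpu/docs/version.py`, which is what the new `test_python_import_symbol` check added to the build script appears to verify. A minimal smoke test, offered here only as an illustrative sketch and not part of the patch, might look like:

```python
# Illustrative smoke test for the restructured package; not part of the patch.
import torch
import fbgemm_gpu  # importing the package triggers fbgemm_gpu.docs

# __version__ is re-exported from fbgemm_gpu.docs.version (generated by setup.py)
print(fbgemm_gpu.__version__)

# fbgemm_gpu.docs.common.add_docs() set __doc__ on the pybind11-generated ops
assert torch.ops.fbgemm.jagged_2d_to_dense.__doc__ is not None
print(torch.ops.fbgemm.jagged_2d_to_dense.__doc__.splitlines()[1])
```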