Skip to content

Commit

Permalink
Fix batch index select ops registration (pytorch#2598)
Browse files Browse the repository at this point in the history
Summary:
- Fix batch index select ops registration

Pull Request resolved: pytorch#2598

Reviewed By: sryap

Differential Revision: D57482140

Pulled By: q10

fbshipit-source-id: 50c6ca29f62a6a423383587404c4eb48ff9c4b15
  • Loading branch information
q10 authored and facebook-github-bot committed May 17, 2024
1 parent 578ab67 commit 1c0344f
Show file tree
Hide file tree
Showing 9 changed files with 167 additions and 138 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/FbgemmGpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ set(fbgemm_gpu_sources_static_cpu
src/split_embeddings_cache/lru_cache_populate_byte.cpp
src/split_embeddings_cache/lxu_cache.cpp
src/split_embeddings_cache/split_embeddings_cache_ops.cpp
codegen/training/index_select/batch_index_select_dim0_ops.cpp
codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)

if(NOT FBGEMM_CPU_ONLY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,71 +489,11 @@ Tensor batch_index_select_dim0_tensor_cpu_autograd(

// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def(
"batch_index_select_dim0("
" Tensor inputs,"
" Tensor indices,"
" SymInt[] input_num_indices,"
" SymInt[] input_rows,"
" SymInt[] input_columns,"
" bool permute_output_dim_0_1=False) -> Tensor");
DISPATCH_TO_CPU(
"batch_index_select_dim0", batch_index_select_dim0_cpu_autograd);
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.impl_abstract_pystub(
"fbgemm_gpu.sparse_ops",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");

m.def(
"batch_index_select_dim0("
" Tensor inputs,"
" Tensor indices,"
" SymInt[] input_num_indices,"
" SymInt[] input_rows,"
" SymInt[] input_columns,"
" bool permute_output_dim_0_1=False) -> Tensor",
{PT2_COMPLIANT_TAG});

m.def(
"batch_index_select_dim0_forward_cpu_impl("
"Tensor inputs,"
"Tensor indices,"
"SymInt[] input_num_indices,"
"SymInt[] input_rows,"
"SymInt[] input_columns,"
"bool permute_output_dim_0_1) -> Tensor[]");

m.def(
"batch_index_select_dim0_backward_cpu_impl("
"Tensor grad_output,"
"Tensor indices,"
"Tensor indices_numels,"
"Tensor input_num_indices,"
"Tensor input_rows,"
"Tensor input_columns,"
"bool permute_output_dim_0_1,"
"Tensor saved_tensor) -> Tensor");

m.def(
"batch_index_select_dim0_tensor("
" Tensor inputs,"
" Tensor indices,"
" Tensor input_num_indices,"
" Tensor input_rows,"
" Tensor input_columns,"
" bool permute_output_dim_0_1=False) -> Tensor");

m.def(
"batch_index_select_dim0_tensor_forward_cpu_impl("
"Tensor inputs,"
"Tensor indices,"
"Tensor input_num_indices,"
"Tensor input_rows,"
"Tensor input_columns,"
"bool permute_output_dim_0_1) -> Tensor[]");

DISPATCH_TO_CPU(
"batch_index_select_dim0_forward_cpu_impl",
BatchIndexSelectDim0CPUOp::forward_impl);
Expand All @@ -565,14 +505,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"batch_index_select_dim0_backward_cpu_impl",
BatchIndexSelectDim0CPUOp::backward_impl);

m.impl(
"batch_index_select_dim0",
torch::dispatch(
c10::DispatchKey::AutogradCPU,
TORCH_FN(batch_index_select_dim0_cpu_autograd)));
m.impl(
DISPATCH_TO_AUTOGRAD_CPU(
"batch_index_select_dim0", batch_index_select_dim0_cpu_autograd);

DISPATCH_TO_AUTOGRAD_CPU(
"batch_index_select_dim0_tensor",
torch::dispatch(
c10::DispatchKey::AutogradCPU,
TORCH_FN(batch_index_select_dim0_tensor_cpu_autograd)));
batch_index_select_dim0_tensor_cpu_autograd);
}
Original file line number Diff line number Diff line change
Expand Up @@ -712,28 +712,6 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.set_python_module("fbgemm_gpu.sparse_ops");
m.def(
"batch_index_select_dim0_forward_cuda_impl("
"Tensor inputs,"
"Tensor indices,"
"SymInt[] input_num_indices,"
"SymInt[] input_rows,"
"SymInt[] input_columns,"
"bool permute_output_dim_0_1) -> Tensor[]");
m.def(
"batch_index_select_dim0_backward_cuda_impl("
"Tensor grad_output,"
"Tensor dev_weights,"
"Tensor weights_offsets,"
"Tensor D_offsets,"
"Tensor hash_size_cumsum,"
"Tensor indices,"
"int max_segment_length_per_warp,"
"Tensor grad_offsets,"
"Tensor total_L_offsets,"
"bool permute_output_dim_0_1,"
"Tensor saved_tensor) ->Tensor");
DISPATCH_TO_CUDA(
"batch_index_select_dim0_forward_cuda_impl",
BatchIndexSelectDim0GPUOp::forward_impl);
Expand All @@ -742,30 +720,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
BatchIndexSelectDim0GPUOp::backward_impl);
DISPATCH_TO_AUTOGRAD_CUDA(
"batch_index_select_dim0", batch_index_select_dim0_gpu);

// Tensor alternative
m.def(
"batch_index_select_dim0_tensor_forward_cuda_impl("
"Tensor inputs,"
"Tensor indices,"
"Tensor input_num_indices,"
"Tensor input_rows,"
"Tensor input_columns,"
"bool permute_output_dim_0_1) -> Tensor[]");

m.def(
"batch_index_select_dim0_tensor_backward_cuda_impl("
"Tensor grad_output,"
"Tensor dev_weights,"
"Tensor weights_offsets,"
"Tensor D_offsets,"
"Tensor hash_size_cumsum,"
"Tensor indices,"
"int max_segment_length_per_warp,"
"Tensor grad_offsets,"
"Tensor total_L_offsets,"
"bool permute_output_dim_0_1,"
"Tensor saved_tensor) -> Tensor");
DISPATCH_TO_CUDA(
"batch_index_select_dim0_tensor_forward_cuda_impl",
BatchIndexSelectDim0TensorGPUOp::forward_impl);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fbgemm_gpu/dispatch_macros.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def(
"batch_index_select_dim0("
" Tensor inputs,"
" Tensor indices,"
" SymInt[] input_num_indices,"
" SymInt[] input_rows,"
" SymInt[] input_columns,"
" bool permute_output_dim_0_1=False) -> Tensor");
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.set_python_module("fbgemm_gpu.sparse_ops");

m.impl_abstract_pystub(
"fbgemm_gpu.sparse_ops",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");

m.def(
"batch_index_select_dim0("
" Tensor inputs,"
" Tensor indices,"
" SymInt[] input_num_indices,"
" SymInt[] input_rows,"
" SymInt[] input_columns,"
" bool permute_output_dim_0_1=False) -> Tensor",
{PT2_COMPLIANT_TAG});

m.def(
"batch_index_select_dim0_tensor("
" Tensor inputs,"
" Tensor indices,"
" Tensor input_num_indices,"
" Tensor input_rows,"
" Tensor input_columns,"
" bool permute_output_dim_0_1=False) -> Tensor");

m.def(
"batch_index_select_dim0_forward_cpu_impl("
" Tensor inputs,"
" Tensor indices,"
" SymInt[] input_num_indices,"
" SymInt[] input_rows,"
" SymInt[] input_columns,"
" bool permute_output_dim_0_1) -> Tensor[]");

m.def(
"batch_index_select_dim0_backward_cpu_impl("
" Tensor grad_output,"
" Tensor indices,"
" Tensor indices_numels,"
" Tensor input_num_indices,"
" Tensor input_rows,"
" Tensor input_columns,"
" bool permute_output_dim_0_1,"
" Tensor saved_tensor) -> Tensor");

m.def(
"batch_index_select_dim0_tensor_forward_cpu_impl("
" Tensor inputs,"
" Tensor indices,"
" Tensor input_num_indices,"
" Tensor input_rows,"
" Tensor input_columns,"
" bool permute_output_dim_0_1) -> Tensor[]");

// CUDA ops

m.def(
"batch_index_select_dim0_forward_cuda_impl("
" Tensor inputs,"
" Tensor indices,"
" SymInt[] input_num_indices,"
" SymInt[] input_rows,"
" SymInt[] input_columns,"
" bool permute_output_dim_0_1) -> Tensor[]");

m.def(
"batch_index_select_dim0_backward_cuda_impl("
" Tensor grad_output,"
" Tensor dev_weights,"
" Tensor weights_offsets,"
" Tensor D_offsets,"
" Tensor hash_size_cumsum,"
" Tensor indices,"
" int max_segment_length_per_warp,"
" Tensor grad_offsets,"
" Tensor total_L_offsets,"
" bool permute_output_dim_0_1,"
" Tensor saved_tensor) -> Tensor");

m.def(
"batch_index_select_dim0_tensor_forward_cuda_impl("
" Tensor inputs,"
" Tensor indices,"
" Tensor input_num_indices,"
" Tensor input_rows,"
" Tensor input_columns,"
" bool permute_output_dim_0_1) -> Tensor[]");

m.def(
"batch_index_select_dim0_tensor_backward_cuda_impl("
" Tensor grad_output,"
" Tensor dev_weights,"
" Tensor weights_offsets,"
" Tensor D_offsets,"
" Tensor hash_size_cumsum,"
" Tensor indices,"
" int max_segment_length_per_warp,"
" Tensor grad_offsets,"
" Tensor total_L_offsets,"
" bool permute_output_dim_0_1,"
" Tensor saved_tensor) -> Tensor");
}
5 changes: 3 additions & 2 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")


import torch.utils._pytree as pytree
from torch import SymInt, Tensor
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
Expand Down Expand Up @@ -829,7 +830,7 @@ def group_index_select_dim0_gpu_backward_abstract(
return ret


@impl_abstract("fbgemm::keyed_jagged_index_select_dim1_forward_cuda_impl")
@impl_abstract("fbgemm::keyed_jagged_index_select_dim1_forward")
def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
values: torch.Tensor,
lengths: torch.Tensor,
Expand Down Expand Up @@ -865,7 +866,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
]


@impl_abstract("fbgemm::keyed_jagged_index_select_dim1_backward_cuda_impl")
@impl_abstract("fbgemm::keyed_jagged_index_select_dim1_backward")
def keyed_jagged_index_select_dim1_backward_cuda_impl_abstract(
grad: torch.Tensor,
indices: torch.Tensor,
Expand Down
2 changes: 2 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .split_embeddings_cache_ops import get_unique_indices # noqa: F401
5 changes: 5 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ inline bool torch_tensor_empty_or_on_cpu_check(
#define DISPATCH_TO_AUTOGRAD(name, function) \
m.impl(name, torch::dispatch(c10::DispatchKey::Autograd, TORCH_FN(function)))

#define DISPATCH_TO_AUTOGRAD_CPU(name, function) \
m.impl( \
name, \
torch::dispatch(c10::DispatchKey::AutogradCPU, TORCH_FN(function)))

#define DISPATCH_TO_AUTOGRAD_CUDA(name, function) \
m.impl( \
name, \
Expand Down
21 changes: 21 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,27 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"fbgemm_gpu.sparse_ops",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
#endif

m.set_python_module("fbgemm_gpu.sparse_ops");

m.def(
"keyed_jagged_index_select_dim1_forward("
" Tensor values,"
" Tensor lengths,"
" Tensor offsets,"
" Tensor indices,"
" SymInt batch_size,"
" Tensor? weights,"
" SymInt? selected_lengths_sum) -> Tensor[]");

m.def(
"keyed_jagged_index_select_dim1_backward("
" Tensor grad,"
" Tensor indices,"
" Tensor grad_offsets,"
" Tensor output_offsets,"
" Tensor saved_tensor) -> Tensor");

// (dense, offsets) -> jagged. Offsets output is same as input.
// SymInt is a new PyTorch 2.0 feature to support dynamic shape. See more
// details at https://pytorch.org/get-started/pytorch-2.0/#dynamic-shapes. If
Expand Down
Loading

0 comments on commit 1c0344f

Please sign in to comment.