diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index dca726321..a6824b57c 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -417,6 +417,7 @@ set(fbgemm_gpu_sources_static_cpu
     src/split_embeddings_cache/lru_cache_populate_byte.cpp
     src/split_embeddings_cache/lxu_cache.cpp
     src/split_embeddings_cache/split_embeddings_cache_ops.cpp
+    codegen/training/index_select/batch_index_select_dim0_ops.cpp
     codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp)

 if(NOT FBGEMM_CPU_ONLY)
diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp
index 752a95ef5..4b5473f7e 100644
--- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp
+++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp
@@ -489,71 +489,11 @@ Tensor batch_index_select_dim0_tensor_cpu_autograd(

 // Deprecated for fb namespace! Please use fbgemm namespace instead!
 TORCH_LIBRARY_FRAGMENT(fb, m) {
-  m.def(
-      "batch_index_select_dim0("
-      " Tensor inputs,"
-      " Tensor indices,"
-      " SymInt[] input_num_indices,"
-      " SymInt[] input_rows,"
-      " SymInt[] input_columns,"
-      " bool permute_output_dim_0_1=False) -> Tensor");
   DISPATCH_TO_CPU(
       "batch_index_select_dim0", batch_index_select_dim0_cpu_autograd);
 }

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.impl_abstract_pystub(
-      "fbgemm_gpu.sparse_ops",
-      "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
-
-  m.def(
-      "batch_index_select_dim0("
-      " Tensor inputs,"
-      " Tensor indices,"
-      " SymInt[] input_num_indices,"
-      " SymInt[] input_rows,"
-      " SymInt[] input_columns,"
-      " bool permute_output_dim_0_1=False) -> Tensor",
-      {PT2_COMPLIANT_TAG});
-
-  m.def(
-      "batch_index_select_dim0_forward_cpu_impl("
-      "Tensor inputs,"
-      "Tensor indices,"
-      "SymInt[] input_num_indices,"
-      "SymInt[] input_rows,"
-      "SymInt[] input_columns,"
-      "bool permute_output_dim_0_1) -> Tensor[]");
-
-  m.def(
-      "batch_index_select_dim0_backward_cpu_impl("
-      "Tensor grad_output,"
-      "Tensor indices,"
-      "Tensor indices_numels,"
-      "Tensor input_num_indices,"
-      "Tensor input_rows,"
-      "Tensor input_columns,"
-      "bool permute_output_dim_0_1,"
-      "Tensor saved_tensor) -> Tensor");
-
-  m.def(
-      "batch_index_select_dim0_tensor("
-      " Tensor inputs,"
-      " Tensor indices,"
-      " Tensor input_num_indices,"
-      " Tensor input_rows,"
-      " Tensor input_columns,"
-      " bool permute_output_dim_0_1=False) -> Tensor");
-
-  m.def(
-      "batch_index_select_dim0_tensor_forward_cpu_impl("
-      "Tensor inputs,"
-      "Tensor indices,"
-      "Tensor input_num_indices,"
-      "Tensor input_rows,"
-      "Tensor input_columns,"
-      "bool permute_output_dim_0_1) -> Tensor[]");
-
   DISPATCH_TO_CPU(
       "batch_index_select_dim0_forward_cpu_impl",
       BatchIndexSelectDim0CPUOp::forward_impl);
@@ -565,14 +505,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "batch_index_select_dim0_backward_cpu_impl",
       BatchIndexSelectDim0CPUOp::backward_impl);

-  m.impl(
-      "batch_index_select_dim0",
-      torch::dispatch(
-          c10::DispatchKey::AutogradCPU,
-          TORCH_FN(batch_index_select_dim0_cpu_autograd)));
-  m.impl(
+  DISPATCH_TO_AUTOGRAD_CPU(
+      "batch_index_select_dim0", batch_index_select_dim0_cpu_autograd);
+
+  DISPATCH_TO_AUTOGRAD_CPU(
       "batch_index_select_dim0_tensor",
-      torch::dispatch(
-          c10::DispatchKey::AutogradCPU,
-          TORCH_FN(batch_index_select_dim0_tensor_cpu_autograd)));
+      batch_index_select_dim0_tensor_cpu_autograd);
 }
diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp
index e2fd8f376..64e13ad8e 100644
--- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp
+++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp
@@ -712,28 +712,6 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
 }

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
-  m.set_python_module("fbgemm_gpu.sparse_ops");
-  m.def(
-      "batch_index_select_dim0_forward_cuda_impl("
-      "Tensor inputs,"
-      "Tensor indices,"
-      "SymInt[] input_num_indices,"
-      "SymInt[] input_rows,"
-      "SymInt[] input_columns,"
-      "bool permute_output_dim_0_1) -> Tensor[]");
-  m.def(
-      "batch_index_select_dim0_backward_cuda_impl("
-      "Tensor grad_output,"
-      "Tensor dev_weights,"
-      "Tensor weights_offsets,"
-      "Tensor D_offsets,"
-      "Tensor hash_size_cumsum,"
-      "Tensor indices,"
-      "int max_segment_length_per_warp,"
-      "Tensor grad_offsets,"
-      "Tensor total_L_offsets,"
-      "bool permute_output_dim_0_1,"
-      "Tensor saved_tensor) ->Tensor");
   DISPATCH_TO_CUDA(
       "batch_index_select_dim0_forward_cuda_impl",
       BatchIndexSelectDim0GPUOp::forward_impl);
@@ -742,30 +720,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       BatchIndexSelectDim0GPUOp::backward_impl);
   DISPATCH_TO_AUTOGRAD_CUDA(
       "batch_index_select_dim0", batch_index_select_dim0_gpu);
-
-  // Tensor alternative
-  m.def(
-      "batch_index_select_dim0_tensor_forward_cuda_impl("
-      "Tensor inputs,"
-      "Tensor indices,"
-      "Tensor input_num_indices,"
-      "Tensor input_rows,"
-      "Tensor input_columns,"
-      "bool permute_output_dim_0_1) -> Tensor[]");
-
-  m.def(
-      "batch_index_select_dim0_tensor_backward_cuda_impl("
-      "Tensor grad_output,"
-      "Tensor dev_weights,"
-      "Tensor weights_offsets,"
-      "Tensor D_offsets,"
-      "Tensor hash_size_cumsum,"
-      "Tensor indices,"
-      "int max_segment_length_per_warp,"
-      "Tensor grad_offsets,"
-      "Tensor total_L_offsets,"
-      "bool permute_output_dim_0_1,"
-      "Tensor saved_tensor) -> Tensor");
   DISPATCH_TO_CUDA(
       "batch_index_select_dim0_tensor_forward_cuda_impl",
       BatchIndexSelectDim0TensorGPUOp::forward_impl);
diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_ops.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_ops.cpp
new file mode 100644
index 000000000..4ff18498c
--- /dev/null
+++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_ops.cpp
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/dispatch_macros.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
+
+TORCH_LIBRARY_FRAGMENT(fb, m) {
+  m.def(
+      "batch_index_select_dim0("
+      " Tensor inputs,"
+      " Tensor indices,"
+      " SymInt[] input_num_indices,"
+      " SymInt[] input_rows,"
+      " SymInt[] input_columns,"
+      " bool permute_output_dim_0_1=False) -> Tensor");
+}
+
+TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
+  m.set_python_module("fbgemm_gpu.sparse_ops");
+
+  m.impl_abstract_pystub(
+      "fbgemm_gpu.sparse_ops",
+      "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
+
+  m.def(
+      "batch_index_select_dim0("
+      " Tensor inputs,"
+      " Tensor indices,"
+      " SymInt[] input_num_indices,"
+      " SymInt[] input_rows,"
+      " SymInt[] input_columns,"
+      " bool permute_output_dim_0_1=False) -> Tensor",
+      {PT2_COMPLIANT_TAG});
+
+  m.def(
+      "batch_index_select_dim0_tensor("
+      " Tensor inputs,"
+      " Tensor indices,"
+      " Tensor input_num_indices,"
+      " Tensor input_rows,"
+      " Tensor input_columns,"
+      " bool permute_output_dim_0_1=False) -> Tensor");
+
+  m.def(
+      "batch_index_select_dim0_forward_cpu_impl("
+      " Tensor inputs,"
+      " Tensor indices,"
+      " SymInt[] input_num_indices,"
+      " SymInt[] input_rows,"
+      " SymInt[] input_columns,"
+      " bool permute_output_dim_0_1) -> Tensor[]");
+
+  m.def(
+      "batch_index_select_dim0_backward_cpu_impl("
+      " Tensor grad_output,"
+      " Tensor indices,"
+      " Tensor indices_numels,"
+      " Tensor input_num_indices,"
+      " Tensor input_rows,"
+      " Tensor input_columns,"
+      " bool permute_output_dim_0_1,"
+      " Tensor saved_tensor) -> Tensor");
+
+  m.def(
+      "batch_index_select_dim0_tensor_forward_cpu_impl("
+      " Tensor inputs,"
+      " Tensor indices,"
+      " Tensor input_num_indices,"
+      " Tensor input_rows,"
+      " Tensor input_columns,"
+      " bool permute_output_dim_0_1) -> Tensor[]");
+
+  // CUDA ops
+
+  m.def(
+      "batch_index_select_dim0_forward_cuda_impl("
+      " Tensor inputs,"
+      " Tensor indices,"
+      " SymInt[] input_num_indices,"
+      " SymInt[] input_rows,"
+      " SymInt[] input_columns,"
+      " bool permute_output_dim_0_1) -> Tensor[]");
+
+  m.def(
+      "batch_index_select_dim0_backward_cuda_impl("
+      " Tensor grad_output,"
+      " Tensor dev_weights,"
+      " Tensor weights_offsets,"
+      " Tensor D_offsets,"
+      " Tensor hash_size_cumsum,"
+      " Tensor indices,"
+      " int max_segment_length_per_warp,"
+      " Tensor grad_offsets,"
+      " Tensor total_L_offsets,"
+      " bool permute_output_dim_0_1,"
+      " Tensor saved_tensor) -> Tensor");
+
+  m.def(
+      "batch_index_select_dim0_tensor_forward_cuda_impl("
+      " Tensor inputs,"
+      " Tensor indices,"
+      " Tensor input_num_indices,"
+      " Tensor input_rows,"
+      " Tensor input_columns,"
+      " bool permute_output_dim_0_1) -> Tensor[]");
+
+  m.def(
+      "batch_index_select_dim0_tensor_backward_cuda_impl("
+      " Tensor grad_output,"
+      " Tensor dev_weights,"
+      " Tensor weights_offsets,"
+      " Tensor D_offsets,"
+      " Tensor hash_size_cumsum,"
+      " Tensor indices,"
+      " int max_segment_length_per_warp,"
+      " Tensor grad_offsets,"
+      " Tensor total_L_offsets,"
+      " bool permute_output_dim_0_1,"
+      " Tensor saved_tensor) -> Tensor");
+}
diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
index c18106415..0ae295a6c 100644
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -47,6 +47,7 @@
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:input_combine_cpu")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu/codegen:index_select_ops")

+
 import torch.utils._pytree as pytree
 from torch import SymInt, Tensor
 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@@ -829,7 +830,7 @@ def group_index_select_dim0_gpu_backward_abstract(
     return ret


-@impl_abstract("fbgemm::keyed_jagged_index_select_dim1_forward_cuda_impl")
+@impl_abstract("fbgemm::keyed_jagged_index_select_dim1_forward")
 def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
     values: torch.Tensor,
     lengths: torch.Tensor,
@@ -865,7 +866,7 @@ def keyed_jagged_index_select_dim1_forward_cuda_impl_abstract(
     ]


-@impl_abstract("fbgemm::keyed_jagged_index_select_dim1_backward_cuda_impl")
+@impl_abstract("fbgemm::keyed_jagged_index_select_dim1_backward")
 def keyed_jagged_index_select_dim1_backward_cuda_impl_abstract(
     grad: torch.Tensor,
     indices: torch.Tensor,
diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py
index a9fdb3b99..29316fda1 100644
--- a/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py
+++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py
@@ -4,3 +4,5 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+from .split_embeddings_cache_ops import get_unique_indices  # noqa: F401
diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
index 54fd509dd..85a557623 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
@@ -111,6 +111,11 @@ inline bool torch_tensor_empty_or_on_cpu_check(
 #define DISPATCH_TO_AUTOGRAD(name, function) \
   m.impl(name, torch::dispatch(c10::DispatchKey::Autograd, TORCH_FN(function)))

+#define DISPATCH_TO_AUTOGRAD_CPU(name, function) \
+  m.impl(                                        \
+      name,                                      \
+      torch::dispatch(c10::DispatchKey::AutogradCPU, TORCH_FN(function)))
+
 #define DISPATCH_TO_AUTOGRAD_CUDA(name, function) \
   m.impl(                                         \
       name,                                       \
diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
index 3b0a41180..ecc646ac6 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
@@ -1597,6 +1597,27 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "fbgemm_gpu.sparse_ops",
       "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
 #endif
+
+  m.set_python_module("fbgemm_gpu.sparse_ops");
+
+  m.def(
+      "keyed_jagged_index_select_dim1_forward("
+      " Tensor values,"
+      " Tensor lengths,"
+      " Tensor offsets,"
+      " Tensor indices,"
+      " SymInt batch_size,"
+      " Tensor? weights,"
+      " SymInt? selected_lengths_sum) -> Tensor[]");
+
+  m.def(
+      "keyed_jagged_index_select_dim1_backward("
+      " Tensor grad,"
+      " Tensor indices,"
+      " Tensor grad_offsets,"
+      " Tensor output_offsets,"
+      " Tensor saved_tensor) -> Tensor");
+
   // (dense, offsets) -> jagged. Offsets output is same as input.
   // SymInt is a new PyTorch 2.0 feature to support dynamic shape. See more
   // details at https://pytorch.org/get-started/pytorch-2.0/#dynamic-shapes. If
diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
index 2ed8473e1..bafc11111 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
@@ -380,7 +380,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
     static auto forward_op_impl =
         c10::Dispatcher::singleton()
             .findSchemaOrThrow(
-                "fbgemm::keyed_jagged_index_select_dim1_forward_cuda_impl", "")
+                "fbgemm::keyed_jagged_index_select_dim1_forward", "")
             .typed();

     auto res = forward_op_impl.call(
@@ -501,7 +501,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
     static auto backward_op =
         c10::Dispatcher::singleton()
             .findSchemaOrThrow(
-                "fbgemm::keyed_jagged_index_select_dim1_backward_cuda_impl", "")
+                "fbgemm::keyed_jagged_index_select_dim1_backward", "")
             .typed();

     auto grad_input = backward_op.call(
@@ -542,29 +542,13 @@ std::vector keyed_jagged_index_select_dim_1_gpu(

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.set_python_module("fbgemm_gpu.sparse_ops");
-  m.def(
-      "keyed_jagged_index_select_dim1_forward_cuda_impl("
-      "Tensor values,"
-      "Tensor lengths,"
-      "Tensor offsets,"
-      "Tensor indices,"
-      "SymInt batch_size,"
-      "Tensor? weights,"
-      "SymInt? selected_lengths_sum) -> Tensor[]");
-
-  m.def(
-      "keyed_jagged_index_select_dim1_backward_cuda_impl("
-      "Tensor grad,"
-      "Tensor indices,"
-      "Tensor grad_offsets,"
-      "Tensor output_offsets,"
-      "Tensor saved_tensor) -> Tensor");

   DISPATCH_TO_CUDA(
-      "keyed_jagged_index_select_dim1_forward_cuda_impl",
+      "keyed_jagged_index_select_dim1_forward",
       fbgemm_gpu::KeyedJaggedIndexSelectDim1GPUOp::forward_impl);
+
   DISPATCH_TO_CUDA(
-      "keyed_jagged_index_select_dim1_backward_cuda_impl",
+      "keyed_jagged_index_select_dim1_backward",
       fbgemm_gpu::KeyedJaggedIndexSelectDim1GPUOp::backward_impl);
   DISPATCH_TO_AUTOGRAD_CUDA(
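For reference, the consolidated `fbgemm::batch_index_select_dim0` schema registered above can be exercised from Python roughly as follows. This is a minimal sketch, not part of the patch: it assumes an FBGEMM build in which these registrations are loaded (e.g. via `import fbgemm_gpu`), and all shapes and values are illustrative only.

    import torch
    import fbgemm_gpu  # noqa: F401  (assumed to load the fbgemm op libraries)

    # Two logical tables, 4x2 and 3x5, flattened into one contiguous 1-D buffer,
    # matching the "Tensor inputs" argument of the schema.
    input_rows = [4, 3]
    input_columns = [2, 5]
    inputs = torch.randn(sum(r * c for r, c in zip(input_rows, input_columns)))

    # Select 3 rows from the first table and 2 rows from the second; indices
    # are concatenated per table, with input_num_indices giving the split.
    input_num_indices = [3, 2]
    indices = torch.tensor([0, 2, 3, 1, 2])

    out = torch.ops.fbgemm.batch_index_select_dim0(
        inputs, indices, input_num_indices, input_rows, input_columns
    )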