-
Notifications
You must be signed in to change notification settings - Fork 479
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix batch index select ops registration (#2598)
Summary: - Fix batch index select ops registration Pull Request resolved: #2598 Reviewed By: sryap Differential Revision: D57482140 Pulled By: q10 fbshipit-source-id: 50c6ca29f62a6a423383587404c4eb48ff9c4b15
- Loading branch information
1 parent
578ab67
commit 1c0344f
Showing
9 changed files
with
167 additions
and
138 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
125 changes: 125 additions & 0 deletions
125
fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_ops.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include "fbgemm_gpu/dispatch_macros.h" | ||
#include "fbgemm_gpu/sparse_ops_utils.h" | ||
|
||
TORCH_LIBRARY_FRAGMENT(fb, m) { | ||
m.def( | ||
"batch_index_select_dim0(" | ||
" Tensor inputs," | ||
" Tensor indices," | ||
" SymInt[] input_num_indices," | ||
" SymInt[] input_rows," | ||
" SymInt[] input_columns," | ||
" bool permute_output_dim_0_1=False) -> Tensor"); | ||
} | ||
|
||
TORCH_LIBRARY_FRAGMENT(fbgemm, m) { | ||
m.set_python_module("fbgemm_gpu.sparse_ops"); | ||
|
||
m.impl_abstract_pystub( | ||
"fbgemm_gpu.sparse_ops", | ||
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); | ||
|
||
m.def( | ||
"batch_index_select_dim0(" | ||
" Tensor inputs," | ||
" Tensor indices," | ||
" SymInt[] input_num_indices," | ||
" SymInt[] input_rows," | ||
" SymInt[] input_columns," | ||
" bool permute_output_dim_0_1=False) -> Tensor", | ||
{PT2_COMPLIANT_TAG}); | ||
|
||
m.def( | ||
"batch_index_select_dim0_tensor(" | ||
" Tensor inputs," | ||
" Tensor indices," | ||
" Tensor input_num_indices," | ||
" Tensor input_rows," | ||
" Tensor input_columns," | ||
" bool permute_output_dim_0_1=False) -> Tensor"); | ||
|
||
m.def( | ||
"batch_index_select_dim0_forward_cpu_impl(" | ||
" Tensor inputs," | ||
" Tensor indices," | ||
" SymInt[] input_num_indices," | ||
" SymInt[] input_rows," | ||
" SymInt[] input_columns," | ||
" bool permute_output_dim_0_1) -> Tensor[]"); | ||
|
||
m.def( | ||
"batch_index_select_dim0_backward_cpu_impl(" | ||
" Tensor grad_output," | ||
" Tensor indices," | ||
" Tensor indices_numels," | ||
" Tensor input_num_indices," | ||
" Tensor input_rows," | ||
" Tensor input_columns," | ||
" bool permute_output_dim_0_1," | ||
" Tensor saved_tensor) -> Tensor"); | ||
|
||
m.def( | ||
"batch_index_select_dim0_tensor_forward_cpu_impl(" | ||
" Tensor inputs," | ||
" Tensor indices," | ||
" Tensor input_num_indices," | ||
" Tensor input_rows," | ||
" Tensor input_columns," | ||
" bool permute_output_dim_0_1) -> Tensor[]"); | ||
|
||
// CUDA ops | ||
|
||
m.def( | ||
"batch_index_select_dim0_forward_cuda_impl(" | ||
" Tensor inputs," | ||
" Tensor indices," | ||
" SymInt[] input_num_indices," | ||
" SymInt[] input_rows," | ||
" SymInt[] input_columns," | ||
" bool permute_output_dim_0_1) -> Tensor[]"); | ||
|
||
m.def( | ||
"batch_index_select_dim0_backward_cuda_impl(" | ||
" Tensor grad_output," | ||
" Tensor dev_weights," | ||
" Tensor weights_offsets," | ||
" Tensor D_offsets," | ||
" Tensor hash_size_cumsum," | ||
" Tensor indices," | ||
" int max_segment_length_per_warp," | ||
" Tensor grad_offsets," | ||
" Tensor total_L_offsets," | ||
" bool permute_output_dim_0_1," | ||
" Tensor saved_tensor) -> Tensor"); | ||
|
||
m.def( | ||
"batch_index_select_dim0_tensor_forward_cuda_impl(" | ||
" Tensor inputs," | ||
" Tensor indices," | ||
" Tensor input_num_indices," | ||
" Tensor input_rows," | ||
" Tensor input_columns," | ||
" bool permute_output_dim_0_1) -> Tensor[]"); | ||
|
||
m.def( | ||
"batch_index_select_dim0_tensor_backward_cuda_impl(" | ||
" Tensor grad_output," | ||
" Tensor dev_weights," | ||
" Tensor weights_offsets," | ||
" Tensor D_offsets," | ||
" Tensor hash_size_cumsum," | ||
" Tensor indices," | ||
" int max_segment_length_per_warp," | ||
" Tensor grad_offsets," | ||
" Tensor total_L_offsets," | ||
" bool permute_output_dim_0_1," | ||
" Tensor saved_tensor) -> Tensor"); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.