From 90ac420dc1fde32911c7f9b432fc33d732177082 Mon Sep 17 00:00:00 2001 From: Ivan Kobzarev Date: Wed, 15 May 2024 23:55:08 -0700 Subject: [PATCH] Make batch_index_select_dim0 cuda pt2 autograd compatible (#2591) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2591 Making batch_index_select_dim0 with custom autograd functions PT2 traceable. The main flow is described https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9 Introducing batch_index_select_dim0_tensor operator that accepts Tensor instead of List[int]. Reviewed By: williamwen42 Differential Revision: D57232975 fbshipit-source-id: 5aa9dc67cc897611b29353502efd9cbe4f832c4e --- .../batch_index_select_dim0_cpu_host.cpp | 562 +++++++++--- .../batch_index_select_dim0_host.cpp | 832 ++++++++++++++---- fbgemm_gpu/fbgemm_gpu/sparse_ops.py | 186 ++++ fbgemm_gpu/test/sparse/failures_dict.json | 15 +- fbgemm_gpu/test/sparse/index_select_test.py | 34 +- 5 files changed, 1311 insertions(+), 318 deletions(-) diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp index 51470bd17..752a95ef5 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_cpu_host.cpp @@ -11,136 +11,282 @@ #include #include +#include +#include "fbgemm_gpu/dispatch_macros.h" #include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/sparse_ops.h" #include "fbgemm_gpu/sparse_ops_utils.h" +#include +#include + using Tensor = at::Tensor; using namespace fbgemm_gpu; +namespace { +Tensor tensor_from_vec(const std::vector& vec) { + auto tensor = at::empty( + {static_cast(vec.size())}, + at::TensorOptions().dtype(torch::kInt64)); + TORCH_CHECK(tensor.is_contiguous()); + std::memcpy( + tensor.data_ptr(), vec.data(), sizeof(int64_t) * vec.size()); + return tensor; +}; + +std::vector vecref_from_tensor(const Tensor& t) { + TORCH_CHECK(t.is_contiguous()); + const auto numel = static_cast(t.numel()); + const auto* ptr = t.data_ptr(); + return std::vector(ptr, ptr + numel); +}; + +} // namespace + class BatchIndexSelectDim0CPUOp : public torch::autograd::Function { public: - static torch::autograd::variable_list forward( - torch::autograd::AutogradContext* ctx, + static torch::autograd::variable_list forward_impl( const Tensor& inputs, const Tensor& indices, - const std::vector& input_num_indices, - const std::vector& input_rows, - const std::vector& input_columns, + const c10::SymIntArrayRef _input_num_indices, + const c10::SymIntArrayRef _input_rows, + const c10::SymIntArrayRef _input_columns, const bool permute_output_dim_0_1) { - const int64_t num_inputs = input_num_indices.size(); - ctx->save_for_backward({indices}); + const int64_t num_inputs = _input_num_indices.size(); + TORCH_CHECK( + num_inputs == static_cast(_input_rows.size()), + "[batch_index_select_dim0] input_rows must have the same length as " + "input_num_indices."); + TORCH_CHECK( + num_inputs == static_cast(_input_columns.size()), + "[batch_index_select_dim0] input_columns must have the same length as " + "input_num_indices."); + TORCH_CHECK( + reinterpret_cast(inputs.data_ptr()) % 16 == 0, + "Currently batch_index_select only supports 16-byte align input tensors"); - ctx->saved_data["input_numel"] = inputs.numel(); - ctx->saved_data["input_num_indices"] = input_num_indices; - ctx->saved_data["input_rows"] = input_rows; - 
ctx->saved_data["input_columns"] = input_columns; - ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1; + static auto to_vec_int64 = + [](const c10::SymIntArrayRef& sym_vec) -> std::vector { + std::vector vec; + std::transform( + sym_vec.begin(), + sym_vec.end(), + std::back_inserter(vec), + [](const auto& symint) { + return symint.guard_int(__FILE__, __LINE__); + }); + return vec; + }; - // Early exit - if (inputs.numel() == 0) { - return {at::empty({0}, inputs.options())}; - } + Tensor ret; + Tensor indices_numels_tensor; + std::vector input_num_indices; + std::vector input_rows; + std::vector input_columns; - // Compute section sizes for splitting tensors - std::vector input_numels; - std::vector indices_numels; - input_numels.reserve(num_inputs); - indices_numels.reserve(num_inputs); - for (auto i = 0; i < num_inputs; i++) { - input_numels.push_back(input_rows[i] * input_columns[i]); - indices_numels.push_back(input_num_indices[i]); - } + Tensor input_num_indices_tensor; + Tensor input_columns_tensor; + Tensor input_rows_tensor; - ctx->saved_data["indices_numels"] = indices_numels; + // Early exit + if (inputs.numel() == 0) { + ret = at::empty({0}, inputs.options()); + } else { + input_num_indices = to_vec_int64(_input_num_indices); + input_num_indices_tensor = tensor_from_vec(input_num_indices); + input_rows = to_vec_int64(_input_rows); + input_columns = to_vec_int64(_input_columns); + input_columns_tensor = tensor_from_vec(input_columns); + input_rows_tensor = tensor_from_vec(input_rows); - // Split tensors into vectors - const auto inputs_ = at::split_with_sizes(inputs, input_numels, 0); - const auto indices_ = at::split_with_sizes(indices, indices_numels, 0); + TORCH_CHECK( + torch::all(torch::gt(input_columns_tensor, 0)).item(), + "[batch_index_select_dim0] All input_columns must be the same."); + TORCH_CHECK( + torch::all(torch::gt(input_rows_tensor, 0)).item(), + "[batch_index_select_dim0] All input_rows must be the same."); - std::vector outputs; - outputs.reserve(num_inputs); - for (auto i = 0; i < num_inputs; i++) { - const auto input = inputs_[i].view({input_rows[i], input_columns[i]}); - const auto index = indices_[i]; - const auto output = at::index_select(input, 0, index); if (permute_output_dim_0_1) { - outputs.push_back(output); + // All output rows must be the same + TORCH_CHECK(input_num_indices[0] > 0); + TORCH_CHECK( + torch::all( + torch::eq(input_num_indices_tensor, input_num_indices[0])) + .item(), + "[batch_index_select_dim0] All input_num_indices must be the same if " + "permute_output_dim_0_1 is true."); } else { - outputs.push_back(output.flatten()); + TORCH_CHECK( + torch::all(torch::gt(input_num_indices_tensor, 0)).item(), + "[batch_index_select_dim0] All input_num_indices must be greater than zero."); + } + + // Compute section sizes for splitting tensors + std::vector input_numels; + std::vector indices_numels; + input_numels.reserve(num_inputs); + indices_numels.reserve(num_inputs); + for (auto i = 0; i < num_inputs; i++) { + input_numels.push_back(input_rows[i] * input_columns[i]); + indices_numels.push_back(input_num_indices[i]); + } + indices_numels_tensor = tensor_from_vec(indices_numels); + + // Split tensors into vectors + const auto inputs_ = at::split_with_sizes(inputs, input_numels, 0); + const auto indices_ = at::split_with_sizes(indices, indices_numels, 0); + + std::vector outputs; + outputs.reserve(num_inputs); + for (auto i = 0; i < num_inputs; i++) { + const auto input = inputs_[i].view({input_rows[i], input_columns[i]}); + 
const auto index = indices_[i]; + const auto output = at::index_select(input, 0, index); + if (permute_output_dim_0_1) { + outputs.push_back(output); + } else { + outputs.push_back(output.flatten()); + } } + + // permute_output_dim_0_1 = true shape: (batch_size, num_inputs, cols) + // permute_output_dim_0_1 = false shape: (num_inputs, batch_size cols) + ret = at::concat(outputs, permute_output_dim_0_1 ? 1 : 0).flatten(); } - // permute_output_dim_0_1 = true shape: (batch_size, num_inputs, cols) - // permute_output_dim_0_1 = false shape: (num_inputs, batch_size cols) - return {at::concat(outputs, permute_output_dim_0_1 ? 1 : 0).flatten()}; + auto saved_data_tensor = tensor_from_vec({inputs.numel()}); + + return { + ret, + + input_num_indices_tensor, + input_rows_tensor, + input_columns_tensor, + + indices_numels_tensor, + saved_data_tensor}; } - static torch::autograd::variable_list backward( + static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_outputs) { - using torch::autograd::Variable; + const Tensor& inputs, + const Tensor& indices, + const c10::SymIntArrayRef input_num_indices, + const c10::SymIntArrayRef input_rows, + const c10::SymIntArrayRef input_columns, + const bool permute_output_dim_0_1) { + at::AutoDispatchBelowADInplaceOrView guard; + static auto forward_op_impl = + torch::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::batch_index_select_dim0_forward_cpu_impl", "") + .typed(); - TORCH_CHECK_EQ(grad_outputs.size(), 1); + auto res = forward_op_impl.call( + inputs, + indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1); + ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1; + ctx->save_for_backward( + std::vector{indices, res[1], res[2], res[3], res[4], res[5]}); + res.resize(1); + return res; + } - const auto grad_output = grad_outputs[0]; - const auto input_numel = ctx->saved_data["input_numel"].toInt(); + static Tensor backward_impl( + const Tensor& grad_output, + const Tensor& indices, + const Tensor& indices_numels, + const Tensor& input_num_indices, + const Tensor& input_rows, + const Tensor& input_columns, + const bool permute_output_dim_0_1, + const Tensor& saved_tensor) { + const int64_t input_numel = saved_tensor[0].item(); // Early exit if (input_numel == 0) { - return { - at::empty({0}, grad_output.options()), - Variable(), // indices - Variable(), // input_num_indices - Variable(), // input_rows - Variable(), // input_columns - Variable() // permute_output_dim_0_1 - }; + return at::empty({0}, grad_output.options()); } + const int64_t num_inputs = input_num_indices.size(0); - const auto saved = ctx->get_saved_variables(); - auto indices = *std::begin(saved); - - const auto input_num_indices = - ctx->saved_data["input_num_indices"].toIntVector(); - const auto input_rows = ctx->saved_data["input_rows"].toIntVector(); - const auto input_cols = ctx->saved_data["input_columns"].toIntVector(); - const auto permute_output_dim_0_1 = - ctx->saved_data["permute_output_dim_0_1"].toBool(); - const auto indices_numels = ctx->saved_data["indices_numels"].toIntVector(); - - const int64_t num_inputs = input_num_indices.size(); + auto input_num_indices_vec = vecref_from_tensor(input_num_indices); + auto input_rows_vec = vecref_from_tensor(input_rows); + auto input_columns_vec = vecref_from_tensor(input_columns); std::vector grads; if (permute_output_dim_0_1) { grads = at::split_with_sizes( - grad_output.view({input_num_indices[0], -1}), input_cols, 1); 
+ grad_output.view({input_num_indices_vec[0], -1}), + input_columns_vec, + 1); } else { std::vector grad_numels; grad_numels.reserve(num_inputs); for (auto i = 0; i < num_inputs; i++) { - grad_numels.push_back(input_num_indices[i] * input_cols[i]); + grad_numels.push_back(input_num_indices_vec[i] * input_columns_vec[i]); } grads = at::split_with_sizes(grad_output, grad_numels, 0); } - const auto indices_ = at::split_with_sizes(indices, indices_numels, 0); + const auto indices_ = + at::split_with_sizes(indices, vecref_from_tensor(indices_numels), 0); std::vector grad_inputs; grad_inputs.reserve(num_inputs); for (auto i = 0; i < num_inputs; i++) { - const auto num_indices = input_num_indices[i]; - const auto grad_input = - at::zeros({input_rows[i], input_cols[i]}, grad_output.options()); + const auto num_indices = input_num_indices_vec[i]; + const auto grad_input = at::zeros( + {input_rows_vec[i], input_columns_vec[i]}, grad_output.options()); const auto grad = permute_output_dim_0_1 ? grads[i] : grads[i].view({num_indices, -1}); grad_inputs.push_back( at::index_add(grad_input, 0, indices_[i], grad).flatten()); } + return at::concat(grad_inputs, 0); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + using torch::autograd::Variable; + TORCH_CHECK_EQ(grad_outputs.size(), 1); + + const auto grad_output = grad_outputs[0]; + const auto permute_output_dim_0_1 = + ctx->saved_data["permute_output_dim_0_1"].toBool(); + const auto saved = ctx->get_saved_variables(); + + auto savedItr = std::begin(saved); + auto indices = *savedItr++; + + auto input_num_indices = *savedItr++; + auto input_rows = *savedItr++; + auto input_columns = *savedItr++; + + auto indices_numels = *savedItr++; + auto saved_tensor = *savedItr++; + static auto backward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::batch_index_select_dim0_backward_cpu_impl", "") + .typed(); + auto ret = backward_op.call( + grad_output, + indices, + indices_numels, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1, + saved_tensor); return { - at::concat(grad_inputs, 0), + ret, Variable(), // indices Variable(), // input_num_indices Variable(), // input_rows @@ -150,58 +296,171 @@ class BatchIndexSelectDim0CPUOp } }; -Tensor batch_index_select_dim0_cpu( - Tensor inputs, - Tensor indices, - std::vector input_num_indices, - std::vector input_rows, - std::vector input_columns, - // Permute dim 0 and 1 of the output tensor - const bool permute_output_dim_0_1) { - const int64_t num_inputs = input_num_indices.size(); - TORCH_CHECK( - num_inputs == static_cast(input_rows.size()), - "[batch_index_select_dim0] input_rows must have the same length as " - "input_num_indices."); - TORCH_CHECK( - num_inputs == static_cast(input_columns.size()), - "[batch_index_select_dim0] input_columns must have the same length as " - "input_num_indices."); - - TORCH_CHECK( - reinterpret_cast(inputs.data_ptr()) % 16 == 0, - "Currently batch_index_select only supports 16-byte align input tensors"); - - const auto int_opts = torch::TensorOptions().dtype(torch::kInt64); - const auto num_cols = - torch::from_blob(input_columns.data(), {num_inputs}, int_opts); - const auto input_num_rows = - torch::from_blob(input_rows.data(), {num_inputs}, int_opts); - const auto output_num_rows = - torch::from_blob(input_num_indices.data(), {num_inputs}, int_opts); - - if (num_inputs > 0) { +class BatchIndexSelectDim0TensorCPUOp + : public 
torch::autograd::Function { + public: + static torch::autograd::variable_list forward_impl( + const Tensor& inputs, + const Tensor& indices, + const Tensor& input_num_indices, + const Tensor& input_rows, + const Tensor& input_columns, + const bool permute_output_dim_0_1) { + const int64_t num_inputs = input_num_indices.size(0); + TORCH_CHECK( + num_inputs == input_rows.size(0), + "[batch_index_select_dim0] input_rows must have the same length as " + "input_num_indices."); + TORCH_CHECK( + num_inputs == input_columns.size(0), + "[batch_index_select_dim0] input_columns must have the same length as " + "input_num_indices."); + TORCH_CHECK( + reinterpret_cast(inputs.data_ptr()) % 16 == 0, + "Currently batch_index_select only supports 16-byte align input tensors"); + + auto saved_data_tensor = tensor_from_vec({inputs.numel()}); + + // Early exit + if (inputs.numel() == 0) { + return {at::empty({0}, inputs.options()), saved_data_tensor}; + } + TORCH_CHECK( - torch::all(torch::gt(num_cols, 0)).item(), + torch::all(torch::gt(input_columns, 0)).item(), "[batch_index_select_dim0] All input_columns must be the same."); TORCH_CHECK( - torch::all(torch::gt(input_num_rows, 0)).item(), + torch::all(torch::gt(input_rows, 0)).item(), "[batch_index_select_dim0] All input_rows must be the same."); + if (permute_output_dim_0_1) { // All output rows must be the same - TORCH_CHECK(input_num_indices[0] > 0); + const auto item0 = input_num_indices[0].item(); + TORCH_CHECK(item0 > 0); TORCH_CHECK( - torch::all(torch::eq(output_num_rows, input_num_indices[0])) - .item(), + torch::all(torch::eq(input_num_indices, item0)).item(), "[batch_index_select_dim0] All input_num_indices must be the same if " "permute_output_dim_0_1 is true."); } else { TORCH_CHECK( - torch::all(torch::gt(output_num_rows, 0)).item(), + torch::all(torch::gt(input_num_indices, 0)).item(), "[batch_index_select_dim0] All input_num_indices must be greater than zero."); } + + const auto input_numels = at::mul(input_rows, input_columns); + + // Split tensors into vectors + const auto inputs_ = + at::split_with_sizes(inputs, vecref_from_tensor(input_numels), 0); + const auto indices_ = + at::split_with_sizes(indices, vecref_from_tensor(input_num_indices), 0); + + const auto input_rows_vec = vecref_from_tensor(input_rows); + const auto input_columns_vec = vecref_from_tensor(input_columns); + + std::vector outputs; + outputs.reserve(num_inputs); + for (auto i = 0; i < num_inputs; i++) { + const auto input = + inputs_[i].view({input_rows_vec[i], input_columns_vec[i]}); + const auto index = indices_[i]; + const auto output = at::index_select(input, 0, index); + if (permute_output_dim_0_1) { + outputs.push_back(output); + } else { + outputs.push_back(output.flatten()); + } + } + + // permute_output_dim_0_1 = true shape: (batch_size, num_inputs, cols) + // permute_output_dim_0_1 = false shape: (num_inputs, batch_size cols) + + return { + at::concat(outputs, permute_output_dim_0_1 ? 
1 : 0).flatten(), + saved_data_tensor}; + } + + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const Tensor& inputs, + const Tensor& indices, + const Tensor& input_num_indices, + const Tensor& input_rows, + const Tensor& input_columns, + const bool permute_output_dim_0_1) { + at::AutoDispatchBelowADInplaceOrView guard; + static auto forward_op_impl = + torch::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::batch_index_select_dim0_tensor_forward_cpu_impl", "") + .typed(); + + auto res = forward_op_impl.call( + inputs, + indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1); + ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1; + ctx->save_for_backward(std::vector{ + indices, input_num_indices, input_rows, input_columns, res[1]}); + res.resize(1); + return res; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + using torch::autograd::Variable; + TORCH_CHECK_EQ(grad_outputs.size(), 1); + + const auto grad_output = grad_outputs[0]; + const auto permute_output_dim_0_1 = + ctx->saved_data["permute_output_dim_0_1"].toBool(); + const auto saved = ctx->get_saved_variables(); + + auto savedItr = std::begin(saved); + + auto indices = *savedItr++; + auto input_num_indices = *savedItr++; + auto input_rows = *savedItr++; + auto input_columns = *savedItr++; + auto saved_tensor = *savedItr++; + + static auto backward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::batch_index_select_dim0_backward_cpu_impl", "") + .typed(); + auto ret = backward_op.call( + grad_output, + indices, + input_num_indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1, + saved_tensor); + return { + ret, + Variable(), // indices + Variable(), // input_num_indices + Variable(), // input_rows + Variable(), // input_columns + Variable() // permute_output_dim_0_1 + }; } +}; +Tensor batch_index_select_dim0_cpu_autograd( + Tensor inputs, + Tensor indices, + const c10::SymIntArrayRef input_num_indices, + const c10::SymIntArrayRef input_rows, + const c10::SymIntArrayRef input_columns, + // Permute dim 0 and 1 of the output tensor + const bool permute_output_dim_0_1) { return BatchIndexSelectDim0CPUOp::apply( inputs, indices, @@ -211,6 +470,23 @@ Tensor batch_index_select_dim0_cpu( permute_output_dim_0_1)[0]; } +Tensor batch_index_select_dim0_tensor_cpu_autograd( + const Tensor& inputs, + const Tensor& indices, + const Tensor& input_num_indices, + const Tensor& input_rows, + const Tensor& input_columns, + // Permute dim 0 and 1 of the output tensor + const bool permute_output_dim_0_1) { + return BatchIndexSelectDim0TensorCPUOp::apply( + inputs, + indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1)[0]; +} + // Deprecated for fb namespace! Please use fbgemm namespace instead! 
TORCH_LIBRARY_FRAGMENT(fb, m) { m.def( @@ -221,13 +497,15 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { " SymInt[] input_rows," " SymInt[] input_columns," " bool permute_output_dim_0_1=False) -> Tensor"); - DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu); + DISPATCH_TO_CPU( + "batch_index_select_dim0", batch_index_select_dim0_cpu_autograd); } TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.impl_abstract_pystub( "fbgemm_gpu.sparse_ops", "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); + m.def( "batch_index_select_dim0(" " Tensor inputs," @@ -235,6 +513,66 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { " SymInt[] input_num_indices," " SymInt[] input_rows," " SymInt[] input_columns," + " bool permute_output_dim_0_1=False) -> Tensor", + {PT2_COMPLIANT_TAG}); + + m.def( + "batch_index_select_dim0_forward_cpu_impl(" + "Tensor inputs," + "Tensor indices," + "SymInt[] input_num_indices," + "SymInt[] input_rows," + "SymInt[] input_columns," + "bool permute_output_dim_0_1) -> Tensor[]"); + + m.def( + "batch_index_select_dim0_backward_cpu_impl(" + "Tensor grad_output," + "Tensor indices," + "Tensor indices_numels," + "Tensor input_num_indices," + "Tensor input_rows," + "Tensor input_columns," + "bool permute_output_dim_0_1," + "Tensor saved_tensor) -> Tensor"); + + m.def( + "batch_index_select_dim0_tensor(" + " Tensor inputs," + " Tensor indices," + " Tensor input_num_indices," + " Tensor input_rows," + " Tensor input_columns," " bool permute_output_dim_0_1=False) -> Tensor"); - DISPATCH_TO_CPU("batch_index_select_dim0", batch_index_select_dim0_cpu); + + m.def( + "batch_index_select_dim0_tensor_forward_cpu_impl(" + "Tensor inputs," + "Tensor indices," + "Tensor input_num_indices," + "Tensor input_rows," + "Tensor input_columns," + "bool permute_output_dim_0_1) -> Tensor[]"); + + DISPATCH_TO_CPU( + "batch_index_select_dim0_forward_cpu_impl", + BatchIndexSelectDim0CPUOp::forward_impl); + DISPATCH_TO_CPU( + "batch_index_select_dim0_tensor_forward_cpu_impl", + BatchIndexSelectDim0TensorCPUOp::forward_impl); + + DISPATCH_TO_CPU( + "batch_index_select_dim0_backward_cpu_impl", + BatchIndexSelectDim0CPUOp::backward_impl); + + m.impl( + "batch_index_select_dim0", + torch::dispatch( + c10::DispatchKey::AutogradCPU, + TORCH_FN(batch_index_select_dim0_cpu_autograd))); + m.impl( + "batch_index_select_dim0_tensor", + torch::dispatch( + c10::DispatchKey::AutogradCPU, + TORCH_FN(batch_index_select_dim0_tensor_cpu_autograd))); } diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index f0d10e682..e2fd8f376 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -7,6 +7,7 @@ */ #include +#include #include #include #include @@ -51,55 +52,267 @@ Tensor batch_index_select_dim0_codegen_backward_cuda( class BatchIndexSelectDim0GPUOp : public torch::autograd::Function { public: + static torch::autograd::variable_list forward_impl( + Tensor inputs, + Tensor indices, + c10::SymIntArrayRef _input_num_indices, + c10::SymIntArrayRef _input_rows, + c10::SymIntArrayRef _input_columns, + // Permute dim 0 and 1 of the output tensor + const bool permute_output_dim_0_1) { + auto to_vec_int64 = + [](const c10::SymIntArrayRef& sym_vec) -> std::vector { + std::vector vec; + std::transform( + sym_vec.begin(), + sym_vec.end(), + std::back_inserter(vec), + [](const auto& symint) { + return symint.guard_int(__FILE__, 
__LINE__); + }); + return vec; + }; + auto input_num_indices = to_vec_int64(_input_num_indices); + auto input_rows = to_vec_int64(_input_rows); + auto input_columns = to_vec_int64(_input_columns); + TORCH_CHECK(input_num_indices.size() == _input_num_indices.size()); + + // From the empirical study, this value provides the best perf + constexpr int64_t ROWS_PER_WARP = 1; + const int64_t num_inputs = input_num_indices.size(); + + TORCH_CHECK( + num_inputs == static_cast(input_rows.size()), + "[batch_index_select_dim0] input_rows must have the same length as " + "input_num_indices."); + TORCH_CHECK( + num_inputs == static_cast(input_columns.size()), + "[batch_index_select_dim0] input_columns must have the same length as " + "input_num_indices."); + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(inputs, indices); + + TORCH_CHECK( + reinterpret_cast(inputs.data_ptr()) % 16 == 0, + "Currently batch_index_select only supports 16-byte align input tensors"); + + const auto int_opts = torch::TensorOptions().dtype(torch::kInt64); + const auto num_cols = + torch::from_blob(input_columns.data(), {num_inputs}, int_opts); + const auto max_col = num_inputs > 0 ? num_cols.max().item() : 0; + const auto input_num_rows = + torch::from_blob(input_rows.data(), {num_inputs}, int_opts); + const auto output_num_rows = + torch::from_blob(input_num_indices.data(), {num_inputs}, int_opts); + + if (num_inputs > 0) { + TORCH_CHECK( + torch::all(torch::gt(num_cols, 0)).item(), + "[batch_index_select_dim0] All input_columns must be the same."); + TORCH_CHECK( + torch::all(torch::gt(input_num_rows, 0)).item(), + "[batch_index_select_dim0] All input_rows must be the same."); + if (permute_output_dim_0_1) { + // All output rows must be the same + TORCH_CHECK(input_num_indices[0] > 0); + TORCH_CHECK( + torch::all(torch::eq(output_num_rows, input_num_indices[0])) + .item(), + "[batch_index_select_dim0] All input_num_indices must be the same if " + "permute_output_dim_0_1 is true."); + } else { + TORCH_CHECK( + torch::all(torch::gt(output_num_rows, 0)).item(), + "[batch_index_select_dim0] All input_num_indices must be greater than zero."); + } + } + + const auto max_output_num_rows = + num_inputs > 0 ? output_num_rows.max().item() : 0; + + const auto input_numels = input_num_rows * num_cols; + const auto output_numels = + permute_output_dim_0_1 ? Tensor() : (output_num_rows * num_cols); + + // Takes ~1.2 ms for num_inputs = 1024 on CPU + auto D_offsets = fbgemm_gpu::asynchronous_complete_cumsum_cpu(num_cols).to( + torch::kInt32); + auto input_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(input_numels); + auto input_row_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(input_num_rows); + auto total_L_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(output_num_rows); + int64_t total_hash_size_bits = + std::log2(static_cast(input_row_offsets[-1].item())) + + 1; + input_offsets = + torch::narrow(input_offsets, 0, 0, input_offsets.numel() - 1); + const int64_t num_warps_per_input = + (max_output_num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + + // Transfer helper tensors to GPU + const auto device = inputs.device(); + constexpr bool non_blocking = true; + + int64_t output_size; + Tensor output_offsets; + if (permute_output_dim_0_1) { + // output_offsets is not required because the output tensor is not jagged + output_offsets = at::empty({0}, inputs.options().dtype(at::kLong)); + output_size = num_inputs > 0 + ? 
(input_num_indices[0] * D_offsets[-1].item()) + : 0; + } else { + output_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(output_numels); + output_size = output_offsets[-1].item(); + output_offsets = output_offsets.to(device, non_blocking); + } + + D_offsets = D_offsets.to(device, non_blocking); + input_offsets = input_offsets.to(device, non_blocking); + input_row_offsets = input_row_offsets.to(device, non_blocking); + total_L_offsets = total_L_offsets.to(device, non_blocking); + + const auto sparse_type = fbgemm_gpu::getSparseType(inputs.scalar_type()); + TORCH_CHECK( + sparse_type == SparseType::FP32 || sparse_type == SparseType::FP16, + "batch_index_select_dim0 supports only either float or half") + + const auto output_dtype = + static_cast(fbgemm_gpu::getSparseType(inputs.scalar_type())); + + const auto max_D = max_col; + const auto fixed_L_per_warp = ROWS_PER_WARP; + + auto output = inputs.numel() > 0 + ? batch_index_select_dim0_codegen_forward_cuda( + inputs, // dev_weights + input_offsets, // weights_offsets + D_offsets, + max_D, + indices, + output_dtype, + output_offsets, + total_L_offsets, + output_size, + fixed_L_per_warp, + num_warps_per_input, // num_warps_per_feature + permute_output_dim_0_1) + : at::empty({0}, inputs.options()); + + int64_t saved_data[] = { + max_D, + total_hash_size_bits, + fixed_L_per_warp, + num_warps_per_input, + }; + + auto saved_data_tensor = at::empty( + {sizeof(saved_data) / sizeof(int64_t)}, + at::TensorOptions().dtype(at::kLong)); + TORCH_CHECK(saved_data_tensor.is_contiguous()); + memcpy( + saved_data_tensor.data_ptr(), saved_data, sizeof(saved_data)); + + return { + output, // 0:op_output + input_offsets, // 1:weights_offsets + input_row_offsets, // 2:hash_size_cumsum, + D_offsets, // 3:D_offsets, + output_offsets, // 4:output_offsets, + total_L_offsets, // 5:total_L_offsets + saved_data_tensor, // 6:saved_data_tensor + }; + } + + // make scheme the same as main op static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, - const int64_t output_dtype, + Tensor inputs, + Tensor indices, + c10::SymIntArrayRef input_num_indices, + c10::SymIntArrayRef input_rows, + c10::SymIntArrayRef input_columns, + // Permute dim 0 and 1 of the output tensor + const bool permute_output_dim_0_1) { + at::AutoDispatchBelowADInplaceOrView guard; + static auto forward_op_impl = + torch::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::batch_index_select_dim0_forward_cuda_impl", "") + .typed(); + + auto res = forward_op_impl.call( + inputs, + indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1); + + // 0:op_output + // 1:weights_offsets, + // 2:hash_size_cumsum, + // 3:D_offsets, + // 4:output_offsets, + // 5:total_L_offsets + // 6:saved_data_tensor = [max_D, total_hash_size_bits, fixed_L_per_warp, + // num_warps_per_input] + + ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1; + + ctx->save_for_backward(std::vector{ + inputs, indices, res[1], res[2], res[3], res[4], res[5], res[6]}); + + res.resize(1); + return res; + } + + static Tensor backward_impl( + const Tensor& grad_output, const Tensor& dev_weights, const Tensor& weights_offsets, + const Tensor& D_offsets, const Tensor& hash_size_cumsum, - const int64_t total_hash_size_bits, const Tensor& indices, - const Tensor& D_offsets, - const c10::SymInt max_D, - const Tensor& output_offsets, + const int64_t max_segment_length_per_warp, + const Tensor& grad_offsets, const Tensor& total_L_offsets, - const int64_t output_size, - 
const int64_t fixed_L_per_warp, - const int64_t num_warps_per_feature, - const bool permute_output_dim_0_1) { - ctx->save_for_backward( - {dev_weights, - weights_offsets, - hash_size_cumsum, - indices, - D_offsets, - output_offsets, - total_L_offsets}); - - ctx->saved_data["max_D"] = max_D; - ctx->saved_data["total_hash_size_bits"] = total_hash_size_bits; - ctx->saved_data["fixed_L_per_warp"] = fixed_L_per_warp; - ctx->saved_data["num_warps_per_feature"] = num_warps_per_feature; - ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1; - - // Early exit + const bool permute_output_dim_0_1, + const Tensor& saved_tensor) { if (dev_weights.numel() == 0) { - return {at::empty({0}, dev_weights.options())}; + return at::empty({0}, dev_weights.options()); } - return {batch_index_select_dim0_codegen_forward_cuda( + auto _grad_output = grad_output; + // FIXME: to support aligned memory access in Vec4T load/store function + // 16 for FP32 and 8 for FP16 + if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0 || + at::has_internal_overlap(grad_output) != at::MemOverlap::No) { + _grad_output = at::empty_like(grad_output).copy_(grad_output); + } + + const auto max_D = saved_tensor[0].item(); + const auto total_hash_size_bits = saved_tensor[1].item(); + const auto fixed_L_per_warp = saved_tensor[2].item(); + const auto num_warps_per_feature = saved_tensor[3].item(); + + return batch_index_select_dim0_codegen_backward_cuda( + _grad_output, dev_weights, weights_offsets, D_offsets, max_D, + hash_size_cumsum, + total_hash_size_bits, indices, - output_dtype, - output_offsets, + max_segment_length_per_warp, + grad_offsets, total_L_offsets, - output_size, fixed_L_per_warp, num_warps_per_feature, - permute_output_dim_0_1)}; + permute_output_dim_0_1); } static torch::autograd::variable_list backward( @@ -107,71 +320,54 @@ class BatchIndexSelectDim0GPUOp torch::autograd::variable_list grad_outputs) { const auto saved = ctx->get_saved_variables(); auto savedItr = std::begin(saved); - auto dev_weights = *savedItr++; + auto dev_weights = *savedItr++; // inputs + auto indices = *savedItr++; // indices + auto weights_offsets = *savedItr++; auto hash_size_cumsum = *savedItr++; - auto indices = *savedItr++; auto D_offsets = *savedItr++; auto grad_offsets = *savedItr++; auto total_L_offsets = *savedItr++; - const auto max_D = ctx->saved_data["max_D"].toSymInt(); - const auto total_hash_size_bits = - ctx->saved_data["total_hash_size_bits"].toInt(); - const auto fixed_L_per_warp = ctx->saved_data["fixed_L_per_warp"].toInt(); - const auto num_warps_per_feature = - ctx->saved_data["num_warps_per_feature"].toInt(); + auto saved_tensor = *savedItr++; + const auto permute_output_dim_0_1 = ctx->saved_data["permute_output_dim_0_1"].toBool(); using torch::autograd::Variable; Tensor grad_dev_weights; - if (dev_weights.numel() == 0) { - grad_dev_weights = at::empty({0}, dev_weights.options()); - } else { - TORCH_CHECK_EQ(grad_outputs.size(), 1); + TORCH_CHECK_EQ(grad_outputs.size(), 1); - constexpr int32_t max_segment_length_per_warp = 32; + constexpr int32_t max_segment_length_per_warp = 32; - auto grad_output = grad_outputs[0]; - // FIXME: to support aligned memory access in Vec4T load/store function - // 16 for FP32 and 8 for FP16 - if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0) { - grad_output = at::empty_like(grad_output).copy_(grad_output); - } + auto grad_output = grad_outputs[0]; - grad_dev_weights = batch_index_select_dim0_codegen_backward_cuda( - grad_output, - dev_weights, - weights_offsets, - 
D_offsets, - max_D, - hash_size_cumsum, - total_hash_size_bits, - indices, - max_segment_length_per_warp, - grad_offsets, - total_L_offsets, - fixed_L_per_warp, - num_warps_per_feature, - permute_output_dim_0_1); - } + static auto backward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::batch_index_select_dim0_backward_cuda_impl", "") + .typed(); + + auto res = backward_op.call( + grad_output, + dev_weights, + weights_offsets, + D_offsets, + hash_size_cumsum, + indices, + max_segment_length_per_warp, + grad_offsets, + total_L_offsets, + permute_output_dim_0_1, + saved_tensor); return { - Variable(), // output_dtype - grad_dev_weights, // grad_dev_weights - Variable(), // weights_offsets - Variable(), // hash_size_cumsum - Variable(), // total_hash_size_bits + res, // inputs Variable(), // indices - Variable(), // D_offsets - Variable(), // max_D - Variable(), // output_offsets - Variable(), // total_L_offsets - Variable(), // output_size - Variable(), // fixed_L_per_warp - Variable(), // num_warps_per_feature + Variable(), // input_num_indices + Variable(), // input_rows + Variable(), // input_columns Variable(), // permute_output_dim_0_1 }; } @@ -180,126 +376,333 @@ class BatchIndexSelectDim0GPUOp Tensor batch_index_select_dim0_gpu( Tensor inputs, Tensor indices, - std::vector input_num_indices, - std::vector input_rows, - std::vector input_columns, + c10::SymIntArrayRef input_num_indices, + c10::SymIntArrayRef input_rows, + c10::SymIntArrayRef input_columns, // Permute dim 0 and 1 of the output tensor const bool permute_output_dim_0_1) { - // From the empirical study, this value provides the best perf - constexpr int64_t ROWS_PER_WARP = 1; - const int64_t num_inputs = input_num_indices.size(); - TORCH_CHECK( - num_inputs == static_cast(input_rows.size()), - "[batch_index_select_dim0] input_rows must have the same length as " - "input_num_indices."); - TORCH_CHECK( - num_inputs == static_cast(input_columns.size()), - "[batch_index_select_dim0] input_columns must have the same length as " - "input_num_indices."); - TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(inputs, indices); - - TORCH_CHECK( - reinterpret_cast(inputs.data_ptr()) % 16 == 0, - "Currently batch_index_select only supports 16-byte align input tensors"); - - const auto int_opts = torch::TensorOptions().dtype(torch::kInt64); - const auto num_cols = - torch::from_blob(input_columns.data(), {num_inputs}, int_opts); - const auto max_col = num_inputs > 0 ? 
num_cols.max().item() : 0; - const auto input_num_rows = - torch::from_blob(input_rows.data(), {num_inputs}, int_opts); - const auto output_num_rows = - torch::from_blob(input_num_indices.data(), {num_inputs}, int_opts); - - if (num_inputs > 0) { + return BatchIndexSelectDim0GPUOp::apply( + inputs, + indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1)[0]; +} + +class BatchIndexSelectDim0TensorGPUOp + : public torch::autograd::Function { + public: + static torch::autograd::variable_list forward_impl( + const Tensor& inputs, + const Tensor& indices, + const Tensor& input_num_indices, + const Tensor& input_rows, + const Tensor& input_columns, + // Permute dim 0 and 1 of the output tensor + const bool permute_output_dim_0_1) { + // From the empirical study, this value provides the best perf + constexpr int64_t ROWS_PER_WARP = 1; + const int64_t num_inputs = input_num_indices.size(0); + TORCH_CHECK( - torch::all(torch::gt(num_cols, 0)).item(), - "[batch_index_select_dim0] All input_columns must be the same."); + num_inputs == input_rows.size(0), + "[batch_index_select_dim0] input_rows must have the same length as " + "input_num_indices."); TORCH_CHECK( - torch::all(torch::gt(input_num_rows, 0)).item(), - "[batch_index_select_dim0] All input_rows must be the same."); - if (permute_output_dim_0_1) { - // All output rows must be the same - TORCH_CHECK(input_num_indices[0] > 0); + num_inputs == input_columns.size(0), + "[batch_index_select_dim0] input_columns must have the same length as " + "input_num_indices."); + TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(inputs, indices); + + TORCH_CHECK( + reinterpret_cast(inputs.data_ptr()) % 16 == 0, + "Currently batch_index_select only supports 16-byte align input tensors"); + + const auto num_cols = input_columns; + const auto max_col = + num_inputs > 0 ? input_columns.max().item() : 0; + const auto input_num_rows = input_rows; + const auto output_num_rows = input_num_indices; + + if (num_inputs > 0) { TORCH_CHECK( - torch::all(torch::eq(output_num_rows, input_num_indices[0])) - .item(), - "[batch_index_select_dim0] All input_num_indices must be the same if " - "permute_output_dim_0_1 is true."); - } else { + torch::all(torch::gt(input_columns, 0)).item(), + "[batch_index_select_dim0] All input_columns must be the same."); TORCH_CHECK( - torch::all(torch::gt(output_num_rows, 0)).item(), - "[batch_index_select_dim0] All input_num_indices must be greater than zero."); + torch::all(torch::gt(input_num_rows, 0)).item(), + "[batch_index_select_dim0] All input_rows must be the same."); + if (permute_output_dim_0_1) { + // All output rows must be the same + TORCH_CHECK(input_num_indices[0].item() > 0); + TORCH_CHECK( + torch::all(torch::eq(output_num_rows, input_num_indices[0])) + .item(), + "[batch_index_select_dim0] All input_num_indices must be the same if " + "permute_output_dim_0_1 is true."); + } else { + TORCH_CHECK( + torch::all(torch::gt(output_num_rows, 0)).item(), + "[batch_index_select_dim0] All input_num_indices must be greater than zero."); + } + } + + const auto max_output_num_rows = + num_inputs > 0 ? output_num_rows.max().item() : 0; + + const auto input_numels = input_num_rows * num_cols; + const auto output_numels = + permute_output_dim_0_1 ? 
Tensor() : (output_num_rows * num_cols); + + // Takes ~1.2 ms for num_inputs = 1024 on CPU + auto D_offsets = fbgemm_gpu::asynchronous_complete_cumsum_cpu(num_cols).to( + torch::kInt32); + auto input_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(input_numels); + auto input_row_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(input_num_rows); + auto total_L_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(output_num_rows); + int64_t total_hash_size_bits = + std::log2(static_cast(input_row_offsets[-1].item())) + + 1; + input_offsets = + torch::narrow(input_offsets, 0, 0, input_offsets.numel() - 1); + const int64_t num_warps_per_input = + (max_output_num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + + // Transfer helper tensors to GPU + const auto device = inputs.device(); + constexpr bool non_blocking = true; + + int64_t output_size; + Tensor output_offsets; + if (permute_output_dim_0_1) { + // output_offsets is not required because the output tensor is not jagged + output_offsets = at::empty({0}, inputs.options().dtype(at::kLong)); + output_size = num_inputs > 0 ? (input_num_indices[0].item() * + D_offsets[-1].item()) + : 0; + } else { + output_offsets = + fbgemm_gpu::asynchronous_complete_cumsum_cpu(output_numels); + output_size = output_offsets[-1].item(); + output_offsets = output_offsets.to(device, non_blocking); } + + D_offsets = D_offsets.to(device, non_blocking); + input_offsets = input_offsets.to(device, non_blocking); + input_row_offsets = input_row_offsets.to(device, non_blocking); + total_L_offsets = total_L_offsets.to(device, non_blocking); + + const auto sparse_type = fbgemm_gpu::getSparseType(inputs.scalar_type()); + TORCH_CHECK( + sparse_type == SparseType::FP32 || sparse_type == SparseType::FP16, + "batch_index_select_dim0 supports only either float or half") + + const auto output_dtype = + static_cast(fbgemm_gpu::getSparseType(inputs.scalar_type())); + const auto max_D = max_col; + const auto fixed_L_per_warp = ROWS_PER_WARP; + + auto output = inputs.numel() > 0 + ? batch_index_select_dim0_codegen_forward_cuda( + inputs, // dev_weights + input_offsets, // weights_offsets + D_offsets, + max_D, + indices, + output_dtype, + output_offsets, + total_L_offsets, + output_size, + fixed_L_per_warp, + num_warps_per_input, // num_warps_per_feature + permute_output_dim_0_1) + : at::empty({0}, inputs.options()); + + int64_t saved_data[] = { + max_D, + total_hash_size_bits, + fixed_L_per_warp, + num_warps_per_input, + }; + + auto saved_data_tensor = at::empty( + {sizeof(saved_data) / sizeof(int64_t)}, + at::TensorOptions().dtype(at::kLong)); + TORCH_CHECK(saved_data_tensor.is_contiguous()); + memcpy( + saved_data_tensor.data_ptr(), saved_data, sizeof(saved_data)); + + return { + output, // 0:op_output + input_offsets, // 1:weights_offsets + input_row_offsets, // 2:hash_size_cumsum, + D_offsets, // 3:D_offsets, + output_offsets, // 4:output_offsets, + total_L_offsets, // 5:total_L_offsets + saved_data_tensor, // 6:saved_data_tensor + }; } - const auto max_output_num_rows = - num_inputs > 0 ? output_num_rows.max().item() : 0; - - const auto input_numels = input_num_rows * num_cols; - const auto output_numels = - permute_output_dim_0_1 ? 
Tensor() : (output_num_rows * num_cols); - - // Takes ~1.2 ms for num_inputs = 1024 on CPU - auto D_offsets = - fbgemm_gpu::asynchronous_complete_cumsum_cpu(num_cols).to(torch::kInt32); - auto input_offsets = - fbgemm_gpu::asynchronous_complete_cumsum_cpu(input_numels); - auto input_row_offsets = - fbgemm_gpu::asynchronous_complete_cumsum_cpu(input_num_rows); - auto total_L_offsets = - fbgemm_gpu::asynchronous_complete_cumsum_cpu(output_num_rows); - int64_t total_hash_size_bits = - std::log2(static_cast(input_row_offsets[-1].item())) + 1; - input_offsets = torch::narrow(input_offsets, 0, 0, input_offsets.numel() - 1); - - const int64_t num_warps_per_input = - (max_output_num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; - - // Transfer helper tensors to GPU - const auto device = inputs.device(); - constexpr bool non_blocking = true; - - int64_t output_size; - Tensor output_offsets; - if (permute_output_dim_0_1) { - // output_offsets is not required because the output tensor is not jagged - output_offsets = at::empty({0}, inputs.options().dtype(at::kLong)); - output_size = num_inputs > 0 - ? (input_num_indices[0] * D_offsets[-1].item()) - : 0; - } else { - output_offsets = - fbgemm_gpu::asynchronous_complete_cumsum_cpu(output_numels); - output_size = output_offsets[-1].item(); - output_offsets = output_offsets.to(device, non_blocking); + // make scheme the same as main op + static torch::autograd::variable_list forward( + torch::autograd::AutogradContext* ctx, + const Tensor& inputs, + const Tensor& indices, + const Tensor& input_num_indices, + const Tensor& input_rows, + const Tensor& input_columns, + // Permute dim 0 and 1 of the output tensor + const bool permute_output_dim_0_1) { + at::AutoDispatchBelowADInplaceOrView guard; + static auto forward_op_impl = + torch::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::batch_index_select_dim0_tensor_forward_cuda_impl", "") + .typed(); + + auto res = forward_op_impl.call( + inputs, + indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1); + + // 0:op_output + // 1:weights_offsets, + // 2:hash_size_cumsum, + // 3:D_offsets, + // 4:output_offsets, + // 5:total_L_offsets + // 6:saved_data_tensor = [max_D, total_hash_size_bits, fixed_L_per_warp, + // num_warps_per_input] + + ctx->saved_data["permute_output_dim_0_1"] = permute_output_dim_0_1; + + ctx->save_for_backward(std::vector{ + inputs, indices, res[1], res[2], res[3], res[4], res[5], res[6]}); + + // res.resize(1); + return res; } - D_offsets = D_offsets.to(device, non_blocking); - input_offsets = input_offsets.to(device, non_blocking); - input_row_offsets = input_row_offsets.to(device, non_blocking); - total_L_offsets = total_L_offsets.to(device, non_blocking); + static Tensor backward_impl( + const Tensor& grad_output, + const Tensor& dev_weights, + const Tensor& weights_offsets, + const Tensor& D_offsets, + const Tensor& hash_size_cumsum, + const Tensor& indices, + const int64_t max_segment_length_per_warp, + const Tensor& grad_offsets, + const Tensor& total_L_offsets, + const bool permute_output_dim_0_1, + const Tensor& saved_tensor) { + if (dev_weights.numel() == 0) { + return at::empty({0}, dev_weights.options()); + } - const auto sparse_type = fbgemm_gpu::getSparseType(inputs.scalar_type()); - TORCH_CHECK( - sparse_type == SparseType::FP32 || sparse_type == SparseType::FP16, - "batch_index_select_dim0 supports only either float or half") + auto _grad_output = grad_output; + // FIXME: to support aligned memory access in Vec4T load/store function + // 
16 for FP32 and 8 for FP16 + if (reinterpret_cast(grad_output.data_ptr()) % 16 != 0 || + at::has_internal_overlap(grad_output) != at::MemOverlap::No) { + _grad_output = at::empty_like(grad_output).copy_(grad_output); + } - // Call TBE - return BatchIndexSelectDim0GPUOp::apply( - static_cast(fbgemm_gpu::getSparseType(inputs.scalar_type())), + const auto max_D = saved_tensor[0].item(); + const auto total_hash_size_bits = saved_tensor[1].item(); + const auto fixed_L_per_warp = saved_tensor[2].item(); + const auto num_warps_per_feature = saved_tensor[3].item(); + + return batch_index_select_dim0_codegen_backward_cuda( + _grad_output, + dev_weights, + weights_offsets, + D_offsets, + max_D, + hash_size_cumsum, + total_hash_size_bits, + indices, + max_segment_length_per_warp, + grad_offsets, + total_L_offsets, + fixed_L_per_warp, + num_warps_per_feature, + permute_output_dim_0_1); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_outputs) { + const auto saved = ctx->get_saved_variables(); + auto savedItr = std::begin(saved); + auto dev_weights = *savedItr++; // inputs + auto indices = *savedItr++; // indices + + auto weights_offsets = *savedItr++; + auto hash_size_cumsum = *savedItr++; + auto D_offsets = *savedItr++; + auto grad_offsets = *savedItr++; + auto total_L_offsets = *savedItr++; + + auto saved_tensor = *savedItr++; + + const auto permute_output_dim_0_1 = + ctx->saved_data["permute_output_dim_0_1"].toBool(); + + constexpr int32_t max_segment_length_per_warp = 32; + + auto grad_output = grad_outputs[0]; + + static auto backward_op = + torch::Dispatcher::singleton() + .findSchemaOrThrow( + "fbgemm::batch_index_select_dim0_tensor_backward_cuda_impl", "") + .typed(); + + auto res = backward_op.call( + grad_output, + dev_weights, + weights_offsets, + D_offsets, + hash_size_cumsum, + indices, + max_segment_length_per_warp, + grad_offsets, + total_L_offsets, + permute_output_dim_0_1, + saved_tensor); + + using torch::autograd::Variable; + return { + std::move(res), // inputs + Variable(), // indices + Variable(), // input_num_indices + Variable(), // input_rows + Variable(), // input_columns + Variable(), // permute_output_dim_0_1 + }; + } +}; + +Tensor batch_index_select_dim0_tensor_gpu( + const Tensor& inputs, + const Tensor& indices, + const Tensor& input_num_indices, + const Tensor& input_rows, + const Tensor& input_columns, + // Permute dim 0 and 1 of the output tensor + const bool permute_output_dim_0_1) { + return BatchIndexSelectDim0TensorGPUOp::apply( inputs, - input_offsets, - input_row_offsets, - total_hash_size_bits, indices, - D_offsets, - max_col, - output_offsets, - total_L_offsets, - output_size, - ROWS_PER_WARP, // fixed_L_per_warp - num_warps_per_input, + input_num_indices, + input_rows, + input_columns, permute_output_dim_0_1)[0]; } @@ -309,5 +712,66 @@ TORCH_LIBRARY_FRAGMENT(fb, m) { } TORCH_LIBRARY_FRAGMENT(fbgemm, m) { - DISPATCH_TO_CUDA("batch_index_select_dim0", batch_index_select_dim0_gpu); + m.set_python_module("fbgemm_gpu.sparse_ops"); + m.def( + "batch_index_select_dim0_forward_cuda_impl(" + "Tensor inputs," + "Tensor indices," + "SymInt[] input_num_indices," + "SymInt[] input_rows," + "SymInt[] input_columns," + "bool permute_output_dim_0_1) -> Tensor[]"); + m.def( + "batch_index_select_dim0_backward_cuda_impl(" + "Tensor grad_output," + "Tensor dev_weights," + "Tensor weights_offsets," + "Tensor D_offsets," + "Tensor hash_size_cumsum," + "Tensor indices," + "int 
max_segment_length_per_warp," + "Tensor grad_offsets," + "Tensor total_L_offsets," + "bool permute_output_dim_0_1," + "Tensor saved_tensor) ->Tensor"); + DISPATCH_TO_CUDA( + "batch_index_select_dim0_forward_cuda_impl", + BatchIndexSelectDim0GPUOp::forward_impl); + DISPATCH_TO_CUDA( + "batch_index_select_dim0_backward_cuda_impl", + BatchIndexSelectDim0GPUOp::backward_impl); + DISPATCH_TO_AUTOGRAD_CUDA( + "batch_index_select_dim0", batch_index_select_dim0_gpu); + + // Tensor alternative + m.def( + "batch_index_select_dim0_tensor_forward_cuda_impl(" + "Tensor inputs," + "Tensor indices," + "Tensor input_num_indices," + "Tensor input_rows," + "Tensor input_columns," + "bool permute_output_dim_0_1) -> Tensor[]"); + + m.def( + "batch_index_select_dim0_tensor_backward_cuda_impl(" + "Tensor grad_output," + "Tensor dev_weights," + "Tensor weights_offsets," + "Tensor D_offsets," + "Tensor hash_size_cumsum," + "Tensor indices," + "int max_segment_length_per_warp," + "Tensor grad_offsets," + "Tensor total_L_offsets," + "bool permute_output_dim_0_1," + "Tensor saved_tensor) -> Tensor"); + DISPATCH_TO_CUDA( + "batch_index_select_dim0_tensor_forward_cuda_impl", + BatchIndexSelectDim0TensorGPUOp::forward_impl); + DISPATCH_TO_CUDA( + "batch_index_select_dim0_tensor_backward_cuda_impl", + BatchIndexSelectDim0TensorGPUOp::backward_impl); + DISPATCH_TO_AUTOGRAD_CUDA( + "batch_index_select_dim0_tensor", batch_index_select_dim0_tensor_gpu); } diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py index 49b000ec7..c9b0c7f37 100644 --- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py +++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py @@ -13,6 +13,7 @@ from fbgemm_gpu.split_embedding_configs import SparseType from fbgemm_gpu.split_table_batched_embeddings_ops_common import PoolingMode +from torch.fx.experimental.symbolic_shapes import guard_size_oblivious try: # pyre-ignore @@ -517,6 +518,107 @@ def batch_index_select_dim0_abstract( return inputs.new_empty([output_numel]) +@torch.library.impl_abstract("fbgemm::batch_index_select_dim0_tensor") +def batch_index_select_dim0_tensor_abstract( + inputs: torch.Tensor, + indices: torch.Tensor, + input_num_indices: torch.Tensor, + input_rows: torch.Tensor, + input_columns: torch.Tensor, + permute_output_dim_0_1: bool, +) -> torch.Tensor: + torch._check(input_num_indices.size(0) == input_rows.size(0)) + torch._check(input_num_indices.size(0) == input_columns.size(0)) + output_numel = torch.library.get_ctx().new_dynamic_size() + return inputs.new_empty([output_numel]) + + +@impl_abstract("fbgemm::batch_index_select_dim0_forward_cuda_impl") +def batch_index_select_dim0_forward_cuda_impl_abstract( + inputs: torch.Tensor, + indices: torch.Tensor, + input_num_indices: List[int], + input_rows: List[int], + input_columns: List[int], + permute_output_dim_0_1: bool, +) -> List[torch.Tensor]: + num_inputs = len(input_rows) + torch._check(len(input_num_indices) == len(input_rows)) + torch._check(len(input_num_indices) == len(input_columns)) + + output_numel = 0 + for i, cols in enumerate(input_columns): + output_numel += input_num_indices[i] * cols + + output_offsets = ( + inputs.new_empty([0], dtype=torch.int64) + if permute_output_dim_0_1 + else inputs.new_empty([num_inputs + 1], dtype=torch.int64) + ) + + if permute_output_dim_0_1: + for i in range(num_inputs): + torch._check(input_num_indices[0] == input_num_indices[i]) + + return [ + inputs.new_empty([output_numel]), + inputs.new_empty([num_inputs], dtype=torch.int64), + inputs.new_empty([num_inputs + 1], 
dtype=torch.int64), + inputs.new_empty([num_inputs + 1], dtype=torch.int32), # D_offsets + output_offsets, + inputs.new_empty([num_inputs + 1], dtype=torch.int64), + inputs.new_empty([4], dtype=torch.int64, device="cpu"), + ] + + +@torch.library.impl_abstract("fbgemm::batch_index_select_dim0_tensor_forward_cuda_impl") +def batch_index_select_dim0_tensor_forward_cuda_impl_abstract( + inputs: torch.Tensor, + indices: torch.Tensor, + input_num_indices: torch.Tensor, + input_rows: torch.Tensor, + input_columns: torch.Tensor, + permute_output_dim_0_1: bool, +) -> List[torch.Tensor]: + num_inputs: int = input_rows.size(0) + torch._check(input_num_indices.size(0) == input_rows.size(0)) + torch._check(input_num_indices.size(0) == input_columns.size(0)) + output_numel = torch.library.get_ctx().new_dynamic_size() + if permute_output_dim_0_1: + output_offsets = inputs.new_empty([0], dtype=torch.int64) + else: + output_offsets = inputs.new_empty([num_inputs + 1], dtype=torch.int64) + + return [ + inputs.new_empty([output_numel]), + inputs.new_empty([num_inputs], dtype=torch.int64), + inputs.new_empty([num_inputs + 1], dtype=torch.int64), + inputs.new_empty([num_inputs + 1], dtype=torch.int32), # D_offsets + output_offsets, + inputs.new_empty([num_inputs + 1], dtype=torch.int64), # total_L_offsets + inputs.new_empty([4], dtype=torch.int64, device="cpu"), + ] + + +@torch.library.impl_abstract( + "fbgemm::batch_index_select_dim0_tensor_backward_cuda_impl" +) +def batch_index_select_dim0_tensor_backward_cuda_impl_abstract( + grad_output: torch.Tensor, + dev_weights: torch.Tensor, + weights_offsets: torch.Tensor, + D_offsets: torch.Tensor, + hash_size_cumsum: torch.Tensor, + indices: torch.Tensor, + max_segment_length_per_warp: int, + grad_offsets: torch.Tensor, + total_L_offsets: torch.Tensor, + permute_output_dim_0_1: bool, + saved_tensor: torch.Tensor, +) -> torch.Tensor: + return grad_output.new_empty(dev_weights.shape) + + @impl_abstract("fbgemm::keyed_jagged_index_select_dim1") def keyed_jagged_index_select_dim1_abstract( values: torch.Tensor, @@ -564,6 +666,90 @@ def keyed_jagged_index_select_dim1_abstract( return ret +@torch.library.impl_abstract("fbgemm::batch_index_select_dim0_backward_cuda_impl") +def batch_index_select_dim0_backward_cuda_impl_abstract( + grad_output: torch.Tensor, + dev_weights: torch.Tensor, + weights_offsets: torch.Tensor, + D_offsets: torch.Tensor, + hash_size_cumsum: torch.Tensor, + indices: torch.Tensor, + max_segment_length_per_warp: int, + grad_offsets: torch.Tensor, + total_L_offsets: torch.Tensor, + permute_output_dim_0_1: bool, + saved_tensor: torch.Tensor, +) -> torch.Tensor: + return grad_output.new_empty(dev_weights.shape) + + +@impl_abstract("fbgemm::batch_index_select_dim0_forward_cpu_impl") +def batch_index_select_dim0_forward_cpu_impl_abstract( + inputs: torch.Tensor, + indices: torch.Tensor, + input_num_indices: List[int], + input_rows: List[int], + input_columns: List[int], + permute_output_dim_0_1: bool, +) -> List[torch.Tensor]: + # input lists must have the same length + num_inputs = len(input_num_indices) + torch._check(num_inputs == len(input_rows)) + torch._check(num_inputs == len(input_columns)) + + if permute_output_dim_0_1 and guard_size_oblivious(len(input_num_indices) > 0): + # All num_indices must be the same if permute_output_dim_0_1 is True + for x in input_num_indices: + torch._check(x == input_num_indices[0]) + + output_numel: int = sum([i * c for i, c in zip(input_num_indices, input_columns)]) + + return [ + inputs.new_empty([output_numel]), 
+ inputs.new_empty([len(input_num_indices)], dtype=torch.int64), + inputs.new_empty([len(input_rows)], dtype=torch.int64), + inputs.new_empty([len(input_columns)], dtype=torch.int64), + inputs.new_empty([num_inputs], dtype=torch.int64), # indices_numels + inputs.new_empty([1], dtype=torch.int64), # saved_tensor + ] + + +@impl_abstract("fbgemm::batch_index_select_dim0_tensor_forward_cpu_impl") +def batch_index_select_dim0_tensor_forward_cpu_impl_abstract( + inputs: torch.Tensor, + indices: torch.Tensor, + input_num_indices: torch.Tensor, + input_rows: torch.Tensor, + input_columns: torch.Tensor, + permute_output_dim_0_1: bool, +) -> List[torch.Tensor]: + # input lists must have the same length + num_inputs = len(input_num_indices) + torch._check(num_inputs == len(input_rows)) + torch._check(num_inputs == len(input_columns)) + + output_numel = torch.library.get_ctx().new_dynamic_size() + + return [ + inputs.new_empty([output_numel]), + inputs.new_empty([1], dtype=torch.int64), + ] + + +@impl_abstract("fbgemm::batch_index_select_dim0_backward_cpu_impl") +def batch_index_select_dim0_backward_cpu_impl_abstract( + grad_output: torch.Tensor, + indices: torch.Tensor, + indices_numels: torch.Tensor, + input_num_indices: torch.Tensor, + input_rows: torch.Tensor, + input_columns: torch.Tensor, + permute_output_dim_0_1: bool, + saved_tensor: torch.Tensor, +) -> torch.Tensor: + return grad_output.new_empty([torch.library.get_ctx().new_dynamic_size()]) + + @impl_abstract("fbgemm::bounds_check_indices") def bounds_check_indices_abstract( rows_per_table: torch.Tensor, diff --git a/fbgemm_gpu/test/sparse/failures_dict.json b/fbgemm_gpu/test/sparse/failures_dict.json index fefa85fa0..7e8aa175e 100644 --- a/fbgemm_gpu/test/sparse/failures_dict.json +++ b/fbgemm_gpu/test/sparse/failures_dict.json @@ -5,20 +5,7 @@ "fbgemm::asynchronous_complete_cumsum": {}, "fbgemm::asynchronous_exclusive_cumsum": {}, "fbgemm::asynchronous_inclusive_cumsum": {}, - "fbgemm::batch_index_select_dim0": { - "IndexSelectTest.test_aot_dispatch_dynamic__test_batch_index_select_dim0": { - "comment": "", - "status": "xfail" - }, - "IndexSelectTest.test_autograd_registration__test_batch_index_select_dim0": { - "comment": "", - "status": "xfail" - }, - "IndexSelectTest.test_faketensor__test_batch_index_select_dim0": { - "comment": "", - "status": "xfail" - } - }, + "fbgemm::batch_index_select_dim0": {}, "fbgemm::block_bucketize_sparse_features": { "BlockBucketizeTest.test_aot_dispatch_dynamic__test_block_bucketize_sparse_features": { "comment": "", diff --git a/fbgemm_gpu/test/sparse/index_select_test.py b/fbgemm_gpu/test/sparse/index_select_test.py index fd1e54c59..d23f89337 100644 --- a/fbgemm_gpu/test/sparse/index_select_test.py +++ b/fbgemm_gpu/test/sparse/index_select_test.py @@ -247,6 +247,12 @@ def compare_tensor_groups( permute_output_dim_0_1=st.booleans(), dtype=st.sampled_from([torch.float, torch.half]), use_cpu=st.booleans() if gpu_available else st.just(True), + op=st.sampled_from( + [ + torch.ops.fbgemm.batch_index_select_dim0, + torch.ops.fbgemm.batch_index_select_dim0_tensor, + ] + ), ) @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None) def test_batch_index_select_dim0( # noqa: C901 @@ -258,6 +264,8 @@ def test_batch_index_select_dim0( # noqa: C901 permute_output_dim_0_1: bool, dtype: torch.dtype, use_cpu: bool, + # pyre-ignore + op, ) -> None: device = "cpu" if use_cpu else "cuda" input_rows = torch.randint( @@ -331,14 +339,24 @@ def validate( concat_inputs.requires_grad = True - output_test = 
torch.ops.fbgemm.batch_index_select_dim0( - concat_inputs, - concat_indices, - input_num_indices, - input_rows, - input_columns, - permute_output_dim_0_1, - ) + if op == torch.ops.fbgemm.batch_index_select_dim0_tensor: + output_test = op( + concat_inputs, + concat_indices, + torch.tensor(input_num_indices, dtype=torch.int64, device="cpu"), + torch.tensor(input_rows, dtype=torch.int64, device="cpu"), + torch.tensor(input_columns, dtype=torch.int64, device="cpu"), + permute_output_dim_0_1, + ) + else: + output_test = op( + concat_inputs, + concat_indices, + input_num_indices, + input_rows, + input_columns, + permute_output_dim_0_1, + ) if permute_output_dim_0_1 and num_inputs > 0: output_list = output_test.view(input_num_indices[0], -1).split(
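As a minimal usage sketch (not part of the patch), the new Tensor-argument operator can be exercised through torch.compile roughly as follows, assuming an fbgemm_gpu build that includes this change and an available CUDA device. The table sizes and the wrapper function select_batched are illustrative only; the flattened-input layout and the int64 CPU metadata tensors mirror what the test above constructs.

    import torch
    import fbgemm_gpu  # noqa: F401  -- loads the native fbgemm::* operators

    num_tables, rows, cols, num_indices = 2, 8, 4, 3

    # Each "table" is flattened and concatenated into a single 1-D input tensor,
    # matching the concat_inputs layout used in the test.
    tables = [torch.randn(rows, cols, device="cuda") for _ in range(num_tables)]
    inputs = torch.cat([t.flatten() for t in tables]).contiguous()
    indices = torch.cat(
        [torch.randint(0, rows, (num_indices,), device="cuda") for _ in range(num_tables)]
    )

    def select_batched(inputs: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
        # Per-table metadata is passed as int64 CPU tensors, as in the test above.
        return torch.ops.fbgemm.batch_index_select_dim0_tensor(
            inputs,
            indices,
            torch.tensor([num_indices] * num_tables, dtype=torch.int64, device="cpu"),
            torch.tensor([rows] * num_tables, dtype=torch.int64, device="cpu"),
            torch.tensor([cols] * num_tables, dtype=torch.int64, device="cpu"),
            False,  # permute_output_dim_0_1
        )

    out = torch.compile(select_batched)(inputs, indices)  # PT2-traceable path
    print(out.shape)  # flattened output: num_tables * num_indices * cols elements

The existing batch_index_select_dim0 overload takes the same metadata as plain List[int] arguments, as the else branch of the test shows.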