From 3972d030dd8326d8ad4d7a0c57080857475e61df Mon Sep 17 00:00:00 2001 From: Flavio Sales Truzzi <590773+flaviotruzzi@users.noreply.github.com> Date: Tue, 24 Sep 2024 10:11:34 -0700 Subject: [PATCH] - Required changes for kernels (#3165) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3165 X-link: https://github.com/facebookresearch/FBGEMM/pull/259 Adding small changes to kernels for CompiledAutograd support: adding `static constexpr bool is_traceable = true;` on kernels, making some kernels use tensors instead of `double`, and unrolling the input shapes on GroupIndexSelectDim0GPUOp from a vector into individual ctx saved_data entries, to help enable CompiledAutograd. Reviewed By: Microve Differential Revision: D63151913 --- .../batch_index_select_dim0_host.cpp | 1 + .../permute_multi_embedding_function.h | 1 + .../permute_pooled_embs_function_split.h | 1 + fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp | 1 + fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp | 23 ++++++++++++++++--- 5 files changed, 24 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp index f5a70eb3b0..06cd53b16b 100644 --- a/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp +++ b/fbgemm_gpu/codegen/training/index_select/batch_index_select_dim0_host.cpp @@ -53,6 +53,7 @@ Tensor batch_index_select_dim0_codegen_backward_cuda( class BatchIndexSelectDim0GPUOp : public torch::autograd::Function { public: + static constexpr bool is_traceable = true; static torch::autograd::variable_list forward_impl( Tensor inputs, Tensor indices, diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h index 1cfc1a987c..b1c5386b0e 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h +++ 
b/fbgemm_gpu/include/fbgemm_gpu/permute_multi_embedding_function.h @@ -30,6 +30,7 @@ using torch::autograd::variable_list; class PermuteMultiEmbeddingOp : public torch::autograd::Function { public: + static constexpr bool is_traceable = true; static variable_list forward( AutogradContext* ctx, const at::TensorList& pooled_embs, diff --git a/fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function_split.h b/fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function_split.h index 9d4a49179d..8bed1ac9d1 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function_split.h +++ b/fbgemm_gpu/include/fbgemm_gpu/permute_pooled_embs_function_split.h @@ -27,6 +27,7 @@ class PermutePooledEmbsFunctionSplit : public torch::autograd::Function< PermutePooledEmbsFunctionSplit> { public: + static constexpr bool is_traceable = true; static Variable forward( AutogradContext* ctx, const at::Tensor& pooled_embs, // [B_local][Sum_T_global(D)] diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp index a80eea05e4..ced1610f79 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp @@ -66,6 +66,7 @@ namespace fbgemm_gpu { // Needed this to support backward pass. 
class PackSegments : public torch::autograd::Function { public: + static constexpr bool is_traceable = true; static torch::autograd::variable_list forward( torch::autograd::AutogradContext* ctx, const Tensor& t_in, diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp index 6325017e89..703015770d 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp +++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_gpu.cpp @@ -216,6 +216,7 @@ group_index_select_dim0_unpack( class GroupIndexSelectDim0GPUOp : public torch::autograd::Function { public: + static constexpr bool is_traceable = true; // need to combine input_group and indices_group into one tensor list // to get this working with autograd. static torch::autograd::variable_list forward_impl( @@ -435,13 +436,17 @@ class GroupIndexSelectDim0GPUOp input_shape_group.end(), input_shape.begin(), input_shape.end()); } + for (unsigned long i = 0; i < input_shape_group.size(); i++) { + ctx->saved_data["input_shape_group_" + std::to_string(i)] = + input_shape_group[i]; + } + // save indices, args_tensor, saved_data auto saved_tensors = std::vector(indices_group); saved_tensors.insert( saved_tensors.end(), result.cbegin() + group_size, result.cend()); saved_tensors.push_back(input_group[0]); ctx->save_for_backward(saved_tensors); - ctx->saved_data["input_shape_group"] = input_shape_group; return result; } @@ -609,8 +614,20 @@ class GroupIndexSelectDim0GPUOp auto saved_tensors = ctx->get_saved_variables(); TORCH_CHECK(saved_tensors.size() == group_size + 3); - auto output_shape_group = - ctx->saved_data["input_shape_group"].toSymIntVector(); + + std::vector output_shape_group; + int i = 0; + while (true) { + auto el = ctx->saved_data.find("input_shape_group_" + std::to_string(i)); + + if (el == ctx->saved_data.end()) { + break; + } + + output_shape_group.emplace_back(el->second.toSymInt()); + i++; + } + grad_output_group.insert( grad_output_group.end(), saved_tensors.begin(), 
saved_tensors.end()); static auto backward_op =