From f5d37c3f993c969f4688a8d7f1955687f6d4b149 Mon Sep 17 00:00:00 2001 From: Jun Luo Date: Mon, 31 Jul 2023 11:59:25 -0700 Subject: [PATCH] Fix the Autograd with fbgemm:: split_embedding_codegen_lookup_sgd_function_cpu (#1899) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/1899 According to the discussion in https://fb.workplace.com/groups/1405155842844877/posts/7259624410731295/?comment_id=7260170757343327. > autograd.Function that uses non-PyTorch code in the backward (like direct .data_ptr manipulation or directly invoking a CUDA kernel) will lose the backwards when you attempt to trace it. Unfortunately the only way to resolve this is to go through your autograd.Function and refactor the backward to call a custom operator that is understood by the PyTorch dispatcher. This diff addresses the problem. Reviewed By: jiyuanzFB Differential Revision: D47852595 fbshipit-source-id: 11b945bf63486baa40943065ece0360b46cddcfa --- .../embedding_backward_split_cpu_template.cpp | 13 +++++++ ...dding_backward_split_host_cpu_template.cpp | 36 +++++++++++++------ .../codegen/embedding_forward_split_cpu.cpp | 24 +++++++++++++ 3 files changed, 62 insertions(+), 11 deletions(-) diff --git a/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp index ad5cbee86..e9c8dc4fc 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp @@ -13,12 +13,15 @@ #include #include +#include +#include #include "codegen/embedding_forward_split_cpu.h" #include "fbgemm/FbgemmEmbedding.h" #include "fbgemm/Types.h" #include "fbgemm_gpu/embedding_common.h" #include "fbgemm_gpu/cpu_utils.h" +#include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -396,4 +399,14 @@ for (const auto d : c10::irange(D)) { return grad; {% endif %} } + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + {% if not dense %} + m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, bool stochastic_rounding, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int")}}, int output_dtype = 0) -> ()"); + {% else %} + m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor host_weights, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int")}}) -> Tensor"); + {% endif %} + DISPATCH_TO_CPU("split_embedding_backward_codegen_{{ optimizer }}_cpu", split_embedding_backward_codegen_{{ optimizer }}_cpu); +} + // clang-format on diff --git a/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp b/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp index afe7cca44..d052fdebd 100644 --- a/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp +++ b/fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp @@ -83,8 +83,11 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function< {% for (var, _) in args.saved_data %} ctx->saved_data["{{ var }}"] = {{ var }}; {% endfor %} - - return {split_embedding_codegen_forward_cpu( + static auto op = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "") + .typed(); + return {op.call( host_weights, weights_offsets, D_offsets, @@ -132,9 +135,12 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function< TORCH_CHECK_EQ(grad_outputs.size(), 1); using torch::autograd::Variable; - + static auto op1 = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::split_embedding_backward_codegen_{{ optimizer }}_cpu", "") + .typed(); auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0]; - split_embedding_backward_codegen_{{ optimizer }}_cpu( + op1.call( grad_output, host_weights, weights_placements, @@ -150,9 +156,13 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function< stochastic_rounding, {{ args.split_function_arg_names | join(", ") }}, output_dtype); + static auto op2 = + torch::Dispatcher::singleton() + .findSchemaOrThrow("fbgemm::split_embedding_codegen_grad_indice_weights_cpu", "") + .typed(); // NOTE: MEAN pooling will not work with indice_weights! auto grad_indice_weights = indice_weights.defined() - ? split_embedding_codegen_grad_indice_weights_cpu( + ? op2.call( grad_outputs[0], host_weights, weights_offsets, @@ -234,16 +244,20 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu( // Deprecated for fb namespace! Please use fbgemm namespace instead! TORCH_LIBRARY_FRAGMENT(fb, m) { m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor"); - DISPATCH_TO_CPU( - "split_embedding_codegen_lookup_{{ optimizer }}_function_cpu", - split_embedding_codegen_lookup_{{ optimizer }}_function_cpu); + m.impl( + "split_embedding_codegen_lookup_{{ optimizer }}_function_cpu", + torch::dispatch( + c10::DispatchKey::Autograd, + TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_cpu))); } TORCH_LIBRARY_FRAGMENT(fbgemm, m) { m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor"); - DISPATCH_TO_CPU( - "split_embedding_codegen_lookup_{{ optimizer }}_function_cpu", - split_embedding_codegen_lookup_{{ optimizer }}_function_cpu); + m.impl( + "split_embedding_codegen_lookup_{{ optimizer }}_function_cpu", + torch::dispatch( + c10::DispatchKey::Autograd, + TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_cpu))); } } // namespace diff --git a/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp b/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp index 8ef970ba7..fc879b525 100644 --- a/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp +++ b/fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp @@ -12,6 +12,7 @@ #include "fbgemm/Utils.h" #include "fbgemm_gpu/cpu_utils.h" #include "fbgemm_gpu/embedding_common.h" +#include "fbgemm_gpu/sparse_ops_utils.h" #ifdef FBCODE_CAFFE2 #include #include "folly/container/F14Map.h" @@ -19,7 +20,10 @@ #include #endif +#include #include +#include +#include using Tensor = at::Tensor; using namespace fbgemm_gpu; @@ -608,3 +612,23 @@ template void csr2csc( int64_t num_embeddings); } // namespace internal + +namespace { + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "split_embedding_codegen_grad_indice_weights_cpu(Tensor grad_output, Tensor weights, Tensor weights_offsets, Tensor D_offsets, Tensor indices, Tensor offsets, Tensor feature_requires_grad) -> Tensor"); + DISPATCH_TO_CPU( + "split_embedding_codegen_grad_indice_weights_cpu", + split_embedding_codegen_grad_indice_weights_cpu); +} + +TORCH_LIBRARY_FRAGMENT(fbgemm, m) { + m.def( + "split_embedding_codegen_forward_cpu(Tensor weights, Tensor weights_offsets, Tensor D_offsets, int total_D, Tensor hash_size_cumsum, Tensor indices, Tensor offsets, int pooling_mode, Tensor indice_weights, int output_dtype) -> Tensor"); + DISPATCH_TO_CPU( + "split_embedding_codegen_forward_cpu", + split_embedding_codegen_forward_cpu); +} + +} // namespace