Fix the Autograd with fbgemm::split_embedding_codegen_lookup_sgd_function_cpu (pytorch#1899)

Summary:
Pull Request resolved: pytorch#1899

According to the discussion in https://fb.workplace.com/groups/1405155842844877/posts/7259624410731295/?comment_id=7260170757343327.
> autograd.Function that uses non-PyTorch code in the backward (like direct .data_ptr manipulation or directly invoking a CUDA kernel) will lose the backwards when you attempt to trace it. Unfortunately the only way to resolve this is to go through your autograd.Function and refactor the backward to call a custom operator that is understood by the PyTorch dispatcher.

This diff addresses the problem.
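
Concretely, the pattern this diff applies is: expose the CPU kernel as a dispatcher-registered custom operator, and have the autograd::Function's backward fetch and call that operator through the dispatcher instead of invoking the C++ function directly. A minimal, self-contained sketch of the pattern follows; the names (my_ns::my_bwd, my_bwd_cpu, MyFunction) are illustrative placeholders, not the actual FBGEMM operators.

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/torch.h>

// A plain CPU backward kernel, registered as a custom operator so that the
// dispatcher (and therefore tracing) can see it.
at::Tensor my_bwd_cpu(const at::Tensor& grad) {
  return grad * 2.0;
}

TORCH_LIBRARY_FRAGMENT(my_ns, m) {
  m.def("my_bwd(Tensor grad) -> Tensor");
  m.impl("my_bwd", torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(my_bwd_cpu)));
}

class MyFunction : public torch::autograd::Function<MyFunction> {
 public:
  static at::Tensor forward(
      torch::autograd::AutogradContext* /*ctx*/,
      const at::Tensor& x) {
    return x * 2.0;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* /*ctx*/,
      torch::autograd::variable_list grad_outputs) {
    // Go through the dispatcher rather than calling my_bwd_cpu directly;
    // this is what keeps the backward visible when the function is traced.
    static auto op = c10::Dispatcher::singleton()
                         .findSchemaOrThrow("my_ns::my_bwd", "")
                         .typed<decltype(my_bwd_cpu)>();
    return {op.call(grad_outputs[0])};
  }
};

// Usage: auto y = MyFunction::apply(x);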

Reviewed By: jiyuanzFB

Differential Revision: D47852595

fbshipit-source-id: 11b945bf63486baa40943065ece0360b46cddcfa
egienvalue authored and facebook-github-bot committed Jul 31, 2023
1 parent 26cc4df commit f5d37c3
Showing 3 changed files with 62 additions and 11 deletions.
13 changes: 13 additions & 0 deletions fbgemm_gpu/codegen/embedding_backward_split_cpu_template.cpp
@@ -13,12 +13,15 @@

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>

#include "codegen/embedding_forward_split_cpu.h"
#include "fbgemm/FbgemmEmbedding.h"
#include "fbgemm/Types.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/cpu_utils.h"
#include "fbgemm_gpu/sparse_ops_utils.h"

using Tensor = at::Tensor;
using namespace fbgemm_gpu;
@@ -396,4 +399,14 @@ for (const auto d : c10::irange(D)) {
return grad;
{% endif %}
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
{% if not dense %}
m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, bool stochastic_rounding, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int")}}, int output_dtype = 0) -> ()");
{% else %}
m.def("split_embedding_backward_codegen_{{ optimizer }}_cpu(Tensor grad_output, Tensor host_weights, Tensor weights_offsets, Tensor D_offsets, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets,int pooling_mode, Tensor indice_weights, {{ (args.split_function_args | join(", ")).replace("double", "float").replace("int64_t", "int")}}) -> Tensor");
{% endif %}
DISPATCH_TO_CPU("split_embedding_backward_codegen_{{ optimizer }}_cpu", split_embedding_backward_codegen_{{ optimizer }}_cpu);
}

// clang-format on
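
The DISPATCH_TO_CPU macro used in these registrations comes from fbgemm_gpu/sparse_ops_utils.h (included above). Judging from how it is used here, it is essentially a thin wrapper over m.impl with the CPU dispatch key; an approximate expansion (an assumption, not the verbatim macro) would be:

// Approximate expansion of DISPATCH_TO_CPU, inferred from its usage here;
// see fbgemm_gpu/sparse_ops_utils.h for the authoritative definition.
#define DISPATCH_TO_CPU(name, function) \
  m.impl(                               \
      name,                             \
      torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(function)))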
36 changes: 25 additions & 11 deletions fbgemm_gpu/codegen/embedding_backward_split_host_cpu_template.cpp
@@ -83,8 +83,11 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
{% for (var, _) in args.saved_data %}
ctx->saved_data["{{ var }}"] = {{ var }};
{% endfor %}

return {split_embedding_codegen_forward_cpu(
static auto op =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::split_embedding_codegen_forward_cpu", "")
.typed<decltype(split_embedding_codegen_forward_cpu)>();
return {op.call(
host_weights,
weights_offsets,
D_offsets,
@@ -132,9 +135,12 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
TORCH_CHECK_EQ(grad_outputs.size(), 1);

using torch::autograd::Variable;

static auto op1 =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::split_embedding_backward_codegen_{{ optimizer }}_cpu", "")
.typed<decltype(split_embedding_backward_codegen_{{ optimizer }}_cpu)>();
auto grad_output = gradient_clipping ? clamp(grad_outputs[0], -max_gradient, max_gradient) : grad_outputs[0];
split_embedding_backward_codegen_{{ optimizer }}_cpu(
op1.call(
grad_output,
host_weights,
weights_placements,
@@ -150,9 +156,13 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
stochastic_rounding,
{{ args.split_function_arg_names | join(", ") }},
output_dtype);
static auto op2 =
torch::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::split_embedding_codegen_grad_indice_weights_cpu", "")
.typed<decltype(split_embedding_codegen_grad_indice_weights_cpu)>();
// NOTE: MEAN pooling will not work with indice_weights!
auto grad_indice_weights = indice_weights.defined()
? split_embedding_codegen_grad_indice_weights_cpu(
? op2.call(
grad_outputs[0],
host_weights,
weights_offsets,
@@ -234,16 +244,20 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(
// Deprecated for fb namespace! Please use fbgemm namespace instead!
TORCH_LIBRARY_FRAGMENT(fb, m) {
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor");
DISPATCH_TO_CPU(
"split_embedding_codegen_lookup_{{ optimizer }}_function_cpu",
split_embedding_codegen_lookup_{{ optimizer }}_function_cpu);
m.impl(
"split_embedding_codegen_lookup_{{ optimizer }}_function_cpu",
torch::dispatch(
c10::DispatchKey::Autograd,
TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_cpu)));
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_cpu(Tensor host_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor");
DISPATCH_TO_CPU(
"split_embedding_codegen_lookup_{{ optimizer }}_function_cpu",
split_embedding_codegen_lookup_{{ optimizer }}_function_cpu);
m.impl(
"split_embedding_codegen_lookup_{{ optimizer }}_function_cpu",
torch::dispatch(
c10::DispatchKey::Autograd,
TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_cpu)));
}

} // namespace
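
Note the registration change for the top-level lookup function itself: instead of binding it to the CPU backend via DISPATCH_TO_CPU, the autograd-wrapping entry point is now bound to the Autograd dispatch key with m.impl(..., torch::dispatch(c10::DispatchKey::Autograd, ...)), so any call routed through the dispatcher picks up the custom backward. A self-contained sketch of that registration style, with placeholder names (my_ns::my_lookup, my_lookup_autograd) rather than the real FBGEMM symbols:

#include <torch/library.h>
#include <torch/torch.h>

// Autograd-level entry point; in the real code this wraps an
// autograd::Function (SplitLookupFunction_..._Op::apply). Here a trivial
// body stands in for it.
at::Tensor my_lookup_autograd(
    const at::Tensor& weights,
    const at::Tensor& indices) {
  return weights.index_select(0, indices);
}

TORCH_LIBRARY_FRAGMENT(my_ns, m) {
  m.def("my_lookup(Tensor weights, Tensor indices) -> Tensor");
  // Bind under the Autograd key rather than the CPU key, mirroring the
  // m.impl(..., c10::DispatchKey::Autograd, ...) registrations above.
  m.impl(
      "my_lookup",
      torch::dispatch(
          c10::DispatchKey::Autograd,
          TORCH_FN(my_lookup_autograd)));
}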
24 changes: 24 additions & 0 deletions fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
@@ -12,14 +12,18 @@
#include "fbgemm/Utils.h"
#include "fbgemm_gpu/cpu_utils.h"
#include "fbgemm_gpu/embedding_common.h"
#include "fbgemm_gpu/sparse_ops_utils.h"
#ifdef FBCODE_CAFFE2
#include <libdivide.h>
#include "folly/container/F14Map.h"
#else
#include <omp.h>
#endif

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/core/op_registration/op_registration.h>
#include <torch/script.h>

using Tensor = at::Tensor;
using namespace fbgemm_gpu;
@@ -608,3 +612,23 @@ template void csr2csc<double>(
int64_t num_embeddings);

} // namespace internal

namespace {

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"split_embedding_codegen_grad_indice_weights_cpu(Tensor grad_output, Tensor weights, Tensor weights_offsets, Tensor D_offsets, Tensor indices, Tensor offsets, Tensor feature_requires_grad) -> Tensor");
DISPATCH_TO_CPU(
"split_embedding_codegen_grad_indice_weights_cpu",
split_embedding_codegen_grad_indice_weights_cpu);
}

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"split_embedding_codegen_forward_cpu(Tensor weights, Tensor weights_offsets, Tensor D_offsets, int total_D, Tensor hash_size_cumsum, Tensor indices, Tensor offsets, int pooling_mode, Tensor indice_weights, int output_dtype) -> Tensor");
DISPATCH_TO_CPU(
"split_embedding_codegen_forward_cpu",
split_embedding_codegen_forward_cpu);
}

} // namespace
