Skip to content

Commit

Permalink
Always avoid CPU-blocking D2H sync in JaggedIndexSelect2dOp backward (pytorch#2510)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#2510

Since the value of `num_dense_output_rows` has already been computed in the forward pass, directly save this number (as a non-optional value) in the autograd context and reuse it in the backward pass to avoid D2H syncs.

Reviewed By: sryap

Differential Revision: D56212536

fbshipit-source-id: 3008745310f9056192dac417c775b553428f9bfc
  • Loading branch information
Shen Li authored and facebook-github-bot committed Apr 17, 2024
1 parent a2fdc90 commit f052e98
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 20 deletions.
26 changes: 14 additions & 12 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ class JaggedIndexSelect2dOp
const Tensor& values,
const Tensor& lengths,
const Tensor& indices,
const c10::optional<int64_t> num_dense_output_rows) {
const c10::optional<int64_t> optional_num_dense_output_rows) {
TORCH_CHECK(
values.dim() == 2, "jagged_index_select supports only 2D inputs")
TENSORS_ON_SAME_DEVICE(lengths, indices);
Expand All @@ -616,7 +616,6 @@ class JaggedIndexSelect2dOp

ctx->save_for_backward({indices, output_offsets, input_offsets});
ctx->saved_data["num_input_rows"] = values.sym_size(0);
ctx->saved_data["num_dense_output_rows"] = num_dense_output_rows;

static auto op =
c10::Dispatcher::singleton()
Expand All @@ -628,14 +627,17 @@ class JaggedIndexSelect2dOp
const Tensor& output_offsets,
const c10::optional<int64_t>)>();

return {
op.call(
values,
indices,
input_offsets,
output_offsets,
num_dense_output_rows),
output_lengths};
auto out = op.call(
values,
indices,
input_offsets,
output_offsets,
optional_num_dense_output_rows);

// Always save output size to avoid triggering D2H sync in backward
ctx->saved_data["num_dense_output_rows"] = out.sym_size(0);

return {out, output_lengths};
}

static torch::autograd::variable_list backward(
Expand All @@ -654,7 +656,7 @@ class JaggedIndexSelect2dOp

auto num_output_rows = ctx->saved_data["num_input_rows"].toSymInt();
auto num_dense_input_rows =
ctx->saved_data["num_dense_output_rows"].toOptional<int64_t>();
ctx->saved_data["num_dense_output_rows"].toSymInt();

static auto op =
c10::Dispatcher::singleton()
Expand All @@ -665,7 +667,7 @@ class JaggedIndexSelect2dOp
const Tensor& input_offsets,
const Tensor& output_offsets,
c10::SymInt num_output_rows,
const c10::optional<int64_t> optional_num_dense_input_rows)>();
c10::SymInt num_dense_input_rows)>();

return {
op.call(
Expand Down
10 changes: 2 additions & 8 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1161,13 +1161,7 @@ Tensor jagged_index_add_2d_forward_v2_impl(
const Tensor& input_offsets,
const Tensor& output_offsets,
const int64_t num_output_rows,
const c10::optional<int64_t> optional_num_dense_input_rows) {
// Intentionally not using optional::value_or here to avoid materializing
// .item() call when possible.
int64_t num_dense_input_rows = optional_num_dense_input_rows.has_value()
? optional_num_dense_input_rows.value()
: input_offsets[input_offsets.numel() - 1].item<int64_t>();

const int64_t num_dense_input_rows) {
static auto v1_op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "")
Expand Down Expand Up @@ -1681,7 +1675,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"jagged_index_add_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_input_rows, int num_output_rows) -> Tensor");
m.def(
"jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows, int? num_dense_input_rows) -> Tensor",
"jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows, SymInt num_dense_input_rows) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_1d_to_truncated_values(Tensor values, Tensor lengths, int max_truncated_length) -> Tensor");
Expand Down

0 comments on commit f052e98

Please sign in to comment.