Skip to content

Commit

Permalink
Allow jagged_index_select backward to accept pre-computed output shape
Browse files Browse the repository at this point in the history
Summary: Save `num_dense_output_rows` computed during the forward pass and use it to avoid a blocking `.item()` call (a device-to-host sync) during the backward pass.

Reviewed By: sryap

Differential Revision: D54173841

fbshipit-source-id: 113c035d6462963d00df7545dd54ce4dd15ed753
  • Loading branch information
Shen Li authored and facebook-github-bot committed Feb 26, 2024
1 parent ad70943 commit 3995cc6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
1 change: 1 addition & 0 deletions fbgemm_gpu/fbgemm_gpu/sparse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def jagged_index_add_2d_forward_v2_abstract(
input_offsets: Tensor,
output_offsets: Tensor,
num_output_rows: int,
num_dense_input_rows: Optional[int] = None,
) -> Tensor:
torch._check(values.device == indices.device)
torch._check(values.device == input_offsets.device)
Expand Down
14 changes: 12 additions & 2 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,7 @@ class JaggedIndexSelect2dOp

ctx->save_for_backward({indices, output_offsets, input_offsets});
ctx->saved_data["num_input_rows"] = values.sym_size(0);
ctx->saved_data["num_dense_output_rows"] = num_dense_output_rows;

static auto op =
c10::Dispatcher::singleton()
Expand Down Expand Up @@ -652,6 +653,8 @@ class JaggedIndexSelect2dOp
TENSORS_ON_SAME_DEVICE(grad, indices);

auto num_output_rows = ctx->saved_data["num_input_rows"].toSymInt();
auto num_dense_input_rows =
ctx->saved_data["num_dense_output_rows"].toOptional<int64_t>();

static auto op =
c10::Dispatcher::singleton()
Expand All @@ -661,10 +664,17 @@ class JaggedIndexSelect2dOp
const Tensor& indices,
const Tensor& input_offsets,
const Tensor& output_offsets,
c10::SymInt num_output_rows)>();
c10::SymInt num_output_rows,
const c10::optional<int64_t> optional_num_dense_input_rows)>();

return {
op.call(grad, indices, grad_offsets, output_offsets, num_output_rows),
op.call(
grad,
indices,
grad_offsets,
output_offsets,
num_output_rows,
num_dense_input_rows),
torch::autograd::Variable(), // lengths
torch::autograd::Variable(), // indices
torch::autograd::Variable() // num_dense_output_rows
Expand Down
15 changes: 10 additions & 5 deletions fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1189,9 +1189,14 @@ Tensor jagged_index_add_2d_forward_v2_impl(
const Tensor& indices,
const Tensor& input_offsets,
const Tensor& output_offsets,
const int64_t num_output_rows) {
int64_t num_dense_output_rows =
input_offsets[input_offsets.numel() - 1].item<int64_t>();
const int64_t num_output_rows,
const c10::optional<int64_t> optional_num_dense_input_rows) {
// Intentionally not using optional::value_or here: value_or evaluates its
// argument eagerly, which would force the blocking `.item()` call even when
// the optional already holds a precomputed value.
int64_t num_dense_input_rows = optional_num_dense_input_rows.has_value()
? optional_num_dense_input_rows.value()
: input_offsets[input_offsets.numel() - 1].item<int64_t>();

static auto v1_op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "")
Expand All @@ -1207,7 +1212,7 @@ Tensor jagged_index_add_2d_forward_v2_impl(
indices,
input_offsets,
output_offsets,
num_dense_output_rows,
num_dense_input_rows,
num_output_rows);
}

Expand Down Expand Up @@ -1730,7 +1735,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"jagged_index_add_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_input_rows, int num_output_rows) -> Tensor");
m.def(
"jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows) -> Tensor",
"jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows, int? num_dense_input_rows) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"jagged_1d_to_truncated_values(Tensor values, Tensor lengths, int max_truncated_length) -> Tensor");
Expand Down

0 comments on commit 3995cc6

Please sign in to comment.