From f052e98af7bb7bfcee1803669d60e81064700025 Mon Sep 17 00:00:00 2001
From: Shen Li
Date: Wed, 17 Apr 2024 16:13:20 -0700
Subject: [PATCH] Always avoid CPU blocking D2H in JaggedIndexSelect2dOp
 backward (#2510)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2510

Since the value of `num_dense_output_rows` has already been computed in the
forward pass, directly save this number (not an optional) in the context and
reuse it in the backward pass to avoid D2H syncs. (A minimal sketch of this
save-in-forward / reuse-in-backward pattern follows the diff below.)

Reviewed By: sryap

Differential Revision: D56212536

fbshipit-source-id: 3008745310f9056192dac417c775b553428f9bfc
---
 .../jagged_tensor_ops_autograd.cpp | 26 ++++++++++---------
 .../jagged_tensor_ops_cpu.cpp      | 10 ++-----
 2 files changed, 16 insertions(+), 20 deletions(-)

diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
index 57a0d4376..2e08efb4d 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
@@ -604,7 +604,7 @@ class JaggedIndexSelect2dOp
       const Tensor& values,
       const Tensor& lengths,
       const Tensor& indices,
-      const c10::optional<int64_t> num_dense_output_rows) {
+      const c10::optional<int64_t> optional_num_dense_output_rows) {
     TORCH_CHECK(
         values.dim() == 2, "jagged_index_select supports only 2D inputs")
     TENSORS_ON_SAME_DEVICE(lengths, indices);
@@ -616,7 +616,6 @@ class JaggedIndexSelect2dOp
 
     ctx->save_for_backward({indices, output_offsets, input_offsets});
     ctx->saved_data["num_input_rows"] = values.sym_size(0);
-    ctx->saved_data["num_dense_output_rows"] = num_dense_output_rows;
 
     static auto op =
         c10::Dispatcher::singleton()
@@ -628,14 +627,17 @@ class JaggedIndexSelect2dOp
               const Tensor& output_offsets,
               const c10::optional<int64_t>)>();
 
-    return {
-        op.call(
-            values,
-            indices,
-            input_offsets,
-            output_offsets,
-            num_dense_output_rows),
-        output_lengths};
+    auto out = op.call(
+        values,
+        indices,
+        input_offsets,
+        output_offsets,
+        optional_num_dense_output_rows);
+
+    // Always save output size to avoid triggering D2H sync in backward
+    ctx->saved_data["num_dense_output_rows"] = out.sym_size(0);
+
+    return {out, output_lengths};
   }
 
   static torch::autograd::variable_list backward(
@@ -654,7 +656,7 @@ class JaggedIndexSelect2dOp
 
     auto num_output_rows = ctx->saved_data["num_input_rows"].toSymInt();
     auto num_dense_input_rows =
-        ctx->saved_data["num_dense_output_rows"].toOptional<int64_t>();
+        ctx->saved_data["num_dense_output_rows"].toSymInt();
 
     static auto op =
         c10::Dispatcher::singleton()
@@ -665,7 +667,7 @@ class JaggedIndexSelect2dOp
               const Tensor& input_offsets,
               const Tensor& output_offsets,
               c10::SymInt num_output_rows,
-              const c10::optional<int64_t> optional_num_dense_input_rows)>();
+              c10::SymInt num_dense_input_rows)>();
 
     return {
         op.call(
diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
index 4ba75765a..3b0a41180 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
@@ -1161,13 +1161,7 @@ Tensor jagged_index_add_2d_forward_v2_impl(
     const Tensor& input_offsets,
     const Tensor& output_offsets,
     const int64_t num_output_rows,
-    const c10::optional<int64_t> optional_num_dense_input_rows) {
-  // Intentionally not using optional::value_or here to avoid materializing
-  // .item() call when possible.
-  int64_t num_dense_input_rows = optional_num_dense_input_rows.has_value()
-      ? optional_num_dense_input_rows.value()
-      : input_offsets[input_offsets.numel() - 1].item<int64_t>();
-
+    const int64_t num_dense_input_rows) {
   static auto v1_op =
       c10::Dispatcher::singleton()
           .findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "")
@@ -1681,7 +1675,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "jagged_index_add_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_input_rows, int num_output_rows) -> Tensor");
   m.def(
-      "jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows, int? num_dense_input_rows) -> Tensor",
+      "jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows, SymInt num_dense_input_rows) -> Tensor",
       {PT2_COMPLIANT_TAG});
   m.def(
       "jagged_1d_to_truncated_values(Tensor values, Tensor lengths, int max_truncated_length) -> Tensor");
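
Not part of the patch: a minimal, self-contained sketch of the save-in-forward /
reuse-in-backward pattern the Summary describes, using a toy dense
`index_select` autograd Function. The class name `MySelectOp` and the dense op
are hypothetical stand-ins for the jagged ops touched above; the point is only
that a row count already known on the host in `forward()` is stashed in
`ctx->saved_data`, so `backward()` never has to recover it from a device tensor
via `.item()`, which would block on a D2H copy.

```cpp
// Illustrative sketch only; MySelectOp is a hypothetical stand-in, not FBGEMM code.
#include <torch/torch.h>

class MySelectOp : public torch::autograd::Function<MySelectOp> {
 public:
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& values, // assumed 2-D, as in jagged_index_select
      const torch::Tensor& indices) {
    auto out = values.index_select(0, indices);
    ctx->save_for_backward({indices});
    // The size is already known on the host here; saving it is free and
    // lets backward() avoid any .item() call on a device tensor.
    ctx->saved_data["num_input_rows"] = values.size(0);
    return out;
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    const auto saved = ctx->get_saved_variables();
    const auto& indices = saved[0];
    // Reuse the size saved in forward() instead of deriving it from a
    // device-side tensor (e.g. offsets[-1].item<int64_t>()), which would
    // trigger a blocking device-to-host copy.
    const int64_t num_input_rows = ctx->saved_data["num_input_rows"].toInt();
    auto grad_input = torch::zeros(
        {num_input_rows, grad_outputs[0].size(1)}, grad_outputs[0].options());
    grad_input.index_add_(0, indices, grad_outputs[0]);
    return {grad_input, torch::Tensor()}; // no gradient for indices
  }
};

// Usage: auto y = MySelectOp::apply(values, indices);
```

The patch applies the same idea with SymInts: the forward wrapper records
`out.sym_size(0)` unconditionally, so the backward no longer needs an optional
and never falls back to a D2H read of the offsets tensor.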