diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
index dae8e5017..76ba2fe57 100644
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -226,6 +226,7 @@ def jagged_index_add_2d_forward_v2_abstract(
     input_offsets: Tensor,
     output_offsets: Tensor,
     num_output_rows: int,
+    num_dense_input_rows: Optional[int] = None,
 ) -> Tensor:
     torch._check(values.device == indices.device)
     torch._check(values.device == input_offsets.device)
diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
index 48b7f85e2..57a0d4376 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_autograd.cpp
@@ -616,6 +616,7 @@ class JaggedIndexSelect2dOp
     ctx->save_for_backward({indices, output_offsets, input_offsets});
     ctx->saved_data["num_input_rows"] = values.sym_size(0);
+    ctx->saved_data["num_dense_output_rows"] = num_dense_output_rows;
 
     static auto op =
         c10::Dispatcher::singleton()
@@ -652,6 +653,8 @@ class JaggedIndexSelect2dOp
     TENSORS_ON_SAME_DEVICE(grad, indices);
 
     auto num_output_rows = ctx->saved_data["num_input_rows"].toSymInt();
+    auto num_dense_input_rows =
+        ctx->saved_data["num_dense_output_rows"].toOptional<int64_t>();
 
     static auto op =
         c10::Dispatcher::singleton()
@@ -661,10 +664,17 @@ class JaggedIndexSelect2dOp
                 const Tensor& indices,
                 const Tensor& input_offsets,
                 const Tensor& output_offsets,
-                c10::SymInt num_output_rows)>();
+                c10::SymInt num_output_rows,
+                const c10::optional<int64_t> optional_num_dense_input_rows)>();
 
     return {
-        op.call(grad, indices, grad_offsets, output_offsets, num_output_rows),
+        op.call(
+            grad,
+            indices,
+            grad_offsets,
+            output_offsets,
+            num_output_rows,
+            num_dense_input_rows),
         torch::autograd::Variable(), // lengths
         torch::autograd::Variable(), // indices
         torch::autograd::Variable() // num_dense_output_rows
diff --git a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
index aab5142b8..22415edef 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops/jagged_tensor_ops_cpu.cpp
@@ -1189,9 +1189,14 @@ Tensor jagged_index_add_2d_forward_v2_impl(
     const Tensor& indices,
     const Tensor& input_offsets,
     const Tensor& output_offsets,
-    const int64_t num_output_rows) {
-  int64_t num_dense_output_rows =
-      input_offsets[input_offsets.numel() - 1].item<int64_t>();
+    const int64_t num_output_rows,
+    const c10::optional<int64_t> optional_num_dense_input_rows) {
+  // Intentionally not using optional::value_or here to avoid materializing
+  // the .item() call when possible.
+  int64_t num_dense_input_rows = optional_num_dense_input_rows.has_value()
+      ? optional_num_dense_input_rows.value()
+      : input_offsets[input_offsets.numel() - 1].item<int64_t>();
+
   static auto v1_op =
       c10::Dispatcher::singleton()
           .findSchemaOrThrow("fbgemm::jagged_index_add_2d_forward", "")
@@ -1207,7 +1212,7 @@ Tensor jagged_index_add_2d_forward_v2_impl(
       indices,
       input_offsets,
       output_offsets,
-      num_dense_output_rows,
+      num_dense_input_rows,
       num_output_rows);
 }
 
@@ -1730,7 +1735,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "jagged_index_add_2d_forward(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, int num_dense_input_rows, int num_output_rows) -> Tensor");
   m.def(
-      "jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows) -> Tensor",
+      "jagged_index_add_2d_forward_v2(Tensor values, Tensor indices, Tensor input_offsets, Tensor output_offsets, SymInt num_output_rows, int? num_dense_input_rows) -> Tensor",
       {PT2_COMPLIANT_TAG});
   m.def(
       "jagged_1d_to_truncated_values(Tensor values, Tensor lengths, int max_truncated_length) -> Tensor");
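
Usage note (not part of the patch): a minimal Python sketch of how a caller might exercise the new optional argument. The toy shapes, lengths, and inclusive-cumsum offset convention below are illustrative assumptions, not taken from this diff.

```python
import torch
import fbgemm_gpu  # noqa: F401  (loads the fbgemm:: operator library)

# Toy jagged data: `values` holds the dense rows being added back; each
# element of `indices` owns a contiguous slice of `values` delimited by
# the inclusive-cumsum `input_offsets`.
values = torch.randn(5, 4)                   # 5 dense input rows, dim 4
indices = torch.tensor([0, 2])               # destination jagged indices
input_offsets = torch.tensor([2, 5])         # rows 0-1 -> index 0, rows 2-4 -> index 2
output_offsets = torch.tensor([2, 3, 6, 7])  # inclusive cumsum of output lengths [2, 1, 3, 1]
num_output_rows = 7                          # total dense rows of the output

# Without the hint, the op falls back to input_offsets[-1].item() to find
# the dense input row count, which is a device-to-host sync on GPU tensors.
out = torch.ops.fbgemm.jagged_index_add_2d_forward_v2(
    values, indices, input_offsets, output_offsets, num_output_rows, None
)

# A caller that already knows the count can pass it and skip that sync.
out = torch.ops.fbgemm.jagged_index_add_2d_forward_v2(
    values, indices, input_offsets, output_offsets, num_output_rows, values.size(0)
)
```

Here `values.size(0)` equals `input_offsets[-1]` by construction, which is exactly the quantity the `.item()` fallback in `jagged_index_add_2d_forward_v2_impl` would otherwise compute.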