From a50b09f74bb83c9c28e4756d1287d67bf6720402 Mon Sep 17 00:00:00 2001
From: Joshua Deng
Date: Mon, 3 Jun 2024 10:24:35 -0700
Subject: [PATCH] Add back logic to remove d2h sync from `keyed_jagged_index_select_dim1` (#2663)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2663

Functionality to remove the D2H sync by providing the output size through `selected_lengths_sum` was removed in https://github.com/pytorch/FBGEMM/pull/2590.

This diff adds it back.

Reviewed By: IvanKobzarev

Differential Revision: D58072649

fbshipit-source-id: 2d091639f032a3fe88a3993c8c5ee13eb33a1402
---
 .../jagged_tensor_ops/keyed_jagged_index_select_dim1.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
index 514a4f6df..034aae174 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu
@@ -175,7 +175,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
       const Tensor& indices,
       const c10::SymInt _batch_size,
       const std::optional<Tensor>& weights,
-      const std::optional<c10::SymInt>& selected_lengths_sum) {
+      const std::optional<c10::SymInt> selected_lengths_sum) {
     at::cuda::OptionalCUDAGuard device_guard;
     device_guard.set_index(values.get_device());
 
@@ -249,9 +249,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
         });
       });
 
-    // TODO: Try to not do D->H transfer
-    const int64_t num_outputs =
-        output_offsets[output_offsets.numel() - 1].item<int64_t>();
+    const int64_t num_outputs = (selected_lengths_sum.has_value())
+        ? selected_lengths_sum.value().guard_int(__FILE__, __LINE__)
+        : output_offsets[output_offsets.numel() - 1].item<int64_t>();
     Tensor output = at::empty({num_outputs}, values.options());
     Tensor output_weights;
     if (weights.has_value()) {
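
Note (not part of the patch): below is a minimal standalone sketch of the pattern the new code follows, assuming the same ATen/c10 types that appear in the diff. The helper name compute_num_outputs is hypothetical and exists only to show why the fallback path forces a device-to-host synchronization.

#include <ATen/ATen.h>
#include <c10/core/SymInt.h>
#include <optional>

// Hypothetical helper, not part of FBGEMM: decide the output size either
// from a host-known (possibly symbolic) total or by reading the last
// cumulative offset back from the device.
static int64_t compute_num_outputs(
    const at::Tensor& output_offsets, // 1-D device tensor of cumulative output lengths
    const std::optional<c10::SymInt>& selected_lengths_sum) {
  if (selected_lengths_sum.has_value()) {
    // The total is already known on the host, so no GPU read-back is needed;
    // guard_int() concretizes a symbolic int (and records a guard) on the host.
    return selected_lengths_sum.value().guard_int(__FILE__, __LINE__);
  }
  // Fallback: .item() copies one element from device to host and blocks
  // until the copy finishes -- this is the D2H sync the patch lets callers avoid.
  return output_offsets[output_offsets.numel() - 1].item<int64_t>();
}

Callers that already track the total selected length on the host (for example, as a symbolic size under compilation) can pass it as `selected_lengths_sum` so the op never has to read `output_offsets` back from the device.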