Add back logic to remove d2h sync from `keyed_jagged_index_select_dim1` (pytorch#2663)

Summary:
Pull Request resolved: pytorch#2663

The functionality to remove a D2H sync by providing the output size through `selected_lengths_sum` was removed in pytorch#2590.

This diff adds it back.

Reviewed By: IvanKobzarev

Differential Revision: D58072649

fbshipit-source-id: 2d091639f032a3fe88a3993c8c5ee13eb33a1402
joshuadeng authored and facebook-github-bot committed Jun 3, 2024
1 parent 293a30b commit a50b09f
Showing 1 changed file with 4 additions and 4 deletions.
```diff
@@ -175,7 +175,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
       const Tensor& indices,
       const c10::SymInt _batch_size,
       const std::optional<Tensor>& weights,
-      const std::optional<c10::SymInt>& selected_lengths_sum) {
+      const std::optional<c10::SymInt> selected_lengths_sum) {
     at::cuda::OptionalCUDAGuard device_guard;
     device_guard.set_index(values.get_device());
@@ -249,9 +249,9 @@ class KeyedJaggedIndexSelectDim1GPUOp
         });
       });
   // TODO: Try to not do D->H transfer
-  const int64_t num_outputs =
-      output_offsets[output_offsets.numel() - 1].item<int64_t>();
+  const int64_t num_outputs = (selected_lengths_sum.has_value())
+      ? selected_lengths_sum.value().guard_int(__FILE__, __LINE__)
+      : output_offsets[output_offsets.numel() - 1].item<int64_t>();
   Tensor output = at::empty({num_outputs}, values.options());
   Tensor output_weights;
   if (weights.has_value()) {
```
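The control flow of the restored fast path can be sketched in plain Python (a minimal sketch: `compute_num_outputs` is a hypothetical helper mirroring the C++ change, and `output_offsets` here is an ordinary list rather than a GPU tensor — in the real kernel, reading the last offset via `.item<int64_t>()` is what forces the device-to-host sync that the optional argument avoids):

```python
def compute_num_outputs(output_offsets, selected_lengths_sum=None):
    """Hypothetical mirror of the patched num_outputs computation.

    If the caller already knows the total selected length, use it directly
    and skip reading the last offset back from device memory (a D2H sync).
    """
    if selected_lengths_sum is not None:
        # Fast path: output size is known up front, no D2H transfer needed.
        return selected_lengths_sum
    # Fallback: equivalent to
    # output_offsets[output_offsets.numel() - 1].item<int64_t>() in the C++ diff.
    return output_offsets[-1]
```

Both `compute_num_outputs([0, 3, 7])` and `compute_num_outputs([0, 3, 7], selected_lengths_sum=7)` return the same size; only the first would require a sync if `output_offsets` were a real CUDA tensor.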
