Skip to content

Commit

Permalink
Make batch_index_select_dim0 cuda pt2 autograd compatible (pytorch#2591)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2591

Making batch_index_select_dim0 with custom autograd functions PT2 traceable.

The main flow is described https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ptttacy8y1u9

Introducing batch_index_select_dim0_tensor operator that accepts Tensor instead of List[int].

Reviewed By: williamwen42

Differential Revision: D57232975

fbshipit-source-id: 5aa9dc67cc897611b29353502efd9cbe4f832c4e
  • Loading branch information
Ivan Kobzarev authored and facebook-github-bot committed May 16, 2024
1 parent 284ab83 commit 90ac420
Show file tree
Hide file tree
Showing 5 changed files with 1,311 additions and 318 deletions.
Loading

0 comments on commit 90ac420

Please sign in to comment.