[PROTOTYPE] generated batching rules for custom dispatcher ops #578
Note: this PR has a bunch of issues, but it shows one potential way of getting "automatically generated" batching rules for custom ops registered to the dispatcher through python.
Let's say you have a custom operator (`foo`) in Python that you've defined a derivative formula for (`foo_vjp`), and you want to vmap over it. If I run vmap over `foo`, then the chain of calls in the dispatcher will look something like this:
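(For concreteness, suppose `foo(x)` computes `sin(x) * 2`. The chain below is a rough sketch - the exact dispatch keys and kernel names are illustrative, not literal output:)

```
vmap(foo)(x)
  -> Batched:   the hand-written batching rule for foo
  -> Autograd:  foo's autograd kernel (sets up the graph so backward calls foo_vjp)
  -> CPU/CUDA:  foo's backend kernel
```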
That will work, but it has a downside: it requires you (the user) to write a custom batching rule for your custom op. In theory, we should be able to get the batching rule for free.
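(For a sense of what that burden looks like: a batching rule is conceptually a function from (argument, batch-dim) pairs to a (result, batch-dim) pair. The sketch below is schematic - it assumes `foo(x) = sin(x) * 2` from above, and the real interface is not exactly this signature.)

```python
def foo_batching_rule(x, x_bdim):
    # Move the batch dimension to the front...
    x = x.movedim(x_bdim, 0)
    # ...and re-implement foo's computation so it works with the extra
    # leading batch dim (here, assuming foo(x) = sin(x) * 2).
    out = x.sin() * 2
    # The result is batched along dim 0.
    return out, 0
```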
One way to "get the batching rule for free" is by running
foo()
, letting it decompose into whateveraten
ops it eventually calls, and running the batching rules on each of those aten ops. There's a problem with that though. If we decomposefoo
when we run the batching rule, then theres no way to "undo" the decomposition below. Any kernels that we redispatch to will see the "base" ops, instead of the original op:How do we get around that? We can't really "undo" the decomposition inside of the call stack... But we could just run "foo" twice: once for the forward pass where we do decompose into the base ops, and run the batching rule on each, and once for the backward pass where we dont decompose, taking care so that:
(1) When we run the forward, we skip autograd
(2) When we run the autograd kernel (to set up the autograd graph), we don't redispatch and run the backend again. (A toy sketch of this two-pass flow is below.)
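Here's a toy, Python-level model of that "run `foo` twice" idea, using `torch.autograd.Function` to stand in for foo's autograd kernel. The actual PR does this at the dispatcher level; the names and the `sin(x) * 2` decomposition are just for illustration:

```python
import torch
from functorch import vmap  # assuming functorch's vmap

# foo's decomposition into base aten ops (illustrative).
def foo_decomposition(x):
    return x.sin() * 2

# The user-provided derivative formula.
def foo_vjp(grad, x):
    return grad * 2 * x.cos()

class Foo(torch.autograd.Function):
    """Stands in for foo's autograd kernel: it sets up the graph with foo_vjp,
    but does NOT recompute the forward - the result from the batched forward
    pass is passed in and reused (step 2: don't redispatch to the backend)."""
    @staticmethod
    def forward(ctx, x, precomputed_out):
        ctx.save_for_backward(x)
        # clone() only to avoid returning an input as an output in this toy;
        # the real thing would reuse the result directly.
        return precomputed_out.clone()

    @staticmethod
    def backward(ctx, grad):
        (x,) = ctx.saved_tensors
        return foo_vjp(grad, x), None

def vmapped_foo(x):
    # Step 1: forward pass - decompose foo and let the batching rules for the
    # base ops do the work, with autograd skipped.
    with torch.no_grad():
        out = vmap(foo_decomposition)(x)
    # Step 2: run foo's "autograd kernel" so the graph records foo/foo_vjp
    # (not the decomposed base ops), without rerunning the backend.
    return Foo.apply(x, out)

x = torch.randn(5, 3, requires_grad=True)
out = vmapped_foo(x)
out.sum().backward()
print(torch.allclose(x.grad, 2 * x.detach().cos()))  # True: d/dx[2*sin(x)] = 2*cos(x)
```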
Known issues:
(1) I'm not sure how composable this is. I haven't thought too hard yet about what would happen if you use the logic together with another functionality (e.g. `amp` or `functionalization`).
(2) It interacts poorly with `DynamicLayer` - I left a comment explaining why, but I'm not sure what the best solution is.
(3) I'm hardcoding that the custom ops accept and return a single TensorList argument (but so does the existing code 😛)
(4) I got one very basic test to pass, but there are probably other problems I just haven't run into.