[PROTOTYPE] generated batching rules for custom dispatcher ops #578

Open · wants to merge 4 commits into main

Conversation

@bdhirsh (Contributor) commented Mar 9, 2022

Note: this PR has a bunch of issues, but it shows one potential way of getting "automatically generated" batching rules for custom ops registered to the dispatcher through python.

Let's say you have a custom operator (foo) in python that you've defined a derivative formula for (foo_vjp), and you want to vmap over it. If I run this:

f = vmap(foo)
out = f(x)
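
(For concreteness, here's roughly what defining such an op might look like. The registration helper and its exact signature are hypothetical; per known issue (3) below, the prototype hardcodes ops that take and return a single TensorList.)

import torch

def foo_fwd(xs):
    # user's forward, written in terms of base aten ops (sin, then cos);
    # takes a TensorList and returns (outputs, tensors saved for backward)
    x, = xs
    return [x.sin().cos()], [x]

def foo_vjp(saved, grads):
    # user-supplied derivative formula: d/dx cos(sin(x)) = -sin(sin(x)) * cos(x)
    x, = saved
    g, = grads
    return [-g * x.sin().sin() * x.cos()]

# foo = register_custom_op("foo", foo_fwd, foo_vjp)  # hypothetical helper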

Then the chain of calls in the dispatcher will look something like this:

foo
-> foo_vmap (vmap kernel for `foo`, which [somehow] does some batching stuff and redispatches)
-> foo_grad (autograd kernel for `foo`, which adds `foo_vjp` to the autograd graph and redispatches)
-> foo_cpu (cpu kernel, which we've directly registered our python `foo` function to; calls into `foo` in python)

That will work, but it has a downside: it requires you (the user) to write a custom batching rule for your custom op. In theory, we should be able to get the batching rule for free.
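
(For reference, this is roughly the kind of rule the user would otherwise have to write by hand. The (tensors, batch dims) convention mirrors functorch's existing batch rules; the exact signature for python custom ops is illustrative only.)

def foo_batch_rule(xs, bdims):
    # hand-written rule: move each batch dim to the front, rely on foo's base
    # ops broadcasting over the leading dimension, and report that every
    # output is batched at dim 0
    xs = [x.movedim(bdim, 0) if bdim is not None else x
          for x, bdim in zip(xs, bdims)]
    outs = foo(xs)
    return outs, [0] * len(outs)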

One way to "get the batching rule for free" is by running foo(), letting it decompose into whatever aten ops it eventually calls, and running the batching rules on each of those aten ops. There's a problem with that, though. If we decompose foo when we run the batching rule, then there's no way to "undo" the decomposition lower down in the dispatch stack. Any kernels that we redispatch to will see the "base" ops instead of the original op:

foo
-> foo_vmap--------------------------------
            |                              |
       base_a                          base_b
       -> base_a_vmap            -> base_b_vmap      // good: we vmap'd over `foo` "for free"
       -> base_a_grad            -> base_b_grad      // bad outcome: we ran autograd on base_a. wanted to run on foo!
       -> base_a_cpu             -> base_b_cpu
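
(A concrete way to see the bad outcome, using the sketch from above where foo decomposes into aten::sin and aten::cos; the grad_fn name below is just what stock autograd reports for the base op:)

x = torch.randn(3, requires_grad=True)
out = x.sin().cos()   # what the decomposed forward effectively ran
print(out.grad_fn)    # <CosBackward0 ...>: autograd recorded the base ops,
                      # so the user's foo_vjp would never be consulted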

How do we get around that? We can't really "undo" the decomposition inside of the call stack... but we could just run foo twice: once for the forward pass, where we do decompose into the base ops and run the batching rule on each, and once for the backward pass, where we don't decompose. We need to take care that:
(1) when we run the forward, we skip autograd;
(2) when we run the autograd kernel (to set up the autograd graph), we don't redispatch and run the backend again.
(A rough sketch of the generated rule follows the diagram below.)

foo
-> foo_vmap
     (1) runs forward (decomposes, skips autograd)
       -> foo_cpu (this dispatches us to the user's python function, which decomposes into base ops)
            |                              |
       base_a                          base_b
       -> base_a_vmap            -> base_b_vmap
       -> base_a_cpu             -> base_b_cpu
     (2) set up autograd graph (doesn't decompose, hits the autograd kernel but "skips" the call to the backend)
-> foo_grad
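
(A rough Python-level sketch of what the generated batching rule does; the prototype itself is C++ dispatcher code, and attach_vjp_node is a made-up stand-in for "set up the autograd graph for foo without re-running the backend".)

def generated_foo_batch_rule(xs):
    # (1) run the forward with autograd disabled; `foo` redispatches to the
    #     user's python function, which decomposes into base ops, and each
    #     base op hits its own batching rule
    with torch.no_grad():
        outs = foo(xs)
    # (2) record a single `foo` node (backed by the user's foo_vjp) in the
    #     autograd graph, without redispatching to the backend again
    outs = attach_vjp_node(foo_vjp, xs, outs)  # hypothetical helper
    return outs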

Known issues:
(1) I'm not sure how composable this is. I haven't thought too hard yet about what would happen if you compose this logic with another functionality (e.g. AMP or functionalization).
(2) It interacts poorly with DynamicLayer - I left a comment explaining why, but I'm not sure what the best solution is.
(3) I'm hardcoding that the custom ops accept and return a single TensorList argument (but so does the existing code 😛)
(4) I got one very basic test to pass, but there are probably other problems I just haven't run into.

// doesn't play well with DynamicLayer (only one layer of vmap works right now).
// Why? In the generated batching rule, I effectively want to treat it as a "composite kernel",
// and have it run the python-defined forward function. But:
// (1) I want to go there through the dispatcher so other functionalities can run (e.g. AMP).
@bdhirsh (Contributor, Author) commented:

Actually, maybe I should just be treating this the same way that DynamicLayer already treats composite ops - just directly call into the composite function. That means that stuff like AMP will run on the base ops and not the composite ops, but maybe that's the right behavior (unless the user wants to write a custom "AMP rule" for their op)

@bdhirsh (Contributor, Author) commented:

(ended up doing this, although there's another issue with disabling autograd that I left a comment about)
