
Forward-mode S2S autograd compiler to derive rules for some low-level functions #337

Open
yebai opened this issue Nov 1, 2024 · 7 comments
Labels
enhancement, low priority

Comments

@yebai
Contributor

yebai commented Nov 1, 2024

Because it lacks a backward pass, forward-mode autograd often has considerably different implementation properties from reverse-mode autograd. Given its different performance trade-offs, I wonder whether a forward-mode transformation could be friendlier to autograd compilers than a reverse-mode one (like Mooncake/Zygote), or at least compensate for some extreme cases of reverse-mode autograd.

For example, sum_1000 is a vector-input, scalar-output function, which is a perfect example for forward-mode autograd but likely hard for the reverse-mode compiler to handle well (at least without significantly more compiler-optimization effort). This advantage goes further if we have chunk-mode forward-mode autograd.

I am talking about the source-to-source approach for both forward- and reverse-mode autograd implementations.

@willtebbutt
Member

For example, sum_1000 is a vector-input, scalar-output function, which is a perfect example for forward-mode autograd but likely hard for the reverse-mode compiler to handle well (at least without significantly more compiler-optimization effort). This advantage goes further if we have chunk-mode forward-mode autograd.

I don't believe that this is correct. The fact that this is a many-input single-output function means that it is precisely the wrong kind of function to target with forwards-mode AD, no?
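To make the counting concrete, here is a minimal dual-number sketch (purely illustrative -- not Mooncake's machinery): recovering the gradient of an R^N -> R function this way takes one forward pass per input, whereas reverse-mode gets the whole gradient from a single pass.

```julia
# Minimal dual-number forward mode (illustrative only, not Mooncake's machinery).
struct Dual
    val::Float64
    tan::Float64
end
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.tan + b.tan)

# One forward pass per input: the gradient of an R^N -> R function like
# sum_1000 costs N evaluations, whereas reverse mode recovers it in one.
function forward_gradient(f, x::Vector{Float64})
    map(eachindex(x)) do i
        duals = [Dual(x[j], j == i ? 1.0 : 0.0) for j in eachindex(x)]
        f(duals).tan
    end
end

mysum(xs) = reduce(+, xs)              # stand-in for the sum_1000 example
forward_gradient(mysum, randn(1000))   # 1000 forward passes for one gradient
```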

Because it lacks a backward pass, forward-mode autograd often has considerably different implementation properties from reverse-mode autograd. Given its different performance trade-offs, I wonder whether a forward-mode transformation could be friendlier to autograd compilers than a reverse-mode one (like Mooncake/Zygote), or at least compensate for some extreme cases of reverse-mode autograd.

I do agree that it's considerably easier to produce a high-quality source-to-source forwards-mode AD than it is a high-quality reverse-mode AD. I think it's something we should consider doing at some point, but I don't think it's something to do right now.

willtebbutt added the enhancement and low priority labels on Nov 4, 2024
@yebai
Contributor Author

yebai commented Nov 4, 2024

I don't believe that this is correct. The fact that this is a many-input single-output function means that it is precisely the wrong kind of function to target with forwards-mode AD, no?

You are correct here. I was thinking that chunk-mode forward mode could handle a large number of small-dimensional vector-input scalar/vector-output functions. This is still helpful but won't address the general problem.
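Roughly, what I have in mind by chunk mode (purely illustrative, not a proposal for Mooncake's tangent types): carry K tangents alongside each primal value, so that n inputs need ceil(n / K) forward passes rather than n.

```julia
# Purely illustrative chunked dual: K tangents ride along with each primal.
struct ChunkDual{K}
    val::Float64
    tans::NTuple{K,Float64}
end
Base.:*(a::ChunkDual{K}, b::ChunkDual{K}) where {K} =
    ChunkDual{K}(a.val * b.val, a.val .* b.tans .+ b.val .* a.tans)

# For f(x, y) = x * y, one pass with K = 2 seeds both partials at once:
x = ChunkDual{2}(3.0, (1.0, 0.0))
y = ChunkDual{2}(4.0, (0.0, 1.0))
x * y   # ChunkDual{2}(12.0, (4.0, 3.0)): value and (∂f/∂x, ∂f/∂y)
```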

I think it's something we should consider doing at some point, but I don't think it's something to do right now.

I wrote the issue to start a discussion.

EDIT: another interesting question is whether a forward-mode autograd compiler would allow us to handle more Julia language features, for example the try-catch-end block discovered recently in #326 (comment).

@willtebbutt
Member

willtebbutt commented Nov 4, 2024

EDIT: another interesting question is whether a forward-mode autograd compiler would allow us to handle more Julia language features, for example the try-catch-end block discovered recently in #326 (comment).

It's possible that it would, but I think there's probably more that can be done on the reverse-mode side of things to extend our current functionality to support try-catch-end blocks which contain Upsilon / PhiC nodes that don't wind up throwing, as is the case in the linked example and in another that was mentioned in #31 (comment).

I wrote the issue to start a discussion.

Fair enough. Certainly, I think it's true that having forwards-mode which composes with reverse-mode would be nice -- I would quite like to be able to compute Hessian-vector products (and, by extension, Hessians) by doing forwards-mode over reverse-mode.
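For concreteness, the usual forwards-over-reverse Hessian-vector-product pattern looks something like the sketch below, using ForwardDiff and Zygote as stand-ins and assuming they compose for the function in question; the reverse-mode part is where Mooncake would slot in.

```julia
using ForwardDiff, Zygote, LinearAlgebra

# Hessian-vector product via forwards-mode over reverse-mode: differentiate
# the gradient along the direction v. Sketch only -- assumes Zygote's
# gradient is itself ForwardDiff-differentiable for this f.
hvp(f, x::AbstractVector, v::AbstractVector) =
    ForwardDiff.derivative(t -> Zygote.gradient(f, x .+ t .* v)[1], 0.0)

f(x) = 0.5 * dot(x, x) + sum(sin, x)
x, v = randn(5), randn(5)
hvp(f, x, v)   # ≈ (I + Diagonal(-sin.(x))) * v for this particular f
```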

In terms of what would need to be done:

  • The tangent type infrastructure that we have for reverse-mode could be used as-is, with no modification. We would, however, need to consider how best to extend it to make it possible to handle batches of tangents (i.e. to allow batch-mode forwards-mode).
  • Our testing infrastructure would be largely unchanged -- just requiring extension. Moreover, having a forwards-mode AD to include in the consistency tests would actually be quite helpful, because checking that finite differences, reverse-mode and forwards-mode AD all give roughly the same answer is a more powerful check than just comparing finite differences with reverse-mode. In fact, all of our existing tests for reverse-mode could be straightforwardly recycled for forwards-mode -- we would just extend test_rule to test both forwards-mode and reverse-mode AD at the same time. On top of that, we could generate a lot of additional test cases for both our forwards-mode and reverse-mode implementations very straightforwardly by also doing forwards-mode over reverse-mode and reverse-mode over forwards-mode in the test suite.
  • We would need rules (a sketch of what a forward-mode rule looks like follows this list). This is a bit tedious, but we'd get to import roughly the same amount of stuff from ChainRules as we do for reverse-mode, so I don't anticipate this being too much of a pain.
  • The IR transformation required for forwards-mode is much simpler than the one required for reverse-mode, so this should be straightforward. It might even be possible to do it directly using the Core.Compiler infrastructure for IRCode manipulation, and avoid having to use any of the abstractions I've built up for working with large changes to the basic block structure. This is because the basic block structure of the code needed to do forwards-mode AD is identical to that of the primal code, unlike that of reverse-mode.
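For reference, here is the rough shape of a forward-mode rule in the ChainRulesCore style (toy function, purely illustrative of the signatures we would be writing or importing):

```julia
using ChainRulesCore

# Hypothetical low-level function we might want a forward-mode rule for.
mysquare(x::Float64) = x * x

# A forward rule in the ChainRulesCore style: given the input tangents,
# return the primal output together with its tangent.
function ChainRulesCore.frule((_, ẋ), ::typeof(mysquare), x::Float64)
    y = mysquare(x)
    ẏ = 2x * ẋ
    return y, ẏ
end
```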

yebai changed the title from "Forward-mode S2S autograd compiler to derive rules for scalar-output functions" to "Forward-mode S2S autograd compiler to derive rules for some low-level functions" on Nov 4, 2024
@yebai
Contributor Author

yebai commented Nov 6, 2024

We would need rules. This is a bit tedious, but we'd get to import roughly the same amount of stuff from ChainRules as we do for reverse mode, so I don't anticipate this being too much of a pain.

Now that we can freely compose forward and reverse modes, e.g. forward over reverse or reverse over forward, is it possible to use the reverse-mode rules for forward-mode autograd here?

@willtebbutt
Member

In general, no, because you would have to run the rule N times per forward-mode call, where N is the dimension of the output, so performance would be terrible. In scalar cases you might be able to get away with it, but to be honest those aren't the cases that are hard to write rules for anyway.
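To spell that out: a reverse-mode rule gives you a vector-Jacobian product, and recovering the Jacobian-vector product that forward-mode needs costs one call per output dimension. Roughly (hypothetical helper, not anything that exists in Mooncake):

```julia
using LinearAlgebra

# Hypothetical helper: build the Jacobian-vector product forward mode needs
# out of a vector-Jacobian product `vjp` (what a reverse-mode rule provides).
# Requires one vjp call per output dimension m.
function jvp_from_vjp(vjp, ẋ::AbstractVector, m::Integer)
    ẏ = zeros(m)
    for j in 1:m
        eⱼ = zeros(m); eⱼ[j] = 1.0   # j-th output basis vector
        ẏ[j] = dot(vjp(eⱼ), ẋ)       # ⟨J'eⱼ, ẋ⟩ == (Jẋ)ⱼ
    end
    return ẏ
end
```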

@yebai
Contributor Author

yebai commented Nov 6, 2024

This paper might be of interest.

Decomposing reverse-mode automatic differentiation
Roy Frostig, Matthew J. Johnson, Dougal Maclaurin, Adam Paszke, Alexey Radul

We decompose reverse-mode automatic differentiation into (forward-mode) linearization followed by transposition. Doing so isolates the essential difference between forward- and reverse-mode AD, and simplifies their joint implementation. In particular, once forward-mode AD rules are defined for every primitive operation in a source language, only linear primitives require an additional transposition rule in order to arrive at a complete reverse-mode AD implementation. This is how reverse-mode AD is written in JAX and Dex.

https://arxiv.org/abs/2105.09469
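As a toy illustration of the linearize-then-transpose idea for a simple function (names and structure here are made up for the example, not the paper's or Mooncake's API):

```julia
using LinearAlgebra

# Toy example: f(x) = A * sin.(x), with Jacobian J(x) = A * Diagonal(cos.(x)).
A = [1.0 2.0; 3.0 4.0]
f(x) = A * sin.(x)

# Forward-mode linearization at x: the linear map ẋ ↦ J(x) * ẋ.
jvp(x, ẋ) = A * (cos.(x) .* ẋ)

# Transposing that linear map gives the reverse-mode pullback: ȳ ↦ J(x)' * ȳ.
vjp(x, ȳ) = cos.(x) .* (A' * ȳ)

# The adjoint identity ⟨ȳ, J ẋ⟩ == ⟨J' ȳ, ẋ⟩ ties the two modes together.
x, ẋ, ȳ = [0.1, 0.2], [1.0, 0.0], [0.0, 1.0]
@assert dot(ȳ, jvp(x, ẋ)) ≈ dot(vjp(x, ȳ), ẋ)
```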

@willtebbutt
Member

I'm familiar with this paper. It's a great way to frame AD, and it explains very nicely what's going on. Mooncake's docs essentially frame things in the same way; in fact, I'm pretty sure we reference their follow-up paper. We just don't break the 1-1 mapping between linearisation and transposition (read: computing the Fréchet derivative and finding its adjoint operator).

I've always been a bit sceptical about the claim that you need to implement many fewer "transpose" rules than you do reverse-mode rules: there are surprisingly many linear operators in a language, and I'm reasonably sure that you wouldn't decompose many of the more monolithic functions (e.g. Cholesky factorisation) into a sequence of simpler linear transformations at the linearisation step, but would instead wind up with a single "linearised Cholesky" operator.

That being said, I've also not dug into it in any great depth, so it might be worth me revisiting this.
