
Missing rules for SpecialFunctions #412

Open
gdalle opened this issue Dec 7, 2024 · 3 comments
Labels
enhancement New feature or request

Comments

@gdalle
Collaborator

gdalle commented Dec 7, 2024

Hi there!
I saw this Discourse post benchmarking reverse-mode AD on a code with Bessel functions: https://discourse.julialang.org/t/speeding-up-zygote-autodiff-for-numerical-loop/123515/
I tried Mooncake on it, but it was missing a rule for bessely. When I tried implementing one, I got a weird error message (I can't reproduce it here since I'm not on my work computer), which I think may be linked to the partial implementation of the ChainRulesCore.rrule in SpecialFunctions:
https://github.com/JuliaMath/SpecialFunctions.jl/blob/ed6a36732712a71e99de397cca45a58432f22a0e/ext/SpecialFunctionsChainRulesCoreExt.jl#L97-L103
Wanna take a look? It would be fun to see how close we can get to Enzyme.

@willtebbutt added the enhancement (New feature or request) label Dec 7, 2024
@willtebbutt
Member

willtebbutt commented Dec 7, 2024

Yeah, the problem is that Mooncake can't (currently) handle NotImplemented tangents from ChainRules. To support them without risking silently dropping gradient information, we'd need to be able to prove that ignoring a NotImplemented is safe, and throw an error if it isn't. At the minute we don't have any mechanism for this.

This is all to say that, currently, the only way to support the Bessel functions is to finish off the implementation of their rules.

@yebai
Contributor

yebai commented Dec 10, 2024

@willtebbutt, it's not very clear what is missing here. Can you clarify what ChainRulesCore.NotImplemented does and why Mooncake needs safety checks?

@willtebbutt
Member

willtebbutt commented Dec 10, 2024

A good example is the method of SpecialFunctions.gamma found here, with its associated ChainRulesCore.rrule found here.

If you inspect the rule, you will see that (for whatever reason) the authors of SpecialFunctions have not provided the gradient w.r.t. the first argument, and instead return whatever ChainRulesCore.@not_implemented(INCOMPLETE_GAMMA_INFO) produces, which IIRC is a ChainRulesCore.NotImplemented. This is useful in situations where you don't actually need the gradient w.r.t. the first argument of this function because it is, e.g., a hard-coded constant. Again, IIRC, ChainRulesCore.NotImplemented has the property that if you try to do anything with it, you'll get an error.
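To make this behaviour concrete, here is a rough Python analogy; it is not ChainRulesCore's actual API. It assumes that gamma(a, x) is the upper incomplete gamma function Γ(a, x), whose derivative with respect to x is -x^(a-1) e^(-x). The names NotImplementedTangent, GAMMA_A_INFO and gamma_upper_pullback are hypothetical, and this sentinel simply raises on every arithmetic operation, which may be stricter than the real ChainRulesCore.NotImplemented.

```python
import math


class NotImplementedTangent:
    """Hypothetical stand-in for ChainRulesCore.NotImplemented.

    It carries an info string and raises as soon as arithmetic is attempted,
    mimicking "if you try to do anything with it, you'll get an error".
    """

    def __init__(self, info):
        self.info = info

    def _fail(self, *args):
        raise NotImplementedError(f"gradient not implemented: {self.info}")

    __add__ = __radd__ = __mul__ = __rmul__ = __sub__ = __rsub__ = _fail


GAMMA_A_INFO = "derivative of gamma(a, x) w.r.t. a"  # hypothetical message


def gamma_upper_pullback(a, x, dy):
    """Hypothetical pullback of the upper incomplete gamma for cotangent dy.

    Returns (da, dx): da is left unimplemented, while dx uses the analytic
    derivative d/dx gamma(a, x) = -x**(a - 1) * exp(-x).
    """
    da = NotImplementedTangent(GAMMA_A_INFO)
    dx = -dy * x ** (a - 1) * math.exp(-x)
    return da, dx
```

Calling gamma_upper_pullback(2.0, 1.0, 1.0) gives dx equal to -exp(-1), while any attempt to use da, such as 1.0 + da, raises NotImplementedError.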

This works well in Zygote because Zygote defers accumulating gradients until the rule that consumes them is reached. This means that if no rule consumes the ChainRulesCore.NotImplemented (e.g. because it's the gradient w.r.t. a hard-coded constant), everything works correctly. Notice that there is no risk of incorrectly dropping gradients here: the only way to avoid an error is for the result to be entirely unused.
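The deferred behaviour can be mimicked with a toy store that only records contributions during the reverse pass and sums them when a gradient is actually requested. This is a hypothetical sketch of the idea, not how Zygote is implemented; NotImplementedTangent and DeferredGradients are made-up names, and the sentinel again raises on any arithmetic.

```python
from collections import defaultdict


class NotImplementedTangent:
    """Hypothetical sentinel: any arithmetic on it raises."""

    def __init__(self, info):
        self.info = info

    def _fail(self, *args):
        raise NotImplementedError(f"gradient not implemented: {self.info}")

    __add__ = __radd__ = __mul__ = __rmul__ = _fail


class DeferredGradients:
    """Toy store that defers accumulation until a gradient is consumed."""

    def __init__(self):
        self.pending = defaultdict(list)

    def record(self, name, tangent):
        # No arithmetic happens here; the contribution is only remembered.
        self.pending[name].append(tangent)

    def gradient(self, name):
        # Contributions are summed only when someone asks for this gradient,
        # so an unconsumed NotImplementedTangent never triggers an error.
        total = 0.0
        for t in self.pending[name]:
            total = total + t
        return total


grads = DeferredGradients()
grads.record("a", NotImplementedTangent("d/da"))  # a is a hard-coded constant
grads.record("x", -0.5)
print(grads.gradient("x"))  # -0.5; "a" is never consumed, so nothing raises
```

Asking for grads.gradient("a") would still raise, which matches the point above: the error is only avoided when the NotImplemented result is never used.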

In Mooncake we accumulate the results eagerly, i.e. immediately after the reverse pass of each rule. This makes the reverse pass dramatically simpler in general, but has the unfortunate consequence that, in this case, we can't make use of the NotImplemented feature of ChainRules quite as straightforwardly. Consequently, we need some kind of mechanism that explicitly tells us that a NotImplemented gradient is safe to ignore. In principle we have this information, but I haven't started to think about how we would exploit it in a systematic way; it feels to me like something that we should resolve as part of a broader discussion around activity analysis in Mooncake.
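By contrast, eager accumulation forces the NotImplemented to be touched as soon as its pullback runs. The following hypothetical sketch does not reflect Mooncake's internals. It adds each contribution immediately and shows one possible shape for the missing safety mechanism: an assumed set of variables known to be inactive (a stand-in for activity analysis), for which a NotImplemented tangent is dropped, while any other NotImplemented still raises. EagerGradients and its inactive parameter are made-up names.

```python
class NotImplementedTangent:
    """Hypothetical sentinel: any arithmetic on it raises."""

    def __init__(self, info):
        self.info = info

    def _fail(self, *args):
        raise NotImplementedError(f"gradient not implemented: {self.info}")

    __add__ = __radd__ = __mul__ = __rmul__ = _fail


class EagerGradients:
    """Toy store that adds each contribution as soon as a pullback returns it.

    `inactive` is a hypothetical set of variables known not to need gradients
    (e.g. hard-coded constants); it stands in for activity information.
    """

    def __init__(self, inactive=()):
        self.grads = {}
        self.inactive = set(inactive)

    def accumulate(self, name, tangent):
        if isinstance(tangent, NotImplementedTangent) and name in self.inactive:
            return  # known to be inactive, so dropping it loses no information
        # Eager update of the running sum. For anything not known to be
        # inactive, a NotImplementedTangent raises here immediately, even if
        # the gradient would never have been used.
        self.grads[name] = self.grads.get(name, 0.0) + tangent


naive = EagerGradients()
try:
    naive.accumulate("a", NotImplementedTangent("d/da"))
except NotImplementedError as err:
    print("naive eager accumulation fails:", err)

checked = EagerGradients(inactive={"a"})
checked.accumulate("a", NotImplementedTangent("d/da"))
checked.accumulate("x", -0.5)
checked.accumulate("x", -0.25)
print(checked.grads)  # {'x': -0.75}
```

Raising for every variable not known to be inactive errs on the side of reporting an error rather than silently losing gradient information, which is the trade-off described in this thread.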
