-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SpeedyTransforms: Differentiable transforms via custom Enzyme rule #589
Conversation
I had a very quick check if some non-sensical differentiation through a timestep of a barotropic model is possible. I get an error from Enzyme and the compiler from the |
So, all my direct tests pass. This is moving towards being ready. Left to do:
|
Hmmm it would be good if the tests aren't adding massive time for all other development but they are obviously crucial for differentiability testing. Not sure how we can achieve that. Could we only keep the test without FiniteDifferences? Use a lower resolution transform? You're in a better decision than me to decide. |
Ah yes, of course, you are right. So far I was even using 8 layers in the test, which is just a bit stupid. |
Now it's much faster, I think we can keep them like that in the CI |
Okay, I feel a bit stupid, haha. EnzymeTestUtils is also not stuck anymore. Before I was letting it run longer than 2 hours, now it finishes in less in 1 second for a lower dimensional problem. In order the tests of EnzymeTestUtils to pass I had to increase the tolerance a bit from the default (to |
The tests pass on my laptop, but not in the CI. I'll investigate this a bit later. |
All tests pass also in the 1.10 CI for the For the
I'd wait to merge this until everything works fine with Julia 1.11 as well. It doesn't cause any errors outside of using Enzyme, but still. |
With the lastest update to Enzyme the tests now pass on 1.11 as well. Hence, I marked this now as a proper PR and not a draft anymore. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks fine to me.
@michel2323 was looking at custom rules for FFT IIRC
We want to make the spherical harmonics transforms differentiable.
The complete model is supposed to rely on Enzyme for that, Enzyme can deal with many of our functions already by itself in principle. However, the FFTs are not computed by Julia code, but with FFTW and cuFFT. Transforms are also the most performance critical part of the model, so optimising the gradients of the transforms also makes sense.
Enzyme
_legendre!
Enzyme can handle by itself_fourier!
, the forward version usesrfft
plans, the inversebrfft
plans._fourier!
isn't normalised, the normalisation is done in_legendre!
.brfft
). For real-valued FFT you have an additional scale in the adjoint reflecting the different coefficients in the formula (there's a two in front of the cos/sine terms)ChainRules/Zygote/…
For uses outside of our model (e.g. Spherical FNOs) it would be great to also have this differentiable by other ADs. These ADs usually don’t support mutation, so that I would define rules for the non-mutating
transform(grid, S)
. I’ll do that a bit later. This might also need some changes to the code.This PR
_fourier!
and tests it.Currently I just add Enzyme, EnzymeRules and EnzymeTestUtils to the main env to make testing a bit easier for me while developing this PR, that will be changed before merging thisEnzyme (and most of this code) could go in an extension. Is this what we want long-term? Probably yes, but I don’t have a good impression yet how many custom rules we have to define. Hopefully not many though.Testing
I'd love to test the rules directly with EnzymeTestUtils as well, but there are several problems currentlyI have to look into this a bit furtherFiniteDifferences tests can take a while. Going forward do we want to put all gradient correctness checks in the regular CI? The four quite simple tests that are currently in this PR already take 5 mins.Complex Numbers
Differentiating complex numbers is a bit of a topic for itself. Enzyme also has a quite long explanation of their approach in their FAQs (https://enzymead.github.io/Enzyme.jl/stable/faq/#Complex-numbers). I think for us, we can treat the real and imaginary part basically as separate numbers, as they would be for real SPH (for which coefficients with m<0 correspond to our imaginary part (sort of)). I am not quite sure yet if this works well out of the box with Enzyme or if we have to be more carefully somewhere. It's something we should think about.