SpeedyTransforms: Differentiable transforms via custom Enzyme rule #589

Merged: 74 commits into main, Nov 29, 2024

Conversation

maximilian-gelbrecht
Member

@maximilian-gelbrecht maximilian-gelbrecht commented Oct 17, 2024

We want to make the spherical harmonics transforms differentiable.

The complete model is supposed to rely on Enzyme for that. In principle, Enzyme can already deal with many of our functions by itself. However, the FFTs are not computed by Julia code, but by FFTW and cuFFT. The transforms are also the most performance-critical part of the model, so optimising their gradients makes sense as well.

Enzyme

  • Enzyme can handle the Legendre transform _legendre! by itself.
  • We do Fourier transforms consecutively on each latitude ring in _fourier!. The forward version uses rfft plans, the inverse uses brfft plans. _fourier! isn't normalised; the normalisation is done in _legendre!.
  • The FFT needs a rule here. The adjoint of a Fourier transform is the unnormalised inverse Fourier transform (so brfft). For a real-valued FFT there is an additional scaling in the adjoint, reflecting the different coefficients in the formula (there's a two in front of the cos/sine terms).
  • The rules are tested with Enzyme's test tools and with a finite-difference (FD) comparison for the whole transform.
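
The adjoint relation described in the bullets above can be checked numerically with a dot-product test, i.e. ⟨rfft(x), y⟩ = ⟨x, adjoint(y)⟩ with the real inner product Re(conj(a)·b). The following is a minimal sketch in plain Python, not the Julia code of this PR: it uses naive O(n²) DFTs instead of FFTW, the function names are illustrative, and the weights (1 for the zero and Nyquist frequencies, 1/2 for all others) are an assumption about how the "two in front of the cos/sine terms" enters, not taken from the actual rule.

```python
import cmath
import random


def rfft(x):
    """Unnormalised real-to-complex DFT, keeping frequencies 0..n//2."""
    n = len(x)
    return [sum(x[m] * cmath.exp(-2j * cmath.pi * m * k / n) for m in range(n))
            for k in range(n // 2 + 1)]


def brfft(y, n):
    """Unnormalised complex-to-real inverse DFT of a half spectrum y."""
    # Rebuild the full Hermitian spectrum from the stored half.
    full = [complex(v) for v in y]
    full += [complex(y[n - k]).conjugate() for k in range(n // 2 + 1, n)]
    # Taking the real part drops any imaginary part of the zero/Nyquist terms.
    return [sum(full[k] * cmath.exp(2j * cmath.pi * m * k / n)
                for k in range(n)).real
            for m in range(n)]


def rfft_adjoint(y, n):
    """Adjoint of rfft: brfft applied to a rescaled cotangent (assumed weights)."""
    scaled = [v * (1.0 if (k == 0 or 2 * k == n) else 0.5)
              for k, v in enumerate(y)]
    return brfft(scaled, n)


def dot_test(n, seed=0):
    """Return both sides of <rfft(x), y> = <x, rfft_adjoint(y)>."""
    rng = random.Random(seed)
    x = [rng.uniform(-1, 1) for _ in range(n)]
    y = [complex(rng.uniform(-1, 1), rng.uniform(-1, 1))
         for _ in range(n // 2 + 1)]
    X = rfft(x)
    lhs = sum(a.real * b.real + a.imag * b.imag for a, b in zip(X, y))
    rhs = sum(a * b for a, b in zip(x, rfft_adjoint(y, n)))
    return lhs, rhs
```

Calling dot_test(8) or dot_test(7) should return two numbers that agree up to floating-point rounding, for even and odd ring lengths respectively.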

ChainRules/Zygote/…

For uses outside of our model (e.g. Spherical FNOs) it would be great to also have this differentiable by other ADs. These ADs usually don’t support mutation, so I would define rules for the non-mutating transform(grid, S). I’ll do that a bit later. This might also need some changes to the code.

This PR

  • This PR implements the Enzyme Rule for _fourier! and tests it.
  • As far as I can see, this rule should also work with the GPU version, as soon as we have it, with just a tiny adjustment.
  • We could pre-compute the scaling. It’s a relatively small Int-valued matrix. That would make the adjoint marginally faster.
  • I also introduce a new extension for FiniteDifferences.jl that makes our data structure compatible with the library.
  • I set up a new environment for the tests.
  • Custom Enzyme rules go into another extension.
  • Currently I just add Enzyme, EnzymeRules and EnzymeTestUtils to the main env to make testing a bit easier for me while developing this PR. That will be changed before merging.
    • Do we want the differentiability tests to always live in the same files as the other tests, or in a separate folder?
    • Enzyme (and most of this code) could go in an extension. Is this what we want long-term? Probably yes, but I don’t yet have a good sense of how many custom rules we will have to define. Hopefully not many though.

Testing

  • I do a manual test comparing the full spherical harmonics transforms against FiniteDifferences, a test that $$\frac{d}{dx}\mathcal{S}^{-1}(\mathcal{S}(x)) = 1$$, and tests with EnzymeTestUtils (which also do a comparison to FD).
  • I'd love to test the rules directly with EnzymeTestUtils as well, but there are several problems currently:
    • There's a problem with the FFT plans having uninitialized fields. There's a quite hacky way around this though.
    • There's a problem with the function having complex-valued outputs.
    • The tests are stuck; I waited 30 min without a result.
    • I have to look into this a bit further.
  • FiniteDifferences tests can take a while. Going forward, do we want to put all gradient correctness checks in the regular CI? The four quite simple tests that are currently in this PR already take 5 min.
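
As an illustration of the kind of FD comparison listed above, here is a minimal sketch in plain Python. It is not the test code of this PR: the loss, the function names and the step size h are hypothetical, and a naive DFT stands in for the real transform. It differentiates a scalar loss of an unnormalised real FFT once through the adjoint and once with central finite differences.

```python
import cmath
import random


def rfft(x):
    """Unnormalised real-to-complex DFT, keeping frequencies 0..n//2."""
    n = len(x)
    return [sum(x[m] * cmath.exp(-2j * cmath.pi * m * k / n) for m in range(n))
            for k in range(n // 2 + 1)]


def rfft_adjoint(y, n):
    """Adjoint of rfft w.r.t. the real inner product Re(conj(a) * b)."""
    return [sum(v * cmath.exp(2j * cmath.pi * m * k / n)
                for k, v in enumerate(y)).real
            for m in range(n)]


def loss(x):
    """Hypothetical scalar loss: sum of squared magnitudes of the spectrum."""
    return sum(abs(c) ** 2 for c in rfft(x))


def grad_loss(x):
    """Reverse-mode gradient: pull the cotangent 2*X back through the adjoint."""
    X = rfft(x)
    return rfft_adjoint([2 * c for c in X], len(x))


def fd_grad(f, x, h=1e-5):
    """Central finite differences, one input component at a time."""
    grad = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += h
        xm[i] -= h
        grad.append((f(xp) - f(xm)) / (2 * h))
    return grad
```

The FD gradient needs two loss evaluations per input component, whereas the adjoint-based gradient needs one forward and one adjoint transform in total, which is consistent with the FD-based tests dominating the run time.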

Complex Numbers

Differentiating complex numbers is a bit of a topic in itself. Enzyme also has a quite long explanation of their approach in their FAQs (https://enzymead.github.io/Enzyme.jl/stable/faq/#Complex-numbers). I think for us, we can treat the real and imaginary parts basically as separate numbers, as they would be for real SPH (for which coefficients with m<0 correspond to our imaginary part (sort of)). I am not quite sure yet if this works well out of the box with Enzyme or if we have to be more careful somewhere. It's something we should think about.

@maximilian-gelbrecht maximilian-gelbrecht added transform ⬅️ ➡️ Our spherical harmonic transform, grid to spectral, spectral to grid differentiability 🤖 Making the model differentiable via AD labels Oct 17, 2024
@maximilian-gelbrecht
Member Author

I quickly checked whether some nonsensical differentiation through a timestep of a barotropic model is possible.

I get an error from Enzyme and the compiler, originating in the _divergence! function. That's roughly what I suspected. After the transform itself, the spatial gradient functions are the other challenging bit. I'll look into differentiating them again when we also make steps towards GPU-ifying those bits. That should be next on the agenda after the transforms.

@maximilian-gelbrecht
Member Author

maximilian-gelbrecht commented Oct 30, 2024

So, all my direct tests pass. This is moving towards being ready.

Left to do:

  • Figure out why EnzymeTestUtils is stuck
  • Do we want to keep all of these tests in the CI? Right now for one grid type it's 8min. If we add more differentiability tests later, it might get too much. Enzyme is fast, but Finite Differences are slow. That's why it takes so long
  • Clear up some open questions about the usage of Enzyme/AD here and To-Dos along the code

@milankl
Member

milankl commented Oct 30, 2024

> Do we want to keep all of these tests in the CI? Right now for one grid type it's 8min. If we add more differentiability tests later, it might get too much. Enzyme is fast, but Finite Differences are slow. That's why it takes so long

Hmmm, it would be good if the tests didn't add massive time to all other development, but they are obviously crucial for differentiability testing. Not sure how we can achieve that. Could we only keep the tests without FiniteDifferences? Use a lower-resolution transform? You're in a better position than me to decide.

@maximilian-gelbrecht
Member Author

Ah yes, of course, you are right. So far I was even using 8 layers in the test, which is just a bit stupid.

@maximilian-gelbrecht
Member Author

Now it's much faster. I think we can keep them like that in the CI.

@maximilian-gelbrecht
Member Author

Okay, I feel a bit stupid, haha. EnzymeTestUtils is also not stuck anymore. Before, I was letting it run for longer than 2 hours; now it finishes in less than 1 second for a lower-dimensional problem.

In order for the EnzymeTestUtils tests to pass I had to increase the tolerance a bit from the default (to 1e-4) and increase the accuracy of the FD. For my QG3.jl I had to do the same, so that's not really a surprise to me.

@maximilian-gelbrecht
Member Author

The tests pass on my laptop, but not in the CI. I'll investigate this a bit later.

@maximilian-gelbrecht
Member Author

maximilian-gelbrecht commented Oct 31, 2024

All tests also pass in the 1.10 CI for the FullGaussianGrid.

For the OctahedralGaussianGrid, it's a bit weird.

  • The FD comparisons of the full transforms that I wrote pass.
  • EnzymeTestUtils passes most tests, especially the relevant ones (derivative of the function wrt its relevant input).
  • But the tests of the derivative wrt the pre-allocated f_north, f_south don't pass. We set those derivatives manually to zero, as recommended by the Enzyme docs: f_north is overwritten by _fourier!, so the derivative wrt that input is zero. But FiniteDifferences.jl computes a few non-zero entries in these derivative matrices. This doesn't happen for the FullGaussianGrid. It's a bit odd, but we know that those derivatives should be zero, so Enzyme is right. My guess at what goes wrong is that for reduced grids we don't fully use the pre-allocated arrays. The parts of the arrays that we don't overwrite should just stay zero though, so a zero derivative is also correct there, but somehow FiniteDifferences seems to struggle with that and returns some non-zero derivatives for those unused parts of the scratch memory.

I'd wait to merge this until everything works fine with Julia 1.11 as well. It doesn't cause any errors outside of using Enzyme, but still.

@maximilian-gelbrecht maximilian-gelbrecht marked this pull request as ready for review November 6, 2024 14:03
@maximilian-gelbrecht
Member Author

maximilian-gelbrecht commented Nov 6, 2024

With the latest update to Enzyme the tests now pass on 1.11 as well. Hence, I have now marked this as a proper PR and not a draft anymore.

Contributor

@vchuravy vchuravy left a comment


Looks fine to me.

@michel2323 was looking at custom rules for FFT IIRC

@maximilian-gelbrecht maximilian-gelbrecht merged commit 7e8d01d into main Nov 29, 2024
5 checks passed
Labels
differentiability 🤖 Making the model differentiable via AD transform ⬅️ ➡️ Our spherical harmonic transform, grid to spectral, spectral to grid

4 participants