From d54d65a61fdb971ef99b36cb8f5c2a02ba4e206b Mon Sep 17 00:00:00 2001 From: Max Date: Fri, 13 Dec 2024 22:05:49 +0100 Subject: [PATCH 1/3] WIP spectral gradients gradients --- src/SpeedyTransforms/spectral_gradients.jl | 10 ++-- test/spectral_transform_ad_rules.jl | 53 ++++++++++++++++++++++ 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/src/SpeedyTransforms/spectral_gradients.jl b/src/SpeedyTransforms/spectral_gradients.jl index d5889c3c6..0bf554643 100644 --- a/src/SpeedyTransforms/spectral_gradients.jl +++ b/src/SpeedyTransforms/spectral_gradients.jl @@ -18,9 +18,9 @@ function curl!( add::Bool=false, kwargs..., ) - # = -(∂λ - ∂θ) or (∂λ - ∂θ), adding or overwriting the output curl - kernel(o, a, b, c) = flipsign ? (add ? o-(a+b-c) : -(a+b-c)) : - (add ? o+(a+b-c) : a+b-c ) + # = -(∂λ - ∂θ) or (∂λ - ∂θ), adding or overwriting the output curl + kernel = flipsign ? (add ? (o, a, b, c) -> o-(a+b-c) : (o, a, b, c) -> -(a+b-c)) : + (add ? (o, a, b, c) -> o+(a+b-c) : (o, a, b, c) -> a+b-c) _divergence!(kernel, curl, v, u, S; kwargs...) # flip u, v -> v, u end @@ -42,8 +42,8 @@ function divergence!( kwargs..., ) # = -(∂λ + ∂θ) or (∂λ + ∂θ), adding or overwriting the output div - kernel(o, a, b, c) = flipsign ? (add ? o-(a-b+c) : -(a-b+c)) : - (add ? o+(a-b+c) : a-b+c ) + kernel = flipsign ? (add ? (o, a, b, c) -> o-(a-b+c) : (o, a, b, c) -> -(a-b+c)) : + (add ? (o, a, b, c) -> o+(a-b+c) : (o, a, b, c) -> a-b+c) _divergence!(kernel, div, u, v, S; kwargs...) end diff --git a/test/spectral_transform_ad_rules.jl b/test/spectral_transform_ad_rules.jl index 4d94f32ac..bcda54ccc 100644 --- a/test/spectral_transform_ad_rules.jl +++ b/test/spectral_transform_ad_rules.jl @@ -184,6 +184,59 @@ end end end end + + @testset "Spectral Gradient Enzyme" begin + for (i_grid, grid_type) in enumerate(grid_types) + + if fd_tests[i_grid] + + spectral_grid = SpectralGrid(Grid=grid_type, trunc=10, nlayers=1, dealiasing=grid_dealiasing[i_grid]) + S = SpectralTransform(spectral_grid, one_more_degree=true) + dS = deepcopy(S) + + u_grid = rand(spectral_grid.Grid{spectral_grid.NF}, spectral_grid.nlat_half, spectral_grid.nlayers) + v_grid = rand(spectral_grid.Grid{spectral_grid.NF}, spectral_grid.nlat_half, spectral_grid.nlayers) + + u = transform(u_grid, S) + v = transform(v_grid, S) + du = zero(u) + dv = zero(v) + + cu = zero(u) + dcu = zero(u) + fill!(dcu, 1+1im) + + # curl test + autodiff(Reverse, curl!, Const, Duplicated(cu, dcu), Duplicated(u, du), Duplicated(v, dv), Duplicated(S, dS)) + + # new seed + dcu2 = zero(dcu) + fill!(dcu2, 1+1im) + + # finite difference comparision, seeded with a one adjoint to get the direct gradient + fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> curl(x[1],x[2], S), dcu2, (u, v)) + @test isapprox(du, fd_jvp[1][1]) + @test isapprox(dv, fd_jvp[1][2]) + + # div test + + du = zero(u) + dv = zero(v) + div = zero(u) + ddiv = zero(u) + fill!(ddiv, 1+1im) + + autodiff(Reverse, divergence!, Const, Duplicated(div, ddiv), Duplicated(u, du), Duplicated(v, dv), Duplicated(S, dS)) + + ddiv2 = zero(ddiv) + fill!(ddiv, 1+1im) + + fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> divergence(x[1],x[2], S), ddiv2, (u, v)) + @test isapprox(du, fd_jvp[1][1]) + @test isapprox(dv, fd_jvp[1][2]) + end + end + end end @testset "Complete Transform ChainRules" begin From 83d7b384b839d31edf7ac4b4546c2cb27ca5a19e Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 17 Dec 2024 09:00:24 +0100 Subject: [PATCH 2/3] WIP spectral gradient tests --- test/spectral_transform_ad_rules.jl | 40 +++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/spectral_transform_ad_rules.jl b/test/spectral_transform_ad_rules.jl index bcda54ccc..3a3d24a17 100644 --- a/test/spectral_transform_ad_rules.jl +++ b/test/spectral_transform_ad_rules.jl @@ -9,6 +9,9 @@ grid_types = [FullGaussianGrid, OctahedralGaussianGrid] # one full and one reduc grid_dealiasing = [2, 3] fd_tests = [true, true] +i_grid = 1 +grid_type = grid_types[i_grid] + # currenlty there's an issue with EnzymeTestUtils not being able to work with structs with undefined fields like FFT plans # https://github.com/EnzymeAD/Enzyme.jl/issues/1992 # This is a very hacky workaround @@ -234,6 +237,43 @@ end fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> divergence(x[1],x[2], S), ddiv2, (u, v)) @test isapprox(du, fd_jvp[1][1]) @test isapprox(dv, fd_jvp[1][2]) + + # UV_from_vor! + + u = zero(u) + du = fill!(du, 1+1im) + + v = zero(v) + dv = fill!(dv, 1+1im) + + vor_grid = rand(spectral_grid.Grid{spectral_grid.NF}, spectral_grid.nlat_half, spectral_grid.nlayers) + vor = transform(vor_grid, S) + dvor = zero(vor) + + autodiff(Reverse, SpeedyWeather.SpeedyTransforms.UV_from_vor!, Const, Duplicated(u, du), Duplicated(v, dv), Duplicated(vor, dvor), Duplicated(S, dS)) + + dvor = zero(dvor) + fill!(dvor, 1+1im) + + function uvfvor(vor, S) + u = zero(vor) + v = zero(vor) + SpeedyWeather.SpeedyTransforms.UV_from_vor!(u, v, vor, S) + return () + end + + + fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> SpeedyWeather.SpeedyTransforms.UV_from_vor!(u, v, x, S), ) + @test isapprox(du, fd_jvp[1][1]) + + # Δ + + + + # ∇ + + + end end end From e5d75f407cac57d442c8cffab0fd8e943b46613b Mon Sep 17 00:00:00 2001 From: Max Date: Tue, 17 Dec 2024 13:50:59 +0100 Subject: [PATCH 3/3] more wip --- test/spectral_transform_ad_rules.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/test/spectral_transform_ad_rules.jl b/test/spectral_transform_ad_rules.jl index 3a3d24a17..b0d60271d 100644 --- a/test/spectral_transform_ad_rules.jl +++ b/test/spectral_transform_ad_rules.jl @@ -259,12 +259,14 @@ end u = zero(vor) v = zero(vor) SpeedyWeather.SpeedyTransforms.UV_from_vor!(u, v, vor, S) - return () + return cat(u, v, dims=2) end + uv_input = cat(u, v, dims=2) + duv_input = zero(uv_input) - fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> SpeedyWeather.SpeedyTransforms.UV_from_vor!(u, v, x, S), ) - @test isapprox(du, fd_jvp[1][1]) + fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> uvfvor(x, S), duv_input, vor) + @test isapprox(du, fd_jvp[1]) # Δ