Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Spectral Gradients Gradients #638

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/SpeedyTransforms/spectral_gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ function curl!(
add::Bool=false,
kwargs...,
)
# = -(∂λ - ∂θ) or (∂λ - ∂θ), adding or overwriting the output curl
kernel(o, a, b, c) = flipsign ? (add ? o-(a+b-c) : -(a+b-c)) :
(add ? o+(a+b-c) : a+b-c )
# = -(∂λ - ∂θ) or (∂λ - ∂θ), adding or overwriting the output curl
kernel = flipsign ? (add ? (o, a, b, c) -> o-(a+b-c) : (o, a, b, c) -> -(a+b-c)) :
(add ? (o, a, b, c) -> o+(a+b-c) : (o, a, b, c) -> a+b-c)
_divergence!(kernel, curl, v, u, S; kwargs...) # flip u, v -> v, u
end

Expand All @@ -42,8 +42,8 @@ function divergence!(
kwargs...,
)
# = -(∂λ + ∂θ) or (∂λ + ∂θ), adding or overwriting the output div
kernel(o, a, b, c) = flipsign ? (add ? o-(a-b+c) : -(a-b+c)) :
(add ? o+(a-b+c) : a-b+c )
kernel = flipsign ? (add ? (o, a, b, c) -> o-(a-b+c) : (o, a, b, c) -> -(a-b+c)) :
(add ? (o, a, b, c) -> o+(a-b+c) : (o, a, b, c) -> a-b+c)
_divergence!(kernel, div, u, v, S; kwargs...)
end

Expand Down
95 changes: 95 additions & 0 deletions test/spectral_transform_ad_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ grid_types = [FullGaussianGrid, OctahedralGaussianGrid] # one full and one reduc
grid_dealiasing = [2, 3]
fd_tests = [true, true]

i_grid = 1
grid_type = grid_types[i_grid]

# currenlty there's an issue with EnzymeTestUtils not being able to work with structs with undefined fields like FFT plans
# https://github.com/EnzymeAD/Enzyme.jl/issues/1992
# This is a very hacky workaround
Expand Down Expand Up @@ -184,6 +187,98 @@ end
end
end
end

@testset "Spectral Gradient Enzyme" begin
for (i_grid, grid_type) in enumerate(grid_types)

if fd_tests[i_grid]

spectral_grid = SpectralGrid(Grid=grid_type, trunc=10, nlayers=1, dealiasing=grid_dealiasing[i_grid])
S = SpectralTransform(spectral_grid, one_more_degree=true)
dS = deepcopy(S)

u_grid = rand(spectral_grid.Grid{spectral_grid.NF}, spectral_grid.nlat_half, spectral_grid.nlayers)
v_grid = rand(spectral_grid.Grid{spectral_grid.NF}, spectral_grid.nlat_half, spectral_grid.nlayers)

u = transform(u_grid, S)
v = transform(v_grid, S)
du = zero(u)
dv = zero(v)

cu = zero(u)
dcu = zero(u)
fill!(dcu, 1+1im)

# curl test
autodiff(Reverse, curl!, Const, Duplicated(cu, dcu), Duplicated(u, du), Duplicated(v, dv), Duplicated(S, dS))

# new seed
dcu2 = zero(dcu)
fill!(dcu2, 1+1im)

# finite difference comparision, seeded with a one adjoint to get the direct gradient
fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> curl(x[1],x[2], S), dcu2, (u, v))
@test isapprox(du, fd_jvp[1][1])
@test isapprox(dv, fd_jvp[1][2])

# div test

du = zero(u)
dv = zero(v)
div = zero(u)
ddiv = zero(u)
fill!(ddiv, 1+1im)

autodiff(Reverse, divergence!, Const, Duplicated(div, ddiv), Duplicated(u, du), Duplicated(v, dv), Duplicated(S, dS))

ddiv2 = zero(ddiv)
fill!(ddiv, 1+1im)

fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> divergence(x[1],x[2], S), ddiv2, (u, v))
@test isapprox(du, fd_jvp[1][1])
@test isapprox(dv, fd_jvp[1][2])

# UV_from_vor!

u = zero(u)
du = fill!(du, 1+1im)

v = zero(v)
dv = fill!(dv, 1+1im)

vor_grid = rand(spectral_grid.Grid{spectral_grid.NF}, spectral_grid.nlat_half, spectral_grid.nlayers)
vor = transform(vor_grid, S)
dvor = zero(vor)

autodiff(Reverse, SpeedyWeather.SpeedyTransforms.UV_from_vor!, Const, Duplicated(u, du), Duplicated(v, dv), Duplicated(vor, dvor), Duplicated(S, dS))

dvor = zero(dvor)
fill!(dvor, 1+1im)

function uvfvor(vor, S)
u = zero(vor)
v = zero(vor)
SpeedyWeather.SpeedyTransforms.UV_from_vor!(u, v, vor, S)
return cat(u, v, dims=2)
end

uv_input = cat(u, v, dims=2)
duv_input = zero(uv_input)

fd_jvp = FiniteDifferences.j′vp(central_fdm(5,1), x -> uvfvor(x, S), duv_input, vor)
@test isapprox(du, fd_jvp[1])

# Δ



# ∇



end
end
end
end

@testset "Complete Transform ChainRules" begin
Expand Down
Loading