Skip to content

Commit

Permalink
Fix CI (#1992)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored May 16, 2023
1 parent 316fc3b commit 6ee75cc
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions test/essential/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,21 @@
x = map(x->Float64(x), vi[SampleFromPrior()])

trackerℓ = LogDensityProblemsAD.ADgradient(TrackerAD(), ℓ)
@test trackerℓ isa LogDensityProblemsAD.LogDensityProblemsADTrackerExt.TrackerGradientLogDensity
if isdefined(Base, :get_extension)
@test trackerℓ isa Base.get_extension(LogDensityProblemsAD, :LogDensityProblemsADTrackerExt).TrackerGradientLogDensity
else
@test trackerℓ isa LogDensityProblemsAD.LogDensityProblemsADTrackerExt.TrackerGradientLogDensity
end
@test trackerℓ.===
∇E1 = LogDensityProblems.logdensity_and_gradient(trackerℓ, x)[2]
@test sort(∇E1) grad_FWAD atol=1e-9

zygoteℓ = LogDensityProblemsAD.ADgradient(ZygoteAD(), ℓ)
@test zygoteℓ isa LogDensityProblemsAD.LogDensityProblemsADZygoteExt.ZygoteGradientLogDensity
if isdefined(Base, :get_extension)
@test zygoteℓ isa Base.get_extension(LogDensityProblemsAD, :LogDensityProblemsADZygoteExt).ZygoteGradientLogDensity
else
@test zygoteℓ isa LogDensityProblemsAD.LogDensityProblemsADZygoteExt.ZygoteGradientLogDensity
end
@test zygoteℓ.===
∇E2 = LogDensityProblems.logdensity_and_gradient(zygoteℓ, x)[2]
@test sort(∇E2) grad_FWAD atol=1e-9
Expand Down

0 comments on commit 6ee75cc

Please sign in to comment.