Improve stacktraces by using custom tag for ForwardDiff (#1841)
* Improve stacktraces by using custom tag for ForwardDiff

* Fix typo

* Additional fixes

* Simplify code and define `LogDensityFunction`

* A bit simpler

* Add tests
devmotion authored Jun 21, 2022
1 parent 8adfa22 commit 9951638
Showing 7 changed files with 104 additions and 78 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.21.5"
version = "0.21.6"

AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -11,6 +11,7 @@ AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down Expand Up @@ -42,6 +43,7 @@ AdvancedVI = "0.1"
BangBang = "0.3"
Bijectors = "0.8, 0.9, 0.10"
DataStructures = "0.18"
DiffResults = "1"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8"
Expand Down
19 changes: 19 additions & 0 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,25 @@ function setprogress!(progress::Bool)
return progress

# Log density function
struct LogDensityFunction{V,M,S,C}

function (f::LogDensityFunction)(θ::AbstractVector)
return getlogp(last(DynamicPPL.evaluate!!(f.model, VarInfo(f.varinfo, f.sampler, θ), f.sampler, f.context)))

# Standard tag: Improves stacktraces
# Ref: (reference link is not present here — TODO restore)
# Field-less marker type. `gradient_logp` wraps an instance of it in a
# `ForwardDiff.Tag` when `standardtag(ad)` is true, instead of building the tag
# from the log density function itself.
struct TuringTag end

# Accept tags built from `TuringTag` whenever a `LogDensityFunction` is
# differentiated (gradient etc.) on an array whose element type matches the tag.
function ForwardDiff.checktag(
    ::Type{ForwardDiff.Tag{TuringTag,V}}, ::LogDensityFunction, ::AbstractArray{V}
) where {V}
    return true
end

# Random probability measures.
Expand Down
1 change: 1 addition & 0 deletions src/essential/Essential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using StatsFuns: logsumexp, softmax
using Requires

import AdvancedPS
import DiffResults
import ZygoteRules

Expand Down
86 changes: 41 additions & 45 deletions src/essential/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,18 @@ function setchunksize(chunk_size::Int)

abstract type ADBackend end
struct ForwardDiffAD{chunk} <: ADBackend end
# ForwardDiff-based AD backend. `chunk` is the ForwardDiff chunk size (a value
# of 0 makes `gradient_logp` use `ForwardDiff.Chunk(θ)`); `standardtag` selects
# whether the `TuringTag` tag is used (only the value `true` enables it).
struct ForwardDiffAD{chunk,standardtag} <: ADBackend end

# When only the chunk size is given, default to the standard tag.
function ForwardDiffAD{N}() where {N}
    return ForwardDiffAD{N,true}()
end

# Chunk size carried in the type of a `ForwardDiffAD` backend.
function getchunksize(::Type{<:ForwardDiffAD{chunk}}) where {chunk}
    return chunk
end

# A sampler type defers to the chunk size of its algorithm type.
function getchunksize(::Type{<:Sampler{Talg}}) where {Talg}
    return getchunksize(Talg)
end

# `SampleFromPrior` reads the current value of the global `CHUNKSIZE` reference.
function getchunksize(::Type{SampleFromPrior})
    return CHUNKSIZE[]
end

# Report whether the backend uses the standard tag: `true` only when the second
# type parameter is exactly `true`; any other value (e.g. `false`, `0`, `1`)
# yields `false`.
function standardtag(::ForwardDiffAD{chunk,tag}) where {chunk,tag}
    return tag === true
end

# AD backend marker selecting Tracker.jl.
struct TrackerAD <: ADBackend end
# AD backend marker selecting Zygote.jl.
struct ZygoteAD <: ADBackend end

Expand Down Expand Up @@ -95,59 +102,54 @@ Compute the value of the log joint of `θ` and its gradient for the model
specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{N}()` uses `ForwardDiff.jl` with chunk size `N`, `TrackerAD()` uses `Tracker.jl` and `ZygoteAD()` uses `Zygote.jl`.
function gradient_logp(
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
# Define function to compute log joint.
logp_old = getlogp(vi)
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
logp = getlogp(new_vi)
# Don't need to capture the resulting `vi` since this is only
# needed if `vi` is mutable.
setlogp!!(vi, ForwardDiff.value(logp))
return logp
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Set chunk size and do ForwardMode.
chunk_size = getchunksize(typeof(sampler))
# Define configuration for ForwardDiff.
tag = if standardtag(ad)
ForwardDiff.Tag(Turing.TuringTag(), eltype(θ))
ForwardDiff.Tag(f, eltype(θ))
chunk_size = getchunksize(typeof(ad))
config = if chunk_size == 0
ForwardDiff.GradientConfig(f, θ)
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(θ), tag)
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size), tag)
∂l∂θ = ForwardDiff.gradient!(similar(θ), f, θ, config)
l = getlogp(vi)
setlogp!!(vi, logp_old)

return l, ∂l∂θ
# Obtain both value and gradient of the log density function.
out = DiffResults.GradientResult(θ)
ForwardDiff.gradient!(out, f, θ, config)
logp = DiffResults.value(out)
∂logp∂θ = DiffResults.gradient(out)

return logp, ∂logp∂θ
function gradient_logp(
sampler::AbstractSampler = SampleFromPrior(),
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
return getlogp(new_vi)
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Compute forward and reverse passes.
# Compute forward pass and pullback.
l_tracked, ȳ = Tracker.forward(f, θ)
# Remove tracking info from variables in model (because mutable state).
l::T, ∂l∂θ::typeof(θ) = Tracker.data(l_tracked), Tracker.data(ȳ(1)[1])

# Remove tracking info.
l::typeof(getlogp(vi)) = Tracker.data(l_tracked)
∂l∂θ::typeof(θ) = Tracker.data(only(ȳ(1)))

return l, ∂l∂θ
Expand All @@ -160,18 +162,12 @@ function gradient_logp(
sampler::AbstractSampler = SampleFromPrior(),
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, context))
return getlogp(new_vi)
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Compute forward and reverse passes.
l::T, ȳ = ZygoteRules.pullback(f, θ)
∂l∂θ::typeof(θ) = ȳ(1)[1]
# Compute forward pass and pullback.
l::typeof(getlogp(vi)), ȳ = ZygoteRules.pullback(f, θ)
∂l∂θ::typeof(θ) = only(ȳ(1))

return l, ∂l∂θ
Expand Down
39 changes: 14 additions & 25 deletions src/essential/compat/reversediff.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using .ReverseDiff: compile, GradientTape
using .ReverseDiff.DiffResults: GradientResult

struct ReverseDiffAD{cache} <: ADBackend end
const RDCache = Ref(false)
Expand All @@ -22,26 +21,20 @@ function gradient_logp(
sampler::AbstractSampler = SampleFromPrior(),
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
T = typeof(getlogp(vi))
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler, context)
return getlogp(new_vi)
# Obtain both value and gradient of the log density function.
tp, result = taperesult(f, θ)
ReverseDiff.gradient!(result, tp, θ)
l = DiffResults.value(result)
∂l∂θ::typeof(θ) = DiffResults.gradient(result)
logp = DiffResults.value(result)
∂logp∂θ = DiffResults.gradient(result)

return l, ∂l∂θ
return logp, ∂logp∂θ

# Record a ReverseDiff gradient tape of `f` at the input `x`.
function tape(f, x)
    return GradientTape(f, x)
end
function taperesult(f, x)
return tape(f, x), GradientResult(x)
# Pair a gradient tape of `f` at `x` with a `DiffResults` gradient buffer built from `x`.
function taperesult(f, x)
    recorded = tape(f, x)
    buffer = DiffResults.GradientResult(x)
    return recorded, buffer
end

@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
# Switch the `RDCache` flag on; the assigned value (`true`) is returned.
function setrdcache(::Val{true})
    return RDCache[] = true
end
Expand All @@ -58,20 +51,16 @@ end
sampler::AbstractSampler = SampleFromPrior(),
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
T = typeof(getlogp(vi))
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler, context)
return getlogp(new_vi)
# Obtain both value and gradient of the log density function.
ctp, result = memoized_taperesult(f, θ)
ReverseDiff.gradient!(result, ctp, θ)
l = DiffResults.value(result)
∂l∂θ = DiffResults.gradient(result)
logp = DiffResults.value(result)
∂logp∂θ = DiffResults.gradient(result)

return l, ∂l∂θ
return logp, ∂logp∂θ

# This makes sure we generate a single tape per Turing model and sampler
Expand All @@ -85,7 +74,7 @@ end
# Wrap `(f, x)` in an `RDTapeKey` and forward to the keyed method.
function memoized_taperesult(f, x)
    key = RDTapeKey(f, x)
    return memoized_taperesult(key)
end
Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey)
return compiledtape(k.f, k.x), GradientResult(k.x)
return compiledtape(k.f, k.x), DiffResults.GradientResult(k.x)
# Record a gradient tape of `f` at `x`, then compile it.
function compiledtape(f, x)
    recorded = GradientTape(f, x)
    return compile(recorded)
end
12 changes: 12 additions & 0 deletions test/essential/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,16 @@
@test Turing.CHUNKSIZE[] == 0
@test Turing.AdvancedVI.CHUNKSIZE[] == 0

@testset "tag" begin
@test Turing.ADBackend(Val(:forwarddiff))() === Turing.ForwardDiffAD{Turing.CHUNKSIZE[],true}()
for chunksize in (0, 1, 10)
ad = Turing.ForwardDiffAD{chunksize}()
@test ad === Turing.ForwardDiffAD{chunksize,true}()
@test Turing.Essential.standardtag(ad)
for standardtag in (false, 0, 1)
@test !Turing.Essential.standardtag(Turing.ForwardDiffAD{chunksize,standardtag}())
21 changes: 14 additions & 7 deletions test/test_utils/ad_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,20 @@ function test_model_ad(model, f, syms::Vector{Symbol})

spl = SampleFromPrior()
_, ∇E = gradient_logp(ForwardDiffAD{1}(), vi[spl], vi, model)
grad_Turing = sort(∇E)
# Compute primal.
x = vec(vnvals)
logp = f(x)

# Call ForwardDiff's AD
grad_FWAD = sort(ForwardDiff.gradient(f, vec(vnvals)))
# Call ForwardDiff's AD directly.
grad_FWAD = sort(ForwardDiff.gradient(f, x))

# Compare result
@test grad_Turing ≈ grad_FWAD atol=1e-9
# Compare with `gradient_logp`.
z = vi[SampleFromPrior()]
for chunksize in (0, 1, 10), standardtag in (true, false, 0, 3)
l, ∇E = gradient_logp(ForwardDiffAD{chunksize, standardtag}(), z, vi, model)

# Compare result
@test l ≈ logp
@test sort(∇E) ≈ grad_FWAD atol=1e-9

2 comments on commit 9951638

