Improve stacktraces by using custom tag for ForwardDiff (#1841)
* Improve stacktraces by using custom tag for ForwardDiff

* Fix typo

* Additional fixes

* Simplify code and define `LogDensityFunction`

* A bit simpler

* Add tests
devmotion authored Jun 21, 2022
1 parent 8adfa22 commit 9951638
Showing 7 changed files with 104 additions and 78 deletions.
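
The change adopts the package-tag pattern from the blog post referenced in src/Turing.jl below: ForwardDiff's dual numbers carry a dedicated tag type, so stacktraces and dual type parameters show a short package name instead of a long closure type. A minimal standalone sketch of that pattern, assuming an illustrative `MyPackageTag` type and function `f` that are not part of this commit:

using ForwardDiff

# Package-specific tag type; its only purpose is to appear in dual-number
# types, and hence in stacktraces, instead of a long closure type.
struct MyPackageTag end

f(x) = sum(abs2, x)

# Allow the custom tag in ForwardDiff calls on `f` (ForwardDiff would otherwise
# reject a tag that was not derived from `f` itself).
ForwardDiff.checktag(::Type{ForwardDiff.Tag{MyPackageTag,V}}, ::typeof(f), ::AbstractArray{V}) where {V} = true

x = rand(3)
tag = ForwardDiff.Tag(MyPackageTag(), eltype(x))
cfg = ForwardDiff.GradientConfig(f, x, ForwardDiff.Chunk(x), tag)
ForwardDiff.gradient(f, x, cfg)
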
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.21.5"
version = "0.21.6"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -11,6 +11,7 @@ AdvancedVI = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
@@ -42,6 +43,7 @@ AdvancedVI = "0.1"
BangBang = "0.3"
Bijectors = "0.8, 0.9, 0.10"
DataStructures = "0.18"
DiffResults = "1"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8"
19 changes: 19 additions & 0 deletions src/Turing.jl
@@ -25,6 +25,25 @@ function setprogress!(progress::Bool)
return progress
end

# Log density function
struct LogDensityFunction{V,M,S,C}
varinfo::V
model::M
sampler::S
context::C
end

function (f::LogDensityFunction)(θ::AbstractVector)
return getlogp(last(DynamicPPL.evaluate!!(f.model, VarInfo(f.varinfo, f.sampler, θ), f.sampler, f.context)))
end

# Standard tag: Improves stacktraces
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
struct TuringTag end

# Allow Turing tag in gradient etc. calls of the log density function
ForwardDiff.checktag(::Type{ForwardDiff.Tag{TuringTag, V}}, ::LogDensityFunction, ::AbstractArray{V}) where {V} = true

# Random probability measures.
include("stdlib/distributions.jl")
include("stdlib/RandomMeasures.jl")
1 change: 1 addition & 0 deletions src/essential/Essential.jl
@@ -16,6 +16,7 @@ using StatsFuns: logsumexp, softmax
using Requires

import AdvancedPS
import DiffResults
import ZygoteRules

include("container.jl")
86 changes: 41 additions & 45 deletions src/essential/ad.jl
@@ -34,11 +34,18 @@ function setchunksize(chunk_size::Int)
end

abstract type ADBackend end
struct ForwardDiffAD{chunk} <: ADBackend end
struct ForwardDiffAD{chunk,standardtag} <: ADBackend end

# Use standard tag if not specified otherwise
ForwardDiffAD{N}() where {N} = ForwardDiffAD{N,true}()

getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk
getchunksize(::Type{<:Sampler{Talg}}) where Talg = getchunksize(Talg)
getchunksize(::Type{SampleFromPrior}) = CHUNKSIZE[]

standardtag(::ForwardDiffAD{<:Any,true}) = true
standardtag(::ForwardDiffAD) = false

struct TrackerAD <: ADBackend end
struct ZygoteAD <: ADBackend end

@@ -95,59 +102,54 @@ Compute the value of the log joint of `θ` and its gradient for the model
specified by `(vi, sampler, model)` using `backend` for AD, e.g. `ForwardDiffAD{N}()` uses `ForwardDiff.jl` with chunk size `N`, `TrackerAD()` uses `Tracker.jl` and `ZygoteAD()` uses `Zygote.jl`.
"""
function gradient_logp(
::ForwardDiffAD,
ad::ForwardDiffAD,
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
# Define function to compute log joint.
logp_old = getlogp(vi)
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
logp = getlogp(new_vi)
# Don't need to capture the resulting `vi` since this is only
# needed if `vi` is mutable.
setlogp!!(vi, ForwardDiff.value(logp))
return logp
end
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Set chunk size and do ForwardMode.
chunk_size = getchunksize(typeof(sampler))
# Define configuration for ForwardDiff.
tag = if standardtag(ad)
ForwardDiff.Tag(Turing.TuringTag(), eltype(θ))
else
ForwardDiff.Tag(f, eltype(θ))
end
chunk_size = getchunksize(typeof(ad))
config = if chunk_size == 0
ForwardDiff.GradientConfig(f, θ)
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(θ), tag)
else
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size), tag)
end
∂l∂θ = ForwardDiff.gradient!(similar(θ), f, θ, config)
l = getlogp(vi)
setlogp!!(vi, logp_old)

return l, ∂l∂θ
# Obtain both value and gradient of the log density function.
out = DiffResults.GradientResult(θ)
ForwardDiff.gradient!(out, f, θ, config)
logp = DiffResults.value(out)
∂logp∂θ = DiffResults.gradient(out)

return logp, ∂logp∂θ
end
function gradient_logp(
::TrackerAD,
θ::AbstractVector{<:Real},
vi::VarInfo,
model::Model,
sampler::AbstractSampler = SampleFromPrior(),
ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, ctx))
return getlogp(new_vi)
end
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Compute forward and reverse passes.
# Compute forward pass and pullback.
l_tracked, ȳ = Tracker.forward(f, θ)
# Remove tracking info from variables in model (because mutable state).
l::T, ∂l∂θ::typeof(θ) = Tracker.data(l_tracked), Tracker.data(ȳ(1)[1])

# Remove tracking info.
l::typeof(getlogp(vi)) = Tracker.data(l_tracked)
∂l∂θ::typeof(θ) = Tracker.data(only(ȳ(1)))

return l, ∂l∂θ
end
@@ -160,18 +162,12 @@ function gradient_logp(
sampler::AbstractSampler = SampleFromPrior(),
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
new_vi = last(DynamicPPL.evaluate!!(model, new_vi, sampler, context))
return getlogp(new_vi)
end
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Compute forward and reverse passes.
l::T, ȳ = ZygoteRules.pullback(f, θ)
∂l∂θ::typeof(θ) = ȳ(1)[1]
# Compute forward pass and pullback.
l::typeof(getlogp(vi)), ȳ = ZygoteRules.pullback(f, θ)
∂l∂θ::typeof(θ) = only(ȳ(1))

return l, ∂l∂θ
end
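
For reference, a hedged sketch of calling `gradient_logp` with the ForwardDiff backend after this change. The `gdemo` model is illustrative only; the tests in test/test_utils/ad_utils.jl below exercise the same call with several chunk sizes and tag settings:

using Turing
using Turing: DynamicPPL

# Illustrative model, not part of this commit.
@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

model = gdemo(1.5)
vi = DynamicPPL.VarInfo(model)
θ = vi[DynamicPPL.SampleFromPrior()]

# Log joint and its gradient via ForwardDiff with chunk size 10 and the
# standard `TuringTag` (second type parameter `true`).
l, ∂l∂θ = Turing.Essential.gradient_logp(Turing.ForwardDiffAD{10,true}(), θ, vi, model)
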
39 changes: 14 additions & 25 deletions src/essential/compat/reversediff.jl
@@ -1,5 +1,4 @@
using .ReverseDiff: compile, GradientTape
using .ReverseDiff.DiffResults: GradientResult

struct ReverseDiffAD{cache} <: ADBackend end
const RDCache = Ref(false)
@@ -22,26 +21,20 @@ function gradient_logp(
sampler::AbstractSampler = SampleFromPrior(),
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler, context)
return getlogp(new_vi)
end
# Obtain both value and gradient of the log density function.
tp, result = taperesult(f, θ)
ReverseDiff.gradient!(result, tp, θ)
l = DiffResults.value(result)
∂l∂θ::typeof(θ) = DiffResults.gradient(result)
logp = DiffResults.value(result)
∂logp∂θ = DiffResults.gradient(result)

return l, ∂l∂θ
return logp, ∂logp∂θ
end

tape(f, x) = GradientTape(f, x)
function taperesult(f, x)
return tape(f, x), GradientResult(x)
end
taperesult(f, x) = (tape(f, x), DiffResults.GradientResult(x))

@require Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @eval begin
setrdcache(::Val{true}) = RDCache[] = true
@@ -58,20 +51,16 @@ end
sampler::AbstractSampler = SampleFromPrior(),
context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
)
T = typeof(getlogp(vi))
# Define log density function.
f = Turing.LogDensityFunction(vi, model, sampler, context)

# Specify objective function.
function f(θ)
new_vi = VarInfo(vi, sampler, θ)
model(new_vi, sampler, context)
return getlogp(new_vi)
end
# Obtain both value and gradient of the log density function.
ctp, result = memoized_taperesult(f, θ)
ReverseDiff.gradient!(result, ctp, θ)
l = DiffResults.value(result)
∂l∂θ = DiffResults.gradient(result)
logp = DiffResults.value(result)
∂logp∂θ = DiffResults.gradient(result)

return l, ∂l∂θ
return logp, ∂logp∂θ
end

# This makes sure we generate a single tape per Turing model and sampler
@@ -85,7 +74,7 @@ end
end
memoized_taperesult(f, x) = memoized_taperesult(RDTapeKey(f, x))
Memoization.@memoize Dict function memoized_taperesult(k::RDTapeKey)
return compiledtape(k.f, k.x), GradientResult(k.x)
return compiledtape(k.f, k.x), DiffResults.GradientResult(k.x)
end
compiledtape(f, x) = compile(GradientTape(f, x))
end
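
The memoization above caches one compiled `GradientTape` plus a `DiffResults.GradientResult` buffer per model/sampler pair. A standalone sketch of the underlying ReverseDiff/DiffResults pattern, with an illustrative function `f`:

using ReverseDiff, DiffResults

f(x) = sum(abs2, x) / 2
x = randn(5)

# Record and compile the tape once; it can be reused for any input of the same shape.
tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x))
result = DiffResults.GradientResult(x)

ReverseDiff.gradient!(result, tape, x)
DiffResults.value(result)     # value of f at x
DiffResults.gradient(result)  # gradient of f at x
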
12 changes: 12 additions & 0 deletions test/essential/ad.jl
@@ -177,4 +177,16 @@
@test Turing.CHUNKSIZE[] == 0
@test Turing.AdvancedVI.CHUNKSIZE[] == 0
end

@testset "tag" begin
@test Turing.ADBackend(Val(:forwarddiff))() === Turing.ForwardDiffAD{Turing.CHUNKSIZE[],true}()
for chunksize in (0, 1, 10)
ad = Turing.ForwardDiffAD{chunksize}()
@test ad === Turing.ForwardDiffAD{chunksize,true}()
@test Turing.Essential.standardtag(ad)
for standardtag in (false, 0, 1)
@test !Turing.Essential.standardtag(Turing.ForwardDiffAD{chunksize,standardtag}())
end
end
end
end
21 changes: 14 additions & 7 deletions test/test_utils/ad_utils.jl
@@ -83,13 +83,20 @@ function test_model_ad(model, f, syms::Vector{Symbol})
end
end

spl = SampleFromPrior()
_, ∇E = gradient_logp(ForwardDiffAD{1}(), vi[spl], vi, model)
grad_Turing = sort(∇E)
# Compute primal.
x = vec(vnvals)
logp = f(x)

# Call ForwardDiff's AD
grad_FWAD = sort(ForwardDiff.gradient(f, vec(vnvals)))
# Call ForwardDiff's AD directly.
grad_FWAD = sort(ForwardDiff.gradient(f, x))

# Compare result
@test grad_Turing ≈ grad_FWAD atol=1e-9
# Compare with `gradient_logp`.
z = vi[SampleFromPrior()]
for chunksize in (0, 1, 10), standardtag in (true, false, 0, 3)
l, ∇E = gradient_logp(ForwardDiffAD{chunksize, standardtag}(), z, vi, model)

# Compare result
@test l ≈ logp
@test sort(∇E) ≈ grad_FWAD atol=1e-9
end
end

2 comments on commit 9951638

@devmotion (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/62810

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.21.6 -m "<description of version>" 9951638916a4372bef31f20c723224d49e1f4124
git push origin v0.21.6
