Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility with new DPPL version #1900

Merged
merged 37 commits into from
Nov 12, 2022
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
b3973a0
use unflatten in evaluation of LogDensityFunction
torfjelde Jul 7, 2022
b377ea1
make AD-related functions able to take AbstractVarInfo
torfjelde Jul 7, 2022
977377a
use unflatten where appropriate
torfjelde Jul 7, 2022
276a39b
updated Gibbs
torfjelde Jul 7, 2022
bf8ec74
updated HMC
torfjelde Jul 7, 2022
9d41506
move to using BangBang versions of link and invlink
torfjelde Jul 18, 2022
0ead42f
use link!!
torfjelde Jul 22, 2022
9b8e937
update tests to be compatible with new DynamicPPL.TestUtils
torfjelde Jul 23, 2022
e5d7168
updated deps for tests
torfjelde Jul 23, 2022
4857211
Merge branch 'tor/dppl-bump' into tor/unflatten
torfjelde Jul 23, 2022
205c8c3
fixed tests for ESS
torfjelde Jul 23, 2022
4962e1d
upper-bound distributions in tests because otherwise depwarns will
torfjelde Jul 23, 2022
e260720
Merge branch 'tor/dppl-bump' into tor/unflatten
torfjelde Jul 23, 2022
a2d73d3
replace link! with link!!, etc.
torfjelde Jul 25, 2022
b28b22f
added Setfield and updated optimization stuff
torfjelde Jul 25, 2022
3310eee
updated the contrib to use link!!, etc.
torfjelde Jul 27, 2022
3178bab
updated AD tests
torfjelde Jul 27, 2022
901211b
Merge branch 'master' into tor/unflatten
torfjelde Sep 7, 2022
6660a39
Merge branch 'master' into tor/unflatten
torfjelde Nov 1, 2022
b2139c3
updated DPPL versions
torfjelde Nov 4, 2022
13e445d
removed usage of deprecated inv
torfjelde Nov 7, 2022
66e773a
made some function signatures more restrictive
torfjelde Nov 8, 2022
574ef2d
Update src/inference/mh.jl
torfjelde Nov 9, 2022
9c08dda
Merge branch 'tor/unflatten' of github.com:TuringLang/Turing.jl into …
torfjelde Nov 10, 2022
c056b01
fixed MH sampler
torfjelde Nov 10, 2022
6b264dc
Merge branch 'master' into tor/unflatten
torfjelde Nov 10, 2022
67975b6
increase atol for certain tests to make them pass on MacOS
torfjelde Nov 10, 2022
5159ad0
reduce atol for a MH test
torfjelde Nov 11, 2022
4d5396d
disable emcee tests for now
torfjelde Nov 11, 2022
a0255b8
Update Project.toml
torfjelde Nov 11, 2022
130dbad
further reductions in atol to make tests pass
torfjelde Nov 11, 2022
e634b20
Merge branch 'tor/unflatten' of github.com:TuringLang/Turing.jl into …
torfjelde Nov 11, 2022
39dc618
Update test/runtests.jl
yebai Nov 11, 2022
286dbc0
Update mh.jl
yebai Nov 11, 2022
e5db993
restrict ForwardDiff for tests to avoid issue with cholesky
torfjelde Nov 12, 2022
0ab3dd5
Merge branch 'tor/unflatten' of github.com:TuringLang/Turing.jl into …
torfjelde Nov 12, 2022
ab69a21
increased number of samples and lowered atol for MH tests
torfjelde Nov 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -45,7 +46,7 @@ DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicPPL = "0.20"
DynamicPPL = "0.21"
EllipticalSliceSampling = "0.5, 1"
ForwardDiff = "0.10.3"
Libtask = "0.6.7, 0.7"
Expand All @@ -55,6 +56,7 @@ NamedArrays = "0.9"
Reexport = "0.2, 1"
Requires = "0.5, 1.0"
SciMLBase = "1.37.1"
Setfield = "0.8"
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9, 1"
Expand Down
3 changes: 2 additions & 1 deletion src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ struct LogDensityFunction{V,M,S,C}
end

function (f::LogDensityFunction)(θ::AbstractVector)
return getlogp(last(DynamicPPL.evaluate!!(f.model, VarInfo(f.varinfo, f.sampler, θ), f.sampler, f.context)))
vi_new = DynamicPPL.unflatten(f.varinfo, f.sampler, θ)
return getlogp(last(DynamicPPL.evaluate!!(f.model, vi_new, f.sampler, f.context)))
end

# LogDensityProblems interface
Expand Down
12 changes: 6 additions & 6 deletions src/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ function DynamicPPL.initialstep(
)
# Ensure that initial sample is in unconstrained space.
if !DynamicPPL.islinked(vi, spl)
DynamicPPL.link!(vi, spl)
model(rng, vi, spl)
vi = DynamicPPL.link!!(vi, spl, model)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remind me why we run the model here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And why we use the SamplingContext?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was done to update the logp after link!. But now link!! also includes the log-absdet-jacobian correction, so we could potentially drop it.

And regarding why we use the SamplingContext: didn't want to make changes to something I wasn't 100% certain would have no side-effects. But yes, I also think this is no longer necessary.

The question is if we should make these changes now or in another PR.

Copy link
Member

@devmotion devmotion Nov 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe safer to move these questions to a separate issue and potentially address them in a separate PR 👍

end

# Define log-density function.
Expand All @@ -79,8 +79,8 @@ function DynamicPPL.initialstep(
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)

# Update the variables.
vi[spl] = Q.q
DynamicPPL.setlogp!!(vi, Q.ℓq)
vi = DynamicPPL.setindex!!(vi, Q.q, spl)
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)

# Create first sample and state.
sample = Transition(vi)
Expand Down Expand Up @@ -109,8 +109,8 @@ function AbstractMCMC.step(
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)

# Update the variables.
vi[spl] = Q.q
DynamicPPL.setlogp!!(vi, Q.ℓq)
vi = DynamicPPL.setindex!!(vi, Q.q, spl)
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to redefine the log-density function since l.vi is potentially a different object?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oooo good point! Though, in that case we should probably just do away with the state.vi and use state.logdensity.varinfo, no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I guess. It seems state.vi is not even used in the DynamicHMC calls above but only state.logdensity.varinfo implicitly. I assume the values in state.logdensity.varinfo are not completely crucial at the moment though, since the DynamicHMC calls update its values anyway.


# Create next sample and state.
sample = Transition(vi)
Expand Down
16 changes: 8 additions & 8 deletions src/contrib/inference/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ function DynamicPPL.initialstep(
)
# Transform the samples to unconstrained space and compute the joint log probability.
if !DynamicPPL.islinked(vi, spl)
DynamicPPL.link!(vi, spl)
model(rng, vi, spl)
vi = DynamicPPL.link!!(vi, spl, model)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here?

end

# Compute initial sample and state.
Expand Down Expand Up @@ -90,8 +90,8 @@ function AbstractMCMC.step(
newv = (1 - α) .* v .+ η .* grad .+ sqrt(2 * η * α) .* randn(rng, eltype(v), length(v))

# Save new variables and recompute log density.
vi[spl] = θ
model(rng, vi, spl)
vi = DynamicPPL.setindex!!(vi, θ, spl)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here


# Compute next sample and state.
sample = Transition(vi)
Expand Down Expand Up @@ -209,8 +209,8 @@ function DynamicPPL.initialstep(
)
# Transform the samples to unconstrained space and compute the joint log probability.
if !DynamicPPL.islinked(vi, spl)
DynamicPPL.link!(vi, spl)
model(rng, vi, spl)
vi = DynamicPPL.link!!(vi, spl, model)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

end

# Create first sample and state.
Expand Down Expand Up @@ -238,8 +238,8 @@ function AbstractMCMC.step(
θ .+= (stepsize / 2) .* grad .+ sqrt(stepsize) .* randn(rng, eltype(θ), length(θ))

# Save new variables and recompute log density.
vi[spl] = θ
model(rng, vi, spl)
vi = DynamicPPL.setindex!!(vi, θ, spl)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need a new logdensity function?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And is sampling needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not need the SamplingContext in case this is used as part of a gibbs kernel?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in, in that case spl should only be dealing with a subset of the variables.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not have a sampler-specific evaluation context? It's mainly because we just evaluate (it seems) that I think passing around rng seems unintuitive.


# Compute next sample and state.
sample = SGLDTransition(vi, stepsize)
Expand Down
8 changes: 4 additions & 4 deletions src/inference/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function AbstractMCMC.step(
ArgumentError("initial parameters have to be specified for each walker")
)
vis = map(vis, init_params) do vi, init
vi = DynamicPPL.initialize_parameters!!(vi, init, spl)
vi = DynamicPPL.initialize_parameters!!(vi, init, spl, model)

# Update log joint probability.
last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromPrior()))
Expand All @@ -57,7 +57,7 @@ function AbstractMCMC.step(
state = EmceeState(
vis[1],
map(vis) do vi
DynamicPPL.link!(vi, spl)
vi = DynamicPPL.link!!(vi, spl, model)
AMH.Transition(vi[spl], getlogp(vi))
end
)
Expand All @@ -82,9 +82,9 @@ function AbstractMCMC.step(
# Compute the next transition and state.
transition = map(states) do _state
vi = setindex!!(vi, _state.params, spl)
DynamicPPL.invlink!(vi, spl)
vi = DynamicPPL.invlink!!(vi, spl, model)
t = Transition(tonamedtuple(vi), _state.lp)
DynamicPPL.link!(vi, spl)
vi = DynamicPPL.link!!(vi, spl, model)
return t
end
newstate = EmceeState(vi, states)
Expand Down
2 changes: 1 addition & 1 deletion src/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ function DynamicPPL.initialstep(
states = map(samplers) do local_spl
# Recompute `vi.logp` if needed.
if local_spl.selector.rerun
model(rng, vi, local_spl)
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, local_spl)))
end

# Compute initial state.
Expand Down
22 changes: 12 additions & 10 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ function DynamicPPL.initialstep(
kwargs...
)
# Transform the samples to unconstrained space and compute the joint log probability.
link!(vi, spl)
vi = last(DynamicPPL.evaluate!!(model, rng, vi, spl))
vi = link!!(vi, spl, model)

# Extract parameters.
theta = vi[spl]
Expand All @@ -173,8 +172,8 @@ function DynamicPPL.initialstep(
# and its gradient are finite.
if init_params === nothing
while !isfinite(z)
# NOTE: This will sample in the unconstrained space.
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
link!(vi, spl)
theta = vi[spl]

hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)
Expand Down Expand Up @@ -210,10 +209,10 @@ function DynamicPPL.initialstep(

# Update `vi` based on acceptance
if t.stat.is_accept
vi = setindex!!(vi, t.z.θ, spl)
vi = DynamicPPL.unflatten(vi, spl, t.z.θ)
vi = setlogp!!(vi, t.stat.log_density)
else
vi = setindex!!(vi, theta, spl)
vi = DynamicPPL.unflatten(vi, spl, theta)
vi = setlogp!!(vi, log_density_old)
end

Expand Down Expand Up @@ -252,7 +251,7 @@ function AbstractMCMC.step(
# Update variables
vi = state.vi
if t.stat.is_accept
vi = setindex!!(vi, t.z.θ, spl)
vi = DynamicPPL.unflatten(vi, spl, t.z.θ)
vi = setlogp!!(vi, t.stat.log_density)
end

Expand Down Expand Up @@ -532,8 +531,9 @@ function HMCState(
kwargs...
)
# Link everything if needed.
if !islinked(vi, spl)
link!(vi, spl)
waslinked = islinked(vi, spl)
if !waslinked
vi = link!!(vi, spl, model)
end

# Get the initial log pdf and gradient functions.
Expand Down Expand Up @@ -562,8 +562,10 @@ function HMCState(
# Generate a phasepoint. Replaced during sample_init!
h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ.

# Unlink everything.
invlink!(vi, spl)
# Unlink everything, if it was indeed linked before.
if waslinked
vi = invlink!!(vi, spl, model)
end

return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
end
37 changes: 16 additions & 21 deletions src/inference/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,11 @@ end

Places the values of a `NamedTuple` into the relevant places of a `VarInfo`.
"""
function set_namedtuple!(vi::VarInfo, nt::NamedTuple)
function set_namedtuple!(vi::DynamicPPL.VarInfoOrThreadSafeVarInfo, nt::NamedTuple)
# TODO: Replace this with something like
# for vn in keys(vi)
# vi = DynamicPPL.setindex!!(vi, get(nt, vn))
# end
for (n, vals) in pairs(nt)
vns = vi.metadata[n].vns
nvns = length(vns)
Expand Down Expand Up @@ -245,6 +249,7 @@ This variant uses the `set_namedtuple!` function to update the `VarInfo`.
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}

function (f::MHLogDensityFunction)(x::NamedTuple)
# TODO: Make this work with immutable `f.varinfo` too.
sampler = f.sampler
vi = f.varinfo

Expand Down Expand Up @@ -286,14 +291,14 @@ function reconstruct(
end

"""
dist_val_tuple(spl::Sampler{<:MH}, vi::AbstractVarInfo)
dist_val_tuple(spl::Sampler{<:MH}, vi::VarInfo)

Return two `NamedTuples`.

The first `NamedTuple` has symbols as keys and distributions as values.
The second `NamedTuple` has model symbols as keys and their stored values as values.
"""
function dist_val_tuple(spl::Sampler{<:MH}, vi::AbstractVarInfo)
function dist_val_tuple(spl::Sampler{<:MH}, vi::DynamicPPL.VarInfoOrThreadSafeVarInfo)
vns = _getvns(vi, spl)
dt = _dist_tuple(spl.alg.proposals, vi, vns)
vt = _val_tuple(vi, vns)
Expand Down Expand Up @@ -349,15 +354,12 @@ function should_link(
return true
end

function maybe_link!(varinfo, sampler, proposal)
if should_link(varinfo, sampler, proposal)
link!(varinfo, sampler)
end
return nothing
function maybe_link!!(varinfo, sampler, proposal, model)
return should_link(varinfo, sampler, proposal) ? link!!(varinfo, sampler, model) : varinfo
end

# Make a proposal if we don't have a covariance proposal matrix (the default).
function propose!(
function propose!!(
rng::AbstractRNG,
vi::AbstractVarInfo,
model::Model,
Expand All @@ -378,13 +380,11 @@ function propose!(
# TODO: Make this compatible with immutable `VarInfo`.
# Update the values in the VarInfo.
set_namedtuple!(vi, trans.params)
setlogp!!(vi, trans.lp)

return vi
return setlogp!!(vi, trans.lp)
end

# Make a proposal if we DO have a covariance proposal matrix.
function propose!(
function propose!!(
rng::AbstractRNG,
vi::AbstractVarInfo,
model::Model,
Expand All @@ -403,12 +403,7 @@ function propose!(
densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

# TODO: Make this compatible with immutable `VarInfo`.
# Update the values in the VarInfo.
setindex!!(vi, trans.params, spl)
setlogp!!(vi, trans.lp)

return vi
return setlogp!!(DynamicPPL.unflatten(vi, spl, trans.params), trans.lp)
end

function DynamicPPL.initialstep(
Expand All @@ -420,7 +415,7 @@ function DynamicPPL.initialstep(
)
# If we're doing random walk with a covariance matrix,
# just link everything before sampling.
maybe_link!(vi, spl, spl.alg.proposals)
vi = maybe_link!!(vi, spl, spl.alg.proposals, model)

return Transition(vi), vi
end
Expand All @@ -435,7 +430,7 @@ function AbstractMCMC.step(
# Cases:
# 1. A covariance proposal matrix
# 2. A bunch of NamedTuples that specify the proposal space
propose!(rng, vi, model, spl, spl.alg.proposals)
vi = propose!!(rng, vi, model, spl, spl.alg.proposals)

return Transition(vi), vi
end
Expand Down
Loading