Skip to content

Commit

Permalink
Unify transition also in external samplers (#2030)
Browse files Browse the repository at this point in the history
* Transition

* Revert "Transition"

This reverts commit 71c8097.

* bug

* repeated functions

* move Transition to inference

* default get_stat

* bug

* Update src/inference/Inference.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/inference/Inference.jl

Co-authored-by: David Widmann <[email protected]>

* Update src/inference/Inference.jl

Co-authored-by: David Widmann <[email protected]>

* rest of david changes

* bring back Transition(a,b)

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
JaimeRZP and devmotion authored Jul 10, 2023
1 parent e41f58c commit a1bf714
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 22 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
[compat]
AbstractMCMC = "4"
AdvancedHMC = "0.3.0, 0.4"
AdvancedMH = "0.6.8, 0.7"
AdvancedPS = "0.4"
AdvancedVI = "0.2"
BangBang = "0.3"
Expand Down
17 changes: 1 addition & 16 deletions src/contrib/inference/abstractmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,13 @@ struct TuringState{S,F}
logdensity::F
end

struct TuringTransition{T,NT<:NamedTuple,F<:AbstractFloat}
θ::T
lp::F
stat::NT
end

function TuringTransition(vi::AbstractVarInfo, t)
theta = tonamedtuple(vi)
lp = getlogp(vi)
return TuringTransition(theta, lp, getstats(t))
end

metadata(t::TuringTransition) = merge((lp = t.lp,), t.stat)
DynamicPPL.getlogp(t::TuringTransition) = t.lp

state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
θ = getparams(transition)
varinfo = DynamicPPL.unflatten(f.varinfo, θ)
# TODO: `deepcopy` is overkill; make more efficient.
varinfo = DynamicPPL.invlink!!(deepcopy(varinfo), f.model)
return TuringTransition(varinfo, transition)
return Transition(varinfo, transition)
end

# NOTE: Only thing that depends on the underlying sampler.
Expand Down
11 changes: 6 additions & 5 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ end
######################
# Default Transition #
######################
# Default
# Extended in contrib/inference/abstractmcmc.jl
getstats(t) = nothing

struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
θ :: T
Expand All @@ -132,10 +135,10 @@ end

Transition(θ, lp) = Transition(θ, lp, nothing)

function Transition(vi::AbstractVarInfo; nt::NamedTuple=NamedTuple())
function Transition(vi::AbstractVarInfo, t=nothing; nt::NamedTuple=NamedTuple())
θ = merge(tonamedtuple(vi), nt)
lp = getlogp(vi)
return Transition(θ, lp, nothing)
return Transition(θ, lp, getstats(t))
end

function metadata(t::Transition)
Expand Down Expand Up @@ -664,9 +667,7 @@ function transitions_from_chain(
model(rng, vi, sampler)

# Convert `VarInfo` into `NamedTuple` and save.
theta = DynamicPPL.tonamedtuple(vi)
lp = Turing.getlogp(vi)
Transition(theta, lp)
Transition(vi)
end

return transitions
Expand Down

0 comments on commit a1bf714

Please sign in to comment.