From a1bf71461aa687cc3ba4d97116db7e1ed161108e Mon Sep 17 00:00:00 2001 From: Jaime RZ Date: Mon, 10 Jul 2023 14:41:32 +0100 Subject: [PATCH] Unify transition also in external samplers (#2030) * Transition * Revert "Transition" This reverts commit 71c809700ca896e5158aa1f1c143dd8d6a5e348a. * bug * repeated functions * move Transition to inference * default get_stat * bug * Update src/inference/Inference.jl Co-authored-by: David Widmann * Update src/inference/Inference.jl Co-authored-by: David Widmann * Update src/inference/Inference.jl Co-authored-by: David Widmann * rest of david changes * bring back Transition(a,b) --------- Co-authored-by: David Widmann --- Project.toml | 1 - src/contrib/inference/abstractmcmc.jl | 17 +---------------- src/inference/Inference.jl | 11 ++++++----- 3 files changed, 7 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index 198739423..95fa40a22 100644 --- a/Project.toml +++ b/Project.toml @@ -39,7 +39,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] AbstractMCMC = "4" AdvancedHMC = "0.3.0, 0.4" -AdvancedMH = "0.6.8, 0.7" AdvancedPS = "0.4" AdvancedVI = "0.2" BangBang = "0.3" diff --git a/src/contrib/inference/abstractmcmc.jl b/src/contrib/inference/abstractmcmc.jl index d0f9c9038..19411ac99 100644 --- a/src/contrib/inference/abstractmcmc.jl +++ b/src/contrib/inference/abstractmcmc.jl @@ -3,28 +3,13 @@ struct TuringState{S,F} logdensity::F end -struct TuringTransition{T,NT<:NamedTuple,F<:AbstractFloat} - θ::T - lp::F - stat::NT -end - -function TuringTransition(vi::AbstractVarInfo, t) - theta = tonamedtuple(vi) - lp = getlogp(vi) - return TuringTransition(theta, lp, getstats(t)) -end - -metadata(t::TuringTransition) = merge((lp = t.lp,), t.stat) -DynamicPPL.getlogp(t::TuringTransition) = t.lp - state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f) function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition) θ = getparams(transition) varinfo = DynamicPPL.unflatten(f.varinfo, θ) # TODO: `deepcopy` is overkill; make more efficient. varinfo = DynamicPPL.invlink!!(deepcopy(varinfo), f.model) - return TuringTransition(varinfo, transition) + return Transition(varinfo, transition) end # NOTE: Only thing that depends on the underlying sampler. diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 4e1f35254..73f190dcc 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -123,6 +123,9 @@ end ###################### # Default Transition # ###################### +# Default +# Extended in contrib/inference/abstractmcmc.jl +getstats(t) = nothing struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} θ :: T @@ -132,10 +135,10 @@ end Transition(θ, lp) = Transition(θ, lp, nothing) -function Transition(vi::AbstractVarInfo; nt::NamedTuple=NamedTuple()) +function Transition(vi::AbstractVarInfo, t=nothing; nt::NamedTuple=NamedTuple()) θ = merge(tonamedtuple(vi), nt) lp = getlogp(vi) - return Transition(θ, lp, nothing) + return Transition(θ, lp, getstats(t)) end function metadata(t::Transition) @@ -664,9 +667,7 @@ function transitions_from_chain( model(rng, vi, sampler) # Convert `VarInfo` into `NamedTuple` and save. - theta = DynamicPPL.tonamedtuple(vi) - lp = Turing.getlogp(vi) - Transition(theta, lp) + Transition(vi) end return transitions