diff --git a/Project.toml b/Project.toml index 4b6233d53..c9ae99c7c 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -25,9 +24,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLMCMCEnzymeCoreExt = ["EnzymeCore"] [compat] AbstractMCMC = "5" @@ -54,3 +55,4 @@ julia = "1.6" [extras] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl new file mode 100644 index 000000000..c0f864009 --- /dev/null +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -0,0 +1,13 @@ +module DynamicPPLEnzymeCoreExt + +if isdefined(Base, :get_extension) + using DynamicPPL: DynamicPPL + using EnzymeCore +else + using ..DynamicPPL: DynamicPPL + using ..EnzymeCore +end + +@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true + +end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b90381dea..b5c29a34e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -189,6 +189,9 @@ end @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( "../ext/DynamicPPLMCMCChainsExt.jl" ) + @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( + "../ext/DynamicPPLEnzymeCoreExt.jl" + ) end end diff --git a/src/contexts.jl b/src/contexts.jl index e83da50cd..83da5d929 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -137,9 +137,6 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte context::C end -using EnzymeCore -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true - function SamplingContext( rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() ) diff --git a/test/Project.toml b/test/Project.toml index 16c793956..878d3c1d1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" diff --git a/test/contexts.jl b/test/contexts.jl index 9b0427cd0..d04aecb52 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -16,6 +16,8 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested +using EnzymeCore + # Dummy context to test nested behaviors. struct ParentContext{C<:AbstractContext} <: AbstractContext context::C @@ -252,6 +254,7 @@ end @test SamplingContext(Random.default_rng(), DefaultContext()) == context @test SamplingContext(SampleFromPrior(), DefaultContext()) == context @test SamplingContext(SampleFromPrior(), DefaultContext()) == context + @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) end @testset "FixedContext" begin