diff --git a/Project.toml b/Project.toml index 4f4a5ecaa..4a095402b 100644 --- a/Project.toml +++ b/Project.toml @@ -24,9 +24,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLEnzymeCoreExt = ["EnzymeCore"] [compat] AbstractMCMC = "5" @@ -38,6 +40,7 @@ Compat = "4" ConstructionBase = "1.5.4" Distributions = "0.25" DocStringExtensions = "0.9" +EnzymeCore = "0.6" LogDensityProblems = "2" MCMCChains = "6" MacroTools = "0.5.6" @@ -52,3 +55,4 @@ julia = "1.6" [extras] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl new file mode 100644 index 000000000..f83d6e8f7 --- /dev/null +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -0,0 +1,13 @@ +module DynamicPPLEnzymeCoreExt + +if isdefined(Base, :get_extension) + using DynamicPPL: DynamicPPL + using EnzymeCore +else + using ..DynamicPPL: DynamicPPL + using ..EnzymeCore +end + +@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true + +end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b90381dea..b5c29a34e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -189,6 +189,9 @@ end @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( "../ext/DynamicPPLMCMCChainsExt.jl" ) + @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( + "../ext/DynamicPPLEnzymeCoreExt.jl" + ) end end diff --git a/test/Project.toml b/test/Project.toml index 16c793956..878d3c1d1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" diff --git a/test/contexts.jl b/test/contexts.jl index 9b0427cd0..d04aecb52 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -16,6 +16,8 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested +using EnzymeCore + # Dummy context to test nested behaviors. struct ParentContext{C<:AbstractContext} <: AbstractContext context::C @@ -252,6 +254,7 @@ end @test SamplingContext(Random.default_rng(), DefaultContext()) == context @test SamplingContext(SampleFromPrior(), DefaultContext()) == context @test SamplingContext(SampleFromPrior(), DefaultContext()) == context + @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) end @testset "FixedContext" begin