From a1d1b35153d26ee49d6c3f9b582217a1555d15d2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 8 Nov 2023 03:06:42 -0600 Subject: [PATCH 1/5] Mark Sampling context as not needing derivatives --- Project.toml | 2 ++ src/contexts.jl | 3 +++ 2 files changed, 5 insertions(+) diff --git a/Project.toml b/Project.toml index 4f4a5ecaa..4b6233d53 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -38,6 +39,7 @@ Compat = "4" ConstructionBase = "1.5.4" Distributions = "0.25" DocStringExtensions = "0.9" +EnzymeCore = "0.6" LogDensityProblems = "2" MCMCChains = "6" MacroTools = "0.5.6" diff --git a/src/contexts.jl b/src/contexts.jl index 83da5d929..e83da50cd 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -137,6 +137,9 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte context::C end +using EnzymeCore +@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true + function SamplingContext( rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() ) From 9158c3ecc23b1f8e971ea941eb39da656919c274 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 8 Nov 2023 03:06:42 -0600 Subject: [PATCH 2/5] Mark Sampling context as not needing derivatives --- Project.toml | 4 +++- ext/DynamicPPLEnzymeCoreExt.jl | 13 +++++++++++++ src/DynamicPPL.jl | 3 +++ src/contexts.jl | 3 --- test/Project.toml | 1 + test/contexts.jl | 3 +++ 6 files changed, 23 insertions(+), 4 deletions(-) create mode 100644 ext/DynamicPPLEnzymeCoreExt.jl diff --git a/Project.toml b/Project.toml index 4b6233d53..c9ae99c7c 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" @@ -25,9 +24,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [weakdeps] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLMCMCEnzymeCoreExt = ["EnzymeCore"] [compat] AbstractMCMC = "5" @@ -54,3 +55,4 @@ julia = "1.6" [extras] MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl new file mode 100644 index 000000000..c0f864009 --- /dev/null +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -0,0 +1,13 @@ +module DynamicPPLEnzymeCoreExt + +if isdefined(Base, :get_extension) + using DynamicPPL: DynamicPPL + using EnzymeCore +else + using ..DynamicPPL: DynamicPPL + using ..EnzymeCore +end + +@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true + +end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b90381dea..b5c29a34e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -189,6 +189,9 @@ end @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( "../ext/DynamicPPLMCMCChainsExt.jl" ) + @require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include( + "../ext/DynamicPPLEnzymeCoreExt.jl" + ) end end diff --git a/src/contexts.jl b/src/contexts.jl index e83da50cd..83da5d929 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -137,9 +137,6 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte context::C end -using EnzymeCore -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true - function SamplingContext( rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior() ) diff --git a/test/Project.toml b/test/Project.toml index 16c793956..878d3c1d1 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -7,6 +7,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" diff --git a/test/contexts.jl b/test/contexts.jl index 9b0427cd0..d04aecb52 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -16,6 +16,8 @@ using DynamicPPL: hasconditioned_nested, getconditioned_nested +using EnzymeCore + # Dummy context to test nested behaviors. struct ParentContext{C<:AbstractContext} <: AbstractContext context::C @@ -252,6 +254,7 @@ end @test SamplingContext(Random.default_rng(), DefaultContext()) == context @test SamplingContext(SampleFromPrior(), DefaultContext()) == context @test SamplingContext(SampleFromPrior(), DefaultContext()) == context + @test EnzymeCore.EnzymeRules.inactive_type(typeof(context)) end @testset "FixedContext" begin From 8ac0f042cf3f8d86a9a9d3fabeba9822db47ed71 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 21 Nov 2023 09:09:44 +0100 Subject: [PATCH 3/5] Fix format --- ext/DynamicPPLEnzymeCoreExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index c0f864009..7af184bdf 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,6 +8,6 @@ else using ..EnzymeCore end -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true +@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T<:SamplingContext} = true end From 80a18b26a8d1b280db6735e0640f206c22eab0f6 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 21 Nov 2023 09:11:07 +0100 Subject: [PATCH 4/5] Fix Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c9ae99c7c..4a095402b 100644 --- a/Project.toml +++ b/Project.toml @@ -28,7 +28,7 @@ EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [extensions] DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLMCMCEnzymeCoreExt = ["EnzymeCore"] +DynamicPPLEnzymeCoreExt = ["EnzymeCore"] [compat] AbstractMCMC = "5" From b3a3757dc223827c476c0b9824a518c6f46b8da4 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 21 Nov 2023 09:26:06 +0100 Subject: [PATCH 5/5] Qualify SamplingContext --- ext/DynamicPPLEnzymeCoreExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 7af184bdf..f83d6e8f7 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,6 +8,6 @@ else using ..EnzymeCore end -@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T<:SamplingContext} = true +@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true end