Skip to content

Commit

Permalink
Mark Sampling context as not needing derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Nov 19, 2023
1 parent a1d1b35 commit 9158c3e
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 4 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand All @@ -25,9 +24,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMCMCEnzymeCoreExt = ["EnzymeCore"]

[compat]
AbstractMCMC = "5"
Expand All @@ -54,3 +55,4 @@ julia = "1.6"

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
13 changes: 13 additions & 0 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module DynamicPPLEnzymeCoreExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using EnzymeCore
else
using ..DynamicPPL: DynamicPPL
using ..EnzymeCore
end

@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true

Check warning on line 11 in ext/DynamicPPLEnzymeCoreExt.jl

View workflow job for this annotation

GitHub Actions / format

[JuliaFormatter] reported by reviewdog 🐶 Raw Output: ext/DynamicPPLEnzymeCoreExt.jl:11:-@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true ext/DynamicPPLEnzymeCoreExt.jl:11:+@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T<:SamplingContext} = true

end
3 changes: 3 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ end
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
"../ext/DynamicPPLEnzymeCoreExt.jl"
)
end
end

Expand Down
3 changes: 0 additions & 3 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte
context::C
end

using EnzymeCore
@inline EnzymeCore.EnzymeRules.inactive_type(::Type{T}) where {T <: SamplingContext} = true

function SamplingContext(
rng::Random.AbstractRNG=Random.default_rng(), sampler::AbstractSampler=SampleFromPrior()
)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Expand Down
3 changes: 3 additions & 0 deletions test/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ using DynamicPPL:
hasconditioned_nested,
getconditioned_nested

using EnzymeCore

# Dummy context to test nested behaviors.
struct ParentContext{C<:AbstractContext} <: AbstractContext
context::C
Expand Down Expand Up @@ -252,6 +254,7 @@ end
@test SamplingContext(Random.default_rng(), DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
end

@testset "FixedContext" begin
Expand Down

0 comments on commit 9158c3e

Please sign in to comment.