From eadb803472adc60665af7e3d2e1e4b42a557516e Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Wed, 21 Feb 2024 19:23:42 +0000 Subject: [PATCH 1/3] Sort out small unions --- src/codual.jl | 1 + src/interpreter/reverse_mode_ad.jl | 26 +++++++++++++++++--------- src/stack.jl | 1 + src/tangents.jl | 3 +-- src/test_utils.jl | 10 ++++++++++ test/codual.jl | 4 ++++ 6 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/codual.jl b/src/codual.jl index e9ea687f5..8114350bf 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -33,6 +33,7 @@ Shorthand for `CoDual{P, tangent_type(P}}` when `P` is concrete, equal to `CoDua """ function codual_type(::Type{P}) where {P} P == DataType && return CoDual + P isa Union && return Union{codual_type(P.a), codual_type(P.b)} return isconcretetype(P) ? CoDual{P, tangent_type(P)} : CoDual end diff --git a/src/interpreter/reverse_mode_ad.jl b/src/interpreter/reverse_mode_ad.jl index 213885563..310e65956 100644 --- a/src/interpreter/reverse_mode_ad.jl +++ b/src/interpreter/reverse_mode_ad.jl @@ -11,14 +11,21 @@ const BwdsInst = Core.OpaqueClosure{Tuple{Int}, Int} const RuleSlot{V} = Union{SlotRef{V}, ConstSlot{V}} where {V<:Tuple{CoDual, Ref}} -primal_type(::AbstractSlot{<:Tuple{<:CoDual{P}, <:Any}}) where {P} = @isdefined(P) ? P : Any -primal_type(::AbstractSlot{<:Tuple{<:CoDual, <:Any}}) = Any +__primal_type(::Type{<:Tuple{<:CoDual{P}, <:Any}}) where {P} = @isdefined(P) ? P : Any +function __primal_type(::Type{P}) where {P<:Tuple{<:CoDual, <:Any}} + P isa Union && return Union{__primal_type(P.a), __primal_type(P.b)} + return Any +end -function make_rule_slot(::SlotRef{P}, ::Any) where {P} - return SlotRef{Tuple{codual_type(P), tangent_ref_type_ub(P)}}() +primal_type(::AbstractSlot{P}) where {P} = __primal_type(P) + +function rule_slot_type(::Type{P}) where {P} + P isa Union && return Union{rule_slot_type(P.a), rule_slot_type(P.b)} + return Tuple{codual_type(P), tangent_ref_type_ub(P)} end -function make_rule_slot(::SlotRef{P}, ::PhiNode) where {P} - return SlotRef{Tuple{codual_type(P), tangent_ref_type_ub(P)}}() + +function make_rule_slot(::SlotRef{P}, ::Any) where {P} + return SlotRef{rule_slot_type(P)}() end function make_rule_slot(x::ConstSlot{P}, ::Any) where {P} cd = uninit_codual(x[]) @@ -112,17 +119,18 @@ function build_coinsts( tangent_stack_stack = make_tangent_ref_stack(tangent_ref_type_ub(primal_type(val))) make_fwds(v) = R(primal(v), tangent(v)) - fwds_inst = @opaque function (p::Int) + function fwds_run() v, tangent_stack = val[] push!(my_tangent_stack, tangent(v)) push!(tangent_stack_stack, tangent_stack) ret[] = (make_fwds(v), top_ref(my_tangent_stack)) return next_blk end - bwds_inst = @opaque function (j::Int) + fwds_inst = @opaque (p::Int) -> fwds_run() + function bwds_run() increment_ref!(pop!(tangent_stack_stack), pop!(my_tangent_stack)) - return j end + bwds_inst = @opaque (j::Int) -> (bwds_run(); return j) return fwds_inst::FwdsInst, bwds_inst::BwdsInst end diff --git a/src/stack.jl b/src/stack.jl index 7e2737a80..77f0e7a2c 100644 --- a/src/stack.jl +++ b/src/stack.jl @@ -113,6 +113,7 @@ __array_ref_type(::Type{P}) where {P} = Base.RefArray{P, Vector{P}, Nothing} function tangent_ref_type_ub(::Type{P}) where {P} P === DataType && return Ref T = tangent_type(P) + T isa Union && return Union{tangent_ref_type_ub(T.a), tangent_ref_type_ub(T.b)} T === NoTangent && return NoTangentRef return isconcretetype(P) ? __array_ref_type(T) : Ref end diff --git a/src/tangents.jl b/src/tangents.jl index d4e57507a..8fef76147 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -212,8 +212,7 @@ end )) # If the type is a Union, then take the union type of its arguments. - # P isa Union && return Union{tangent_type(P.a), tangent_type(P.b)} - P isa Union && return Any + P isa Union && return Union{tangent_type(P.a), tangent_type(P.b)} # If the type is itself abstract, it's tangent could be anything. # The same goes for if the type has any undetermined type parameters. diff --git a/src/test_utils.jl b/src/test_utils.jl index 689ac61d5..51a119c48 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -1406,6 +1406,15 @@ sr(n) = Xoshiro(n) return a < b ? a * b : test_self_reference(b, a) end +# Copied over from https://github.com/TuringLang/Turing.jl/issues/1140 +function _sum(x) + z = 0 + for i in eachindex(x) + z += x[i] + end + return z +end + function generate_test_functions() return Any[ (false, :allocs, nothing, const_tester), @@ -1580,6 +1589,7 @@ function generate_test_functions() randn(sr(2), 700, 500), randn(sr(3), 300, 700), ), + (false, :none, nothing, _sum, randn(1024)), ] end diff --git a/test/codual.jl b/test/codual.jl index 21411666e..1f6a1afcf 100644 --- a/test/codual.jl +++ b/test/codual.jl @@ -8,4 +8,8 @@ @test codual_type(Real) == CoDual @test codual_type(Any) == CoDual @test codual_type(Type{UnitRange{Int}}) == CoDual{Type{UnitRange{Int}}, NoTangent} + @test(==( + codual_type(Union{Float64, Int}), + Union{CoDual{Float64, Float64}, CoDual{Int, NoTangent}}, + )) end From 2aeeea87fd0785d2f462ba1bfac001e360859e72 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 22 Feb 2024 17:11:30 +0000 Subject: [PATCH 2/3] Make array example work --- src/interpreter/reverse_mode_ad.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/interpreter/reverse_mode_ad.jl b/src/interpreter/reverse_mode_ad.jl index 4b3f25ee0..1cfc53583 100644 --- a/src/interpreter/reverse_mode_ad.jl +++ b/src/interpreter/reverse_mode_ad.jl @@ -20,7 +20,6 @@ end primal_type(::AbstractSlot{P}) where {P} = __primal_type(P) function rule_slot_type(::Type{P}) where {P} - P isa Union && return Union{rule_slot_type(P.a), rule_slot_type(P.b)} return Tuple{codual_type(P), tangent_ref_type_ub(P)} end From a871dc9e817c26e2db5df0c18ed61cf43882a32c Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 22 Feb 2024 17:47:41 +0000 Subject: [PATCH 3/3] Loosen bound further --- src/stack.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/stack.jl b/src/stack.jl index 77f0e7a2c..7e2737a80 100644 --- a/src/stack.jl +++ b/src/stack.jl @@ -113,7 +113,6 @@ __array_ref_type(::Type{P}) where {P} = Base.RefArray{P, Vector{P}, Nothing} function tangent_ref_type_ub(::Type{P}) where {P} P === DataType && return Ref T = tangent_type(P) - T isa Union && return Union{tangent_ref_type_ub(T.a), tangent_ref_type_ub(T.b)} T === NoTangent && return NoTangentRef return isconcretetype(P) ? __array_ref_type(T) : Ref end