Skip to content

Commit

Permalink
Merge branch 'main' into wct/cache-rules
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Feb 17, 2024
2 parents d73da17 + aea04bf commit 7c2e51a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/interpreter/reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ function build_coinsts(ir_inst::Expr, P, in_f, _rrule!!, n::Int, b::Int, is_blk_
arg_slots = map(arg -> _get_slot(arg, _rrule!!), (__args..., ))

# Construct signature, and determine how the rrule is to be computed.
primal_sig = _typeof(map(primal get_codual, arg_slots))
primal_sig = Tuple{map(arg -> eltype(_get_slot(arg, in_f)), (__args..., ))...}
evaluator = get_evaluator(in_f.ctx, primal_sig, in_f.interp, is_invoke)
__rrule!! = get_rrule!!_evaluator(evaluator)

Expand Down Expand Up @@ -320,7 +320,7 @@ function rrule!!(_f::CoDual{<:DelayedInterpretedFunction{C, F}}, args::CoDual...
f = primal(_f)
s = _typeof(map(primal, args))
if is_primitive(C, s)
return rrule!!(zero_codual(f.f), args...)
return rrule!!(zero_codual(_eval), args...)
else
in_f = InterpretedFunction(f.ctx, s, f.interp)
return build_rrule!!(in_f)(zero_codual(in_f), args...)
Expand Down
25 changes: 25 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1376,6 +1376,30 @@ function test_union_of_types(x::Ref{Union{Type{Float64}, Type{Int}}})
return x[]
end

# Only one of these is a primitive. Lots of methods to prevent the compiler from
# over-specialising.
@noinline edge_case_tester(x::Float64) = 5x
@noinline edge_case_tester(x::Any) = 5.0
@noinline edge_case_tester(x::Float32) = 6.0
@noinline edge_case_tester(x::Int) = 10
@noinline edge_case_tester(x::String) = "hi"
@is_primitive MinimalCtx Tuple{typeof(edge_case_tester), Float64}
function Taped.rrule!!(::CoDual{typeof(edge_case_tester)}, x::CoDual{Float64})
edge_case_tester_pb!!(dy, df, dx) = df, dx + 5 * dy
return CoDual(5 * primal(x), 0.0), edge_case_tester_pb!!
end

# To test the edge case properly, call this with x = Any[5.0, false]
function test_primitive_dynamic_dispatch(x::Vector{Any})
i = 0
y = 0.0
while i < 2
i += 1
y += edge_case_tester(x[i])
end
return y
end

sr(n) = Xoshiro(n)

function generate_test_functions()
Expand All @@ -1402,6 +1426,7 @@ function generate_test_functions()
(false, :none, nothing, type_unstable_tester, Ref{Any}(5.0)),
(false, :none, nothing, type_unstable_tester_2, Ref{Real}(5.0)),
(false, :none, (lb=1, ub=1000), type_unstable_tester_3, Ref{Any}(5.0)),
(false, :none, (lb=1, ub=1000), test_primitive_dynamic_dispatch, Any[5.0, false]),
(false, :none, nothing, type_unstable_function_eval, Ref{Any}(sin), 5.0),
(false, :allocs, nothing, phi_const_bool_tester, 5.0),
(false, :allocs, nothing, phi_const_bool_tester, -5.0),
Expand Down

0 comments on commit 7c2e51a

Please sign in to comment.