Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache Rules in AbstractInterpreter #80

Merged
merged 4 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions src/interpreter/reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,9 @@ end

function build_rrule!!(in_f::InterpretedFunction{sig}) where {sig}

# If we've already constructed this interpreted function, just return it.
sig in keys(in_f.interp.in_f_rrule_cache) && return in_f.interp.in_f_rrule_cache[sig]

return_slot = SlotRef{codual_type(eltype(in_f.return_slot))}()
return_tangent_slot = SlotRef{tangent_type(eltype(in_f.return_slot))}()
arg_info = make_codual_arginfo(in_f.arg_info)
Expand Down Expand Up @@ -472,16 +475,19 @@ function build_rrule!!(in_f::InterpretedFunction{sig}) where {sig}
# Set PhiNodes.
make_phi_instructions!(in_f, __rrule!!)

in_f.interp.in_f_rrule_cache[sig] = __rrule!!

return __rrule!!
end

struct InterpretedFunctionPb{Tret_tangent<:SlotRef, Targ_info, Tbwds_f, V}
struct InterpretedFunctionPb{Tret_tangent<:SlotRef, Targ_info, Tbwds_f, V, Q}
j::Int
bwds_instructions::Tbwds_f
ret_tangent::Tret_tangent
n_stack::Stack{Int}
arg_info::Targ_info
arg_tangent_stacks::V
arg_tangent_stack_refs::Q
end

function (in_f_rrule!!::InterpretedFunctionRRule{sig})(
Expand All @@ -503,15 +509,18 @@ function (in_f_rrule!!::InterpretedFunctionRRule{sig})(
n = 1
j = length(n_stack)

# Get references to top of tangent stacks for use on reverse-pass.
arg_tangent_stack_refs = map(top_ref, arg_tangent_stacks)

# Run instructions until done.
while next_block != -1
push!(n_stack, n)
if !isassigned(in_f_rrule!!.fwds_instructions, n)
fwds, bwds = generate_coinstructions(in_f, in_f_rrule!!, n)
in_f_rrule!!.fwds_instructions[n] = fwds
in_f_rrule!!.bwds_instructions[n] = bwds
end
next_block = in_f_rrule!!.fwds_instructions[n](prev_block)
push!(n_stack, n)
if next_block == 0
n += 1
elseif next_block > 0
Expand All @@ -530,6 +539,7 @@ function (in_f_rrule!!::InterpretedFunctionRRule{sig})(
n_stack,
arg_info,
arg_tangent_stacks,
arg_tangent_stack_refs,
)
return return_val, interpreted_function_pb!!
end
Expand All @@ -538,8 +548,8 @@ function (if_pb!!::InterpretedFunctionPb)(dout, ::NoTangent, dargs::Vararg{Any,

# Update the output cotangent value to whatever is provided.
if_pb!!.ret_tangent[] = dout
tangent_stacks = if_pb!!.arg_tangent_stacks
set_tangent_stacks!(tangent_stacks, dargs, if_pb!!.arg_info)
tangent_stack_refs = if_pb!!.arg_tangent_stack_refs # this can go when we refactor
set_tangent_stacks!(tangent_stack_refs, dargs, if_pb!!.arg_info)

# Run the instructions in reverse. Present assumes linear instruction ordering.
n_stack = if_pb!!.n_stack
Expand All @@ -550,7 +560,7 @@ function (if_pb!!::InterpretedFunctionPb)(dout, ::NoTangent, dargs::Vararg{Any,
end

# Return resulting tangents from slots.
return NoTangent(), assemble_dout(tangent_stacks, if_pb!!.arg_info)...
return NoTangent(), assemble_dout(if_pb!!.arg_tangent_stacks, if_pb!!.arg_info)...
end

function set_tangent_stacks!(tangent_stacks, dargs, ai::ArgInfo{<:Any, is_va}) where {is_va}
Expand Down
6 changes: 6 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,10 @@ end

sr(n) = Xoshiro(n)

@noinline function test_self_reference(a, b)
return a < b ? a * b : test_self_reference(b, a)
end

function generate_test_functions()
return Any[
(false, :allocs, nothing, const_tester),
Expand Down Expand Up @@ -1477,6 +1481,8 @@ function generate_test_functions()
test_union_of_types,
Ref{Union{Type{Float64}, Type{Int}}}(Float64),
),
(false, :allocs, nothing, test_self_reference, 1.1, 1.5),
(false, :allocs, nothing, test_self_reference, 1.5, 1.1),
(
false,
:none,
Expand Down
Loading