Skip to content

Commit

Permalink
Only use stacks for tangents (#77)
Browse files Browse the repository at this point in the history
* Only use stacks for tangents

* Recomment out localised benchmarking

* Test small array union

* Fix small array union

* Print out test number in DiffTests

* Add failing unit test

* Fix bug in a quick way

* Loosen bound on new test
  • Loading branch information
willtebbutt authored Feb 9, 2024
1 parent 3bbf8fc commit 3cd13a6
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 64 deletions.
93 changes: 60 additions & 33 deletions src/interpreter/reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,65 @@ const BwdsInst = Core.OpaqueClosure{Tuple{Int}, Int}

const CoDualSlot{V} = AbstractSlot{V} where {V<:CoDual}

const CoDualStack{V} = Union{ConstSlot{V}, Stack{V}} where {V<:CoDual}

primal_eltype(::CoDualSlot{CoDual{P, T}}) where {P, T} = P

# Operations on Slots involving CoDuals
struct FwdStack{V<:CoDual, Pslot<:SlotRef, Tstack<:Stack} <: AbstractSlot{V}
primal_slot::Pslot
tangent_stack::Tstack
function FwdStack{CoDual{P, T}}() where {P, T}
Pslot = SlotRef{P}
Tstack = Stack{T}
return new{codual_type(P), Pslot, Tstack}(Pslot(), Tstack())
end
FwdStack{CoDual}() = FwdStack{CoDual{Any, Any}}()
FwdStack(::Pslot) where {P, Pslot<:SlotRef{P}} = FwdStack{codual_type(P)}()
end

function increment_tangent!(x::CoDualSlot{C}, y::CoDualSlot) where {C}
x_val = x[]
x[] = C(primal(x_val), increment!!(tangent(x_val), tangent(y[])))
Base.isassigned(x::FwdStack) = isassigned(x.tangent_stack)

const CoDualStack{V} = Union{ConstSlot{V}, FwdStack{V}} where {V<:CoDual}

function FwdStack(x::C) where {C<:CoDual}
stack = FwdStack{C}()
push!(stack, x)
return stack
end

Base.getindex(x::FwdStack{V}) where {V<:CoDual} = V(x.primal_slot[], x.tangent_stack[])

function Base.push!(x::FwdStack{V}, v::V) where {V<:CoDual}
x.primal_slot[] = primal(v)
push!(x.tangent_stack, tangent(v))
return nothing
end

make_codual_stack(x::SlotRef{P}) where {P} = FwdStack(x)
function make_codual_stack(x::ConstSlot{P}) where {P} # REVISIT THIS TO USE CONSTSLOT AGAIN
stack = FwdStack{codual_type(P)}()
push!(stack, uninit_codual(x[]))
return stack
end

# Operations on Slots involving CoDuals

function increment_tangent!(x::CoDualSlot{C}, t) where {C}
x_val = x[]
x[] = C(primal(x_val), increment!!(tangent(x_val), t))
return nothing
end

function increment_tangent!(x::FwdStack{V}, t) where {V}
x.tangent_stack[] = increment!!(x.tangent_stack[], t)
return nothing
end

## ReturnNode
function build_coinsts(node::ReturnNode, _, _rrule!!, ::Int, ::Int, ::Bool)
return build_coinsts(ReturnNode, _rrule!!.return_slot, _get_slot(node.val, _rrule!!))
end
function build_coinsts(::Type{ReturnNode}, ret_slot::SlotRef{<:CoDual}, val::CoDualStack)
fwds_inst = build_inst(ReturnNode, ret_slot, val)
bwds_inst = @opaque (j::Int) -> (increment_tangent!(val, ret_slot); return j)
bwds_inst = @opaque (j::Int) -> (increment_tangent!(val, tangent(ret_slot[])); return j)
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
end

Expand Down Expand Up @@ -90,7 +124,7 @@ function build_coinsts(
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
end

function transfer_tmp_value!(x::TypedPhiNode{<:Stack}, prev_blk::Int)
function transfer_tmp_value!(x::TypedPhiNode{<:FwdStack}, prev_blk::Int)
map(x.edges, x.values) do edge, v
(edge == prev_blk) && isassigned(v) && (push!(x.ret_slot, x.tmp_slot[]))
end
Expand All @@ -100,15 +134,15 @@ end
function replace_tmp_tangent_from_ret!(x::TypedPhiNode{<:CoDualSlot}, prev_blk::Int)
map(x.edges, x.values) do edge, v
if (edge == prev_blk) && isassigned(v)
replace_tangent!(x.tmp_slot, tangent(pop!(x.ret_slot)))
replace_tangent!(x.tmp_slot, pop!(x.ret_slot.tangent_stack))
end
end
return nothing
end

function increment_predecessor_from_tmp!(x::TypedPhiNode{<:CoDualSlot}, prev_blk::Int)
map(x.edges, x.values) do edge, v
(edge == prev_blk) && isassigned(v) && increment_tangent!(v, x.tmp_slot)
(edge == prev_blk) && isassigned(v) && increment_tangent!(v, tangent(x.tmp_slot[]))
end
return nothing
end
Expand All @@ -128,7 +162,7 @@ function build_coinsts(
return next_blk
end
bwds_inst = @opaque function (j::Int)
increment_tangent!(val, tangent(pop!(ret)))
increment_tangent!(val, pop!(ret.tangent_stack))
return j
end
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
Expand All @@ -141,7 +175,7 @@ function build_coinsts(x::GlobalRef, _, _rrule!!, n::Int, b::Int, is_blk_end::Bo
end
function build_coinsts(::Type{GlobalRef}, x::AbstractSlot, out::CoDualStack, next_blk::Int)
fwds_inst = @opaque (p::Int) -> (push!(out, uninit_codual(x[])); return next_blk)
bwds_inst = @opaque (j::Int) -> (pop!(out); return j)
bwds_inst = @opaque (j::Int) -> (pop!(out.tangent_stack); return j)
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
end

Expand All @@ -153,7 +187,7 @@ function build_coinsts(node, _, _rrule!!, n::Int, b::Int, is_blk_end::Bool)
end
function build_coinsts(::Nothing, x::ConstSlot, out::CoDualStack, next_blk::Int)
fwds_inst = @opaque (p::Int) -> (push!(out, x[]); return next_blk)
bwds_inst = @opaque (j::Int) -> (pop!(out); return j)
bwds_inst = @opaque (j::Int) -> (pop!(out.tangent_stack); return j)
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
end

Expand Down Expand Up @@ -224,9 +258,6 @@ function build_coinsts(ir_inst::Expr, in_f, _rrule!!, n::Int, b::Int, is_blk_end
# Create stack for storing pullbacks.
pb_stack = build_pb_stack(__rrule!!, evaluator, arg_slots)

# Create stack for storing values.
old_vals = Stack{eltype(val_slot)}()

return build_coinsts(
Val(:call), val_slot, arg_slots, evaluator, __rrule!!, pb_stack, next_blk
)
Expand All @@ -243,27 +274,30 @@ function build_coinsts(ir_inst::Expr, in_f, _rrule!!, n::Int, b::Int, is_blk_end
end
end

function build_coinsts(::Val{:boundscheck}, out::CoDualSlot, next_blk::Int)
function build_coinsts(::Val{:boundscheck}, out::CoDualStack, next_blk::Int)
@assert primal_eltype(out) == Bool
fwds_inst::FwdsInst = @opaque (p::Int) -> (push!(out, zero_codual(true)); return next_blk)
bwds_inst::BwdsInst = @opaque (j::Int) -> (pop!(out); return j)
return fwds_inst, bwds_inst
fwds_inst = @opaque (p::Int) -> (push!(out, zero_codual(true)); return next_blk)
bwds_inst = @opaque (j::Int) -> (pop!(out.tangent_stack); return j)
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
end

function replace_tangent!(x::AbstractSlot{<:CoDual{Tx, Tdx}}, new_tangent::Tdx) where {Tx, Tdx}
x_val = x[]
x[] = CoDual(primal(x_val), new_tangent)
x[] = CoDual(primal(x[]), new_tangent)
return nothing
end

function replace_tangent!(x::AbstractSlot{<:CoDual}, new_tangent)
x_val = x[]
x[] = CoDual(primal(x_val), new_tangent)
x[] = CoDual(primal(x[]), new_tangent)
return nothing
end

replace_tangent!(::ConstSlot{<:CoDual}, new_tangent) = nothing

function replace_tangent!(x::FwdStack, new_tangent)
x.tangent_stack[] = new_tangent
return nothing
end

function build_coinsts(
::Val{:call},
out::CoDualStack,
Expand All @@ -288,7 +322,7 @@ function build_coinsts(

function bwds_pass()
pb!! = pop!(pb_stack)
dout = tangent(pop!(out))
dout = pop!(out.tangent_stack)
dargs = tuple_map(set_immutable_to_zero tangent getindex, arg_slots)
new_dargs = pb!!(dout, NoTangent(), dargs...)
map(increment_tangent!, arg_slots, new_dargs[2:end])
Expand Down Expand Up @@ -338,13 +372,6 @@ end
tangent_type(::Type{<:InterpretedFunction}) = NoTangent
tangent_type(::Type{<:DelayedInterpretedFunction}) = NoTangent

make_codual_stack(::SlotRef{P}) where {P} = Stack{codual_type(P)}()
function make_codual_stack(x::ConstSlot{P}) where {P}
stack = Stack{codual_type(P)}()
push!(stack, uninit_codual(x[]))
return stack
end

function make_codual_arginfo(ai::ArgInfo{T, is_vararg}) where {T, is_vararg}
codual_arg_slots = map(make_codual_stack, ai.arg_slots)
return ArgInfo{_typeof(codual_arg_slots), is_vararg}(codual_arg_slots)
Expand Down Expand Up @@ -386,7 +413,7 @@ end
struct InterpretedFunctionRRule{sig<:Tuple, Treturn, Targ_info<:ArgInfo}
return_slot::SlotRef{Treturn}
arg_info::Targ_info
slots::Vector{Stack}
slots::Vector{CoDualStack}
fwds_instructions::Vector{FwdsInst}
bwds_instructions::Vector{BwdsInst}
n_stack::Stack{Int}
Expand Down
17 changes: 17 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1364,6 +1364,12 @@ function test_multi_use_pi_node(x::Base.RefValue{Any})
return v
end

function test_union_of_arrays(x::Vector{Float64}, b::Bool)
y = randn(Xoshiro(1), Float32, 4)
z = b ? x : y
return 2z
end

sr(n) = Xoshiro(n)

function generate_test_functions()
Expand Down Expand Up @@ -1395,6 +1401,16 @@ function generate_test_functions()
(false, :allocs, nothing, phi_const_bool_tester, -5.0),
(false, :allocs, nothing, phi_node_with_undefined_value, true, 4.0),
(false, :allocs, nothing, phi_node_with_undefined_value, false, 4.0),
(
false,
:none,
nothing,
Base._unsafe_getindex,
IndexLinear(),
randn(5),
1,
Base.Slice(Base.OneTo(1)),
), # fun PhiNode example to do with not assigning values
(false, :allocs, nothing, avoid_throwing_path_tester, 5.0),
(false, :allocs, nothing, simple_foreigncall_tester, randn(5)),
(false, :none, nothing, simple_foreigncall_tester_2, randn(6), (2, 3)),
Expand All @@ -1421,6 +1437,7 @@ function generate_test_functions()
(false, :none, nothing, inferred_const_tester, Ref{Any}(nothing)),
(false, :none, (lb=1, ub=1_000), datatype_slot_tester, 1),
(false, :none, (lb=1, ub=1_000), datatype_slot_tester, 2),
(false, :none, (lb=1, ub=100_000_000), test_union_of_arrays, randn(5), true),
(
false,
:none,
Expand Down
3 changes: 2 additions & 1 deletion test/front_matter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ using Taped:
TypedPhiNode,
build_coinsts,
Stack,
_typeof
_typeof,
FwdStack

using .TestUtils:
test_rrule!!,
Expand Down
6 changes: 3 additions & 3 deletions test/integration_testing/diff_tests.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
@testset "diff_tests" begin
interp = Taped.TInterp()
@testset "$f, $(_typeof(x))" for (interface_only, f, x...) in vcat(
@testset "$f, $(_typeof(x))" for (n, (interface_only, f, x...)) in enumerate(vcat(
TestResources.DIFFTESTS_FUNCTIONS[1:31], # SKIPPING SPARSE_LDIV mat2num_4 and softmax due to `_apply_iterate` handling
TestResources.DIFFTESTS_FUNCTIONS[34:66], # SKIPPING SPARSE_LDIV
TestResources.DIFFTESTS_FUNCTIONS[68:89], # SKIPPING SPARSE_LDIV
TestResources.DIFFTESTS_FUNCTIONS[91:end], # SKIPPING SPARSE_LDIV
)
@info "$(_typeof((f, x...)))"
))
@info "$n: $(_typeof((f, x...)))"
TestUtils.test_interpreted_rrule!!(
sr(123456), f, x...;
interp, perf_flag=:none, interface_only=false, is_primitive=false,
Expand Down
Loading

0 comments on commit 3cd13a6

Please sign in to comment.