Skip to content

Commit

Permalink
Stacks Not Slots (#76)
Browse files Browse the repository at this point in the history
* Refactor eltype for slots slightly

* Add setindex and getindex to stack

* Move to stacks in situe

* Tidy up a bit
  • Loading branch information
willtebbutt authored Feb 8, 2024
1 parent 0927aec commit 3bbf8fc
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 113 deletions.
11 changes: 3 additions & 8 deletions src/interpreter/interpreted_function.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# Special types to represent data in an IRCode and a InterpretedFunction.

abstract type AbstractSlot{T} end

"""
SlotRef{T}()
Expand All @@ -26,7 +24,6 @@ end
Base.getindex(x::SlotRef) = getfield(x, :x)
Base.setindex!(x::SlotRef, val) = setfield!(x, :x, val)
Base.isassigned(x::SlotRef) = isdefined(x, :x)
Base.eltype(::SlotRef{T}) where {T} = T
Base.copy(x::SlotRef{T}) where {T} = isassigned(x) ? SlotRef{T}(x[]) : SlotRef{T}()

"""
Expand All @@ -44,7 +41,6 @@ end
Base.getindex(x::ConstSlot) = getfield(x, :x)
Base.setindex!(::ConstSlot, val) = nothing
Base.isassigned(::ConstSlot) = true
Base.eltype(::ConstSlot{T}) where {T} = T
Base.copy(x::ConstSlot{T}) where {T} = ConstSlot{T}(x[])

"""
Expand All @@ -68,7 +64,6 @@ TypedGlobalRef(mod::Module, name::Symbol) = TypedGlobalRef(GlobalRef(mod, name))
Base.getindex(x::TypedGlobalRef{T}) where {T} = getglobal(x.mod, x.name)::T
Base.setindex!(x::TypedGlobalRef, val) = setglobal!(x.mod, x.name, val)
Base.isassigned(::TypedGlobalRef) = true
Base.eltype(::TypedGlobalRef{T}) where {T} = T

#=
Returns either a `ConstSlot` or a `TypedGlobalRef`, both of which are `AbstractSlot`s.
Expand Down Expand Up @@ -121,8 +116,8 @@ end

## PhiNode

struct TypedPhiNode{Tr<:AbstractSlot, Te<:Tuple, Tv<:Tuple}
tmp_slot::Tr
struct TypedPhiNode{Tr<:AbstractSlot, Tt<:AbstractSlot, Te<:Tuple, Tv<:Tuple}
tmp_slot::Tt
ret_slot::Tr
edges::Te
values::Tv
Expand Down Expand Up @@ -153,7 +148,7 @@ function build_typed_phi_nodes(ir_insts::Vector{PhiNode}, in_f, n_first::Int)
end
T = eltype(ret_slot)
values_vec = map(n -> _init[n] isa UndefRef ? SlotRef{T}() : _init[n], eachindex(_init))
return TypedPhiNode(copy(ret_slot), ret_slot, edges, (values_vec..., ))
return TypedPhiNode(SlotRef{T}(), ret_slot, edges, (values_vec..., ))
end
end

Expand Down
137 changes: 63 additions & 74 deletions src/interpreter/reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,19 @@ const BwdsInst = Core.OpaqueClosure{Tuple{Int}, Int}

const CoDualSlot{V} = AbstractSlot{V} where {V<:CoDual}

const CoDualStack{V} = Union{ConstSlot{V}, Stack{V}} where {V<:CoDual}

primal_eltype(::CoDualSlot{CoDual{P, T}}) where {P, T} = P

# Operations on Slots involving CoDuals

function increment_tangent!(x::SlotRef{C}, y::CoDualSlot) where {C}
function increment_tangent!(x::CoDualSlot{C}, y::CoDualSlot) where {C}
x_val = x[]
x[] = C(primal(x_val), increment!!(tangent(x_val), tangent(y[])))
return nothing
end

increment_tangent!(x::ConstSlot{<:CoDual}, new_tangent) = nothing

# Non slot version of increment_tangent!
function increment_tangent!(x::SlotRef{C}, t) where {C}
function increment_tangent!(x::CoDualSlot{C}, t) where {C}
x_val = x[]
x[] = C(primal(x_val), increment!!(tangent(x_val), t))
return nothing
Expand All @@ -34,7 +33,7 @@ end
function build_coinsts(node::ReturnNode, _, _rrule!!, ::Int, ::Int, ::Bool)
return build_coinsts(ReturnNode, _rrule!!.return_slot, _get_slot(node.val, _rrule!!))
end
function build_coinsts(::Type{ReturnNode}, ret_slot::SlotRef{<:CoDual}, val::CoDualSlot)
function build_coinsts(::Type{ReturnNode}, ret_slot::SlotRef{<:CoDual}, val::CoDualStack)
fwds_inst = build_inst(ReturnNode, ret_slot, val)
bwds_inst = @opaque (j::Int) -> (increment_tangent!(val, ret_slot); return j)
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
Expand All @@ -50,7 +49,7 @@ end
function build_coinsts(x::GotoIfNot, _, _rrule!!, ::Int, b::Int, is_blk_end::Bool)
return build_coinsts(GotoIfNot, x.dest, b + 1, _get_slot(x.cond, _rrule!!))
end
function build_coinsts(::Type{GotoIfNot}, dest::Int, next_blk::Int, cond::CoDualSlot)
function build_coinsts(::Type{GotoIfNot}, dest::Int, next_blk::Int, cond::CoDualStack)
fwds_inst = @opaque (p::Int) -> primal(cond[]) ? next_blk : dest
bwds_inst = @opaque (j::Int) -> j
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
Expand All @@ -61,20 +60,13 @@ end
function build_coinsts(ir_insts::Vector{PhiNode}, _, _rrule!!, n_first::Int, b::Int, is_blk_end::Bool)
nodes = (build_typed_phi_nodes(ir_insts, _rrule!!, n_first)..., )
next_blk = _standard_next_block(is_blk_end, b)
return build_coinsts(Vector{PhiNode}, nodes, next_blk, make_stacks(nodes)...)
end

function make_stacks(nodes::NTuple{N, TypedPhiNode}) where {N}
ret_stacks = map(n -> Stack{eltype(n.ret_slot)}(), nodes)
prev_blks = Stack{Int}()
return ret_stacks, prev_blks
return build_coinsts(Vector{PhiNode}, nodes, next_blk, Stack{Int}())
end

function build_coinsts(
::Type{Vector{PhiNode}},
nodes::NTuple{N, TypedPhiNode},
next_blk::Int,
ret_stacks,
prev_stack::Stack{Int},
) where {N}

Expand All @@ -86,27 +78,31 @@ function build_coinsts(
fwds_inst = @opaque function (p::Int)
push!(prev_stack, p) # record the preceding block
map(Base.Fix2(store_tmp_value!, p), nodes) # transfer new value into tmp slots
map(log_ret_value!, nodes, ret_stacks) # log old value
map(transfer_tmp_value!, nodes) # transfer new value from tmp slots into ret slots
map(Base.Fix2(transfer_tmp_value!, p), nodes) # transfer new value from tmp slots into ret slots
return next_blk
end
bwds_inst = @opaque function (j::Int)
p = pop!(prev_stack) # get the index of the previous block
map(replace_tmp_tangent_from_ret!, nodes) # transfer data from ret slots to tmp
map((n, r) -> (!isempty(r)) && (n.ret_slot[] = pop!(r)), nodes, ret_stacks) # restore ret slots to previous state
map(Base.Fix2(increment_predecessor_from_tmp!, p), nodes)
map(Base.Fix2(replace_tmp_tangent_from_ret!, p), nodes) # transfer data from ret slots to tmp
map(Base.Fix2(increment_predecessor_from_tmp!, p), nodes) # inc tangents
return j
end
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
end

function log_ret_value!(x::TypedPhiNode{<:CoDualSlot}, ret_stack)
isassigned(x.ret_slot) && push!(ret_stack, x.ret_slot[])
function transfer_tmp_value!(x::TypedPhiNode{<:Stack}, prev_blk::Int)
map(x.edges, x.values) do edge, v
(edge == prev_blk) && isassigned(v) && (push!(x.ret_slot, x.tmp_slot[]))
end
return nothing
end

function replace_tmp_tangent_from_ret!(x::TypedPhiNode{<:CoDualSlot})
isassigned(x.ret_slot) && replace_tangent!(x.tmp_slot, tangent(x.ret_slot[]))
function replace_tmp_tangent_from_ret!(x::TypedPhiNode{<:CoDualSlot}, prev_blk::Int)
map(x.edges, x.values) do edge, v
if (edge == prev_blk) && isassigned(v)
replace_tangent!(x.tmp_slot, tangent(pop!(x.ret_slot)))
end
end
return nothing
end

Expand All @@ -121,23 +117,18 @@ end
function build_coinsts(x::PiNode, _, _rrule!!, n::Int, b::Int, is_blk_end::Bool)
val = _get_slot(x.val, _rrule!!)
ret = _rrule!!.slots[n]
old_vals = Stack{eltype(ret)}()
return build_coinsts(PiNode, val, ret, old_vals, _standard_next_block(is_blk_end, b))
return build_coinsts(PiNode, val, ret, _standard_next_block(is_blk_end, b))
end
function build_coinsts(
::Type{PiNode}, val::CoDualSlot{V}, ret::CoDualSlot{R}, old_vals::Stack, next_blk::Int,
) where {V, R}
::Type{PiNode}, val::CoDualStack, ret::CoDualStack{R}, next_blk::Int
) where {R}
make_fwds(v) = R(primal(v), tangent(v))
fwds_inst = @opaque function (p::Int)
isassigned(ret) && push!(old_vals, ret[])
ret[] = make_fwds(val[])
push!(ret, make_fwds(val[]))
return next_blk
end
bwds_inst = @opaque function (j::Int)
increment_tangent!(val, ret)
if !isempty(old_vals)
ret[] = pop!(old_vals)
end
increment_tangent!(val, tangent(pop!(ret)))
return j
end
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
Expand All @@ -148,9 +139,9 @@ function build_coinsts(x::GlobalRef, _, _rrule!!, n::Int, b::Int, is_blk_end::Bo
next_blk = _standard_next_block(is_blk_end, b)
return build_coinsts(GlobalRef, _globalref_to_slot(x), _rrule!!.slots[n], next_blk)
end
function build_coinsts(::Type{GlobalRef}, x::AbstractSlot, out::CoDualSlot, next_blk::Int)
fwds_inst = @opaque (p::Int) -> (out[] = uninit_codual(x[]); return next_blk)
bwds_inst = @opaque (j::Int) -> j
function build_coinsts(::Type{GlobalRef}, x::AbstractSlot, out::CoDualStack, next_blk::Int)
fwds_inst = @opaque (p::Int) -> (push!(out, uninit_codual(x[])); return next_blk)
bwds_inst = @opaque (j::Int) -> (pop!(out); return j)
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
end

Expand All @@ -160,9 +151,9 @@ function build_coinsts(node, _, _rrule!!, n::Int, b::Int, is_blk_end::Bool)
next_blk = _standard_next_block(is_blk_end, b)
return build_coinsts(nothing, x, _rrule!!.slots[n], next_blk)
end
function build_coinsts(::Nothing, x::ConstSlot{<:CoDual}, out::CoDualSlot, next_blk::Int)
fwds_inst = @opaque (p::Int) -> (out[] = x[]; return next_blk)
bwds_inst = @opaque (j::Int) -> j
function build_coinsts(::Nothing, x::ConstSlot, out::CoDualStack, next_blk::Int)
fwds_inst = @opaque (p::Int) -> (push!(out, x[]); return next_blk)
bwds_inst = @opaque (j::Int) -> (pop!(out); return j)
return fwds_inst::FwdsInst, bwds_inst::BwdsInst
end

Expand Down Expand Up @@ -237,7 +228,7 @@ function build_coinsts(ir_inst::Expr, in_f, _rrule!!, n::Int, b::Int, is_blk_end
old_vals = Stack{eltype(val_slot)}()

return build_coinsts(
Val(:call), val_slot, arg_slots, evaluator, __rrule!!, old_vals, pb_stack, next_blk
Val(:call), val_slot, arg_slots, evaluator, __rrule!!, pb_stack, next_blk
)
elseif ir_inst.head in [
:code_coverage_effect, :gc_preserve_begin, :gc_preserve_end, :loopinfo,
Expand All @@ -254,18 +245,18 @@ end

function build_coinsts(::Val{:boundscheck}, out::CoDualSlot, next_blk::Int)
@assert primal_eltype(out) == Bool
fwds_inst::FwdsInst = @opaque (p::Int) -> (out[] = zero_codual(true); return next_blk)
bwds_inst::BwdsInst = @opaque (j::Int) -> j
fwds_inst::FwdsInst = @opaque (p::Int) -> (push!(out, zero_codual(true)); return next_blk)
bwds_inst::BwdsInst = @opaque (j::Int) -> (pop!(out); return j)
return fwds_inst, bwds_inst
end

function replace_tangent!(x::SlotRef{<:CoDual{Tx, Tdx}}, new_tangent::Tdx) where {Tx, Tdx}
function replace_tangent!(x::AbstractSlot{<:CoDual{Tx, Tdx}}, new_tangent::Tdx) where {Tx, Tdx}
x_val = x[]
x[] = CoDual(primal(x_val), new_tangent)
return nothing
end

function replace_tangent!(x::SlotRef{<:CoDual}, new_tangent)
function replace_tangent!(x::AbstractSlot{<:CoDual}, new_tangent)
x_val = x[]
x[] = CoDual(primal(x_val), new_tangent)
return nothing
Expand All @@ -275,21 +266,18 @@ replace_tangent!(::ConstSlot{<:CoDual}, new_tangent) = nothing

function build_coinsts(
::Val{:call},
out::CoDualSlot,
arg_slots::NTuple{N, CoDualSlot} where {N},
out::CoDualStack,
arg_slots::NTuple{N, CoDualStack} where {N},
evaluator::Teval,
__rrule!!::Trrule!!,
old_vals::Stack,
pb_stack::Stack,
next_blk::Int,
) where {Teval, Trrule!!}

function fwds_pass()
isassigned(out) && push!(old_vals, out[])
args = tuple_map(getindex, arg_slots)
z_ev = zero_codual(evaluator)
_out, pb!! = __rrule!!(z_ev, args...)
out[] = _out
_out, pb!! = __rrule!!(zero_codual(evaluator), args...)
push!(out, _out)
push!(pb_stack, pb!!)
return nothing
end
Expand All @@ -299,15 +287,11 @@ function build_coinsts(
end

function bwds_pass()
dout = tangent(out[])
dargs = tuple_map(set_immutable_to_zero tangent getindex, arg_slots)
pb!! = pop!(pb_stack)
tmp = pb!!(dout, NoTangent(), dargs...)
new_dargs = tmp[2:end]
map(increment_tangent!, arg_slots, new_dargs)
if !isempty(old_vals)
out[] = pop!(old_vals) # restore old state.
end
dout = tangent(pop!(out))
dargs = tuple_map(set_immutable_to_zero tangent getindex, arg_slots)
new_dargs = pb!!(dout, NoTangent(), dargs...)
map(increment_tangent!, arg_slots, new_dargs[2:end])
return nothing
end
bwds_inst = @opaque function (j::Int)
Expand Down Expand Up @@ -354,16 +338,15 @@ end
tangent_type(::Type{<:InterpretedFunction}) = NoTangent
tangent_type(::Type{<:DelayedInterpretedFunction}) = NoTangent

# Pre-allocate for AD-related instructions and quantities.
function make_codual_slot(::SlotRef{P}) where {P}
return SlotRef{codual_type(P)}()
# return isconcretetype(P) ? SlotRef{CoDual{P, tangent_type(P)}}() : SlotRef{CoDual}()
make_codual_stack(::SlotRef{P}) where {P} = Stack{codual_type(P)}()
function make_codual_stack(x::ConstSlot{P}) where {P}
stack = Stack{codual_type(P)}()
push!(stack, uninit_codual(x[]))
return stack
end

make_codual_slot(x::ConstSlot{P}) where {P} = ConstSlot{codual_type(P)}(uninit_codual(x[]))

function make_codual_arginfo(ai::ArgInfo{T, is_vararg}) where {T, is_vararg}
codual_arg_slots = map(make_codual_slot, ai.arg_slots)
codual_arg_slots = map(make_codual_stack, ai.arg_slots)
return ArgInfo{_typeof(codual_arg_slots), is_vararg}(codual_arg_slots)
end

Expand All @@ -390,13 +373,20 @@ function load_rrule_args!(ai::ArgInfo{T, is_vararg}, args::Tuple) where {T, is_v
end

# Load the arguments into `ai.arg_slots`.
return __load_args!(ai.arg_slots, refined_args)
return __push_args!(ai.arg_slots, refined_args)
end

@generated function __push_args!(arg_slots::Tuple, args::Tuple)
Ts = args.parameters
loaders = map(n -> :(push!(arg_slots[$n], args[$n])), eachindex(Ts))
return Expr(:block, loaders..., :(return nothing))
end


struct InterpretedFunctionRRule{sig<:Tuple, Treturn, Targ_info<:ArgInfo}
return_slot::SlotRef{Treturn}
arg_info::Targ_info
slots::Vector{CoDualSlot}
slots::Vector{Stack}
fwds_instructions::Vector{FwdsInst}
bwds_instructions::Vector{BwdsInst}
n_stack::Stack{Int}
Expand Down Expand Up @@ -428,7 +418,7 @@ function make_phi_instructions!(
foreach(n -> (ir.stmts.inst[n] isa PhiNode) && push!(phi_node_inds, n), bb.stmts)
isempty(phi_node_inds) && continue

# Make a single instruction which runs all of the PhiNodes "simulataneously".
# Make a single instruction which runs all of the PhiNodes "simultaneously".
# Specifically, this instruction runs all of the phi nodes, storing the results of
# this into temporary storage, then writing from the temporary slots to the
# final slots. This has the effect of ensuring that phi nodes that depend on other
Expand All @@ -452,14 +442,14 @@ end

function build_rrule!!(in_f::InterpretedFunction{sig}) where {sig}

return_slot = make_codual_slot(in_f.return_slot)
return_slot = SlotRef{codual_type(eltype(in_f.return_slot))}()
arg_info = make_codual_arginfo(in_f.arg_info)

# Construct rrule!! for in_f.
__rrule!! = InterpretedFunctionRRule{sig, eltype(return_slot), _typeof(arg_info)}(
return_slot,
arg_info,
map(make_codual_slot, in_f.slots), # SlotRefs
map(make_codual_stack, in_f.slots), # SlotRefs
Vector{FwdsInst}(undef, length(in_f.instructions)), # fwds_instructions
Vector{BwdsInst}(undef, length(in_f.instructions)), # bwds_instructions
Stack{Int}(),
Expand Down Expand Up @@ -544,8 +534,7 @@ end

@generated function __load_tangents!(arg_slots::Tuple, dargs::Tuple)
Ts = dargs.parameters
ns = filter(n -> !Base.issingletontype(Ts[n]), eachindex(Ts))
loaders = map(n -> :(replace_tangent!(arg_slots[$n], dargs[$n])), ns)
loaders = map(n -> :(replace_tangent!(arg_slots[$n], dargs[$n])), eachindex(Ts))
return Expr(:block, loaders..., :(return nothing))
end

Expand Down
Loading

0 comments on commit 3bbf8fc

Please sign in to comment.