Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: invoke #207

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ function normalise!(ir::IRCode, spnames::Vector{Symbol})
inst = splatnew_to_call(inst)
inst = intrinsic_to_function(inst)
inst = lift_getfield_and_others(inst)
inst = invoke_to_invoke_wrapper(inst)
ir.stmts.inst[n] = inst
end
return ir
Expand Down Expand Up @@ -182,3 +183,25 @@ end
__get_arg(x::GlobalRef) = getglobal(x.mod, x.name)
__get_arg(x::QuoteNode) = x.value
__get_arg(x) = x

"""
invoke_wrapper(f, TT, x::Vararg{Any, N}) where {N}

Equivalent to `invoke(f, TT, x...)`, but is a primitive that won't get lowered / inlined
away.
"""
invoke_wrapper(f, TT, x::Vararg{Any, N}) where {N} = invoke(f, TT, x...)

@is_primitive MinimalCtx Tuple{typeof(invoke_wrapper), Vararg}

"""
invoke_to_invoke_wrapper(inst)

Replaces `:call`s to `Core.invoke` with calls to `Tapir.invoke_wrapper`. This done to
prevent such calls being inlined / lowered away during optimisation, as access to them is
required later on.
"""
function invoke_to_invoke_wrapper(inst)
(Meta.isexpr(inst, :call) && inst.args[1] == invoke) || return inst
return Expr(:call, invoke_wrapper, inst.args[2:end]...)
end
47 changes: 45 additions & 2 deletions src/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,19 +173,62 @@ end
"""
lookup_ir(interp::AbstractInterpreter, sig::Type{<:Tuple})::Tuple{IRCode, T}

Get the IR unique IR associated to `sig` under `interp`. Throws `ArgumentError`s if there is
Get the IR associated to `sig` under `interp`. Throws `ArgumentError`s if there is
no code found, or if more than one `IRCode` instance returned.

Returns a tuple containing the `IRCode` and its return type.
"""
function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple})

# If `invoke` is the function in question, then do something complicated.
if sig.parameters[1] == typeof(invoke)
return lookup_invoke_ir(interp, sig)
end

# Look up the `IRCode` using the standard mechanism.
output = Base.code_ircode_by_type(sig; interp)
if isempty(output)
throw(ArgumentError("No methods found for signature $sig"))
elseif length(output) > 1
throw(ArgumentError("$(length(output)) methods found for signature $sig"))
end
return only(output)
return only(output)[1]
end

is_invoke_sig(sig) = sig.parameters[1] == typeof(invoke)

function static_sig(sig)
ps = sig.parameters
return is_invoke_sig(sig) ? Tuple{ps[2], ps[3].parameters[1].parameters...} : sig
end

"""
lookup_invoke_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple})

Looks up the `IRCode` associated to an `invoke` call. If the first parameter of `sig` is not
`typeof(invoke)` then an `ArgumentError` is thrown. No other error checking is performed.
"""
function lookup_invoke_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple})
if sig.parameters[1] != typeof(invoke)
throw(ArgumentError("Expected signature for a call to `invoke`."))
end

# Construct the static signature, and dynamic signature.
ps = sig.parameters
_sig = static_sig(sig)
dynamic_sig = Tuple{ps[2], ps[4:end]...}

# Lookup all methods which could apply to the types provided in the signature, and pick
# the one which `which` says would get applied.
# Base on https://github.com/JuliaLang/julia/blob/v1.10.4/base/reflection.jl#L1485
matches = Base._methods_by_ftype(_sig, #=lim=#-1, Base.get_world_counter())
m = which(_sig)
match = only(filter(_m -> m === _m.method, matches))
meth = Base.func_for_method_checked(match.method, _sig, match.sparams)
(code, _) = Core.Compiler.typeinf_ircode(
interp, meth, dynamic_sig, match.sparams, #=optimize_until=#nothing
)
return code
end

"""
Expand Down
49 changes: 35 additions & 14 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -654,39 +654,51 @@ end
# between differing varargs conventions.
#

struct Pullback{Tpb_oc, Tisva<:Val, Tnvargs<:Val}
struct Pullback{Tpb_oc, Tisva<:Val, Tnvargs<:Val, Tis_invoke<:Val}
pb_oc::Tpb_oc
isva::Tisva
nvargs::Tnvargs
is_invoke::Tis_invoke
end

@inline (pb::Pullback)(dy) = __flatten_varargs(pb.isva, pb.pb_oc(dy), pb.nvargs)
@inline function (pb::Pullback)(dy)
return __flatten_varargs(pb.isva, pb.pb_oc(dy), pb.nvargs, pb.is_invoke)
end

struct DerivedRule{Tfwds_oc, Tpb_oc, Tisva<:Val, Tnargs<:Val}
struct DerivedRule{Tfwds_oc, Tpb_oc, Tisva<:Val, Tnargs<:Val, Tis_invoke<:Val}
fwds_oc::Tfwds_oc
pb_oc::Tpb_oc
isva::Tisva
nargs::Tnargs
is_invoke::Tis_invoke
end

@inline function (fwds::DerivedRule{P, Q, S})(args::Vararg{CoDual, N}) where {P, Q, S, N}
uf_args = __unflatten_codual_varargs(fwds.isva, args, fwds.nargs)
pb!! = Pullback(fwds.pb_oc, fwds.isva, nvargs(length(args), fwds.nargs))
uf_args = __unflatten_codual_varargs(fwds.isva, args, fwds.nargs, fwds.is_invoke)
pb!! = Pullback(fwds.pb_oc, fwds.isva, nvargs(length(args), fwds.nargs), fwds.is_invoke)
return fwds.fwds_oc(uf_args...)::CoDual, pb!!
end

@inline nvargs(n_flat, ::Val{nargs}) where {nargs} = Val(n_flat - nargs + 1)

# If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).
function __flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs}
isva || return args
function __flatten_varargs(
::Val{isva}, args, ::Val{nvargs}, ::Val{is_invoke}
) where {isva, nvargs, is_invoke}
isva || (return is_invoke ? __invoke_rdata(args) : args)
last_el = isa(args[end], NoRData) ? ntuple(n -> NoRData(), nvargs) : args[end]
return (args[1:end-1]..., last_el...)
out_args = (args[1:end-1]..., last_el...)
return is_invoke ? __invoke_rdata(out_args) : out_args
end

__invoke_rdata(args) = (NoRData(), args[1], NoRData(), args[2:end]...)

# If isva and nargs=2, then inputs `(CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0))`
# are transformed into `(CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0)))`.
function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs}
function __unflatten_codual_varargs(
::Val{isva}, raw_args, ::Val{nargs}, ::Val{is_invoke}
) where {isva, nargs, is_invoke}
args = is_invoke ? __args_from_invoke_call(raw_args) : raw_args
isva || return args
group_primal = map(primal, args[nargs:end])
if fdata_type(tangent_type(_typeof(group_primal))) == NoFData
Expand All @@ -697,6 +709,8 @@ function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva
return (args[1:nargs-1]..., grouped_args)
end

__args_from_invoke_call(args) = (args[2], args[4:end]...)

#
# Rule derivation.
#
Expand All @@ -705,9 +719,9 @@ end
# important for performance in dynamic dispatch, and to ensure that recursion works
# properly.
function rule_type(interp::TapirInterpreter{C}, ::Type{sig}) where {C, sig}
is_primitive(C, sig) && return typeof(rrule!!)
has_hand_written_rule(C, sig) && return typeof(rrule!!)

ir, _ = lookup_ir(interp, sig)
ir = lookup_ir(interp, sig)
Treturn = Base.Experimental.compute_ir_rettype(ir)
isva, _ = is_vararg_sig_and_sparam_names(sig)

Expand All @@ -716,19 +730,22 @@ function rule_type(interp::TapirInterpreter{C}, ::Type{sig}) where {C, sig}
arg_rvs_types = Tuple{map(rdata_type ∘ tangent_type, arg_types)...}
fwds_return_codual = fcodual_type(Treturn)
rvs_return_type = rdata_type(tangent_type(Treturn))

if isconcretetype(fwds_return_codual)
return DerivedRule{
MistyClosure{OpaqueClosure{arg_fwds_types, fwds_return_codual}},
MistyClosure{OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}},
Val{isva},
Val{length(ir.argtypes)},
Val{is_invoke_sig(sig)},
}
else
return DerivedRule{
MistyClosure{OpaqueClosure{arg_fwds_types, P}} where {P<:fwds_return_codual},
MistyClosure{OpaqueClosure{Tuple{rvs_return_type}, arg_rvs_types}},
Val{isva},
Val{length(ir.argtypes)},
Val{is_invoke_sig(sig)},
}
end
end
Expand All @@ -742,6 +759,8 @@ function build_rrule(args...; safety_on=false)
return build_rrule(PInterp(), _typeof(TestUtils.__get_primals(args)); safety_on)
end

has_hand_written_rule(C, sig) = is_primitive(C, static_sig(sig))

"""
build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C}

Expand All @@ -764,10 +783,10 @@ function build_rrule(
seed_id!()

# If we have a hand-coded rule, just use that.
is_primitive(C, sig) && return (safety_on ? SafeRRule(rrule!!) : rrule!!)
has_hand_written_rule(C, sig) && return (safety_on ? SafeRRule(rrule!!) : rrule!!)

# Grab code associated to the primal.
ir, _ = lookup_ir(interp, sig)
ir = lookup_ir(interp, sig)
Treturn = Base.Experimental.compute_ir_rettype(ir)

# Normalise the IR, and generated BBCode version of it.
Expand Down Expand Up @@ -823,7 +842,9 @@ function build_rrule(
interp.oc_cache[(sig, safety_on)] = (fwds_oc, pb_oc)
end

raw_rule = rule_type(interp, sig)(fwds_oc, pb_oc, Val(isva), Val(num_args(info)))
raw_rule = rule_type(interp, sig)(
fwds_oc, pb_oc, Val(isva), Val(num_args(info)), Val(is_invoke_sig(sig)),
)
return safety_on ? SafeRRule(raw_rule) : raw_rule
end

Expand Down
22 changes: 20 additions & 2 deletions src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,25 @@ function rrule!!(f::CoDual{typeof(getglobal)}, a, b)
return zero_fcodual(getglobal(primal(a), primal(b))), NoPullback(f, a, b)
end

# invoke
# invoke: see https://github.com/compintell/Tapir.jl/issues/206 for more info.
@is_primitive MinimalCtx Tuple{typeof(invoke), Vararg}

@generated function codual_sig(::Type{sig}) where {sig<:Tuple}
return Tuple{map(codual_type, sig.parameters)...}
end

# An rrule!! which can be called if there is an `rrule!!` available for the `invoke`d
# method of rrule.
function rrule!!(
::CoDual{typeof(invoke)}, f::F, ::CoDual{Type{PP}}, args::Vararg{CoDual, N}
) where {F<:CoDual, PP<:Tuple, N}
out, pb!! = invoke(rrule!!, Tuple{F, codual_sig(PP).parameters...}, f, args...)
function invoke_pb!!(dout)
df, dargs... = pb!!(dout)
return NoRData(), df, NoRData(), dargs...
end
return out, invoke_pb!!
end

function rrule!!(f::CoDual{typeof(isa)}, x, T)
return zero_fcodual(isa(primal(x), primal(T))), NoPullback(f, x, T)
Expand Down Expand Up @@ -955,7 +973,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins})
(false, :none, _range, getfield, UInt8, :hash),
(false, :none, _range, getfield, UInt8, :flags),
# getglobal requires compositional testing, because you can't deepcopy a module
# invoke -- NEEDS IMPLEMENTING AND TESTING
# invoke -- tested in s2s_reverse_mode_ad tests
(false, :stability, nothing, isa, 5.0, Float64),
(false, :stability, nothing, isa, 1, Float64),
(false, :stability, nothing, isdefined, MutableFoo(5.0, randn(5)), :sim),
Expand Down
25 changes: 25 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,27 @@ end

test_getfield_of_tuple_of_types(n::Int) = getfield((Float64, Float64), n)

test_for_invoke(x) = 5x

Tapir.is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(test_for_invoke), Any}}) = true

function Tapir.rrule!!(::CoDual{typeof(test_for_invoke)}, x::CoDual)
test_for_invoke_pb!!(dy) = NoRData(), 5 * dy
return Tapir.zero_fcodual(5 * primal(x)), test_for_invoke_pb!!
end

test_for_invoke(x::Float64) = x

function Tapir.is_primitive(
::Type{MinimalCtx}, ::Type{<:Tuple{typeof(test_for_invoke), Float64}}
)
return false
end

test_for_invoke(x::Float64, y::Float64, z::Float64...) = x + sum(y)

inlinable_invoke_call(x::Float64) = invoke(test_for_invoke, Tuple{Float64}, x)

function generate_test_functions()
return Any[
(false, :allocs, nothing, const_tester),
Expand Down Expand Up @@ -1621,6 +1642,10 @@ function generate_test_functions()
(false, :none, nothing, ArgumentError, "hi"),
(false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}(5.0)),
(false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}([1.0])),
(false, :stability_and_allocs, nothing, invoke, test_for_invoke, Tuple{Any}, 5.0),
(false, :allocs, nothing, invoke, test_for_invoke, Tuple{Float64}, 5.0),
(false, :allocs, nothing, invoke, test_for_invoke, Tuple{Float64, Float64, Float64}, 5.0, 4.0, 3.0),
(false, :allocs, nothing, hvcat, (2, 2), 3.0, 2.0, 0.0, 1.0),
]
end

Expand Down
4 changes: 3 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,9 @@ function is_vararg_sig_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
world = Base.get_world_counter()
min = Base.RefValue{UInt}(typemin(UInt))
max = Base.RefValue{UInt}(typemax(UInt))
ms = Base._methods_by_ftype(sig, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))::Vector
ms = Base._methods_by_ftype(
static_sig(sig), nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL)
)::Vector
m = only(ms).method
return m.isva, sparam_names(m)
end
Expand Down
19 changes: 19 additions & 0 deletions test/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ module IRUtilsGlobalRefs
const __x_2 = 5.0
__x_3::Float64 = 5.0
const __x_4::Float64 = 5.0

foo(x) = x
foo(x::Float64) = 5x
end

@testset "ir_utils" begin
Expand Down Expand Up @@ -52,4 +55,20 @@ end
@test Tapir.inc_args(IDGotoIfNot(Argument(1), id)) == IDGotoIfNot(Argument(2), id)
@test Tapir.inc_args(IDGotoNode(id)) == IDGotoNode(id)
end
@testset "lookup_invoke_ir" begin

# Bail out if a non-invoke signature is passed in.
@test_throws(
ArgumentError,
Tapir.lookup_invoke_ir(Tapir.TapirInterpreter(), Tuple{typeof(sin), Float64}),
)

ir = Tapir.lookup_invoke_ir(
Tapir.TapirInterpreter(),
Tuple{typeof(invoke), typeof(IRUtilsGlobalRefs.foo), Type{Tuple{Any}}, Float64},
)
oc = Core.OpaqueClosure(ir; do_compile=true)
@test oc(5.0) == 5.0
@test invoke(IRUtilsGlobalRefs.foo, Tuple{Any}, 5.0) == 5.0
end
end
Loading