Skip to content

Commit

Permalink
invoke attempt 2 (#212)
Browse files Browse the repository at this point in the history
* Some work

* Test LazyZeroRData constructors more carefully

* Improve docstring for new method of lookup_ir

* Improve code formatting to make intent clear

* Rename sig to sig_and_mi where required

* More legitimate primitive detection

* More appropriate name for vararg and sparam finder

* Improve formatting and docstring

* Bump patch
  • Loading branch information
willtebbutt authored Aug 2, 2024
1 parent 7a5c03e commit 4ea0e1d
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.28"
version = "0.2.29"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 2 additions & 2 deletions src/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -737,8 +737,8 @@ end
# zero element and use it later. L is the precise type of `LazyZeroRData` that you wish to
# construct -- very occassionally you need complete control over this, but don't want to
# figure out for yourself whether or not construction can be performed lazily.
@inline function lazy_zero_rdata(::Type{L}, p::P) where {L<:LazyZeroRData, P}
return L(can_produce_zero_rdata_from_type(P) ? nothing : zero_rdata(p))
@inline function lazy_zero_rdata(::Type{L}, p::P) where {S, L<:LazyZeroRData{S}, P}
return L(can_produce_zero_rdata_from_type(S) ? nothing : zero_rdata(p))
end

# If type parameters for `LazyZeroRData` are not provided, use the defaults.
Expand Down
2 changes: 1 addition & 1 deletion src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ static parameter names have been translated into either types, or `:static_param
expressions.
Unfortunately, the static parameter names are not retained in `IRCode`, and the `Method`
from which the `IRCode` is derived must be consulted. `Tapir.is_vararg_sig_and_sparam_names`
from which the `IRCode` is derived must be consulted. `Tapir.is_vararg_and_sparam_names`
provides a convenient way to do this.
"""
function normalise!(ir::IRCode, spnames::Vector{Symbol})
Expand Down
13 changes: 10 additions & 3 deletions src/interpreter/ir_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,10 +171,13 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true)
end

"""
lookup_ir(interp::AbstractInterpreter, sig::Type{<:Tuple})::Tuple{IRCode, T}
lookup_ir(
interp::AbstractInterpreter,
sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},
)::Tuple{IRCode, T}
Get the IR unique IR associated to `sig` under `interp`. Throws `ArgumentError`s if there is
no code found, or if more than one `IRCode` instance returned.
Get the IR unique IR associated to `sig_or_mi` under `interp`. Throws `ArgumentError`s if
there is no code found, or if more than one `IRCode` instance returned.
Returns a tuple containing the `IRCode` and its return type.
"""
Expand All @@ -188,6 +191,10 @@ function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple})
return only(output)
end

function lookup_ir(interp::CC.AbstractInterpreter, mi::Core.MethodInstance)
return CC.typeinf_ircode(interp, mi.def, mi.specTypes, mi.sparam_vals, nothing)
end

"""
is_reachable_return_node(x::ReturnNode)
Expand Down
53 changes: 29 additions & 24 deletions src/interpreter/s2s_reverse_mode_ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,8 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo)
raw_rule = if is_primitive(context_type(info.interp), sig)
rrule!! # intrinsic / builtin / thing we provably have rule for
elseif is_invoke
LazyDerivedRule(info.interp, sig, info.safety_on) # Static dispatch
mi = stmt.args[1]::Core.MethodInstance
LazyDerivedRule(info.interp, mi, info.safety_on) # Static dispatch
else
DynamicDerivedRule(info.interp, info.safety_on) # Dynamic dispatch
end
Expand Down Expand Up @@ -701,15 +702,18 @@ end
# Rule derivation.
#

_is_primitive(C::Type, mi::Core.MethodInstance) = is_primitive(C, mi.specTypes)
_is_primitive(C::Type, sig::Type) = is_primitive(C, sig)

# Compute the concrete type of the rule that will be returned from `build_rrule`. This is
# important for performance in dynamic dispatch, and to ensure that recursion works
# properly.
function rule_type(interp::TapirInterpreter{C}, ::Type{sig}) where {C, sig}
is_primitive(C, sig) && return typeof(rrule!!)
function rule_type(interp::TapirInterpreter{C}, sig_or_mi) where {C}
_is_primitive(C, sig_or_mi) && return typeof(rrule!!)

ir, _ = lookup_ir(interp, sig)
ir, _ = lookup_ir(interp, sig_or_mi)
Treturn = Base.Experimental.compute_ir_rettype(ir)
isva, _ = is_vararg_sig_and_sparam_names(sig)
isva, _ = is_vararg_and_sparam_names(sig_or_mi)

arg_types = map(_type, ir.argtypes)
arg_fwds_types = Tuple{map(fcodual_type, arg_types)...}
Expand Down Expand Up @@ -743,35 +747,35 @@ function build_rrule(args...; safety_on=false)
end

"""
build_rrule(interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false) where {C}
build_rrule(interp::PInterp{C}, sig_or_mi; safety_on=false) where {C}
Returns a `DerivedRule` which is an `rrule!!` for `sig` in context `C`. See the docstring
Returns a `DerivedRule` which is an `rrule!!` for `sig_or_mi` in context `C`. See the docstring
for `rrule!!` for more info.
If `safety_on` is `true`, then all calls to rules are replaced with calls to `SafeRRule`s.
"""
function build_rrule(
interp::PInterp{C}, sig::Type{<:Tuple}; safety_on=false, silence_safety_messages=true
interp::PInterp{C}, sig_or_mi; safety_on=false, silence_safety_messages=true
) where {C}

# If we're compiling in safe mode, let the user know by default.
if !silence_safety_messages && safety_on
@info "Compiling rule for $sig in safe mode. Disable for best performance."
@info "Compiling rule for $sig_or_mi in safe mode. Disable for best performance."
end

# Reset id count. This ensures that the IDs generated are the same each time this
# function runs.
seed_id!()

# If we have a hand-coded rule, just use that.
is_primitive(C, sig) && return (safety_on ? SafeRRule(rrule!!) : rrule!!)
_is_primitive(C, sig_or_mi) && return (safety_on ? SafeRRule(rrule!!) : rrule!!)

# Grab code associated to the primal.
ir, _ = lookup_ir(interp, sig)
ir, _ = lookup_ir(interp, sig_or_mi)
Treturn = Base.Experimental.compute_ir_rettype(ir)

# Normalise the IR, and generated BBCode version of it.
isva, spnames = is_vararg_sig_and_sparam_names(sig)
isva, spnames = is_vararg_and_sparam_names(sig_or_mi)
ir = normalise!(ir, spnames)
primal_ir = BBCode(ir)

Expand All @@ -791,8 +795,8 @@ function build_rrule(

# If we've already derived the OpaqueClosures and info, do not re-derive, just create a
# copy and pass in new shared data.
if haskey(interp.oc_cache, (sig, safety_on))
existing_fwds_oc, existing_pb_oc = interp.oc_cache[(sig, safety_on)]
if haskey(interp.oc_cache, (sig_or_mi, safety_on))
existing_fwds_oc, existing_pb_oc = interp.oc_cache[(sig_or_mi, safety_on)]
fwds_oc = replace_captures(existing_fwds_oc, shared_data)
pb_oc = replace_captures(existing_pb_oc, shared_data)
else
Expand All @@ -801,7 +805,7 @@ function build_rrule(

optimised_fwds_ir = optimise_ir!(optimise_ir!(IRCode(fwds_ir); do_inline=true))
optimised_pb_ir = optimise_ir!(optimise_ir!(IRCode(pb_ir); do_inline=true))
# @show sig
# @show sig_or_mi
# @show Treturn
# @show safety_on
# display(ir)
Expand All @@ -820,10 +824,10 @@ function build_rrule(
OpaqueClosure(optimised_pb_ir, shared_data...; do_compile=true),
optimised_pb_ir,
)
interp.oc_cache[(sig, safety_on)] = (fwds_oc, pb_oc)
interp.oc_cache[(sig_or_mi, safety_on)] = (fwds_oc, pb_oc)
end

raw_rule = rule_type(interp, sig)(fwds_oc, pb_oc, Val(isva), Val(num_args(info)))
raw_rule = rule_type(interp, sig_or_mi)(fwds_oc, pb_oc, Val(isva), Val(num_args(info)))
return safety_on ? SafeRRule(raw_rule) : raw_rule
end

Expand Down Expand Up @@ -1230,7 +1234,7 @@ function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N}
end

#=
LazyDerivedRule(interp, sig, safety_on::Bool)
LazyDerivedRule(interp, mi::Core.MethodInstance, safety_on::Bool)
For internal use only.
Expand All @@ -1242,19 +1246,20 @@ If `safety_on` is `true`, then the rule constructed will be a `SafeRRule`. This
when debugging, but should usually be switched off for production code as it (in general)
incurs some runtime overhead.
=#
mutable struct LazyDerivedRule{sig, Tinterp<:TapirInterpreter, Trule}
mutable struct LazyDerivedRule{Tinterp<:TapirInterpreter, Trule}
interp::Tinterp
safety_on::Bool
mi::Core.MethodInstance
rule::Trule
function LazyDerivedRule(interp::A, ::Type{sig}, safety_on::Bool) where {A, sig}
rt = safety_on ? SafeRRule{rule_type(interp, sig)} : rule_type(interp, sig)
return new{sig, A, rt}(interp, safety_on)
function LazyDerivedRule(interp::A, mi::Core.MethodInstance, safety_on::Bool) where {A}
rt = rule_type(interp, mi)
return new{A, safety_on ? SafeRRule{rt} : rt}(interp, safety_on, mi)
end
end

function (rule::LazyDerivedRule{sig})(args::Vararg{Any, N}) where {N, sig}
function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N}
if !isdefined(rule, :rule)
rule.rule = build_rrule(rule.interp, sig; safety_on=rule.safety_on)
rule.rule = build_rrule(rule.interp, rule.mi; safety_on=rule.safety_on)
end
return rule.rule(args...)
end
2 changes: 2 additions & 0 deletions src/tangents.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ tangent_type(::Type{Nothing}) = NoTangent

tangent_type(::Type{Expr}) = NoTangent

tangent_type(::Type{Core.TypeofVararg}) = NoTangent

tangent_type(::Type{SimpleVector}) = Vector{Any}

tangent_type(::Type{P}) where {P<:Union{UInt8, UInt16, UInt32, UInt64, UInt128}} = NoTangent
Expand Down
15 changes: 15 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,18 @@ end

test_getfield_of_tuple_of_types(n::Int) = getfield((Float64, Float64), n)

test_for_invoke(x) = 5x

inlinable_invoke_call(x::Float64) = invoke(test_for_invoke, Tuple{Float64}, x)

vararg_test_for_invoke(n::Tuple{Int, Int}, x...) = sum(x) + n[1]

function inlinable_vararg_invoke_call(
rows::Tuple{Vararg{Int}}, n1::N, ns::Vararg{N}
) where {N}
return invoke(vararg_test_for_invoke, Tuple{typeof(rows), Vararg{N}}, rows, n1, ns...)
end

function generate_test_functions()
return Any[
(false, :allocs, nothing, const_tester),
Expand Down Expand Up @@ -1621,6 +1633,9 @@ function generate_test_functions()
(false, :none, nothing, ArgumentError, "hi"),
(false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}(5.0)),
(false, :none, nothing, test_small_union, Ref{Union{Float64, Vector{Float64}}}([1.0])),
(false, :allocs, nothing, inlinable_invoke_call, 5.0),
(false, :none, nothing, inlinable_vararg_invoke_call, (2, 2), 5.0, 4.0, 3.0, 2.0),
(false, :none, nothing, hvcat, (2, 2), 3.0, 2.0, 0.0, 1.0),
]
end

Expand Down
28 changes: 21 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,19 +114,33 @@ The usual function `map` doesn't enforce this for `Array`s.
end

#=
is_vararg_sig_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
is_vararg_and_sparam_names(m::Method)
Returns a 2-tuple. The first element is true if the method associated to `sig` is a vararg
method, and false if not. The second element contains all of the names of the static
parameters associated to said method.
Returns a 2-tuple. The first element is true if `m` is a vararg method, and false if not.
The second element contains the names of the static parameters associated to `m`.
=#
function is_vararg_sig_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
is_vararg_and_sparam_names(m::Method) = m.isva, sparam_names(m)

#=
is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
Finds the method associated to `sig`, and calls `is_vararg_and_sparam_names` on it.
=#
function is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}
world = Base.get_world_counter()
min = Base.RefValue{UInt}(typemin(UInt))
max = Base.RefValue{UInt}(typemax(UInt))
ms = Base._methods_by_ftype(sig, nothing, -1, world, true, min, max, Ptr{Int32}(C_NULL))::Vector
m = only(ms).method
return m.isva, sparam_names(m)
return is_vararg_and_sparam_names(only(ms).method)
end

#=
is_vararg_and_sparam_names(mi::Core.MethodInstance)
Calls `is_vararg_and_sparam_names` on `mi.def::Method`.
=#
function is_vararg_and_sparam_names(mi::Core.MethodInstance)::Tuple{Bool, Vector{Symbol}}
return is_vararg_and_sparam_names(mi.def)
end

# Returns the names of all of the static parameters in `m`.
Expand Down
31 changes: 18 additions & 13 deletions test/fwds_rvs_data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,25 @@ end
@testset "lazy construction checks" begin
# Check that lazy construction is in fact lazy for some cases where performance
# really matters -- floats, things with no rdata, etc.
@testset "$p" for (p, fully_lazy) in Any[
(5, true),
(Int32(5), true),
(5.0, true),
(5f0, true),
(Float16(5.0), true),
(StructFoo(5.0), false),
(StructFoo(5.0, randn(4)), false),
(Bool, true),
(Tapir.TestResources.StableFoo, true),
@testset "$p" for (P, p, fully_lazy) in Any[
(Int, 5, true),
(Int32, Int32(5), true),
(Float64, 5.0, true),
(Float32, 5f0, true),
(Float16, Float16(5.0), true),
(StructFoo, StructFoo(5.0), false),
(StructFoo, StructFoo(5.0, randn(4)), false),
(Type{Bool}, Bool, true),
(Type{Tapir.TestResources.StableFoo}, Tapir.TestResources.StableFoo, true),
(Tuple{Float64, Float64}, (5.0, 4.0), true),
(Tuple{Float64, Vararg{Float64}}, (5.0, 4.0, 3.0), false),
]
@test fully_lazy == Base.issingletontype(typeof(lazy_zero_rdata(p)))
@inferred Tapir.instantiate(lazy_zero_rdata(p))
@test typeof(lazy_zero_rdata(p)) == Tapir.lazy_zero_rdata_type(_typeof(p))
L = Tapir.lazy_zero_rdata_type(P)
@test fully_lazy == Base.issingletontype(typeof(lazy_zero_rdata(L, p)))
if isconcretetype(P)
@inferred Tapir.instantiate(lazy_zero_rdata(L, p))
end
@test typeof(lazy_zero_rdata(L, p)) == Tapir.lazy_zero_rdata_type(P)
end
@test isa(
lazy_zero_rdata(Tapir.TestResources.StableFoo),
Expand Down

2 comments on commit 4ea0e1d

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/112277

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.29 -m "<description of version>" 4ea0e1d5ebe50ef21dc5078f74073141aad928c0
git push origin v0.2.29

Please sign in to comment.