Skip to content

Commit

Permalink
optimizer: inline abstract union-split callsite (JuliaLang#44512)
Browse files Browse the repository at this point in the history
Currently the optimizer handles abstract callsite only when there is a
single dispatch candidate (in most cases), and so inlining and static-dispatch
are prohibited when the callsite is union-split (in other word, union-split
happens only when all the dispatch candidates are concrete).

However, there are certain patterns of code (most notably our Julia-level compiler code)
that inherently need to deal with abstract callsite.
The following example is taken from `Core.Compiler` utility:
```julia
julia> @inline isType(@nospecialize t) = isa(t, DataType) && t.name === Type.body.name
isType (generic function with 1 method)

julia> code_typed((Any,)) do x # abstract, but no union-split, successful inlining
           isType(x)
       end |> only
CodeInfo(
1 ─ %1 = (x isa Main.DataType)::Bool
└──      goto #3 if not %1
2 ─ %3 = π (x, DataType)
│   %4 = Base.getfield(%3, :name)::Core.TypeName
│   %5 = Base.getfield(Type{T}, :name)::Core.TypeName
│   %6 = (%4 === %5)::Bool
└──      goto #4
3 ─      goto #4
4 ┄ %9 = φ (#2 => %6, #3 => false)::Bool
└──      return %9
) => Bool

julia> code_typed((Union{Type,Nothing},)) do x # abstract, union-split, unsuccessful inlining
           isType(x)
       end |> only
CodeInfo(
1 ─ %1 = (isa)(x, Nothing)::Bool
└──      goto #3 if not %1
2 ─      goto #4
3 ─ %4 = Main.isType(x)::Bool
└──      goto #4
4 ┄ %6 = φ (#2 => false, #3 => %4)::Bool
└──      return %6
) => Bool
```
(note that this is a limitation of the inlining algorithm, and so any
user-provided hints like callsite inlining annotation doesn't help here)

This commit enables inlining and static dispatch for abstract union-split callsite.
The core idea here is that we can simulate our dispatch semantics by
generating `isa` checks in order of the specialities of dispatch candidates:
```julia
julia> code_typed((Union{Type,Nothing},)) do x # union-split, unsuccessful inlining
                  isType(x)
              end |> only
CodeInfo(
1 ─ %1  = (isa)(x, Nothing)::Bool
└──       goto #3 if not %1
2 ─       goto JuliaLang#9
3 ─ %4  = (isa)(x, Type)::Bool
└──       goto JuliaLang#8 if not %4
4 ─ %6  = π (x, Type)
│   %7  = (%6 isa Main.DataType)::Bool
└──       goto JuliaLang#6 if not %7
5 ─ %9  = π (%6, DataType)
│   %10 = Base.getfield(%9, :name)::Core.TypeName
│   %11 = Base.getfield(Type{T}, :name)::Core.TypeName
│   %12 = (%10 === %11)::Bool
└──       goto JuliaLang#7
6 ─       goto JuliaLang#7
7 ┄ %15 = φ (JuliaLang#5 => %12, JuliaLang#6 => false)::Bool
└──       goto JuliaLang#9
8 ─       Core.throw(ErrorException("fatal error in type inference (type bound)"))::Union{}
└──       unreachable
9 ┄ %19 = φ (#2 => false, JuliaLang#7 => %15)::Bool
└──       return %19
) => Bool
```

Inlining/static-dispatch of abstract union-split callsite will improve
the performance in such situations (and so this commit will improve the
latency of our JIT compilation). Especially, this commit helps us avoid
excessive specializations of `Core.Compiler` code by statically-resolving
`@nospecialize`d callsites, and as the result, the # of precompiled
statements is now reduced from  `2005` ([`master`](f782430)) to `1912` (this commit).

And also, as a side effect, the implementation of our inlining algorithm
gets much simplified now since we no longer need the previous special
handlings for abstract callsites.

One possible drawback would be increased code size.
This change seems to certainly increase the size of sysimage,
but I think these numbers are in an acceptable range:
> [`master`](f782430)
```
❯ du -shk usr/lib/julia/*
17604	usr/lib/julia/corecompiler.ji
194072	usr/lib/julia/sys-o.a
169424	usr/lib/julia/sys.dylib
23784	usr/lib/julia/sys.dylib.dSYM
103772	usr/lib/julia/sys.ji
```

> this commit
```
❯ du -shk usr/lib/julia/*
17512	usr/lib/julia/corecompiler.ji
195588	usr/lib/julia/sys-o.a
170908	usr/lib/julia/sys.dylib
23776	usr/lib/julia/sys.dylib.dSYM
105360	usr/lib/julia/sys.ji
```
  • Loading branch information
aviatesk authored Mar 30, 2022
1 parent f782430 commit 3e1ae47
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 90 deletions.
179 changes: 94 additions & 85 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int,
push!(from_bbs, length(state.new_cfg_blocks))
# TODO: Right now we unconditionally generate a fallback block
# in case of subtyping errors - This is probably unnecessary.
if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig)))
if i != length(cases) || (!fully_covered || (!params.trust_inference))
# This block will have the next condition or the final else case
push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx)))
push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks))
Expand Down Expand Up @@ -313,7 +313,6 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
spec = item.spec::ResolvedInliningSpec
sparam_vals = item.mi.sparam_vals
def = item.mi.def::Method
inline_cfg = spec.ir.cfg
linetable_offset::Int32 = length(linetable)
# Append the linetable of the inlined function to our line table
inlined_at = Int(compact.result[idx][:line])
Expand Down Expand Up @@ -459,6 +458,66 @@ end

const FATAL_TYPE_BOUND_ERROR = ErrorException("fatal error in type inference (type bound)")

"""
ir_inline_unionsplit!
The core idea of this function is to simulate the dispatch semantics by generating
(flat) `isa`-checks corresponding to the signatures of union-split dispatch candidates,
and then inline their bodies into each `isa`-conditional block.
This `isa`-based virtual dispatch requires few pre-conditions to hold in order to simulate
the actual semantics correctly.
The first one is that these dispatch candidates need to be processed in order of their specificity,
and the corresponding `isa`-checks should reflect the method specificities, since now their
signatures are not necessarily concrete.
For example, given the following definitions:
f(x::Int) = ...
f(x::Number) = ...
f(x::Any) = ...
and a callsite:
f(x::Any)
then a correct `isa`-based virtual dispatch would be:
if isa(x, Int)
[inlined/resolved f(x::Int)]
elseif isa(x, Number)
[inlined/resolved f(x::Number)]
else # implies `isa(x, Any)`, which fully covers this call signature,
# otherwise we need to insert a fallback dynamic dispatch case also
[inlined/resolved f(x::Any)]
end
Fortunately, `ml_matches` should already sorted them in that way, except cases when there is
any ambiguity, from which we already bail out at this point.
Another consideration is type equality constraint from type variables: the `isa`-checks are
not enough to simulate the dispatch semantics in cases like:
Given a definition:
g(x::T, y::T) where T<:Integer = ...
transform a callsite:
g(x::Any, y::Any)
into the optimized form:
if isa(x, Integer) && isa(y, Integer)
[inlined/resolved g(x::Integer, y::Integer)]
else
g(x, y) # fallback dynamic dispatch
end
But again, we should already bail out from such cases at this point, essentially by
excluding cases where `case.sig::UnionAll`.
In short, here we can process the dispatch candidates in order, assuming we haven't changed
their order somehow somewhere up to this point.
"""
function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
argexprs::Vector{Any}, linetable::Vector{LineInfoNode},
(; fully_covered, atype, cases, bbs)::UnionSplit,
Expand All @@ -468,17 +527,17 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
join_bb = bbs[end]
pn = PhiNode()
local bb = compact.active_result_bb
@assert length(bbs) >= length(cases)
for i in 1:length(cases)
ncases = length(cases)
@assert length(bbs) >= ncases
for i = 1:ncases
ithcase = cases[i]
mtype = ithcase.sig::DataType # checked within `handle_cases!`
case = ithcase.item
next_cond_bb = bbs[i]
cond = true
nparams = fieldcount(atype)
@assert nparams == fieldcount(mtype)
if i != length(cases) || !fully_covered ||
(!params.trust_inference && isdispatchtuple(cases[i].sig))
if i != ncases || !fully_covered || !params.trust_inference
for i = 1:nparams
a, m = fieldtype(atype, i), fieldtype(mtype, i)
# If this is always true, we don't need to check for it
Expand Down Expand Up @@ -535,7 +594,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
bb += 1
# We're now in the fall through block, decide what to do
if fully_covered
if !params.trust_inference && isdispatchtuple(cases[end].sig)
if !params.trust_inference
e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR)
insert_node_here!(compact, NewInstruction(e, Union{}, line))
insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line))
Expand All @@ -558,7 +617,7 @@ function batch_inline!(todo::Vector{Pair{Int, Any}}, ir::IRCode, linetable::Vect
state = CFGInliningState(ir)
for (idx, item) in todo
if isa(item, UnionSplit)
cfg_inline_unionsplit!(ir, idx, item::UnionSplit, state, params)
cfg_inline_unionsplit!(ir, idx, item, state, params)
else
item = item::InliningTodo
spec = item.spec::ResolvedInliningSpec
Expand Down Expand Up @@ -1172,12 +1231,8 @@ function analyze_single_call!(
sig::Signature, state::InliningState, todo::Vector{Pair{Int, Any}})
argtypes = sig.argtypes
cases = InliningCase[]
local only_method = nothing # keep track of whether there is one matching method
local meth::MethodLookupResult
local any_fully_covered = false
local handled_all_cases = true
local any_covers_full = false
local revisit_idx = nothing

for i in 1:length(infos)
meth = infos[i].results
if meth.ambig
Expand All @@ -1188,66 +1243,20 @@ function analyze_single_call!(
# No applicable methods; try next union split
handled_all_cases = false
continue
else
if length(meth) == 1 && only_method !== false
if only_method === nothing
only_method = meth[1].method
elseif only_method !== meth[1].method
only_method = false
end
else
only_method = false
end
end
for (j, match) in enumerate(meth)
any_covers_full |= match.fully_covers
if !isdispatchtuple(match.spec_types)
if !match.fully_covers
handled_all_cases = false
continue
end
if revisit_idx === nothing
revisit_idx = (i, j)
else
handled_all_cases = false
revisit_idx = nothing
end
else
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
end
for match in meth
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
any_fully_covered |= match.fully_covers
end
end

atype = argtypes_to_type(argtypes)
if handled_all_cases && revisit_idx !== nothing
# If there's only one case that's not a dispatchtuple, we can
# still unionsplit by visiting all the other cases first.
# This is useful for code like:
# foo(x::Int) = 1
# foo(@nospecialize(x::Any)) = 2
# where we where only a small number of specific dispatchable
# cases are split off from an ::Any typed fallback.
(i, j) = revisit_idx
match = infos[i].results[j]
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
elseif length(cases) == 0 && only_method isa Method
# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple.
# -- But don't try it if we already tried to handle the match in the revisit_idx
# case, because that'll (necessarily) be the same method.
if length(infos) > 1
(metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any),
atype, only_method.sig)::SimpleVector
match = MethodMatch(metharg, methsp::SimpleVector, only_method, true)
else
@assert length(meth) == 1
match = meth[1]
end
handle_match!(match, argtypes, flag, state, cases, true) || return nothing
any_covers_full = handled_all_cases = match.fully_covers
if !handled_all_cases
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases,
handled_all_cases & any_fully_covered, todo, state.params)
end

# similar to `analyze_single_call!`, but with constant results
Expand All @@ -1258,8 +1267,8 @@ function handle_const_call!(
(; call, results) = cinfo
infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches
cases = InliningCase[]
local any_fully_covered = false
local handled_all_cases = true
local any_covers_full = false
local j = 0
for i in 1:length(infos)
meth = infos[i].results
Expand All @@ -1275,42 +1284,39 @@ function handle_const_call!(
for match in meth
j += 1
result = results[j]
any_covers_full |= match.fully_covers
any_fully_covered |= match.fully_covers
if isa(result, ConstResult)
case = const_result_item(result, state)
push!(cases, InliningCase(result.mi.specTypes, case))
elseif isa(result, InferenceResult)
handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases)
handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases, true)
else
@assert result === nothing
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
end
end
end

# if the signature is fully covered and there is only one applicable method,
# we can try to inline it even if the signature is not a dispatch tuple
atype = argtypes_to_type(argtypes)
if length(cases) == 0
length(results) == 1 || return nothing
result = results[1]
isa(result, InferenceResult) || return nothing
handle_inf_result!(result, argtypes, flag, state, cases, true) || return nothing
spec_types = cases[1].sig
any_covers_full = handled_all_cases = atype <: spec_types
if !handled_all_cases
# if we've not seen all candidates, union split is valid only for dispatch tuples
filter!(case::InliningCase->isdispatchtuple(case.sig), cases)
end

handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
handle_cases!(ir, idx, stmt, argtypes_to_type(argtypes), cases,
handled_all_cases & any_fully_covered, todo, state.params)
end

function handle_match!(
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
cases::Vector{InliningCase}, allow_abstract::Bool = false)
spec_types = match.spec_types
allow_abstract || isdispatchtuple(spec_types) || return false
# we may see duplicated dispatch signatures here when a signature gets widened
# during abstract interpretation: for the purpose of inlining, we can just skip
# processing this dispatch candidate
_any(case->case.sig === spec_types, cases) && return true
item = analyze_method!(match, argtypes, flag, state)
item === nothing && return false
_any(case->case.sig === spec_types, cases) && return true
push!(cases, InliningCase(spec_types, item))
return true
end
Expand Down Expand Up @@ -1346,7 +1352,9 @@ function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype),
handle_single_case!(ir, idx, stmt, cases[1].item, todo, params)
elseif length(cases) > 0
isa(atype, DataType) || return nothing
all(case::InliningCase->isa(case.sig, DataType), cases) || return nothing
for case in cases
isa(case.sig, DataType) || return nothing
end
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
end
return nothing
Expand Down Expand Up @@ -1442,7 +1450,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)

analyze_single_call!(ir, idx, stmt, infos, flag, sig, state, todo)
end
todo

return todo
end

function linear_inline_eligible(ir::IRCode)
Expand Down
2 changes: 1 addition & 1 deletion base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module Sort
import ..@__MODULE__, ..parentmodule
const Base = parentmodule(@__MODULE__)
using .Base.Order
using .Base: copymutable, LinearIndices, length, (:),
using .Base: copymutable, LinearIndices, length, (:), iterate,
eachindex, axes, first, last, similar, zip, OrdinalRange,
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
Expand Down
Loading

0 comments on commit 3e1ae47

Please sign in to comment.