diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 7fcaa79a468d5..c6760d01a61e8 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -10,7 +10,7 @@ end du::SSADefUse This struct keeps track of all uses of some mutable struct allocated in the current function: -- `du.uses::Vector{Int}` are all instances of `getfield` on the struct +- `du.uses::Vector{Int}` are all instances of `getfield` / `isdefined` on the struct - `du.defs::Vector{Int}` are all instances of `setfield!` on the struct The terminology refers to the uses/defs of the "slot bundle" that the mutable struct represents. @@ -27,7 +27,10 @@ struct SSADefUse end SSADefUse() = SSADefUse(Int[], Int[], Int[]) -compute_live_ins(cfg::CFG, du::SSADefUse) = compute_live_ins(cfg, du.defs, du.uses) +function compute_live_ins(cfg::CFG, du::SSADefUse) + # filter out `isdefined` usages + return compute_live_ins(cfg, du.defs, filter(>(0), du.uses)) +end # assume `stmt == getfield(obj, field, ...)` or `stmt == setfield!(obj, field, val, ...)` try_compute_field_stmt(ir::Union{IncrementalCompact,IRCode}, stmt::Expr) = @@ -725,7 +728,7 @@ function sroa_pass!(ir::IRCode) for ((_, idx), stmt) in compact # check whether this statement is `getfield` / `setfield!` (or other "interesting" statement) isa(stmt, Expr) || continue - is_setfield = false + is_setfield = is_isdefined = false field_ordering = :unspecified if is_known_call(stmt, setfield!, compact) 4 <= length(stmt.args) <= 5 || continue @@ -741,6 +744,13 @@ function sroa_pass!(ir::IRCode) field_ordering = argextype(stmt.args[4], compact) widenconst(field_ordering) === Bool && (field_ordering = :unspecified) end + elseif is_known_call(stmt, isdefined, compact) + 3 <= length(stmt.args) <= 4 || continue + is_isdefined = true + if length(stmt.args) == 4 + field_ordering = argextype(stmt.args[4], compact) + widenconst(field_ordering) === Bool && (field_ordering = :unspecified) + end elseif isexpr(stmt, :foreigncall) nccallargs = length(stmt.args[3]::SimpleVector) preserved = Int[] @@ -795,13 +805,11 @@ function sroa_pass!(ir::IRCode) lift_comparison!(===, compact, idx, stmt, lifting_cache) elseif is_known_call(stmt, isa, compact) lift_comparison!(isa, compact, idx, stmt, lifting_cache) - elseif is_known_call(stmt, isdefined, compact) - lift_comparison!(isdefined, compact, idx, stmt, lifting_cache) end continue end - # analyze this `getfield` / `setfield!` call + # analyze this `getfield` / `isdefined` / `setfield!` call field = try_compute_field_stmt(compact, stmt) field === nothing && continue @@ -812,10 +820,15 @@ function sroa_pass!(ir::IRCode) if isa(struct_typ, Union) && struct_typ <: Tuple struct_typ = unswitchtupleunion(struct_typ) end + if isa(struct_typ, Union) && is_isdefined + lift_comparison!(isdefined, compact, idx, stmt, lifting_cache) + continue + end isa(struct_typ, DataType) || continue struct_typ.name.atomicfields == C_NULL || continue # TODO: handle more - if !(field_ordering === :unspecified || (field_ordering isa Const && field_ordering.val === :not_atomic)) + if !((field_ordering === :unspecified) || + (field_ordering isa Const && field_ordering.val === :not_atomic)) continue end @@ -836,6 +849,8 @@ function sroa_pass!(ir::IRCode) mid, defuse = get!(defuses, def.id, (SPCSet(), SSADefUse())) if is_setfield push!(defuse.defs, idx) + elseif is_isdefined + push!(defuse.uses, -idx) else push!(defuse.uses, idx) end @@ -844,6 +859,8 @@ function sroa_pass!(ir::IRCode) continue elseif is_setfield continue # invalid `setfield!` call, but just ignore here + elseif is_isdefined + continue # TODO? end # perform SROA on immutable structs here on @@ -927,9 +944,9 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse typ = typ::DataType # Partition defuses by field fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)] - all_forwarded = true + all_eliminated = all_forwarded = true for use in defuse.uses - stmt = ir[SSAValue(use)][:inst] # == `getfield` call + stmt = ir[SSAValue(abs(use))][:inst] # == `getfield`/`isdefined` call # We may have discovered above that this use is dead # after the getfield elim of immutables. In that case, # it would have been deleted. That's fine, just ignore @@ -969,7 +986,15 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse blocks[fidx] = phiblocks, allblocks if fidx + 1 > length(defexpr.args) for use in du.uses - has_safe_def(ir, get_domtree(), allblocks, du, newidx, use) || @goto skip + if use > 0 # == `getfield` use + has_safe_def(ir, get_domtree(), allblocks, du, newidx, use) || @goto skip + else # == `isdefined` use + if has_safe_def(ir, get_domtree(), allblocks, du, newidx, -use) + ir[SSAValue(-use)][:inst] = true + else + all_eliminated = false + end + end end end end @@ -991,8 +1016,12 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse NewInstruction(PhiNode(), ftyp)) end # Now go through all uses and rewrite them - for stmt in du.uses - ir[SSAValue(stmt)][:inst] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt) + for use in du.uses + if use > 0 # == `getfield` use + ir[SSAValue(use)][:inst] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use) + else # == `isdefined` use + continue # already rewritten if possible + end end if !isbitstype(ftyp) if preserve_uses !== nothing @@ -1010,6 +1039,10 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse end end end + all_eliminated || continue + # all "usages" (i.e. `getfield` and `isdefined` calls) are eliminated, + # now eliminate "definitions" (`setfield!`) calls + # (NOTE the allocation itself will be eliminated by DCE pass later) for stmt in du.defs stmt == newidx && continue ir[SSAValue(stmt)][:inst] = nothing diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 587d4402f554a..d441c7ebc4889 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -425,6 +425,44 @@ let src = code_typed1() do @test count(isnew, src.code) == 1 end +# isdefined elimination +# --------------------- + +let src = code_typed1((Any,)) do a + r = Ref{Any}() + r[] = a + if isassigned(r) + return r[] + end + return nothing + end + @test is_scalar_replaced(src) +end + +let src = code_typed1((Bool, Any,)) do cnd, a + r = Ref{Any}() + if cnd + r[] = a # this `setfield!` shouldn't be eliminated + end + return isassigned(r) + end + @test count(isnew, src.code) == 1 + @test count(iscall((src, setfield!)), src.code) == 1 +end + +callit(f, args...) = f(args...) +function isdefined_elim() + local arr::Vector{Any} + callit() do + arr = Any[] + end + return arr +end +let src = code_typed1(isdefined_elim) + @test is_scalar_replaced(src) +end +@test isdefined_elim() == Any[] + # comparison lifting # ==================