diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 2a337dd703db8..272131098bd23 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -99,9 +99,22 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I end end -# even when the allocation contains an uninitialized field, we try an extra effort to check -# if this load at `idx` have any "safe" `setfield!` calls that define the field function has_safe_def( + ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, + newidx::Int, fidx::Int) + newexpr = ir[SSAValue(newidx)]::Expr + + fidx + 1 ≤ length(newexpr.args) && return true # assured to have a safe definition for all usages + + # even when the allocation contains an uninitialized field, we try an extra effort to + # check if all loads have "safe" `setfield!` calls that define the uninitialized field + for use in du.uses + has_safe_def_for_uninitialized_field(ir, domtree, allblocks, du, newidx, use) || return false + end + return true # should be safe +end + +function has_safe_def_for_uninitialized_field( ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, newidx::Int, idx::Int) def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx) @@ -206,14 +219,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA end function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), - @nospecialize(typeconstraint)) - callback = function (@nospecialize(pi), @nospecialize(idx)) - if isa(pi, PiNode) - typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ)) + @nospecialize(typeconstraint), @nospecialize(callback = nothing)) + newcallback = function (@nospecialize(x), @nospecialize(idx)) + if isa(x, PiNode) + typeconstraint = typeintersect(typeconstraint, widenconst(x.typ)) end + callback === nothing || callback(x, idx) return false end - def = simple_walk(compact, defssa, callback) + 
def = simple_walk(compact, defssa, newcallback) return Pair{Any, Any}(def, typeconstraint) end @@ -223,7 +237,9 @@ end Starting at `val` walk use-def chains to get all the leaves feeding into this `val` (pruning those leaves rules out by path conditions). """ -function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint)) +function walk_to_defs(compact::IncrementalCompact, + @nospecialize(defssa), @nospecialize(typeconstraint), + @nospecialize(callback = nothing)) visited_phinodes = AnySSAValue[] isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes def = compact[defssa] @@ -259,7 +275,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe val = OldSSAValue(val.id) end if isa(val, AnySSAValue) - new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint) + new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback) if isa(new_def, AnySSAValue) if !haskey(visited_constraints, new_def) push!(worklist_defs, new_def) @@ -720,10 +736,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) continue end if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}() end - mid, defuse = get!(defuses, defidx) do - SPCSet(), SSADefUse() + mid, defuse, phidefs = get!(defuses, defidx) do + SPCSet(), SSADefUse(), PhiDefs(nothing) end push!(defuse.ccall_preserve_uses, idx) union!(mid, intermediaries) @@ -778,16 +794,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) # Mutable stuff here isa(def, SSAValue) || continue if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}() end - mid, defuse = get!(defuses, def.id) do - SPCSet(), SSADefUse() + mid, defuse, phidefs = get!(defuses, def.id) do + SPCSet(), SSADefUse(), PhiDefs(nothing) end if is_setfield push!(defuse.defs, idx) 
else push!(defuse.uses, idx) end + defval = compact[def] + if isa(defval, PhiNode) + phicallback = function (@nospecialize(x), @nospecialize(ssa)) + push!(intermediaries, ssa.id) + return false + end + defs, _ = walk_to_defs(compact, def, struct_typ, phicallback) + if _any(@nospecialize(d)->!isa(d, SSAValue), defs) + delete!(defuses, def.id) + continue + end + phidefs[] = Int[(def::SSAValue).id for def in defs] + end union!(mid, intermediaries) end continue @@ -847,8 +876,14 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) end end +# TODO: +# - run mutable SROA on the same IR as when we collect information about mutable allocations +# - simplify and improve the eliminability check below using an escape analysis + +const PhiDefs = RefValue{Union{Nothing,Vector{Int}}} + function sroa_mutables!(ir::IRCode, - defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, + defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int}, nested_loads::NestedLoads) # Compute domtree, needed below, now that we have finished compacting the IR. # This needs to be after we iterate through the IR with `IncrementalCompact` @@ -858,36 +893,58 @@ function sroa_mutables!(ir::IRCode, nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable` local any_eliminated = false # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield` - for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true) + for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true) intermediaries = collect(intermediaries) + phidefs = phidefs[] # Check if there are any uses we did not account for. If so, the variable # escapes and we cannot eliminate the allocation. This works, because we're guaranteed # not to include any intermediaries that have dead uses. As a result, missing uses will only ever # show up in the nuses_total count. 
- nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) + nleaves = count_leaves(defuse) + if phidefs !== nothing + # if this defines ϕ, we also track leaves of all predecessors as well + # FIXME this doesn't work when any predecessor is used by another ϕ-node + for pidx in phidefs + haskey(defuses, pidx) || continue + pdefuse = defuses[pidx][2] + nleaves += count_leaves(pdefuse) + end + end nuses = 0 for idx in intermediaries nuses += used_ssas[idx] end - nuses_total = used_ssas[idx] + nuses - length(intermediaries) + nuses -= length(intermediaries) + nuses_total = used_ssas[idx] + nuses + if phidefs !== nothing + for pidx in phidefs + # NOTE we don't need to account for intermediates for this predecessor here, + # since they are already included in intermediates of this ϕ-node + # FIXME this doesn't work when any predecessor is used by another ϕ-node + nuses_total += used_ssas[pidx] - 1 # subtract usage count from ϕ-node itself + end + end nleaves == nuses_total || continue # Find the type for this allocation defexpr = ir[SSAValue(idx)] - isa(defexpr, Expr) || continue - if !isexpr(defexpr, :new) - if is_known_call(defexpr, getfield, ir) - val = defexpr.args[2] - if isa(val, SSAValue) - struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) - if ismutabletype(struct_typ) - record_nested_load!(nested_mloads, idx) - end + if isa(defexpr, Expr) + if !isexpr(defexpr, :new) + maybe_record_nested_load!(nested_mloads, ir, idx) + continue + end + elseif isa(defexpr, PhiNode) + phidefs === nothing && continue + for pidx in phidefs + pexpr = ir[SSAValue(pidx)] + if !isexpr(pexpr, :new) + maybe_record_nested_load!(nested_mloads, ir, pidx) + @goto skip end end + else continue end - newidx = idx - typ = ir.stmts[newidx][:type] + typ = ir.stmts[idx][:type] if isa(typ, UnionAll) typ = unwrap_unionall(typ) end @@ -899,25 +956,29 @@ function sroa_mutables!(ir::IRCode, fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)] 
all_forwarded = true for use in defuse.uses - stmt = ir[SSAValue(use)] # == `getfield` call - # We may have discovered above that this use is dead - # after the getfield elim of immutables. In that case, - # it would have been deleted. That's fine, just ignore - # the use in that case. - if stmt === nothing + eliminable = check_use_eliminability!(fielddefuse, ir, use, typ) + if eliminable === nothing + # We may have discovered above that this use is dead + # after the getfield elim of immutables. In that case, + # it would have been deleted. That's fine, just ignore + # the use in that case. all_forwarded = false continue + elseif !eliminable + @goto skip end - field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ) - field === nothing && @goto skip - push!(fielddefuse[field].uses, use) end for def in defuse.defs - stmt = ir[SSAValue(def)]::Expr # == `setfield!` call - field = try_compute_fieldidx_stmt(ir, stmt, typ) - field === nothing && @goto skip - isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error - push!(fielddefuse[field].defs, def) + check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip + end + if phidefs !== nothing + for pidx in phidefs + haskey(defuses, pidx) || continue + pdefuse = defuses[pidx][2] + for pdef in pdefuse.defs + check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip + end + end end # Check that the defexpr has defined values for all the fields # we're accessing. In the future, we may want to relax this, @@ -928,15 +989,24 @@ function sroa_mutables!(ir::IRCode, for fidx in 1:ndefuse du = fielddefuse[fidx] isempty(du.uses) && continue - push!(du.defs, newidx) + if phidefs === nothing + push!(du.defs, idx) + else + for pidx in phidefs + push!(du.defs, pidx) + end + end ldu = compute_live_ins(ir.cfg, du) phiblocks = isempty(ldu.live_in_bbs) ? 
Int[] : iterated_dominance_frontier(ir.cfg, ldu, domtree) allblocks = sort(vcat(phiblocks, ldu.def_bbs)) blocks[fidx] = phiblocks, allblocks - if fidx + 1 > length(defexpr.args) - for use in du.uses - has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip + if phidefs !== nothing + # check if all predecessors have safe definitions + for pidx in phidefs + has_safe_def(ir, domtree, allblocks, du, pidx, fidx) || @goto skip end + else + has_safe_def(ir, domtree, allblocks, du, idx, fidx) || @goto skip end end # Everything accounted for. Go field by field and perform idf @@ -976,9 +1046,16 @@ function sroa_mutables!(ir::IRCode, end end end - for stmt in du.defs - stmt == newidx && continue - ir[SSAValue(stmt)] = nothing + if isa(defexpr, PhiNode) + ir[SSAValue(idx)] = nothing + for pidx in phidefs::Vector{Int} + used_ssas[pidx] -= 1 + end + else + for stmt in du.defs + stmt == idx && continue + ir[SSAValue(stmt)] = nothing + end end end preserve_uses === nothing && continue @@ -986,7 +1063,7 @@ function sroa_mutables!(ir::IRCode, # this means all ccall preserves have been replaced with forwarded loads # so we can potentially eliminate the allocation, otherwise we must preserve # the whole allocation. 
- push!(intermediaries, newidx) + push!(intermediaries, idx) end # Insert the new preserves for (use, new_preserves) in preserve_uses @@ -1002,6 +1079,42 @@ function sroa_mutables!(ir::IRCode, end end +count_leaves(defuse::SSADefUse) = + length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) + +function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int) + defexpr = ir[SSAValue(idx)] + if is_known_call(defexpr, getfield, ir) + val = defexpr.args[2] + if isa(val, SSAValue) + struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) + if ismutabletype(struct_typ) + record_nested_load!(nested_mloads, idx) + end + end + end +end + +function check_use_eliminability!(fielddefuse::Vector{SSADefUse}, + ir::IRCode, useidx::Int, struct_typ::DataType) + stmt = ir[SSAValue(useidx)] # == `getfield` call + stmt === nothing && return nothing + field = try_compute_fieldidx_stmt(ir, stmt::Expr, struct_typ) + field === nothing && return false + push!(fielddefuse[field].uses, useidx) + return true +end + +function check_def_eliminability!(fielddefuse::Vector{SSADefUse}, + ir::IRCode, defidx::Int, struct_typ::DataType) + stmt = ir[SSAValue(defidx)]::Expr # == `setfield!` call + field = try_compute_fieldidx_stmt(ir, stmt, struct_typ) + field === nothing && return false + isconst(struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error + push!(fielddefuse[field].defs, defidx) + return true +end + function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) newex = Expr(:foreigncall) nccallargs = length(origex.args[3]::SimpleVector) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 2d1d85e3df97f..5dcad9e6e83f0 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -230,7 +230,7 @@ let src = code_typed1((Any,Any,Any)) do x, y, z end end # FIXME? 
in order to handle nested mutable `getfield` calls, we run SROA iteratively until -# any nested mutable `getfield` calls become no longer eliminatable: +# any nested mutable `getfield` calls become no longer eliminable: # it's probably not the most efficient option and we may want to introduce some sort of # alias analysis and eliminates all the loads at once. # mutable(immutable(...)) case @@ -308,6 +308,188 @@ let # NOTE `sroa_mutables!` eliminate from innermost definitions, so that it sho @test !any(isnew, src.code) end +# ϕ-allocation elimination +# ------------------------ +mutable struct MutableSome + x::Any + MutableSome(@nospecialize x) = new(x) + MutableSome() = new() +end +Base.getindex(s::MutableSome) = s.x +Base.setindex!(s::MutableSome, @nospecialize x) = s.x = x +@testset "mutable ϕ-allocation elimination" begin + # safe cases + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end + let src = code_typed1((Bool,Bool,Any,Any,Any)) do cond1, cond2, x, y, z + if cond1 + ϕ = MutableSome(x) + elseif cond2 + ϕ = MutableSome(y) + else + ϕ = MutableSome(z) + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(4) in x.values && + #=y=# Core.Argument(5) in x.values && + #=z=# Core.Argument(6) in x.values + end == 1 + end + let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + ϕ[] = z + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 + end + let src = code_typed1((Bool,Any,Any,)) do cond, x, y + if 
cond + ϕ = MutableSome(x) + out1 = ϕ[] + else + ϕ = MutableSome(y) + out1 = ϕ[] + end + out2 = ϕ[] + out1, out2 + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 2 + end + let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = MutableSome(x) + out1 = ϕ[] + else + ϕ = MutableSome(y) + out1 = ϕ[] + ϕ[] = z + end + out2 = ϕ[] + out1, out2 + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 + end + + # unsafe cases + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + end + some_escape(ϕ) + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + some_escape(ϕ) + else + ϕ = MutableSome(y) + end + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,)) do cond, x + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome() + end + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome() + ϕ[] = y + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end + + # FIXME allocation forming multiple ϕ + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ2 = ϕ1 = MutableSome(x) + else + ϕ2 = ϕ1 = MutableSome(y) + end + ϕ1[], ϕ2[] + end + @test_broken !any(isnew, src.code) + @test_broken 
count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end +end +function mutable_ϕ_elim(x, xs) + r = Ref(x) + for x in xs + r = Ref(x) + end + return r[] +end +let xs = String[string(gensym()) for _ in 1:100] + mutable_ϕ_elim("init", xs) + @test @allocated(mutable_ϕ_elim("init", xs)) == 0 +end + # should work nicely with inlining to optimize away a complicated case # adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B struct Point