From 1b8be01ebb6556449211a551ece15a4bf4c9d6fb Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Wed, 22 Dec 2021 02:05:57 +0900 Subject: [PATCH] =?UTF-8?q?optimizer:=20enable=20SROA=20of=20mutable=20?= =?UTF-8?q?=CF=86-nodes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit allows elimination of mutable φ-node (and its predecessor mutables allocations). As an contrived example, it allows this `mutable_ϕ_elim(::String, ::Vector{String})` to run without any allocations at all: ```julia function mutable_ϕ_elim(x, xs) r = Ref(x) for x in xs r = Ref(x) end return r[] end let xs = String[string(gensym()) for _ in 1:100] mutable_ϕ_elim("init", xs) @test @allocated(mutable_ϕ_elim("init", xs)) == 0 end ``` This mutable ϕ-node elimination is still limited though. Most notably, the current implementation doesn't work if a mutable allocation forms multiple ϕ-nodes, since we check allocation eliminability (i.e. escapability) by counting usages counts and thus it's hard to reason about multiple ϕ-nodes at a time. For example, currently mutable allocations involved in cases like below will still not be eliminated: ```julia code_typed((Bool,String,String),) do cond, x, y if cond ϕ2 = ϕ1 = Ref(x) else ϕ2 = ϕ1 = Ref(y) end ϕ1[], ϕ2[] end \# more realistic example mutable struct Point{T} x::T y::T end add(a::Point, b::Point) = Point(a.x + b.x, a.y + b.y) function compute(a::Point{ComplexF64}, b::Point{ComplexF64}) for i in 0:(100000000-1) a = add(add(a, b), b) end a.x, a.y end ``` I'd say this limitation should be addressed by first introducing a better abstraction for reasoning escape information. More specifically, I'd like introduce EscapeAnalysis.jl into Julia base first, and then gradually adapt it to improve our SROA pass, since EA will allow us to reason about all escape information imposed on whatever object more easily and should help us get rid of the complexities of our current SROA implementation. For now, I'd like to get in this enhancement even though it has the limitation elaborated above, as far as this commit doesn't introduce latency problem (which is unlikely). --- base/compiler/ssair/passes.jl | 217 ++++++++++++++++++++++++++-------- test/compiler/irpasses.jl | 184 +++++++++++++++++++++++++++- 2 files changed, 348 insertions(+), 53 deletions(-) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index a1853dbff2f73b..a06a44b60f56c4 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -100,9 +100,22 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I end end -# even when the allocation contains an uninitialized field, we try an extra effort to check -# if this load at `idx` have any "safe" `setfield!` calls that define the field function has_safe_def( + ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, + newidx::Int, fidx::Int) + newexpr = ir[SSAValue(newidx)]::Expr + + fidx + 1 ≤ length(newexpr.args) && return true # assured to have a safe definition for all usages + + # even when the allocation contains an uninitialized field, we try an extra effort to + # check if all loads have "safe" `setfield!` calls that define the uninitialized field + for use in du.uses + has_safe_def_for_uninitialized_field(ir, domtree, allblocks, du, newidx, use) || return false + end + return true # shuold be safe +end + +function has_safe_def_for_uninitialized_field( ir::IRCode, domtree::DomTree, allblocks::Vector{Int}, du::SSADefUse, newidx::Int, idx::Int) def, _, _ = find_def_for_use(ir, domtree, allblocks, du, idx) @@ -207,14 +220,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA end function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), - @nospecialize(typeconstraint)) - callback = function (@nospecialize(pi), @nospecialize(idx)) - if isa(pi, PiNode) - typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ)) + @nospecialize(typeconstraint), @nospecialize(callback = nothing)) + newcallback = function (@nospecialize(x), @nospecialize(idx)) + if isa(x, PiNode) + typeconstraint = typeintersect(typeconstraint, widenconst(x.typ)) end + callback === nothing || callback(x, idx) return false end - def = simple_walk(compact, defssa, callback) + def = simple_walk(compact, defssa, newcallback) return Pair{Any, Any}(def, typeconstraint) end @@ -224,7 +238,9 @@ end Starting at `val` walk use-def chains to get all the leaves feeding into this `val` (pruning those leaves rules out by path conditions). """ -function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint)) +function walk_to_defs(compact::IncrementalCompact, + @nospecialize(defssa), @nospecialize(typeconstraint), + @nospecialize(callback = nothing)) visited_phinodes = AnySSAValue[] isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes def = compact[defssa] @@ -260,7 +276,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe val = OldSSAValue(val.id) end if isa(val, AnySSAValue) - new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint) + new_def, new_constraint = simple_walk_constraint(compact, val, typeconstraint, callback) if isa(new_def, AnySSAValue) if !haskey(visited_constraints, new_def) push!(worklist_defs, new_def) @@ -721,10 +737,10 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) continue end if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}() end - mid, defuse = get!(defuses, defidx) do - SPCSet(), SSADefUse() + mid, defuse, phidefs = get!(defuses, defidx) do + SPCSet(), SSADefUse(), PhiDefs(nothing) end push!(defuse.ccall_preserve_uses, idx) union!(mid, intermediaries) @@ -779,16 +795,29 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) # Mutable stuff here isa(def, SSAValue) || continue if defuses === nothing - defuses = IdDict{Int, Tuple{SPCSet, SSADefUse}}() + defuses = IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}() end - mid, defuse = get!(defuses, def.id) do - SPCSet(), SSADefUse() + mid, defuse, phidefs = get!(defuses, def.id) do + SPCSet(), SSADefUse(), PhiDefs(nothing) end if is_setfield push!(defuse.defs, idx) else push!(defuse.uses, idx) end + defval = compact[def] + if isa(defval, PhiNode) + phicallback = function (@nospecialize(x), @nospecialize(ssa)) + push!(intermediaries, ssa.id) + return false + end + defs, _ = walk_to_defs(compact, def, struct_typ, phicallback) + if _any(@nospecialize(d)->!isa(d, SSAValue), defs) + delete!(defuses, def.id) + continue + end + phidefs[] = Int[(def::SSAValue).id for def in defs] + end union!(mid, intermediaries) end continue @@ -848,8 +877,14 @@ function sroa_pass!(ir::IRCode, optional_opts::Bool = true) end end +# TODO: +# - run mutable SROA on the same IR as when we collect information about mutable allocations +# - simplify and improve the eliminability check below using an escape analysis + +const PhiDefs = RefValue{Union{Nothing,Vector{Int}}} + function sroa_mutables!(ir::IRCode, - defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, + defuses::IdDict{Int, Tuple{SPCSet, SSADefUse, PhiDefs}}, used_ssas::Vector{Int}, nested_loads::NestedLoads) # Compute domtree, needed below, now that we have finished compacting the IR. # This needs to be after we iterate through the IR with `IncrementalCompact` @@ -859,36 +894,58 @@ function sroa_mutables!(ir::IRCode, nested_mloads = NestedLoads() # tracks nested `getfield(getfield(...)::Mutable, ...)::Mutable` local any_eliminated = false # NOTE eliminate from innermost definitions, so that we can track elimination of nested `getfield` - for (idx, (intermediaries, defuse)) in sort!(collect(defuses); by=first, rev=true) + for (idx, (intermediaries, defuse, phidefs)) in sort!(collect(defuses); by=first, rev=true) intermediaries = collect(intermediaries) + phidefs = phidefs[] # Check if there are any uses we did not account for. If so, the variable # escapes and we cannot eliminate the allocation. This works, because we're guaranteed # not to include any intermediaries that have dead uses. As a result, missing uses will only ever # show up in the nuses_total count. - nleaves = length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) + nleaves = count_leaves(defuse) + if phidefs !== nothing + # if this defines ϕ, we also track leaves of all predecessors as well + # FIXME this doesn't work when any predecessor is used by another ϕ-node + for pidx in phidefs + haskey(defuses, pidx) || continue + pdefuse = defuses[pidx][2] + nleaves += count_leaves(pdefuse) + end + end nuses = 0 for idx in intermediaries nuses += used_ssas[idx] end - nuses_total = used_ssas[idx] + nuses - length(intermediaries) + nuses -= length(intermediaries) + nuses_total = used_ssas[idx] + nuses + if phidefs !== nothing + for pidx in phidefs + # NOTE we don't need to accout for intermediates for this predecessor here, + # since they are already included in intermediates of this ϕ-node + # FIXME this doesn't work when any predecessor is used by another ϕ-node + nuses_total += used_ssas[pidx] - 1 # substract usage count from ϕ-node itself + end + end nleaves == nuses_total || continue # Find the type for this allocation defexpr = ir[SSAValue(idx)] - isa(defexpr, Expr) || continue - if !isexpr(defexpr, :new) - if is_known_call(defexpr, getfield, ir) - val = defexpr.args[2] - if isa(val, SSAValue) - struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) - if ismutabletype(struct_typ) - record_nested_load!(nested_mloads, idx) - end + if isa(defexpr, Expr) + if !isexpr(defexpr, :new) + maybe_record_nested_load!(nested_mloads, ir, idx) + continue + end + elseif isa(defexpr, PhiNode) + phidefs === nothing && continue + for pidx in phidefs + pexpr = ir[SSAValue(pidx)] + if !isexpr(pexpr, :new) + maybe_record_nested_load!(nested_mloads, ir, pidx) + @goto skip end end + else continue end - newidx = idx - typ = ir.stmts[newidx][:type] + typ = ir.stmts[idx][:type] if isa(typ, UnionAll) typ = unwrap_unionall(typ) end @@ -900,25 +957,29 @@ function sroa_mutables!(ir::IRCode, fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)] all_forwarded = true for use in defuse.uses - stmt = ir[SSAValue(use)] # == `getfield` call - # We may have discovered above that this use is dead - # after the getfield elim of immutables. In that case, - # it would have been deleted. That's fine, just ignore - # the use in that case. - if stmt === nothing + eliminable = check_use_eliminability!(fielddefuse, ir, use, typ) + if eliminable === nothing + # We may have discovered above that this use is dead + # after the getfield elim of immutables. In that case, + # it would have been deleted. That's fine, just ignore + # the use in that case. all_forwarded = false continue + elseif !eliminable + @goto skip end - field = try_compute_fieldidx_stmt(ir, stmt::Expr, typ) - field === nothing && @goto skip - push!(fielddefuse[field].uses, use) end for def in defuse.defs - stmt = ir[SSAValue(def)]::Expr # == `setfield!` call - field = try_compute_fieldidx_stmt(ir, stmt, typ) - field === nothing && @goto skip - isconst(typ, field) && @goto skip # we discovered an attempt to mutate a const field, which must error - push!(fielddefuse[field].defs, def) + check_def_eliminability!(fielddefuse, ir, def, typ) || @goto skip + end + if phidefs !== nothing + for pidx in phidefs + haskey(defuses, pidx) || continue + pdefuse = defuses[pidx][2] + for pdef in pdefuse.defs + check_def_eliminability!(fielddefuse, ir, pdef, typ) || @goto skip + end + end end # Check that the defexpr has defined values for all the fields # we're accessing. In the future, we may want to relax this, @@ -929,15 +990,24 @@ function sroa_mutables!(ir::IRCode, for fidx in 1:ndefuse du = fielddefuse[fidx] isempty(du.uses) && continue - push!(du.defs, newidx) + if phidefs === nothing + push!(du.defs, idx) + else + for pidx in phidefs + push!(du.defs, pidx) + end + end ldu = compute_live_ins(ir.cfg, du) phiblocks = isempty(ldu.live_in_bbs) ? Int[] : iterated_dominance_frontier(ir.cfg, ldu, domtree) allblocks = sort(vcat(phiblocks, ldu.def_bbs)) blocks[fidx] = phiblocks, allblocks - if fidx + 1 > length(defexpr.args) - for use in du.uses - has_safe_def(ir, domtree, allblocks, du, newidx, use) || @goto skip + if phidefs !== nothing + # check if all predecessors have safe definitions + for pidx in phidefs + has_safe_def(ir, domtree, allblocks, du, pidx, fidx) || @goto skip end + else + has_safe_def(ir, domtree, allblocks, du, idx, fidx) || @goto skip end end # Everything accounted for. Go field by field and perform idf @@ -977,9 +1047,16 @@ function sroa_mutables!(ir::IRCode, end end end - for stmt in du.defs - stmt == newidx && continue - ir[SSAValue(stmt)] = nothing + if isa(defexpr, PhiNode) + ir[SSAValue(idx)] = nothing + for pidx in phidefs::Vector{Int} + used_ssas[pidx] -= 1 + end + else + for stmt in du.defs + stmt == idx && continue + ir[SSAValue(stmt)] = nothing + end end end preserve_uses === nothing && continue @@ -987,7 +1064,7 @@ function sroa_mutables!(ir::IRCode, # this means all ccall preserves have been replaced with forwarded loads # so we can potentially eliminate the allocation, otherwise we must preserve # the whole allocation. - push!(intermediaries, newidx) + push!(intermediaries, idx) end # Insert the new preserves for (use, new_preserves) in preserve_uses @@ -1003,6 +1080,42 @@ function sroa_mutables!(ir::IRCode, end end +count_leaves(defuse::SSADefUse) = + length(defuse.uses) + length(defuse.defs) + length(defuse.ccall_preserve_uses) + +function maybe_record_nested_load!(nested_mloads::NestedLoads, ir::IRCode, idx::Int) + defexpr = ir[SSAValue(idx)] + if is_known_call(defexpr, getfield, ir) + val = defexpr.args[2] + if isa(val, SSAValue) + struct_typ = unwrap_unionall(widenconst(argextype(val, ir))) + if ismutabletype(struct_typ) + record_nested_load!(nested_mloads, idx) + end + end + end +end + +function check_use_eliminability!(fielddefuse::Vector{SSADefUse}, + ir::IRCode, useidx::Int, struct_typ::DataType) + stmt = ir[SSAValue(useidx)] # == `getfield` call + stmt === nothing && return nothing + field = try_compute_fieldidx_stmt(ir, stmt::Expr, struct_typ) + field === nothing && return false + push!(fielddefuse[field].uses, useidx) + return true +end + +function check_def_eliminability!(fielddefuse::Vector{SSADefUse}, + ir::IRCode, defidx::Int, struct_typ::DataType) + stmt = ir[SSAValue(defidx)]::Expr # == `setfield!` call + field = try_compute_fieldidx_stmt(ir, stmt, struct_typ) + field === nothing && return false + isconst(struct_typ, field) && return false # we discovered an attempt to mutate a const field, which must error + push!(fielddefuse[field].defs, defidx) + return true +end + function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) newex = Expr(:foreigncall) nccallargs = length(origex.args[3]::SimpleVector) diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index 2d1d85e3df97fd..5dcad9e6e83f03 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -230,7 +230,7 @@ let src = code_typed1((Any,Any,Any)) do x, y, z end end # FIXME? in order to handle nested mutable `getfield` calls, we run SROA iteratively until -# any nested mutable `getfield` calls become no longer eliminatable: +# any nested mutable `getfield` calls become no longer eliminable: # it's probably not the most efficient option and we may want to introduce some sort of # alias analysis and eliminates all the loads at once. # mutable(immutable(...)) case @@ -308,6 +308,188 @@ let # NOTE `sroa_mutables!` eliminate from innermost definitions, so that it sho @test !any(isnew, src.code) end +# ϕ-allocation elimination +# ------------------------ +mutable struct MutableSome + x::Any + MutableSome(@nospecialize x) = new(x) + MutableSome() = new() +end +Base.getindex(s::MutableSome) = s.x +Base.setindex!(s::MutableSome, @nospecialize x) = s.x = x +@testset "mutable ϕ-allocation elimination" begin + # safe cases + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end + let src = code_typed1((Bool,Bool,Any,Any,Any)) do cond1, cond2, x, y, z + if cond1 + ϕ = MutableSome(x) + elseif cond2 + ϕ = MutableSome(y) + else + ϕ = MutableSome(z) + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(4) in x.values && + #=y=# Core.Argument(5) in x.values && + #=z=# Core.Argument(6) in x.values + end == 1 + end + let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + ϕ[] = z + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 + end + let src = code_typed1((Bool,Any,Any,)) do cond, x, y + if cond + ϕ = MutableSome(x) + out1 = ϕ[] + else + ϕ = MutableSome(y) + out1 = ϕ[] + end + out2 = ϕ[] + out1, out2 + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 2 + end + let src = code_typed1((Bool,Any,Any,Any)) do cond, x, y, z + if cond + ϕ = MutableSome(x) + out1 = ϕ[] + else + ϕ = MutableSome(y) + out1 = ϕ[] + ϕ[] = z + end + out2 = ϕ[] + out1, out2 + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=z=# Core.Argument(5) in x.values + end == 1 + end + + # unsafe cases + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome(y) + end + some_escape(ϕ) + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + some_escape(ϕ) + else + ϕ = MutableSome(y) + end + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,)) do cond, x + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome() + end + ϕ[] + end + @test count(isnew, src.code) == 2 + end + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ = MutableSome(x) + else + ϕ = MutableSome() + ϕ[] = y + end + ϕ[] + end + @test !any(isnew, src.code) + @test count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end + + # FIXME allocation forming multiple ϕ + let src = code_typed1((Bool,Any,Any)) do cond, x, y + if cond + ϕ2 = ϕ1 = MutableSome(x) + else + ϕ2 = ϕ1 = MutableSome(y) + end + ϕ1[], ϕ2[] + end + @test_broken !any(isnew, src.code) + @test_broken count(src.code) do @nospecialize x + isa(x, Core.PhiNode) && + #=x=# Core.Argument(3) in x.values && + #=y=# Core.Argument(4) in x.values + end == 1 + end +end +function mutable_ϕ_elim(x, xs) + r = Ref(x) + for x in xs + r = Ref(x) + end + return r[] +end +let xs = String[string(gensym()) for _ in 1:100] + mutable_ϕ_elim("init", xs) + @test @allocated(mutable_ϕ_elim("init", xs)) == 0 +end + # should work nicely with inlining to optimize away a complicated case # adapted from http://wiki.luajit.org/Allocation-Sinking-Optimization#implementation%5B struct Point