diff --git a/base/Base.jl b/base/Base.jl index 8fd0780133818e..82b1b43822d7b8 100644 --- a/base/Base.jl +++ b/base/Base.jl @@ -397,6 +397,9 @@ include("util.jl") include("asyncmap.jl") +include("compiler/ssair/EscapeAnalysis/EAUtils.jl") +using .EAUtils: code_escapes + # deprecated functions include("deprecated.jl") diff --git a/base/compiler/bootstrap.jl b/base/compiler/bootstrap.jl index 2517b181d28048..75ec987656509f 100644 --- a/base/compiler/bootstrap.jl +++ b/base/compiler/bootstrap.jl @@ -14,7 +14,7 @@ let fs = Any[ # we first create caches for the optimizer, because they contain many loop constructions # and they're better to not run in interpreter even during bootstrapping - run_passes, + analyze_escapes, run_passes, # then we create caches for inference entries typeinf_ext, typeinf, typeinf_edge, ] diff --git a/base/compiler/compiler.jl b/base/compiler/compiler.jl index 5f2f5614ba2091..a0a2432de351ad 100644 --- a/base/compiler/compiler.jl +++ b/base/compiler/compiler.jl @@ -95,6 +95,8 @@ ntuple(f, n) = (Any[f(i) for i = 1:n]...,) # core docsystem include("docs/core.jl") +import Core.Compiler.CoreDocs +Core.atdoc!(CoreDocs.docm) # sorting function sort end diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index 56f23ca7c2b398..d2f61b00644553 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -1,5 +1,45 @@ # This file is a part of Julia. License is MIT: https://julialang.org/license +############# +# constants # +############# + +# The slot has uses that are not statically dominated by any assignment +# This is implied by `SLOT_USEDUNDEF`. +# If this is not set, all the uses are (statically) dominated by the defs. +# In particular, if a slot has `AssignedOnce && !StaticUndef`, it is an SSA. +const SLOT_STATICUNDEF = 1 # slot might be used before it is defined (structurally) +const SLOT_ASSIGNEDONCE = 16 # slot is assigned to only once +const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError +# const SLOT_CALLED = 64 + +# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c + +const IR_FLAG_NULL = 0x00 +# This statement is marked as @inbounds by user. +# Ff replaced by inlining, any contained boundschecks may be removed. +const IR_FLAG_INBOUNDS = 0x01 << 0 +# This statement is marked as @inline by user +const IR_FLAG_INLINE = 0x01 << 1 +# This statement is marked as @noinline by user +const IR_FLAG_NOINLINE = 0x01 << 2 +const IR_FLAG_THROW_BLOCK = 0x01 << 3 +# This statement may be removed if its result is unused. In particular it must +# thus be both pure and effect free. +const IR_FLAG_EFFECT_FREE = 0x01 << 4 + +# known to be always effect-free (in particular nothrow) +const _PURE_BUILTINS = Any[tuple, svec, ===, typeof, nfields] + +# known to be effect-free if the are nothrow +const _PURE_OR_ERROR_BUILTINS = [ + fieldtype, apply_type, isa, UnionAll, + getfield, arrayref, const_arrayref, arraysize, isdefined, Core.sizeof, + Core.kwfunc, Core.ifelse, Core._typevar, (<:) +] + +const TOP_TUPLE = GlobalRef(Core, :tuple) + ##################### # OptimizationState # ##################### @@ -52,7 +92,10 @@ function inlining_policy(interp::AbstractInterpreter, @nospecialize(src), stmt_f return nothing end +function argextype end # imported by EscapeAnalysis include("compiler/ssair/driver.jl") +using .EscapeAnalysis +import .EscapeAnalysis: EscapeState mutable struct OptimizationState linfo::MethodInstance @@ -121,46 +164,6 @@ function ir_to_codeinf!(opt::OptimizationState) return src end -############# -# constants # -############# - -# The slot has uses that are not statically dominated by any assignment -# This is implied by `SLOT_USEDUNDEF`. -# If this is not set, all the uses are (statically) dominated by the defs. -# In particular, if a slot has `AssignedOnce && !StaticUndef`, it is an SSA. -const SLOT_STATICUNDEF = 1 # slot might be used before it is defined (structurally) -const SLOT_ASSIGNEDONCE = 16 # slot is assigned to only once -const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError -# const SLOT_CALLED = 64 - -# NOTE make sure to sync the flag definitions below with julia.h and `jl_code_info_set_ir` in method.c - -const IR_FLAG_NULL = 0x00 -# This statement is marked as @inbounds by user. -# Ff replaced by inlining, any contained boundschecks may be removed. -const IR_FLAG_INBOUNDS = 0x01 << 0 -# This statement is marked as @inline by user -const IR_FLAG_INLINE = 0x01 << 1 -# This statement is marked as @noinline by user -const IR_FLAG_NOINLINE = 0x01 << 2 -const IR_FLAG_THROW_BLOCK = 0x01 << 3 -# This statement may be removed if its result is unused. In particular it must -# thus be both pure and effect free. -const IR_FLAG_EFFECT_FREE = 0x01 << 4 - -# known to be always effect-free (in particular nothrow) -const _PURE_BUILTINS = Any[tuple, svec, ===, typeof, nfields] - -# known to be effect-free if the are nothrow -const _PURE_OR_ERROR_BUILTINS = [ - fieldtype, apply_type, isa, UnionAll, - getfield, arrayref, const_arrayref, arraysize, isdefined, Core.sizeof, - Core.kwfunc, Core.ifelse, Core._typevar, (<:) -] - -const TOP_TUPLE = GlobalRef(Core, :tuple) - ######### # logic # ######### @@ -514,7 +517,8 @@ function run_passes(ci::CodeInfo, sv::OptimizationState) @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) # @timeit "verify 2" verify_ir(ir) @timeit "compact 2" ir = compact!(ir) - @timeit "SROA" ir = sroa_pass!(ir) + nargs = let def = sv.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end + @timeit "SROA" ir = sroa_pass!(ir, nargs) @timeit "ADCE" ir = adce_pass!(ir) @timeit "type lift" ir = type_lift_pass!(ir) @timeit "compact 3" ir = compact!(ir) diff --git a/base/compiler/ssair/EscapeAnalysis/EAUtils.jl b/base/compiler/ssair/EscapeAnalysis/EAUtils.jl new file mode 100644 index 00000000000000..3046b64424c0dd --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/EAUtils.jl @@ -0,0 +1,403 @@ +const EA_AS_PKG = Symbol(@__MODULE__) !== :Base # develop EA as an external package + +module EAUtils + +import ..EA_AS_PKG +if EA_AS_PKG + import ..EscapeAnalysis +else + import Core.Compiler.EscapeAnalysis: EscapeAnalysis + Base.getindex(estate::EscapeAnalysis.EscapeState, @nospecialize(x)) = + Core.Compiler.getindex(estate, x) +end +const EA = EscapeAnalysis +const CC = Core.Compiler + +# entries +# ------- + +@static if EA_AS_PKG +import InteractiveUtils: gen_call_with_extracted_types_and_kwargs + +@doc """ + @code_escapes [options...] f(args...) + +Evaluates the arguments to the function call, determines its types, and then calls +[`code_escapes`](@ref) on the resulting expression. +As with `@code_typed` and its family, any of `code_escapes` keyword arguments can be given +as the optional arguments like `@code_escapes interp=myinterp myfunc(myargs...)`. +""" +macro code_escapes(ex0...) + return gen_call_with_extracted_types_and_kwargs(__module__, :code_escapes, ex0) +end +end # @static if EA_AS_PKG + +""" + code_escapes(f, argtypes=Tuple{}; [world], [interp]) -> result::EscapeResult + code_escapes(tt::Type{<:Tuple}; [world], [interp]) -> result::EscapeResult + +Runs the escape analysis on optimized IR of a generic function call with the given type signature. +Note that the escape analysis runs after inlining, but before any other optimizations. + +```julia +julia> mutable struct SafeRef{T} + x::T + end + +julia> Base.getindex(x::SafeRef) = x.x; + +julia> Base.isassigned(x::SafeRef) = true; + +julia> get′(x) = isassigned(x) ? x[] : throw(x); + +julia> result = code_escapes((String,String,String)) do s1, s2, s3 + r1 = Ref(s1) + r2 = Ref(s2) + r3 = SafeRef(s3) + try + s1 = get′(r1) + ret = sizeof(s1) + catch err + global g = err # will definitely escape `r1` + end + s2 = get′(r2) # still `r2` doesn't escape fully + s3 = get′(r3) # still `r2` doesn't escape fully + return s2, s3 + end +#3(X _2::String, ↑ _3::String, ↑ _4::String) in Main at REPL[7]:2 +2 X 1 ── %1 = %new(Base.RefValue{String}, _2)::Base.RefValue{String} │╻╷╷ Ref +3 *′ │ %2 = %new(Base.RefValue{String}, _3)::Base.RefValue{String} │╻╷╷ Ref +4 ✓′ └─── %3 = %new(SafeRef{String}, _4)::SafeRef{String} │╻╷ SafeRef +5 ◌ 2 ── %4 = \$(Expr(:enter, #8)) │ + ✓′ │ %5 = ϒ (%3)::SafeRef{String} │ + *′ └─── %6 = ϒ (%2)::Base.RefValue{String} │ +6 ◌ 3 ── %7 = Base.isdefined(%1, :x)::Bool │╻╷ get′ + ◌ └─── goto #5 if not %7 ││ + X 4 ── Base.getfield(%1, :x)::String ││╻ getindex + ◌ └─── goto #6 ││ + ◌ 5 ── Main.throw(%1)::Union{} ││ + ◌ └─── unreachable ││ +7 ◌ 6 ── nothing::typeof(Core.sizeof) │╻ sizeof + ◌ │ nothing::Int64 ││ + ◌ └─── \$(Expr(:leave, 1)) │ + ◌ 7 ── goto #10 │ + ✓′ 8 ── %17 = φᶜ (%5)::SafeRef{String} │ + *′ │ %18 = φᶜ (%6)::Base.RefValue{String} │ + ◌ └─── \$(Expr(:leave, 1)) │ + X 9 ── %20 = \$(Expr(:the_exception))::Any │ +9 ◌ │ (Main.g = %20)::Any │ + ◌ └─── \$(Expr(:pop_exception, :(%4)))::Any │ +11 ✓′ 10 ┄ %23 = φ (#7 => %3, #9 => %17)::SafeRef{String} │ + *′ │ %24 = φ (#7 => %2, #9 => %18)::Base.RefValue{String} │ + ◌ │ %25 = Base.isdefined(%24, :x)::Bool ││╻ isassigned + ◌ └─── goto #12 if not %25 ││ + ↑ 11 ─ %27 = Base.getfield(%24, :x)::String │││╻ getproperty + ◌ └─── goto #13 ││ + ◌ 12 ─ Main.throw(%24)::Union{} ││ + ◌ └─── unreachable ││ +12 ↑ 13 ─ %31 = Base.getfield(%23, :x)::String │╻╷╷ get′ +13 ↑ │ %32 = Core.tuple(%27, %31)::Tuple{String, String} │ + ◌ └─── return %32 │ +``` + +The symbols in the side of each call argument and SSA statements represents the following meaning: +- `◌`: this value is not analyzed because escape information of it won't be used anyway (when the object is `isbitstype` for example) +- `✓`: this value never escapes (`has_no_escape(result.state[x])` holds) +- `↑`: this value can escape to the caller via return (`has_return_escape(result.state[x])` holds) +- `X`: this value can escape to somewhere the escape analysis can't reason about like escapes to a global memory (`has_all_escape(result.state[x])` holds) +- `*`: this value's escape state is between the `ReturnEscape` and `AllEscape` in the lattice of [`EscapeInfo`](@ref), e.g. it has unhandled `ThrownEscape` +- `′`: this value has additional field/aliasing information in its `AliasInfo` property + +For testing, escape information of each call argument and SSA value can be inspected programmatically as like: +```julia +julia> result.state[Core.Argument(3)] +ReturnEscape + +julia> result.state[Core.SSAValue(3)] +NoEscape′ +``` +""" +function code_escapes(@nospecialize(args...); + world = get_world_counter(), + interp = Core.Compiler.NativeInterpreter(world)) + interp = EscapeAnalyzer(interp) + results = code_typed(args...; optimize=true, world, interp) + isone(length(results)) || throw(ArgumentError("`code_escapes` only supports single analysis result")) + return EscapeResult(interp.ir, interp.state, interp.linfo) +end + +# AbstractInterpreter +# ------------------- + +# imports +import .CC: + AbstractInterpreter, + NativeInterpreter, + WorldView, + WorldRange, + InferenceParams, + OptimizationParams, + get_world_counter, + get_inference_cache, + lock_mi_inference, + unlock_mi_inference, + add_remark!, + may_optimize, + may_compress, + may_discard_trees, + verbose_stmt_info, + code_cache, + @timeit, + get_inference_cache, + convert_to_ircode, + slot2reg, + compact!, + ssa_inlining_pass!, + sroa_pass!, + adce_pass!, + type_lift_pass!, + JLOptions, + verify_ir, + verify_linetable +# usings +import Core: + CodeInstance, MethodInstance, CodeInfo +import .CC: + OptimizationState, IRCode +import .EA: + analyze_escapes, cache_escapes! + +mutable struct EscapeAnalyzer{State} <: AbstractInterpreter + native::NativeInterpreter + ir::IRCode + state::State + linfo::MethodInstance + EscapeAnalyzer(native::NativeInterpreter) = new{EscapeState}(native) +end + +CC.InferenceParams(interp::EscapeAnalyzer) = InferenceParams(interp.native) +CC.OptimizationParams(interp::EscapeAnalyzer) = OptimizationParams(interp.native) +CC.get_world_counter(interp::EscapeAnalyzer) = get_world_counter(interp.native) + +CC.lock_mi_inference(::EscapeAnalyzer, ::MethodInstance) = nothing +CC.unlock_mi_inference(::EscapeAnalyzer, ::MethodInstance) = nothing + +CC.add_remark!(interp::EscapeAnalyzer, sv, s) = add_remark!(interp.native, sv, s) + +CC.may_optimize(interp::EscapeAnalyzer) = may_optimize(interp.native) +CC.may_compress(interp::EscapeAnalyzer) = may_compress(interp.native) +CC.may_discard_trees(interp::EscapeAnalyzer) = may_discard_trees(interp.native) +CC.verbose_stmt_info(interp::EscapeAnalyzer) = verbose_stmt_info(interp.native) + +CC.get_inference_cache(interp::EscapeAnalyzer) = get_inference_cache(interp.native) + +const GLOBAL_CODE_CACHE = IdDict{MethodInstance,CodeInstance}() +__clear_code_cache!() = empty!(GLOBAL_CODE_CACHE) + +function CC.code_cache(interp::EscapeAnalyzer) + worlds = WorldRange(get_world_counter(interp)) + return WorldView(GlobalCache(), worlds) +end + +struct GlobalCache end + +CC.haskey(wvc::WorldView{GlobalCache}, mi::MethodInstance) = haskey(GLOBAL_CODE_CACHE, mi) + +CC.get(wvc::WorldView{GlobalCache}, mi::MethodInstance, default) = get(GLOBAL_CODE_CACHE, mi, default) + +CC.getindex(wvc::WorldView{GlobalCache}, mi::MethodInstance) = getindex(GLOBAL_CODE_CACHE, mi) + +function CC.setindex!(wvc::WorldView{GlobalCache}, ci::CodeInstance, mi::MethodInstance) + GLOBAL_CODE_CACHE[mi] = ci + add_callback!(mi) # register the callback on invalidation + return nothing +end + +function add_callback!(linfo) + if !isdefined(linfo, :callbacks) + linfo.callbacks = Any[invalidate_cache!] + else + if !any(@nospecialize(cb)->cb===invalidate_cache!, linfo.callbacks) + push!(linfo.callbacks, invalidate_cache!) + end + end + return nothing +end + +function invalidate_cache!(replaced, max_world, depth = 0) + delete!(GLOBAL_CODE_CACHE, replaced) + + if isdefined(replaced, :backedges) + for mi in replaced.backedges + mi = mi::MethodInstance + if !haskey(GLOBAL_CODE_CACHE, mi) + continue # otherwise fall into infinite loop + end + invalidate_cache!(mi, max_world, depth+1) + end + end + return nothing +end + +function CC.optimize(interp::EscapeAnalyzer, opt::OptimizationState, params::OptimizationParams, @nospecialize(result)) + ir = run_passes_with_ea(interp, opt.src, opt) + return CC.finish(interp, opt, params, ir, result) +end + +function run_passes_with_ea(interp::EscapeAnalyzer, ci::CodeInfo, sv::OptimizationState) + @timeit "convert" ir = convert_to_ircode(ci, sv) + @timeit "slot2reg" ir = slot2reg(ir, ci, sv) + # TODO: Domsorting can produce an updated domtree - no need to recompute here + @timeit "compact 1" ir = compact!(ir) + @timeit "Inlining" ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds) + # @timeit "verify 2" verify_ir(ir) + @timeit "compact 2" ir = compact!(ir) + nargs = let def = sv.linfo.def; isa(def, Method) ? Int(def.nargs) : 0; end + local state + try + @timeit "collect escape information" state = analyze_escapes(ir, nargs) + catch err + @info "error happened within `analyze_escapes`, insepct `Main.ir` and `Main.nargs`" + @eval Main (ir = $ir; nargs = $nargs) + rethrow(err) + end + cacheir = Core.Compiler.copy(ir) + # cache this result + cache_escapes!(sv.linfo, state, cacheir) + # return back the result + interp.ir = cacheir + interp.state = state + interp.linfo = sv.linfo + @timeit "SROA" ir = sroa_pass!(ir, nargs) + @timeit "ADCE" ir = adce_pass!(ir) + @timeit "type lift" ir = type_lift_pass!(ir) + @timeit "compact 3" ir = compact!(ir) + if JLOptions().debug_level == 2 + @timeit "verify 3" (verify_ir(ir); verify_linetable(ir.linetable)) + end + return ir +end + +# printing +# -------- + +import Core: Argument, SSAValue +import .CC: widenconst, singleton_type +import .EA: EscapeInfo, EscapeState, ⊑, ⊏ + +# in order to run a whole analysis from ground zero (e.g. for benchmarking, etc.) +__clear_caches!() = (__clear_code_cache!(); EA.__clear_escape_cache!()) + +function get_name_color(x::EscapeInfo, symbol::Bool = false) + getname(x) = string(nameof(x)) + if x === EA.⊥ + name, color = (getname(EA.NotAnalyzed), "◌"), :plain + elseif EA.has_no_escape(x) + name, color = (getname(EA.NoEscape), "✓"), :green + elseif EA.has_all_escape(x) + name, color = (getname(EA.AllEscape), "X"), :red + elseif EA.NoEscape() ⊏ (EA.ignore_thrownescapes ∘ EA.ignore_aliasinfo)(x) ⊑ EA.AllReturnEscape() + name = (getname(EA.ReturnEscape), "↑") + color = EA.has_thrown_escape(x) ? :yellow : :cyan + else + name = (nothing, "*") + color = EA.has_thrown_escape(x) ? :yellow : :bold + end + name = symbol ? last(name) : first(name) + if name !== nothing && !isa(x.AliasInfo, Bool) + name = string(name, "′") + end + return name, color +end + +# pcs = sprint(show, collect(x.EscapeSites); context=:limit=>true) +function Base.show(io::IO, x::EscapeInfo) + name, color = get_name_color(x) + if isnothing(name) + Base.@invoke show(io::IO, x::Any) + else + printstyled(io, name; color) + end +end +function Base.show(io::IO, ::MIME"application/prs.juno.inline", x::EscapeInfo) + name, color = get_name_color(x) + if isnothing(name) + return x # use fancy tree-view + else + printstyled(io, name; color) + end +end + +struct EscapeResult + ir::IRCode + state::EscapeState + linfo::Union{Nothing,MethodInstance} + EscapeResult(ir::IRCode, state::EscapeState, linfo::Union{Nothing,MethodInstance} = nothing) = + new(ir, state, linfo) +end +Base.show(io::IO, result::EscapeResult) = print_with_info(io, result.ir, result.state, result.linfo) +@eval Base.iterate(res::EscapeResult, state=1) = + return state > $(fieldcount(EscapeResult)) ? nothing : (getfield(res, state), state+1) + +# adapted from https://github.com/JuliaDebug/LoweredCodeUtils.jl/blob/4612349432447e868cf9285f647108f43bd0a11c/src/codeedges.jl#L881-L897 +function print_with_info(io::IO, + ir::IRCode, state::EscapeState, linfo::Union{Nothing,MethodInstance}) + # print escape information on SSA values + function preprint(io::IO) + ft = ir.argtypes[1] + f = singleton_type(ft) + if f === nothing + f = widenconst(ft) + end + print(io, f, '(') + for i in 1:state.nargs + arg = state[Argument(i)] + i == 1 && continue + c, color = get_name_color(arg, true) + printstyled(io, c, ' ', '_', i, "::", ir.argtypes[i]; color) + i ≠ state.nargs && print(io, ", ") + end + print(io, ')') + if !isnothing(linfo) + def = linfo.def + printstyled(io, " in ", (isa(def, Module) ? (def,) : (def.module, " at ", def.file, ':', def.line))...; color=:bold) + end + println(io) + end + + # print escape information on SSA values + # nd = ndigits(length(ssavalues)) + function preprint(io::IO, idx::Int) + c, color = get_name_color(state[SSAValue(idx)], true) + # printstyled(io, lpad(idx, nd), ' ', c, ' '; color) + printstyled(io, rpad(c, 2), ' '; color) + end + + print_with_info(preprint, (args...)->nothing, io, ir) +end + +function print_with_info(preprint, postprint, io::IO, ir::IRCode) + io = IOContext(io, :displaysize=>displaysize(io)) + used = Base.IRShow.stmts_used(io, ir) + # line_info_preprinter = Base.IRShow.lineinfo_disabled + line_info_preprinter = function (io::IO, indent::String, idx::Int) + r = Base.IRShow.inline_linfo_printer(ir)(io, indent, idx) + idx ≠ 0 && preprint(io, idx) + return r + end + line_info_postprinter = Base.IRShow.default_expr_type_printer + preprint(io) + bb_idx_prev = bb_idx = 1 + for idx = 1:length(ir.stmts) + preprint(io, idx) + bb_idx = Base.IRShow.show_ir_stmt(io, ir, idx, line_info_preprinter, line_info_postprinter, used, ir.cfg, bb_idx) + postprint(io, idx, bb_idx != bb_idx_prev) + bb_idx_prev = bb_idx + end + max_bb_idx_size = ndigits(length(ir.cfg.blocks)) + line_info_preprinter(io, " "^(max_bb_idx_size + 2), 0) + postprint(io) + return nothing +end + +end # module EAUtils diff --git a/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl new file mode 100644 index 00000000000000..f1ee937563f0b2 --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl @@ -0,0 +1,1602 @@ +baremodule EscapeAnalysis + +export + analyze_escapes, + cache_escapes!, + getaliases, + isaliased, + has_no_escape, + has_arg_escape, + has_return_escape, + has_thrown_escape, + has_all_escape + +const _TOP_MOD = ccall(:jl_base_relative_to, Any, (Any,), EscapeAnalysis)::Module + +# imports +import ._TOP_MOD: ==, getindex, setindex! +# usings +import Core: + MethodInstance, Const, Argument, SSAValue, PiNode, PhiNode, UpsilonNode, PhiCNode, + ReturnNode, GotoNode, GotoIfNot, SimpleVector, sizeof, ifelse, arrayset, arrayref, + arraysize +import ._TOP_MOD: # Base definitions + @__MODULE__, @eval, @assert, @specialize, @nospecialize, @inbounds, @inline, @noinline, + @label, @goto, !, !==, !=, ≠, +, -, ≤, <, ≥, >, &, |, include, error, missing, copy, + Vector, BitSet, IdDict, IdSet, UnitRange, ∪, ⊆, ∩, :, ∈, ∉, in, length, get, first, last, + isempty, isassigned, pop!, push!, pushfirst!, empty!, max, min, Csize_t +import Core.Compiler: # Core.Compiler specific definitions + isbitstype, isexpr, is_meta_expr_head, println, + IRCode, IR_FLAG_EFFECT_FREE, widenconst, argextype, singleton_type, fieldcount_noerror, + try_compute_field, try_compute_fieldidx, hasintersect, ⊑ as ⊑ₜ, intrinsic_nothrow, + array_builtin_common_typecheck, arrayset_typecheck, setfield!_nothrow, compute_trycatch + +if _TOP_MOD !== Core.Compiler + include(@__MODULE__, "disjoint_set.jl") +else + include(@__MODULE__, "compiler/ssair/EscapeAnalysis/disjoint_set.jl") +end + +const AInfo = BitSet # XXX better to be IdSet{Int}? +struct Indexable + array::Bool + infos::Vector{AInfo} +end +struct Unindexable + array::Bool + info::AInfo +end +function merge_to_unindexable(info::AInfo, infos::Vector{AInfo}) + for i = 1:length(infos) + info = info ∪ infos[i] + end + return info +end +merge_to_unindexable(infos::Vector{AInfo}) = merge_to_unindexable(AInfo(), infos) + +const LivenessSet = BitSet + +""" + x::EscapeInfo + +A lattice for escape information, which holds the following properties: +- `x.Analyzed::Bool`: not formally part of the lattice, only indicates `x` has not been analyzed or not +- `x.ReturnEscape::Bool`: indicates `x` can escape to the caller via return +- `x.ThrownEscape::BitSet`: records SSA statements numbers where `x` can be thrown as exception: + this information will be used by `escape_exception!` to propagate potential escapes via exception +- `x.AliasInfo::Union{Indexable,Unindexable,Bool}`: maintains all possible values + that can be aliased to fields or array elements of `x`: + * `x.AliasInfo === false` indicates the fields/elements of `x` isn't analyzed yet + * `x.AliasInfo === true` indicates the fields/elements of `x` can't be analyzed, + e.g. the type of `x` is not known or is not concrete and thus its fields/elements + can't be known precisely + * `x.AliasInfo::Indexable` records all the possible values that can be aliased to fields/elements of `x` with precise index information + * `x.AliasInfo::Unindexable` records all the possible values that can be aliased to fields/elements of `x` without precise index information +- `x.Liveness::BitSet`: records SSA statement numbers where `x` should be live, e.g. + to be used as a call argument, to be returned to a caller, or preserved for `:foreigncall`. + `0 ∈ x.Liveness` has the special meaning that it's a call argument of the currently analyzed + call frame (and thus it's visible from the caller immediately). +- `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through + `setfield!` on argument(s) + * `-1` : no escape + * `0` : unknown or multiple + * `n` : through argument N + +There are utility constructors to create common `EscapeInfo`s, e.g., +- `NoEscape()`: the bottom(-like) element of this lattice, meaning it won't escape to anywhere +- `AllEscape()`: the topmost element of this lattice, meaning it will escape to everywhere + +`analyze_escapes` will transition these elements from the bottom to the top, +in the same direction as Julia's native type inference routine. +An abstract state will be initialized with the bottom(-like) elements: +- the call arguments are initialized as `ArgEscape()`, whose `Liveness` property includes `0` + to indicate that it is passed as a call argument and visible from a caller immediately +- the other states are initialized as `NotAnalyzed()`, which is a special lattice element that + is slightly lower than `NoEscape`, but at the same time doesn't represent any meaning + other than it's not analyzed yet (thus it's not formally part of the lattice) +""" +struct EscapeInfo + Analyzed::Bool + ReturnEscape::Bool + ThrownEscape::LivenessSet + AliasInfo #::Union{Indexable,Unindexable,Bool} + Liveness::LivenessSet + # TODO: ArgEscape::Int + + function EscapeInfo( + Analyzed::Bool, + ReturnEscape::Bool, + ThrownEscape::LivenessSet, + AliasInfo#=::Union{Indexable,Unindexable,Bool}=#, + Liveness::LivenessSet, + ) + @nospecialize AliasInfo + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) + end + function EscapeInfo( + x::EscapeInfo, + # non-concrete fields should be passed as default arguments + # in order to avoid allocating non-concrete `NamedTuple`s + AliasInfo#=::Union{Indexable,Unindexable,Bool}=# = x.AliasInfo; + Analyzed::Bool = x.Analyzed, + ReturnEscape::Bool = x.ReturnEscape, + ThrownEscape::LivenessSet = x.ThrownEscape, + Liveness::LivenessSet = x.Liveness, + ) + @nospecialize AliasInfo + return new( + Analyzed, + ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) + end +end + +# precomputed default values in order to eliminate computations at each callsite +const BOT_THROWN_ESCAPE = LivenessSet() +const TOP_THROWN_ESCAPE = LivenessSet(1:100_000) + +const BOT_ALIAS_INFO = false +const TOP_ALIAS_INFO = true + +const BOT_LIVENESS = LivenessSet() +const TOP_LIVENESS = LivenessSet(0:100_000) +const ARG_LIVENESS = LivenessSet(0) + +# the constructors +NotAnalyzed() = EscapeInfo(false, false, BOT_THROWN_ESCAPE, BOT_ALIAS_INFO, BOT_LIVENESS) # not formally part of the lattice +NoEscape() = EscapeInfo(true, false, BOT_THROWN_ESCAPE, BOT_ALIAS_INFO, BOT_LIVENESS) +ArgEscape() = EscapeInfo(true, false, BOT_THROWN_ESCAPE, TOP_ALIAS_INFO, ARG_LIVENESS) # TODO allow interprocedural alias analysis? +ReturnEscape(pc::Int) = EscapeInfo(true, true, BOT_THROWN_ESCAPE, BOT_ALIAS_INFO, LivenessSet(pc)) +AllReturnEscape() = EscapeInfo(true, true, BOT_THROWN_ESCAPE, BOT_ALIAS_INFO, TOP_LIVENESS) +ThrownEscape(pc::Int) = EscapeInfo(true, false, LivenessSet(pc), BOT_ALIAS_INFO, BOT_LIVENESS) +AllEscape() = EscapeInfo(true, true, TOP_THROWN_ESCAPE, TOP_ALIAS_INFO, TOP_LIVENESS) + +const ⊥, ⊤ = NotAnalyzed(), AllEscape() + +# Convenience names for some ⊑ queries +has_no_escape(x::EscapeInfo) = !x.ReturnEscape && isempty(x.ThrownEscape) +has_arg_escape(x::EscapeInfo) = 0 in x.Liveness +has_return_escape(x::EscapeInfo) = x.ReturnEscape +has_return_escape(x::EscapeInfo, pc::Int) = x.ReturnEscape && pc in x.Liveness +has_thrown_escape(x::EscapeInfo) = !isempty(x.ThrownEscape) +has_thrown_escape(x::EscapeInfo, pc::Int) = pc in x.ThrownEscape +has_all_escape(x::EscapeInfo) = ⊤ ⊑ x + +# utility lattice constructors +ignore_thrownescapes(x::EscapeInfo) = EscapeInfo(x; ThrownEscape=BOT_THROWN_ESCAPE) +ignore_aliasinfo(x::EscapeInfo) = EscapeInfo(x, BOT_ALIAS_INFO) +ignore_liveness(x::EscapeInfo) = EscapeInfo(x; Liveness=BOT_LIVENESS) + +# we need to make sure this `==` operator corresponds to lattice equality rather than object equality, +# otherwise `propagate_changes` can't detect the convergence +x::EscapeInfo == y::EscapeInfo = begin + # fast pass: better to avoid top comparison + x === y && return true + x.Analyzed === y.Analyzed || return false + x.ReturnEscape === y.ReturnEscape || return false + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE + yt === TOP_THROWN_ESCAPE || return false + elseif yt === TOP_THROWN_ESCAPE + return false # x.ThrownEscape === TOP_THROWN_ESCAPE + else + xt == yt || return false + end + xa, ya = x.AliasInfo, y.AliasInfo + if isa(xa, Bool) + xa === ya || return false + elseif isa(xa, Indexable) + isa(ya, Indexable) || return false + xa.array === ya.array || return false + xa.infos == ya.infos || return false + else + xa = xa::Unindexable + isa(ya, Unindexable) || return false + xa.array === ya.array || return false + xa.info == ya.info || return false + end + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS + yl === TOP_LIVENESS || return false + elseif yl === TOP_LIVENESS + return false # x.Liveness === TOP_LIVENESS + else + xl == yl || return false + end + return true +end + +""" + x::EscapeInfo ⊑ y::EscapeInfo -> Bool + +The non-strict partial order over `EscapeInfo`. +""" +x::EscapeInfo ⊑ y::EscapeInfo = begin + # fast pass: better to avoid top comparison + if y === ⊤ + return true + elseif x === ⊤ + return false # return y === ⊤ + elseif x === ⊥ + return true + elseif y === ⊥ + return false # return x === ⊥ + end + x.Analyzed ≤ y.Analyzed || return false + x.ReturnEscape ≤ y.ReturnEscape || return false + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE + yt !== TOP_THROWN_ESCAPE && return false + elseif yt !== TOP_THROWN_ESCAPE + xt ⊆ yt || return false + end + xa, ya = x.AliasInfo, y.AliasInfo + if isa(xa, Bool) + xa && ya !== true && return false + elseif isa(xa, Indexable) + if isa(ya, Indexable) + xa.array === ya.array || return false + xinfos, yinfos = xa.infos, ya.infos + xn, yn = length(xinfos), length(yinfos) + xn > yn && return false + for i in 1:xn + xinfos[i] ⊆ yinfos[i] || return false + end + elseif isa(ya, Unindexable) + xa.array === ya.array || return false + xinfos, yinfo = xa.infos, ya.info + for i = length(xinfos) + xinfos[i] ⊆ yinfo || return false + end + else + ya === true || return false + end + else + xa = xa::Unindexable + if isa(ya, Unindexable) + xa.array === ya.array || return false + xinfo, yinfo = xa.info, ya.info + xinfo ⊆ yinfo || return false + else + ya === true || return false + end + end + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS + yl !== TOP_LIVENESS && return false + elseif yl !== TOP_LIVENESS + xl ⊆ yl || return false + end + return true +end + +""" + x::EscapeInfo ⊏ y::EscapeInfo -> Bool + +The strict partial order over `EscapeInfo`. +This is defined as the irreflexive kernel of `⊏`. +""" +x::EscapeInfo ⊏ y::EscapeInfo = x ⊑ y && !(y ⊑ x) + +""" + x::EscapeInfo ⋤ y::EscapeInfo -> Bool + +This order could be used as a slightly more efficient version of the strict order `⊏`, +where we can safely assume `x ⊑ y` holds. +""" +x::EscapeInfo ⋤ y::EscapeInfo = !(y ⊑ x) + +""" + x::EscapeInfo ⊔ y::EscapeInfo -> EscapeInfo + +Computes the join of `x` and `y` in the partial order defined by `EscapeInfo`. +""" +x::EscapeInfo ⊔ y::EscapeInfo = begin + # fast pass: better to avoid top join + if x === ⊤ || y === ⊤ + return ⊤ + elseif x === ⊥ + return y + elseif y === ⊥ + return x + end + xt, yt = x.ThrownEscape, y.ThrownEscape + if xt === TOP_THROWN_ESCAPE || yt === TOP_THROWN_ESCAPE + ThrownEscape = TOP_THROWN_ESCAPE + elseif xt === BOT_THROWN_ESCAPE + ThrownEscape = yt + elseif yt === BOT_THROWN_ESCAPE + ThrownEscape = xt + else + ThrownEscape = xt ∪ yt + end + xa, ya = x.AliasInfo, y.AliasInfo + if xa === true || ya === true + AliasInfo = true + elseif xa === false + AliasInfo = ya + elseif ya === false + AliasInfo = xa + elseif isa(xa, Indexable) + if isa(ya, Indexable) && xa.array === ya.array + xinfos, yinfos = xa.infos, ya.infos + xn, yn = length(xinfos), length(yinfos) + nmax, nmin = max(xn, yn), min(xn, yn) + infos = Vector{AInfo}(undef, nmax) + for i in 1:nmax + if i > nmin + infos[i] = (xn > yn ? xinfos : yinfos)[i] + else + infos[i] = xinfos[i] ∪ yinfos[i] + end + end + AliasInfo = Indexable(xa.array, infos) + elseif isa(ya, Unindexable) && xa.array === ya.array + xinfos, yinfo = xa.infos, ya.info + info = merge_to_unindexable(yinfo, xinfos) + AliasInfo = Unindexable(xa.array, info) + else + AliasInfo = true # handle conflicting case conservatively + end + else + xa = xa::Unindexable + if isa(ya, Indexable) && xa.array === ya.array + xinfo, yinfos = xa.info, ya.infos + info = merge_to_unindexable(xinfo, yinfos) + AliasInfo = Unindexable(xa.array, info) + elseif isa(ya, Unindexable) && xa.array === ya.array + xinfo, yinfo = xa.info, ya.info + info = xinfo ∪ yinfo + AliasInfo = Unindexable(xa.array, info) + else + AliasInfo = true # handle conflicting case conservatively + end + end + xl, yl = x.Liveness, y.Liveness + if xl === TOP_LIVENESS || yl === TOP_LIVENESS + Liveness = TOP_LIVENESS + elseif xl === BOT_LIVENESS + Liveness = yl + elseif yl === BOT_LIVENESS + Liveness = xl + else + Liveness = xl ∪ yl + end + return EscapeInfo( + x.Analyzed | y.Analyzed, + x.ReturnEscape | y.ReturnEscape, + ThrownEscape, + AliasInfo, + Liveness, + ) +end + +# TODO setup a more effient struct for cache +# which can discard escape information on SSS values and arguments that don't join dispatch signature + +const AliasSet = IntDisjointSet{Int} + +""" + estate::EscapeState + +Extended lattice that maps arguments and SSA values to escape information represented as [`EscapeInfo`](@ref). +Escape information imposed on SSA IR element `x` can be retrieved by `estate[x]`. +""" +struct EscapeState + escapes::Vector{EscapeInfo} + aliasset::AliasSet + nargs::Int +end +function EscapeState(nargs::Int, nstmts::Int) + escapes = EscapeInfo[ + 1 ≤ i ≤ nargs ? ArgEscape() : ⊥ for i in 1:(nargs+nstmts)] + aliaset = AliasSet(nargs+nstmts) + return EscapeState(escapes, aliaset, nargs) +end +function getindex(estate::EscapeState, @nospecialize(x)) + xidx = iridx(x, estate) + return xidx === nothing ? nothing : estate.escapes[xidx] +end +function setindex!(estate::EscapeState, v::EscapeInfo, @nospecialize(x)) + xidx = iridx(x, estate) + if xidx !== nothing + estate.escapes[xidx] = v + end + return estate +end + +""" + iridx(x, estate::EscapeState) -> xidx::Union{Int,Nothing} + +Tries to convert analyzable IR element `x::Union{Argument,SSAValue}` to +its unique identifier number `xidx` that is valid in the analysis context of `estate`. +Returns `nothing` if `x` isn't maintained by `estate` and thus unanalyzable (e.g. `x::GlobalRef`). + +`irval` is the inverse function of `iridx` (not formally), i.e. +`irval(iridx(x::Union{Argument,SSAValue}, state), state) === x`. +""" +function iridx(@nospecialize(x), estate::EscapeState) + if isa(x, Argument) + xidx = x.n + @assert 1 ≤ xidx ≤ estate.nargs "invalid Argument" + elseif isa(x, SSAValue) + xidx = x.id + estate.nargs + else + return nothing + end + return xidx +end + +""" + irval(xidx::Int, estate::EscapeState) -> x::Union{Argument,SSAValue} + +Converts its unique identifier number `xidx` to the original IR element `x::Union{Argument,SSAValue}` +that is analyzable in the context of `estate`. + +`iridx` is the inverse function of `irval` (not formally), i.e. +`iridx(irval(xidx, state), state) === xidx`. +""" +function irval(xidx::Int, estate::EscapeState) + x = xidx > estate.nargs ? SSAValue(xidx-estate.nargs) : Argument(xidx) + return x +end + +function getaliases(x::Union{Argument,SSAValue}, estate::EscapeState) + xidx = iridx(x, estate) + aliases = getaliases(xidx, estate) + aliases === nothing && return nothing + return Union{Argument,SSAValue}[irval(aidx, estate) for aidx in aliases] +end +function getaliases(xidx::Int, estate::EscapeState) + aliasset = estate.aliasset + root = find_root!(aliasset, xidx) + if xidx ≠ root || aliasset.ranks[xidx] > 0 + # the size of this alias set containing `key` is larger than 1, + # collect the entire alias set + aliases = Int[] + for aidx in 1:length(aliasset.parents) + if aliasset.parents[aidx] == root + push!(aliases, aidx) + end + end + return aliases + else + return nothing + end +end + +isaliased(x::Union{Argument,SSAValue}, y::Union{Argument,SSAValue}, estate::EscapeState) = + isaliased(iridx(x, estate), iridx(y, estate), estate) +isaliased(xidx::Int, yidx::Int, estate::EscapeState) = + in_same_set(estate.aliasset, xidx, yidx) + +""" + ArgEscapeInfo(x::EscapeInfo) -> x′::ArgEscapeInfo + +The data structure for caching `x::EscapeInfo` for interprocedural propagation, +which is slightly more efficient than the original `x::EscapeInfo` object. +""" +struct ArgEscapeInfo + AllEscape::Bool + ReturnEscape::Bool + ThrownEscape::Bool + function ArgEscapeInfo(x::EscapeInfo) + x === ⊤ && return new(true, true, true) + ThrownEscape = isempty(x.ThrownEscape) ? false : true + return new(false, x.ReturnEscape, ThrownEscape) + end +end + +""" + cache_escapes!(linfo::MethodInstance, estate::EscapeState, _::IRCode) + +Transforms escape information of `estate` for interprocedural propagation, +and caches it in a global cache that can then be looked up later when +`linfo` callsite is seen again. +""" +function cache_escapes! end + +# when working outside of Core.Compiler, cache as much as information for later inspection and debugging +if _TOP_MOD !== Core.Compiler + struct EscapeCache + cache::Vector{ArgEscapeInfo} + state::EscapeState # preserved just for debugging purpose + ir::IRCode # preserved just for debugging purpose + end + const GLOBAL_ESCAPE_CACHE = IdDict{MethodInstance,EscapeCache}() + function cache_escapes!(linfo::MethodInstance, estate::EscapeState, cacheir::IRCode) + cache = EscapeCache(to_interprocedural(estate), estate, cacheir) + GLOBAL_ESCAPE_CACHE[linfo] = cache + return cache + end + argescapes_from_cache(cache::EscapeCache) = cache.cache +else + const GLOBAL_ESCAPE_CACHE = IdDict{MethodInstance,Vector{ArgEscapeInfo}}() + function cache_escapes!(linfo::MethodInstance, estate::EscapeState, _::IRCode) + cache = to_interprocedural(estate) + GLOBAL_ESCAPE_CACHE[linfo] = cache + return cache + end + argescapes_from_cache(cache::Vector{ArgEscapeInfo}) = cache +end + +function to_interprocedural(estate::EscapeState) + cache = Vector{ArgEscapeInfo}(undef, estate.nargs) + for i = 1:estate.nargs + cache[i] = ArgEscapeInfo(estate.escapes[i]) + end + return cache +end + +__clear_escape_cache!() = empty!(GLOBAL_ESCAPE_CACHE) + +abstract type Change end +struct EscapeChange <: Change + xidx::Int + xinfo::EscapeInfo +end +struct AliasChange <: Change + xidx::Int + yidx::Int +end +struct LivenessChange <: Change + xidx::Int + livepc::Int +end +const Changes = Vector{Change} + +struct AnalysisState + ir::IRCode + estate::EscapeState + changes::Changes +end + +function getinst(ir::IRCode, idx::Int) + nstmts = length(ir.stmts) + if idx ≤ nstmts + return ir.stmts[idx] + else + return ir.new_nodes.stmts[idx - nstmts] + end +end + +""" + analyze_escapes(ir::IRCode, nargs::Int) -> estate::EscapeState + +Analyzes escape information in `ir`. +`nargs` is the number of actual arguments of the analyzed call. +""" +function analyze_escapes(ir::IRCode, nargs::Int) + stmts = ir.stmts + nstmts = length(stmts) + length(ir.new_nodes.stmts) + + # only manage a single state, some flow-sensitivity is encoded as `EscapeInfo` properties + estate = EscapeState(nargs, nstmts) + changes = Changes() # stashes changes that happen at current statement + tryregions = compute_tryregions(ir) + astate = AnalysisState(ir, estate, changes) + + local debug_itr_counter = 0 + while true + local anyupdate = false + + for pc in nstmts:-1:1 + stmt = getinst(ir, pc)[:inst] + + # collect escape information + if isa(stmt, Expr) + head = stmt.head + if head === :call + escape_call!(astate, pc, stmt.args) + elseif head === :invoke + escape_invoke!(astate, pc, stmt.args) + elseif head === :new || head === :splatnew + escape_new!(astate, pc, stmt.args) + elseif head === :(=) + lhs, rhs = stmt.args + if isa(lhs, GlobalRef) # global store + add_escape_change!(astate, rhs, ⊤) + else + unexpected_assignment!(ir, pc) + end + elseif head === :foreigncall + escape_foreigncall!(astate, pc, stmt.args) + elseif head === :throw_undef_if_not # XXX when is this expression inserted ? + add_escape_change!(astate, stmt.args[1], ThrownEscape(pc)) + elseif is_meta_expr_head(head) + # meta expressions doesn't account for any usages + continue + elseif head === :enter || head === :leave || head === :the_exception || head === :pop_exception + # ignore these expressions since escapes via exceptions are handled by `escape_exception!` + # `escape_exception!` conservatively propagates `AllEscape` anyway, + # and so escape information imposed on `:the_exception` isn't computed + continue + elseif head === :static_parameter || # this exists statically, not interested in its escape + head === :copyast || # XXX can this account for some escapes? + head === :undefcheck || # XXX can this account for some escapes? + head === :isdefined || # just returns `Bool`, nothing accounts for any escapes + head === :gc_preserve_begin || # `GC.@preserve` expressions themselves won't be used anywhere + head === :gc_preserve_end # `GC.@preserve` expressions themselves won't be used anywhere + continue + else + for x in stmt.args + add_escape_change!(astate, x, ⊤) + end + end + elseif isa(stmt, ReturnNode) + if isdefined(stmt, :val) + add_escape_change!(astate, stmt.val, ReturnEscape(pc)) + end + elseif isa(stmt, PhiNode) + escape_edges!(astate, pc, stmt.values) + elseif isa(stmt, PiNode) + escape_val_ifdefined!(astate, pc, stmt) + elseif isa(stmt, PhiCNode) + escape_edges!(astate, pc, stmt.values) + elseif isa(stmt, UpsilonNode) + escape_val_ifdefined!(astate, pc, stmt) + elseif isa(stmt, GlobalRef) # global load + add_escape_change!(astate, SSAValue(pc), ⊤) + elseif isa(stmt, SSAValue) + escape_val!(astate, pc, stmt) + elseif isa(stmt, Argument) + escape_val!(astate, pc, stmt) + else # otherwise `stmt` can be GotoNode, GotoIfNot, and inlined values etc. + continue + end + + isempty(changes) && continue + + anyupdate |= propagate_changes!(estate, changes) + + empty!(changes) + end + + tryregions !== nothing && escape_exception!(astate, tryregions) + + debug_itr_counter += 1 + + anyupdate || break + end + + # if debug_itr_counter > 2 + # println("[EA] excessive iteration count found ", debug_itr_counter, " (", singleton_type(ir.argtypes[1]), ")") + # end + + return estate +end + +# propagate changes, and check convergence +function propagate_changes!(estate::EscapeState, changes::Changes) + local anychanged = false + for change in changes + if isa(change, EscapeChange) + anychanged |= propagate_escape_change!(estate, change) + elseif isa(change, LivenessChange) + anychanged |= propagate_liveness_change!(estate, change) + else + change = change::AliasChange + anychanged |= propagate_alias_change!(estate, change) + end + end + return anychanged +end + +@inline propagate_escape_change!(estate::EscapeState, change::EscapeChange) = + propagate_escape_change!(⊔, estate, change) + +# allows this to work as lattice join as well as lattice meet +@inline function propagate_escape_change!(@specialize(op), + estate::EscapeState, change::EscapeChange) + (; xidx, xinfo) = change + anychanged = _propagate_escape_change!(op, estate, xidx, xinfo) + aliases = getaliases(xidx, estate) + if aliases !== nothing + for aidx in aliases + anychanged |= _propagate_escape_change!(op, estate, aidx, xinfo) + end + end + return anychanged +end + +@inline function _propagate_escape_change!(@specialize(op), + estate::EscapeState, xidx::Int, info::EscapeInfo) + old = estate.escapes[xidx] + new = op(old, info) + if old ≠ new + estate.escapes[xidx] = new + return true + end + return false +end + +# propagate Liveness changes separately in order to avoid constructing too many LivenessSet +@inline function propagate_liveness_change!(estate::EscapeState, change::LivenessChange) + (; xidx, livepc) = change + info = estate.escapes[xidx] + Liveness = info.Liveness + Liveness === TOP_LIVENESS && return false + livepc in Liveness && return false + if Liveness === BOT_LIVENESS || Liveness === ARG_LIVENESS + # if this Liveness is a constant, we shouldn't modify it and propagate this change as a new EscapeInfo + Liveness = copy(Liveness) + push!(Liveness, livepc) + estate.escapes[xidx] = EscapeInfo(info; Liveness) + return true + else + # directly modify Liveness property in order to avoid excessive copies + push!(Liveness, livepc) + return true + end +end + +@inline function propagate_alias_change!(estate::EscapeState, change::AliasChange) + (; xidx, yidx) = change + xroot = find_root!(estate.aliasset, xidx) + yroot = find_root!(estate.aliasset, yidx) + if xroot ≠ yroot + union!(estate.aliasset, xroot, yroot) + xinfo = estate.escapes[xidx] + yinfo = estate.escapes[yidx] + xyinfo = xinfo ⊔ yinfo + estate.escapes[xidx] = xyinfo + estate.escapes[yidx] = xyinfo + return true + end + return false +end + +function add_escape_change!(astate::AnalysisState, @nospecialize(x), xinfo::EscapeInfo) + xinfo === ⊥ && return nothing # performance optimization + xidx = iridx(x, astate.estate) + if xidx !== nothing + if !isbitstype(widenconst(argextype(x, astate.ir))) + push!(astate.changes, EscapeChange(xidx, xinfo)) + end + end + return nothing +end + +function add_liveness_change!(astate::AnalysisState, @nospecialize(x), livepc::Int) + xidx = iridx(x, astate.estate) + if xidx !== nothing + if !isbitstype(widenconst(argextype(x, astate.ir))) + push!(astate.changes, LivenessChange(xidx, livepc)) + end + end + return nothing +end + +function add_alias_change!(astate::AnalysisState, @nospecialize(x), @nospecialize(y)) + if isa(x, GlobalRef) + return add_escape_change!(astate, y, ⊤) + elseif isa(y, GlobalRef) + return add_escape_change!(astate, x, ⊤) + end + estate = astate.estate + xidx = iridx(x, estate) + yidx = iridx(y, estate) + if xidx !== nothing && yidx !== nothing && !isaliased(xidx, yidx, astate.estate) + pushfirst!(astate.changes, AliasChange(xidx, yidx)) + end + return nothing +end + +function escape_edges!(astate::AnalysisState, pc::Int, edges::Vector{Any}) + ret = SSAValue(pc) + for i in 1:length(edges) + if isassigned(edges, i) + v = edges[i] + add_alias_change!(astate, ret, v) + end + end +end + +function escape_val_ifdefined!(astate::AnalysisState, pc::Int, x) + if isdefined(x, :val) + escape_val!(astate, pc, x.val) + end +end + +function escape_val!(astate::AnalysisState, pc::Int, @nospecialize(val)) + ret = SSAValue(pc) + add_alias_change!(astate, ret, val) +end + +# NOTE if we don't maintain the alias set that is separated from the lattice state, we can do +# something like below: it essentially incorporates forward escape propagation in our default +# backward propagation, and leads to inefficient convergence that requires more iterations +# # lhs = rhs: propagate escape information of `rhs` to `lhs` +# function escape_alias!(astate::AnalysisState, @nospecialize(lhs), @nospecialize(rhs)) +# if isa(rhs, SSAValue) || isa(rhs, Argument) +# vinfo = astate.estate[rhs] +# else +# return +# end +# add_escape_change!(astate, lhs, vinfo) +# end + +# linear scan to find regions in which potential throws will be caught +function compute_tryregions(ir::IRCode) + tryregions = nothing + for idx in 1:length(ir.stmts) + stmt = ir.stmts[idx][:inst] + if isexpr(stmt, :enter) + tryregions === nothing && (tryregions = UnitRange{Int}[]) + leave_block = stmt.args[1]::Int + leave_pc = first(ir.cfg.blocks[leave_block].stmts) + push!(tryregions, idx:leave_pc) + end + end + for idx in 1:length(ir.new_nodes.stmts) + stmt = ir.new_nodes.stmts[idx][:inst] + @assert !isexpr(stmt, :enter) "try/catch inside new_nodes unsupported" + end + return tryregions +end + +""" + escape_exception!(astate::AnalysisState, tryregions::Vector{UnitRange{Int}}) + +Propagates escapes via exceptions that can happen in `tryregions`. + +Naively it seems enough to propagate escape information imposed on `:the_exception` object, +but actually there are several other ways to access to the exception object such as +`Base.current_exceptions` and manual catch of `rethrow`n object. +For example, escape analysis needs to account for potential escape of the allocated object +via `rethrow_escape!()` call in the example below: +```julia +const Gx = Ref{Any}() +@noinline function rethrow_escape!() + try + rethrow() + catch err + Gx[] = err + end +end +unsafeget(x) = isassigned(x) ? x[] : throw(x) + +code_escapes() do + r = Ref{String}() + try + t = unsafeget(r) + catch err + t = typeof(err) # `err` (which `r` may alias to) doesn't escape here + rethrow_escape!() # `r` can escape here + end + return t +end +``` + +As indicated by the above example, it requires a global analysis in addition to a base escape +analysis to reason about all possible escapes via existing exception interfaces correctly. +For now we conservatively always propagate `AllEscape` to all potentially thrown objects, +since such an additional analysis might not be worthwhile to do given that exception handlings +and error paths usually don't need to be very performance sensitive, and optimizations of +error paths might be very ineffective anyway since they are sometimes "unoptimized" +intentionally for latency reasons. +""" +function escape_exception!(astate::AnalysisState, tryregions::Vector{UnitRange{Int}}) + estate = astate.estate + # NOTE if `:the_exception` is the only way to access the exception, we can do: + # exc = SSAValue(pc) + # excinfo = estate[exc] + excinfo = ⊤ + escapes = estate.escapes + for i in 1:length(escapes) + x = escapes[i] + xt = x.ThrownEscape + xt === TOP_THROWN_ESCAPE && @goto propagate_exception_escape # fast pass + for pc in x.ThrownEscape + for region in tryregions + pc in region && @goto propagate_exception_escape # early break because of AllEscape + end + end + continue + @label propagate_exception_escape + xval = irval(i, estate) + add_escape_change!(astate, xval, excinfo) + end +end + +function escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any}) + linfo = first(args)::MethodInstance + cache = get(GLOBAL_ESCAPE_CACHE, linfo, nothing) + if cache === nothing + for i in 2:length(args) + x = args[i] + add_escape_change!(astate, x, ⊤) + end + else + argescapes = argescapes_from_cache(cache) + ret = SSAValue(pc) + retinfo = astate.estate[ret] # escape information imposed on the call statement + method = linfo.def::Method + nargs = Int(method.nargs) + for i in 2:length(args) + arg = args[i] + if i-1 ≤ nargs + argi = i-1 + else # handle isva signature: COMBAK will this be invalid once we take alias information into account ? + argi = nargs + end + arginfo = argescapes[argi] + info = from_interprocedural(arginfo, retinfo, pc) + if arginfo.ReturnEscape + # if this argument can be "returned", in addition to propagating + # the escape information imposed on this call argument within the callee, + # we should also account for possible aliasing of this argument and the returned value + add_escape_change!(astate, arg, info) + add_alias_change!(astate, ret, arg) + else + # if this is simply passed as the call argument, we can just propagate + # the escape information imposed on this call argument within the callee + add_escape_change!(astate, arg, info) + end + end + end +end + +""" + from_interprocedural(arginfo::ArgEscapeInfo, retinfo::EscapeInfo, pc::Int) -> x::EscapeInfo + +Reinterprets the escape information imposed on the call argument which is cached as `arginfo` +in the context of the caller frame, where `retinfo` is the escape information imposed on +the return value and `pc` is the SSA statement number of the return value. +""" +function from_interprocedural(arginfo::ArgEscapeInfo, retinfo::EscapeInfo, pc::Int) + arginfo.AllEscape && return ⊤ + + ThrownEscape = arginfo.ThrownEscape ? LivenessSet(pc) : BOT_THROWN_ESCAPE + + return EscapeInfo( + #=Analyzed=#true, #=ReturnEscape=#false, ThrownEscape, + # FIXME implement interprocedural memory effect-analysis + # currently, this essentially disables the entire field analysis + # it might be okay from the SROA point of view, since we can't remove the allocation + # as far as it's passed to a callee anyway, but still we may want some field analysis + # for e.g. stack allocation or some other IPO optimizations + #=AliasInfo=#TOP_ALIAS_INFO, #=Liveness=#LivenessSet(pc)) +end + +@noinline function unexpected_assignment!(ir::IRCode, pc::Int) + @eval Main (ir = $ir; pc = $pc) + error("unexpected assignment found: inspect `Main.pc` and `Main.pc`") +end + +function escape_new!(astate::AnalysisState, pc::Int, args::Vector{Any}) + obj = SSAValue(pc) + objinfo = astate.estate[obj] + AliasInfo = objinfo.AliasInfo + nargs = length(args) + if isa(AliasInfo, Bool) + @goto conservative_propagation + elseif isa(AliasInfo, Indexable) && !AliasInfo.array + # fields are known precisely: propagate escape information imposed on recorded possibilities to the exact field values + infos = AliasInfo.infos + nf = length(infos) + objinfo = ignore_aliasinfo(objinfo) + for i in 2:nargs + i-1 > nf && break # may happen when e.g. ϕ-node merges values with different types + arg = args[i] + add_alias_escapes!(astate, arg, infos[i-1]) + push!(infos[i-1], -pc) # record def + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, arg, objinfo) + add_liveness_change!(astate, arg, pc) + end + elseif isa(AliasInfo, Unindexable) && !AliasInfo.array + # fields are known partially: propagate escape information imposed on recorded possibilities to all fields values + info = AliasInfo.info + objinfo = ignore_aliasinfo(objinfo) + for i in 2:nargs + arg = args[i] + add_alias_escapes!(astate, arg, info) + push!(info, -pc) # record def + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, arg, objinfo) + add_liveness_change!(astate, arg, pc) + end + else + # this object has been used as array, but it is allocated as struct here (i.e. should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the fields couldn't be analyzed precisely: propagate the entire escape information + # of this object to all its fields as the most conservative propagation + for i in 2:nargs + arg = args[i] + add_escape_change!(astate, arg, objinfo) + add_liveness_change!(astate, arg, pc) + end + end + if !(getinst(astate.ir, pc)[:flag] & IR_FLAG_EFFECT_FREE ≠ 0) + add_thrown_escapes!(astate, pc, args) + end +end + +function add_alias_escapes!(astate::AnalysisState, @nospecialize(v), ainfo::AInfo) + estate = astate.estate + for aidx in ainfo + aidx < 0 && continue # ignore def + x = SSAValue(aidx) # obviously this won't be true once we implement ArgEscape + add_alias_change!(astate, v, x) + end +end + +function escape_unanalyzable_obj!(astate::AnalysisState, @nospecialize(obj), objinfo::EscapeInfo) + objinfo = EscapeInfo(objinfo, TOP_ALIAS_INFO) + add_escape_change!(astate, obj, objinfo) + return objinfo +end + +function add_thrown_escapes!(astate::AnalysisState, pc::Int, args::Vector{Any}, + first_idx::Int = 1, last_idx::Int = length(args)) + info = ThrownEscape(pc) + for i in first_idx:last_idx + add_escape_change!(astate, args[i], info) + end +end + +function add_liveness_changes!(astate::AnalysisState, pc::Int, args::Vector{Any}, + first_idx::Int = 1, last_idx::Int = length(args)) + for i in first_idx:last_idx + arg = args[i] + add_liveness_change!(astate, arg, pc) + end +end + +function add_fallback_changes!(astate::AnalysisState, pc::Int, args::Vector{Any}, + first_idx::Int = 1, last_idx::Int = length(args)) + info = ThrownEscape(pc) + for i in first_idx:last_idx + arg = args[i] + add_escape_change!(astate, arg, info) + add_liveness_change!(astate, arg, pc) + end +end + +# escape every argument `(args[6:length(args[3])])` and the name `args[1]` +# TODO: we can apply a similar strategy like builtin calls to specialize some foreigncalls +function escape_foreigncall!(astate::AnalysisState, pc::Int, args::Vector{Any}) + nargs = length(args) + if nargs < 6 + # invalid foreigncall, just escape everything + for i = 1:length(args) + add_escape_change!(astate, args[i], ⊤) + end + return + end + argtypes = args[3]::SimpleVector + nargs = length(argtypes) + name = args[1] + nn = normalize(name) + if isa(nn, Symbol) + boundserror_ninds = array_resize_info(nn) + if boundserror_ninds !== nothing + boundserror, ninds = boundserror_ninds + escape_array_resize!(boundserror, ninds, astate, pc, args) + return + end + if is_array_copy(nn) + escape_array_copy!(astate, pc, args) + return + elseif is_array_isassigned(nn) + escape_array_isassigned!(astate, pc, args) + return + end + # if nn === :jl_gc_add_finalizer_th + # # TODO add `FinalizerEscape` ? + # end + end + # NOTE array allocations might have been proven as nothrow (https://github.com/JuliaLang/julia/pull/43565) + nothrow = astate.ir.stmts[pc][:flag] & IR_FLAG_EFFECT_FREE ≠ 0 + name_info = nothrow ? ⊥ : ThrownEscape(pc) + add_escape_change!(astate, name, name_info) + add_liveness_change!(astate, name, pc) + for i = 1:nargs + # we should escape this argument if it is directly called, + # otherwise just impose ThrownEscape if not nothrow + if argtypes[i] === Any + arg_info = ⊤ + else + arg_info = nothrow ? ⊥ : ThrownEscape(pc) + end + add_escape_change!(astate, args[5+i], arg_info) + add_liveness_change!(astate, args[5+i], pc) + end + for i = (5+nargs):length(args) + arg = args[i] + add_escape_change!(astate, arg, ⊥) + add_liveness_change!(astate, arg, pc) + end +end + +normalize(@nospecialize x) = isa(x, QuoteNode) ? x.value : x + +function escape_call!(astate::AnalysisState, pc::Int, args::Vector{Any}) + ir = astate.ir + ft = argextype(first(args), ir, ir.sptypes, ir.argtypes) + f = singleton_type(ft) + if isa(f, Core.IntrinsicFunction) + # XXX somehow `:call` expression can creep in here, ideally we should be able to do: + # argtypes = Any[argextype(args[i], astate.ir) for i = 2:length(args)] + argtypes = Any[] + for i = 2:length(args) + arg = args[i] + push!(argtypes, isexpr(arg, :call) ? Any : argextype(arg, ir)) + end + if intrinsic_nothrow(f, argtypes) + add_liveness_changes!(astate, pc, args, 2) + else + add_fallback_changes!(astate, pc, args, 2) + end + return # TODO accounts for pointer operations? + end + result = escape_builtin!(f, astate, pc, args) + if result === missing + # if this call hasn't been handled by any of pre-defined handlers, + # we escape this call conservatively + for i in 2:length(args) + add_escape_change!(astate, args[i], ⊤) + end + add_escape_change!(astate, SSAValue(pc), ⊤) + return + elseif result === true + add_liveness_changes!(astate, pc, args, 2) + return # ThrownEscape is already checked + else + # we escape statements with the `ThrownEscape` property using the effect-freeness + # computed by `stmt_effect_free` invoked within inlining + # TODO throwness ≠ "effect-free-ness" + if getinst(astate.ir, pc)[:flag] & IR_FLAG_EFFECT_FREE ≠ 0 + add_liveness_changes!(astate, pc, args, 2) + else + add_fallback_changes!(astate, pc, args, 2) + end + return + end +end + +escape_builtin!(@nospecialize(f), _...) = return missing + +# safe builtins +escape_builtin!(::typeof(isa), _...) = return false +escape_builtin!(::typeof(typeof), _...) = return false +escape_builtin!(::typeof(sizeof), _...) = return false +escape_builtin!(::typeof(===), _...) = return false +# not really safe, but `ThrownEscape` will be imposed later +escape_builtin!(::typeof(isdefined), _...) = return false +escape_builtin!(::typeof(throw), _...) = return false + +function escape_builtin!(::typeof(ifelse), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 4 || return false + f, cond, th, el = args + ret = SSAValue(pc) + condt = argextype(cond, astate.ir) + if isa(condt, Const) && (cond = condt.val; isa(cond, Bool)) + if cond + add_alias_change!(astate, th, ret) + else + add_alias_change!(astate, el, ret) + end + else + add_alias_change!(astate, th, ret) + add_alias_change!(astate, el, ret) + end + return false +end + +function escape_builtin!(::typeof(typeassert), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 3 || return false + f, obj, typ = args + ret = SSAValue(pc) + add_alias_change!(astate, ret, obj) + return false +end + +function escape_builtin!(::typeof(tuple), astate::AnalysisState, pc::Int, args::Vector{Any}) + escape_new!(astate, pc, args) + return false +end + +function analyze_fields(ir::IRCode, @nospecialize(typ), @nospecialize(fld)) + nfields = fieldcount_noerror(typ) + if nfields === nothing + return Unindexable(false, AInfo()), 0 + end + if isa(typ, DataType) + fldval = try_compute_field(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx === nothing + return Unindexable(false, AInfo()), 0 + end + return Indexable(false, AInfo[AInfo() for _ in 1:nfields]), fidx +end + +function reanalyze_fields(ir::IRCode, AliasInfo::Indexable, @nospecialize(typ), @nospecialize(fld)) + infos = AliasInfo.infos + nfields = fieldcount_noerror(typ) + if nfields === nothing + return Unindexable(false, merge_to_unindexable(infos)), 0 + end + if isa(typ, DataType) + fldval = try_compute_field(ir, fld) + fidx = try_compute_fieldidx(typ, fldval) + else + fidx = nothing + end + if fidx === nothing + return Unindexable(false, merge_to_unindexable(infos)), 0 + end + ninfos = length(infos) + if nfields > ninfos + for _ in 1:(nfields-ninfos) + push!(infos, AInfo()) + end + end + return AliasInfo, fidx +end + +function escape_builtin!(::typeof(getfield), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 3 || return false + ir, estate = astate.ir, astate.estate + obj = args[2] + typ = widenconst(argextype(obj, ir)) + if hasintersect(typ, Module) # global load + add_escape_change!(astate, SSAValue(pc), ⊤) + end + if isa(obj, SSAValue) || isa(obj, Argument) + objinfo = estate[obj] + else + return false + end + AliasInfo = objinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # the fields of this object haven't been analyzed yet: analyze them now + AliasInfo, fidx = analyze_fields(ir, typ, args[3]) + if isa(AliasInfo, Indexable) + @goto record_indexable_use + else + @goto record_unindexable_use + end + elseif isa(AliasInfo, Indexable) && !AliasInfo.array + AliasInfo, fidx = reanalyze_fields(ir, AliasInfo, typ, args[3]) + isa(AliasInfo, Unindexable) && @goto record_unindexable_use + @label record_indexable_use + push!(AliasInfo.infos[fidx], pc) # record use + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) + elseif isa(AliasInfo, Unindexable) && !AliasInfo.array + @label record_unindexable_use + push!(AliasInfo.info, pc) # record use + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) + else + # this object has been used as array, but it is used as struct here (i.e. should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the field couldn't be analyzed precisely: propagate the escape information + # imposed on the return value of this `getfield` call to the object itself + # as the most conservative propagation + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, obj, ssainfo) + end + return false +end + +function escape_builtin!(::typeof(setfield!), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 4 || return false + ir, estate = astate.ir, astate.estate + obj = args[2] + val = args[4] + if isa(obj, SSAValue) || isa(obj, Argument) + objinfo = estate[obj] + else + # unanalyzable object (e.g. obj::GlobalRef): escape field value conservatively + add_escape_change!(astate, val, ⊤) + @goto add_thrown_escapes + end + AliasInfo = objinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # the fields of this object haven't been analyzed yet: analyze them now + typ = widenconst(argextype(obj, ir)) + AliasInfo, fidx = analyze_fields(ir, typ, args[3]) + if isa(AliasInfo, Indexable) + @goto escape_indexable_def + else + @goto escape_unindexable_def + end + elseif isa(AliasInfo, Indexable) && !AliasInfo.array + typ = widenconst(argextype(obj, ir)) + AliasInfo, fidx = reanalyze_fields(ir, AliasInfo, typ, args[3]) + isa(AliasInfo, Unindexable) && @goto escape_unindexable_def + @label escape_indexable_def + add_alias_escapes!(astate, val, AliasInfo.infos[fidx]) + push!(AliasInfo.infos[fidx], -pc) # record def + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, val, ignore_aliasinfo(objinfo)) + elseif isa(AliasInfo, Unindexable) && !AliasInfo.array + info = AliasInfo.info + @label escape_unindexable_def + add_alias_escapes!(astate, val, AliasInfo.info) + push!(AliasInfo.info, -pc) # record def + objinfo = EscapeInfo(objinfo, AliasInfo) + add_escape_change!(astate, obj, objinfo) + # propagate the escape information of this object ignoring field information + add_escape_change!(astate, val, ignore_aliasinfo(objinfo)) + else + # this object has been used as array, but it is used as struct here (i.e. should throw) + # update obj's field information and just handle this case conservatively + objinfo = escape_unanalyzable_obj!(astate, obj, objinfo) + @label conservative_propagation + # the field couldn't be analyzed: propagate the entire escape information + # of this object to the value being assigned as the most conservative propagation + add_escape_change!(astate, val, objinfo) + end + # also propagate escape information imposed on the return value of this `setfield!` + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, val, ssainfo) + # compute the throwness of this setfield! call here since builtin_nothrow doesn't account for that + @label add_thrown_escapes + argtypes = Any[] + for i = 2:length(args) + push!(argtypes, argextype(args[i], ir)) + end + setfield!_nothrow(argtypes) || add_thrown_escapes!(astate, pc, args, 2) + return true +end + +function escape_builtin!(::typeof(arrayref), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 4 || return false + # check potential thrown escapes from this arrayref call + argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)] + boundcheckt = argtypes[1] + aryt = argtypes[2] + if !array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 3) + add_thrown_escapes!(astate, pc, args, 2) + end + ary = args[3] + inbounds = isa(boundcheckt, Const) && !boundcheckt.val::Bool + inbounds || add_escape_change!(astate, ary, ThrownEscape(pc)) + # we don't track precise index information about this array and thus don't know what values + # can be referenced here: directly propagate the escape information imposed on the return + # value of this `arrayref` call to the array itself as the most conservative propagation + # but also with updated index information + # TODO enable index analysis when constant values are available? + estate = astate.estate + if isa(ary, SSAValue) || isa(ary, Argument) + aryinfo = estate[ary] + else + return true + end + AliasInfo = aryinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # the elements of this array haven't been analyzed yet: set AliasInfo now + AliasInfo = Unindexable(true, AInfo()) + @goto record_unindexable_use + elseif isa(AliasInfo, Indexable) && AliasInfo.array + throw("array index analysis unsupported") + elseif isa(AliasInfo, Unindexable) && AliasInfo.array + @label record_unindexable_use + push!(AliasInfo.info, pc) # record use + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) + else + # this object has been used as struct, but it is used as array here (thus should throw) + # update ary's element information and just handle this case conservatively + aryinfo = escape_unanalyzable_obj!(astate, ary, aryinfo) + @label conservative_propagation + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, ary, ssainfo) + end + return true +end + +function escape_builtin!(::typeof(arrayset), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 5 || return false + # check potential escapes from this arrayset call + # NOTE here we essentially only need to account for TypeError, assuming that + # UndefRefError or BoundsError don't capture any of the arguments here + argtypes = Any[argextype(args[i], astate.ir) for i in 2:length(args)] + boundcheckt = argtypes[1] + aryt = argtypes[2] + valt = argtypes[3] + if !(array_builtin_common_typecheck(boundcheckt, aryt, argtypes, 4) && + arrayset_typecheck(aryt, valt)) + add_thrown_escapes!(astate, pc, args, 2) + end + ary = args[3] + val = args[4] + inbounds = isa(boundcheckt, Const) && !boundcheckt.val::Bool + inbounds || add_escape_change!(astate, ary, ThrownEscape(pc)) + # we don't track precise index information about this array and won't record what value + # is being assigned here: directly propagate the escape information of this array to + # the value being assigned as the most conservative propagation + # TODO enable index analysis when constant values are available? + estate = astate.estate + if isa(ary, SSAValue) || isa(ary, Argument) + aryinfo = estate[ary] + else + # unanalyzable object (e.g. obj::GlobalRef): escape field value conservatively + add_escape_change!(astate, val, ⊤) + return true + end + AliasInfo = aryinfo.AliasInfo + if isa(AliasInfo, Bool) + AliasInfo && @goto conservative_propagation + # the elements of this array haven't been analyzed yet: set AliasInfo now + AliasInfo = Unindexable(true, AInfo()) + @goto escape_unindexable_def + elseif isa(AliasInfo, Indexable) && AliasInfo.array + throw("array index analysis unsupported") + elseif isa(AliasInfo, Unindexable) && AliasInfo.array + @label escape_unindexable_def + add_alias_escapes!(astate, val, AliasInfo.info) + push!(AliasInfo.info, -pc) # record def + add_escape_change!(astate, ary, EscapeInfo(aryinfo, AliasInfo)) + # propagate the escape information of this array ignoring elements information + add_escape_change!(astate, val, ignore_aliasinfo(aryinfo)) + else + # this object has been used as struct, but it is used as array here (thus should throw) + # update ary's element information and just handle this case conservatively + aryinfo = escape_unanalyzable_obj!(astate, ary, aryinfo) + @label conservative_propagation + add_escape_change!(astate, val, aryinfo) + end + # also propagate escape information imposed on the return value of this `arrayset` + ssainfo = estate[SSAValue(pc)] + add_escape_change!(astate, ary, ssainfo) + return true +end + +function escape_builtin!(::typeof(arraysize), astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) == 3 || return false + ary = args[2] + dim = args[3] + if !arraysize_typecheck(ary, dim, astate.ir) + add_escape_change!(astate, ary, ThrownEscape(pc)) + add_escape_change!(astate, dim, ThrownEscape(pc)) + end + # NOTE we may still see "arraysize: dimension out of range", but it doesn't capture anything + return true +end + +function arraysize_typecheck(@nospecialize(ary), @nospecialize(dim), ir::IRCode) + aryt = argextype(ary, ir) + aryt ⊑ₜ Array || return false + dimt = argextype(dim, ir) + dimt ⊑ₜ Int || return false + return true +end + +# returns nothing if this isn't array resizing operation, +# otherwise returns true if it can throw BoundsError and false if not +function array_resize_info(name::Symbol) + if name === :jl_array_grow_beg || name === :jl_array_grow_end + return false, 1 + elseif name === :jl_array_del_beg || name === :jl_array_del_end + return true, 1 + elseif name === :jl_array_grow_at || name === :jl_array_del_at + return true, 2 + else + return nothing + end +end + +# NOTE may potentially throw "cannot resize array with shared data" error, +# but just ignore it since it doesn't capture anything +function escape_array_resize!(boundserror::Bool, ninds::Int, + astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 6+ninds || return add_fallback_changes!(astate, pc, args) + ary = args[6] + aryt = argextype(ary, astate.ir) + aryt ⊑ₜ Array || return add_fallback_changes!(astate, pc, args) + for i in 1:ninds + ind = args[i+6] + indt = argextype(ind, astate.ir) + indt ⊑ₜ Integer || return add_fallback_changes!(astate, pc, args) + end + if boundserror + # this array resizing can potentially throw `BoundsError`, impose it now + add_escape_change!(astate, ary, ThrownEscape(pc)) + end + add_liveness_changes!(astate, pc, args, 6) +end + +is_array_copy(name::Symbol) = name === :jl_array_copy + +# FIXME this implementation is very conservative, improve the accuracy and solve broken test cases +function escape_array_copy!(astate::AnalysisState, pc::Int, args::Vector{Any}) + length(args) ≥ 6 || return add_fallback_changes!(astate, pc, args) + ary = args[6] + aryt = argextype(ary, astate.ir) + aryt ⊑ₜ Array || return add_fallback_changes!(astate, pc, args) + if isa(ary, SSAValue) || isa(ary, Argument) + newary = SSAValue(pc) + aryinfo = astate.estate[ary] + newaryinfo = astate.estate[newary] + add_escape_change!(astate, newary, aryinfo) + add_escape_change!(astate, ary, newaryinfo) + end + add_liveness_changes!(astate, pc, args, 6) +end + +is_array_isassigned(name::Symbol) = name === :jl_array_isassigned + +function escape_array_isassigned!(astate::AnalysisState, pc::Int, args::Vector{Any}) + if !array_isassigned_nothrow(args, astate.ir) + add_thrown_escapes!(astate, pc, args) + end + add_liveness_changes!(astate, pc, args, 6) +end + +function array_isassigned_nothrow(args::Vector{Any}, src::IRCode) + # if !validate_foreigncall_args(args, + # :jl_array_isassigned, Cint, svec(Any,Csize_t), 0, :ccall) + # return false + # end + length(args) ≥ 7 || return false + arytype = argextype(args[6], src) + arytype ⊑ₜ Array || return false + idxtype = argextype(args[7], src) + idxtype ⊑ₜ Csize_t || return false + return true +end + +# # COMBAK do we want to enable this (and also backport this to Base for array allocations?) +# import Core.Compiler: Cint, svec +# function validate_foreigncall_args(args::Vector{Any}, +# name::Symbol, @nospecialize(rt), argtypes::SimpleVector, nreq::Int, convension::Symbol) +# length(args) ≥ 5 || return false +# normalize(args[1]) === name || return false +# args[2] === rt || return false +# args[3] === argtypes || return false +# args[4] === vararg || return false +# normalize(args[5]) === convension || return false +# return true +# end + +if isdefined(Core, :ImmutableArray) + +import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw + +escape_builtin!(::typeof(arrayfreeze), astate::AnalysisState, pc::Int, args::Vector{Any}) = + is_safe_immutable_array_op(Array, astate, args) +escape_builtin!(::typeof(mutating_arrayfreeze), astate::AnalysisState, pc::Int, args::Vector{Any}) = + is_safe_immutable_array_op(Array, astate, args) +escape_builtin!(::typeof(arraythaw), astate::AnalysisState, pc::Int, args::Vector{Any}) = + is_safe_immutable_array_op(ImmutableArray, astate, args) +function is_safe_immutable_array_op(@nospecialize(arytype), astate::AnalysisState, args::Vector{Any}) + length(args) == 2 || return false + argextype(args[2], astate.ir) ⊑ₜ arytype || return false + return true +end + +end # if isdefined(Core, :ImmutableArray) + +# NOTE define fancy package utilities when developing EA as an external package +if _TOP_MOD !== Core.Compiler + include(@__MODULE__, "EAUtils.jl") + using .EAUtils: code_escapes, @code_escapes + export code_escapes, @code_escapes +end + +end # baremodule EscapeAnalysis diff --git a/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl b/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl new file mode 100644 index 00000000000000..915bc214d7c3ce --- /dev/null +++ b/base/compiler/ssair/EscapeAnalysis/disjoint_set.jl @@ -0,0 +1,143 @@ +# A disjoint set implementation adapted from +# https://github.com/JuliaCollections/DataStructures.jl/blob/f57330a3b46f779b261e6c07f199c88936f28839/src/disjoint_set.jl +# under the MIT license: https://github.com/JuliaCollections/DataStructures.jl/blob/master/License.md + +# imports +import ._TOP_MOD: + length, + eltype, + union!, + push! +# usings +import ._TOP_MOD: + OneTo, collect, zero, zeros, one, typemax + +# Disjoint-Set + +############################################################ +# +# A forest of disjoint sets of integers +# +# Since each element is an integer, we can use arrays +# instead of dictionary (for efficiency) +# +# Disjoint sets over other key types can be implemented +# based on an IntDisjointSet through a map from the key +# to an integer index +# +############################################################ + +_intdisjointset_bounds_err_msg(T) = "the maximum number of elements in IntDisjointSet{$T} is $(typemax(T))" + +""" + IntDisjointSet{T<:Integer}(n::Integer) + +A forest of disjoint sets of integers, which is a data structure +(also called a union–find data structure or merge–find set) +that tracks a set of elements partitioned +into a number of disjoint (non-overlapping) subsets. +""" +mutable struct IntDisjointSet{T<:Integer} + parents::Vector{T} + ranks::Vector{T} + ngroups::T +end + +IntDisjointSet(n::T) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(n)), zeros(T, n), n) +IntDisjointSet{T}(n::Integer) where {T<:Integer} = IntDisjointSet{T}(collect(OneTo(T(n))), zeros(T, T(n)), T(n)) +length(s::IntDisjointSet) = length(s.parents) + +""" + num_groups(s::IntDisjointSet) + +Get a number of groups. +""" +num_groups(s::IntDisjointSet) = s.ngroups +eltype(::Type{IntDisjointSet{T}}) where {T<:Integer} = T + +# find the root element of the subset that contains x +# path compression is implemented here +function find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} + p = parents[x] + @inbounds if parents[p] != p + parents[x] = p = _find_root_impl!(parents, p) + end + return p +end + +# unsafe version of the above +function _find_root_impl!(parents::Vector{T}, x::Integer) where {T<:Integer} + @inbounds p = parents[x] + @inbounds if parents[p] != p + parents[x] = p = _find_root_impl!(parents, p) + end + return p +end + +""" + find_root!(s::IntDisjointSet{T}, x::T) + +Find the root element of the subset that contains an member `x`. +Path compression happens here. +""" +find_root!(s::IntDisjointSet{T}, x::T) where {T<:Integer} = find_root_impl!(s.parents, x) + +""" + in_same_set(s::IntDisjointSet{T}, x::T, y::T) + +Returns `true` if `x` and `y` belong to the same subset in `s`, and `false` otherwise. +""" +in_same_set(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} = find_root!(s, x) == find_root!(s, y) + +""" + union!(s::IntDisjointSet{T}, x::T, y::T) + +Merge the subset containing `x` and that containing `y` into one +and return the root of the new set. +""" +function union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} + parents = s.parents + xroot = find_root_impl!(parents, x) + yroot = find_root_impl!(parents, y) + return xroot != yroot ? root_union!(s, xroot, yroot) : xroot +end + +""" + root_union!(s::IntDisjointSet{T}, x::T, y::T) + +Form a new set that is the union of the two sets whose root elements are +`x` and `y` and return the root of the new set. +Assume `x ≠ y` (unsafe). +""" +function root_union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer} + parents = s.parents + rks = s.ranks + @inbounds xrank = rks[x] + @inbounds yrank = rks[y] + + if xrank < yrank + x, y = y, x + elseif xrank == yrank + rks[x] += one(T) + end + @inbounds parents[y] = x + s.ngroups -= one(T) + return x +end + +""" + push!(s::IntDisjointSet{T}) + +Make a new subset with an automatically chosen new element `x`. +Returns the new element. Throw an `ArgumentError` if the +capacity of the set would be exceeded. +""" +function push!(s::IntDisjointSet{T}) where {T<:Integer} + l = length(s) + l < typemax(T) || throw(ArgumentError(_intdisjointset_bounds_err_msg(T))) + x = l + one(T) + push!(s.parents, x) + push!(s.ranks, zero(T)) + s.ngroups += one(T) + return x +end diff --git a/base/compiler/ssair/driver.jl b/base/compiler/ssair/driver.jl index e54a09fe351b3d..7329dafcb11215 100644 --- a/base/compiler/ssair/driver.jl +++ b/base/compiler/ssair/driver.jl @@ -14,8 +14,10 @@ include("compiler/ssair/basicblock.jl") include("compiler/ssair/domtree.jl") include("compiler/ssair/ir.jl") include("compiler/ssair/slot2ssa.jl") -include("compiler/ssair/passes.jl") include("compiler/ssair/inlining.jl") include("compiler/ssair/verify.jl") include("compiler/ssair/legacy.jl") -#@isdefined(Base) && include("compiler/ssair/show.jl") +function try_compute_field end # imported by EscapeAnalysis +include("compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl") +include("compiler/ssair/passes.jl") +# @isdefined(Base) && include("compiler/ssair/show.jl") diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 51b5ec076f25db..c6913dd077d60a 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -641,7 +641,7 @@ its argument). In a case when all usages are fully eliminated, `struct` allocation may also be erased as a result of succeeding dead code elimination. """ -function sroa_pass!(ir::IRCode) +function sroa_pass!(ir::IRCode, nargs::Int) compact = IncrementalCompact(ir) defuses = nothing # will be initialized once we encounter mutability in order to reduce dynamic allocations lifting_cache = IdDict{Pair{AnySSAValue, Any}, AnySSAValue}() @@ -813,7 +813,7 @@ function sroa_pass!(ir::IRCode) used_ssas = copy(compact.used_ssas) simple_dce!(compact, (x::SSAValue) -> used_ssas[x.id] -= 1) ir = complete(compact) - sroa_mutables!(ir, defuses, used_ssas) + sroa_mutables!(ir, defuses, used_ssas, nargs) return ir else simple_dce!(compact) @@ -821,9 +821,10 @@ function sroa_pass!(ir::IRCode) end end -function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}) +function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse}}, used_ssas::Vector{Int}, nargs::Int) # initialization of domtree is delayed to avoid the expensive computation in many cases local domtree = nothing + estate = analyze_escapes(ir, nargs) for (idx, (intermediaries, defuse)) in defuses intermediaries = collect(intermediaries) # Check if there are any uses we did not account for. If so, the variable @@ -899,6 +900,7 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse end end end + is_load_forwardable(estate[SSAValue(idx)]) || println("[EA] bad EA: ", ir.argtypes[1:nargs], " at ", idx) # Everything accounted for. Go field by field and perform idf: # Compute domtree now, needed below, now that we have finished compacting the IR. # This needs to be after we iterate through the IR with `IncrementalCompact` @@ -957,6 +959,11 @@ function sroa_mutables!(ir::IRCode, defuses::IdDict{Int, Tuple{SPCSet, SSADefUse end end +function is_load_forwardable(x::EscapeAnalysis.EscapeInfo) + AliasInfo = x.AliasInfo + return isa(AliasInfo, EscapeAnalysis.Indexable) && !AliasInfo.array +end + function form_new_preserves(origex::Expr, intermediates::Vector{Int}, new_preserves::Vector{Any}) newex = Expr(:foreigncall) nccallargs = length(origex.args[3]::SimpleVector) diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index e97441495f16b8..9b1106e9649199 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -19,6 +19,8 @@ function _any(@nospecialize(f), a) end return false end +any(@nospecialize(f), itr) = _any(f, itr) +any(itr) = _any(identity, itr) function _all(@nospecialize(f), a) for x in a @@ -26,6 +28,8 @@ function _all(@nospecialize(f), a) end return true end +all(@nospecialize(f), itr) = _all(f, itr) +all(itr) = _all(identity, itr) function contains_is(itr, @nospecialize(x)) for y in itr diff --git a/base/exports.jl b/base/exports.jl index c43e66eecb74c1..7d71915909a5b4 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -774,6 +774,7 @@ export # help and reflection code_typed, code_lowered, + code_escapes, fullname, functionloc, isconst, diff --git a/doc/make.jl b/doc/make.jl index 8be3b807400d11..214644f7d74735 100644 --- a/doc/make.jl +++ b/doc/make.jl @@ -148,6 +148,7 @@ DevDocs = [ "devdocs/require.md", "devdocs/inference.md", "devdocs/ssair.md", + "devdocs/EscapeAnalysis/README.md", "devdocs/gc-sa.md", ], "Developing/debugging Julia's C code" => [ diff --git a/doc/src/devdocs/EscapeAnalysis/README.md b/doc/src/devdocs/EscapeAnalysis/README.md new file mode 100644 index 00000000000000..a82ea7196b00e5 --- /dev/null +++ b/doc/src/devdocs/EscapeAnalysis/README.md @@ -0,0 +1,323 @@ +[![CI](https://github.com/aviatesk/EscapeAnalysis.jl/actions/workflows/ci.yml/badge.svg)](https://github.com/aviatesk/EscapeAnalysis.jl/actions/workflows/ci.yml) +[![codecov](https://codecov.io/gh/aviatesk/EscapeAnalysis.jl/branch/master/graph/badge.svg?token=ADEKPZRUJH)](https://codecov.io/gh/aviatesk/EscapeAnalysis.jl) +[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://aviatesk.github.io/EscapeAnalysis.jl/dev/) + +`EscapeAnalysis` is a simple module that collects escape information in +[Julia's SSA optimization IR](@ref Julia-SSA-form-IR) a.k.a. `IRCode`. + +You can give a try to the escape analysis with the convenience entries that +`EscapeAnalysis` exports for testing and debugging purposes: +```@docs +Base.code_escapes +InteractiveUtils.@code_escapes +``` + +## Analysis Design + +### Lattice Design + +`EscapeAnalysis` is implemented as a [data-flow analysis](https://en.wikipedia.org/wiki/Data-flow_analysis) +that works on a lattice of `x::EscapeInfo`, which is composed of the following properties: +- `x.Analyzed::Bool`: not formally part of the lattice, only indicates `x` has not been analyzed or not +- `x.ReturnEscape::BitSet`: records SSA statements where `x` can escape to the caller via return +- `x.ThrownEscape::BitSet`: records SSA statements where `x` can be thrown as exception + (used for the [exception handling](@ref EA-Exception-Handling) described below) +- `x.AliasInfo`: maintains all possible values that can be aliased to fields or array elements of `x` + (used for the [alias analysis](@ref EA-Alias-Analysis) described below) +- `x.ArgEscape::Int` (not implemented yet): indicates it will escape to the caller through + `setfield!` on argument(s) + +These attributes can be combined to create a partial lattice that has a finite height, given +the invariant that an input program has a finite number of statements, which is assured by Julia's semantics. +The clever part of this lattice design is that it enables a simpler implementation of +lattice operations by allowing them to handle each lattice property separately[^LatticeDesign]. + +### Backward Escape Propagation + +This escape analysis implementation is based on the data-flow algorithm described in the paper[^MM02]. +The analysis works on the lattice of `EscapeInfo` and transitions lattice elements from the +bottom to the top until every lattice element gets converged to a fixed point by maintaining +a (conceptual) working set that contains program counters corresponding to remaining SSA +statements to be analyzed. The analysis manages a single global state that tracks +`EscapeInfo` of each argument and SSA statement, but also note that some flow-sensitivity +is encoded as program counters recorded in `EscapeInfo`'s `ReturnEscape` property, +which can be combined with domination analysis later to reason about flow-sensitivity if necessary. + +One distinctive design of this escape analysis is that it is fully _backward_, +i.e. escape information flows _from usages to definitions_. +For example, in the code snippet below, EA first analyzes the statement `return %1` and +imposes `ReturnEscape` on `%1` (corresponding to `obj`), and then it analyzes +`%1 = %new(Base.RefValue{String, _2}))` and propagates the `ReturnEscape` imposed on `%1` +to the call argument `_2` (corresponding to `s`): +```julia +julia> code_escapes((String,)) do s + obj = Ref(s) + return obj + end +#1(↑ _2::String) in Main at REPL[2]:2 +2 ↑ 1 ─ %1 = %new(Base.RefValue{String}, _2)::Base.RefValue{String} │╻╷╷ Ref +3 ◌ └── return %1 │ +``` + +The key observation here is that this backward analysis allows escape information to flow +naturally along the use-def chain rather than control-flow[^BackandForth]. +As a result this scheme enables a simple implementation of escape analysis, +e.g. `PhiNode` for example can be handled simply by propagating escape information +imposed on a `PhiNode` to its predecessor values: +```julia +julia> code_escapes((Bool, String, String)) do cnd, s, t + if cnd + obj = Ref(s) + else + obj = Ref(t) + end + return obj + end + #3(↑ _2::Bool, ↑ _3::String, ↑ _4::String) in Main at REPL[3]:2 +2 ◌ 1 ─ goto #3 if not _2 │ +3 ↑ 2 ─ %2 = %new(Base.RefValue{String}, _3)::Base.RefValue{String} │╻╷╷ Ref + ◌ └── goto #4 │ +5 ↑ 3 ─ %4 = %new(Base.RefValue{String}, _4)::Base.RefValue{String} │╻╷╷ Ref +7 ↑ 4 ┄ %5 = φ (#2 => %2, #3 => %4)::Base.RefValue{String} │ + ◌ └── return %5 │ +``` + +### [Alias Analysis](@id EA-Alias-Analysis) + +`EscapeAnalysis` implements a backward field analysis in order to reason about escapes +imposed on object fields with certain accuracy, +and `x::EscapeInfo`'s `x.AliasInfo` property exists for this purpose. +It records all possible values that can be aliased to fields of `x` at "usage" sites, +and then the escape information of that recorded values are propagated to the actual field values later at "definition" sites. +More specifically, the analysis records a value that may be aliased to a field of object by analyzing `getfield` call, +and then it propagates its escape information to the field when analyzing `%new(...)` expression or `setfield!` call[^Dynamism]. +```julia +julia> mutable struct SafeRef{T} + x::T + end + +julia> Base.getindex(x::SafeRef) = x.x; + +julia> Base.setindex!(x::SafeRef, v) = x.x = v; + +julia> code_escapes((String,)) do s + obj = SafeRef("init") + obj[] = s + v = obj[] + return v + end +#5(↑ _2::String) in Main at REPL[7]:2 +2 ✓′ 1 ─ %1 = %new(SafeRef{String}, "init")::SafeRef{String} │╻╷ SafeRef +3 ◌ │ Base.setfield!(%1, :x, _2)::String │╻╷ setindex! +4 ↑ │ %3 = Base.getfield(%1, :x)::String │╻╷ getindex +5 ◌ └── return %3 │ +``` +In the example above, `ReturnEscape` imposed on `%3` (corresponding to `v`) is _not_ directly +propagated to `%1` (corresponding to `obj`) but rather that `ReturnEscape` is only propagated +to `_2` (corresponding to `s`). Here `%3` is recorded in `%1`'s `AliasInfo` property as +it can be aliased to the first field of `%1`, and then when analyzing `Base.setfield!(%1, :x, _2)::String`, +that escape information is propagated to `_2` but not to `%1`. + +So `EscapeAnalysis` tracks which IR elements can be aliased across a `getfield`-`%new`/`setfield!` chain +in order to analyze escapes of object fields, but actually this alias analysis needs to be +generalized to handle other IR elements as well. This is because in Julia IR the same +object is sometimes represented by different IR elements and so we should make sure that those +different IR elements that actually can represent the same object share the same escape information. +IR elements that return the same object as their operand(s), such as `PiNode` and `typeassert`, +can cause that IR-level aliasing and thus requires escape information imposed on any of such +aliased values to be shared between them. +More interestingly, it is also needed for correctly reasoning about mutations on `PhiNode`. +Let's consider the following example: +```julia +julia> code_escapes((Bool, String,)) do cond, x + if cond + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + ϕ2[] = x + y = ϕ1[] + return y + end +#7(↑ _2::Bool, ↑ _3::String) in Main at REPL[8]:2 +2 ◌ 1 ─ goto #3 if not _2 │ +3 ✓′ 2 ─ %2 = %new(SafeRef{String}, "foo")::SafeRef{String} │╻╷ SafeRef + ◌ └── goto #4 │ +5 ✓′ 3 ─ %4 = %new(SafeRef{String}, "bar")::SafeRef{String} │╻╷ SafeRef +7 ✓′ 4 ┄ %5 = φ (#2 => %2, #3 => %4)::SafeRef{String} │ + ✓′ │ %6 = φ (#2 => %2, #3 => %4)::SafeRef{String} │ + ◌ │ Base.setfield!(%5, :x, _3)::String │╻ setindex! +8 ↑ │ %8 = Base.getfield(%6, :x)::String │╻╷ getindex +9 ◌ └── return %8 │ +``` +`ϕ1 = %5` and `ϕ2 = %6` are aliased and thus `ReturnEscape` imposed on `%8 = Base.getfield(%6, :x)::String` (corresponding to `y = ϕ1[]`) +needs to be propagated to `Base.setfield!(%5, :x, _3)::String` (corresponding to `ϕ2[] = x`). +In order for such escape information to be propagated correctly, the analysis should recognize that +the _predecessors_ of `ϕ1` and `ϕ2` can be aliased as well and equalize their escape information. + +One interesting property of such aliasing information is that it is not known at "usage" site +but can only be derived at "definition" site (as aliasing is conceptually equivalent to assignment), +and thus it doesn't naturally fit in a backward analysis. In order to efficiently propagate escape +information between related values, EscapeAnalysis.jl uses an approach inspired by the escape +analysis algorithm explained in an old JVM paper[^JVM05]. That is, in addition to managing +escape lattice elements, the analysis also maintains an "equi"-alias set, a disjoint set of +aliased arguments and SSA statements. The alias set manages values that can be aliased to +each other and allows escape information imposed on any of such aliased values to be equalized +between them. + +Lastly, this scheme of alias/field analysis can also be generalized to analyze array operations. +`EscapeAnalysis` currently reasons about escapes imposed on array elements using +an imprecise version of the field analysis described above, where `AliasInfo` doesn't +try to track precise array index but rather simply records all possible values that can be +aliased any elements of the array. + +### [Exception Handling](@id EA-Exception-Handling) + +It would be also worth noting how `EscapeAnalysis` handles possible escapes via exceptions. +Naively it seems enough to propagate escape information imposed on `:the_exception` object to +all values that may be thrown in a corresponding `try` block. +But there are actually several other ways to access to the exception object in Julia, +such as `Base.current_exceptions` and manual catch of `rethrow`n object. +For example, escape analysis needs to account for potential escape of `r` in the example below: +```julia +julia> const Gx = Ref{Any}(); + +julia> @noinline function rethrow_escape!() + try + rethrow() + catch err + Gx[] = err + end + end; + +julia> get′(x) = isassigned(x) ? x[] : throw(x); + +julia> code_escapes() do + r = Ref{String}() + local t + try + t = get′(r) + catch err + t = typeof(err) # `err` (which `r` aliases to) doesn't escape here + rethrow_escape!() # but `r` escapes here + end + return t + end +#9() in Main at REPL[12]:2 +2 X 1 ── %1 = %new(Base.RefValue{String})::Base.RefValue{String} │╻╷ Ref +4 ◌ 2 ── %2 = $(Expr(:enter, #8)) │ +5 ◌ 3 ── %3 = Base.isdefined(%1, :x)::Bool │╻╷ get′ + ◌ └─── goto #5 if not %3 ││ + ↑ 4 ── %5 = Base.getfield(%1, :x)::String ││╻ getindex + ◌ └─── goto #6 ││ + ◌ 5 ── Main.throw(%1)::Union{} ││ + ◌ └─── unreachable ││ + ◌ 6 ── $(Expr(:leave, 1)) │ + ◌ 7 ── goto #10 │ + ◌ 8 ── $(Expr(:leave, 1)) │ + ◌ 9 ── %12 = $(Expr(:the_exception))::Any │ +7 ↑ │ %13 = Main.typeof(%12)::DataType │ +8 ◌ │ invoke Main.rethrow_escape!()::Any │ + ◌ └─── $(Expr(:pop_exception, :(%2)))::Any │ +10 ↑ 10 ┄ %16 = φ (#7 => %5, #9 => %13)::Union{DataType, String} │ + ◌ └─── return %16 │ +``` + +It requires a global analysis in order to correctly reason about all possible escapes via +existing exception interfaces. For now we always propagate the topmost escape information to +all potentially thrown objects conservatively, since such an additional analysis might not be +worthwhile to do given that exception handling and error path usually don't need to be +very performance sensitive, and also optimizations of error paths might be very ineffective anyway +since they are often even "unoptimized" intentionally for latency reasons. + +`x::EscapeInfo`'s `x.ThrownEscape` property records SSA statements where `x` can be thrown as an exception. +Using this information `EscapeAnalysis` can propagate possible escapes via exceptions limitedly +to only those may be thrown in each `try` region: +```julia +julia> result = code_escapes((String,String)) do s1, s2 + r1 = Ref(s1) + r2 = Ref(s2) + local ret + try + s1 = get′(r1) + ret = sizeof(s1) + catch err + global g = err # will definitely escape `r1` + end + s2 = get′(r2) # still `r2` doesn't escape fully + return s2 + end +#11(X _2::String, ↑ _3::String) in Main at REPL[13]:2 +2 X 1 ── %1 = %new(Base.RefValue{String}, _2)::Base.RefValue{String} │╻╷╷ Ref +3 *′ └─── %2 = %new(Base.RefValue{String}, _3)::Base.RefValue{String} │╻╷╷ Ref +5 ◌ 2 ── %3 = $(Expr(:enter, #8)) │ + *′ └─── %4 = ϒ (%2)::Base.RefValue{String} │ +6 ◌ 3 ── %5 = Base.isdefined(%1, :x)::Bool │╻╷ get′ + ◌ └─── goto #5 if not %5 ││ + X 4 ── Base.getfield(%1, :x)::String ││╻ getindex + ◌ └─── goto #6 ││ + ◌ 5 ── Main.throw(%1)::Union{} ││ + ◌ └─── unreachable ││ +7 ◌ 6 ── nothing::typeof(Core.sizeof) │╻ sizeof + ◌ │ nothing::Int64 ││ + ◌ └─── $(Expr(:leave, 1)) │ + ◌ 7 ── goto #10 │ + *′ 8 ── %15 = φᶜ (%4)::Base.RefValue{String} │ + ◌ └─── $(Expr(:leave, 1)) │ + X 9 ── %17 = $(Expr(:the_exception))::Any │ +9 ◌ │ (Main.g = %17)::Any │ + ◌ └─── $(Expr(:pop_exception, :(%3)))::Any │ +11 *′ 10 ┄ %20 = φ (#7 => %2, #9 => %15)::Base.RefValue{String} │ + ◌ │ %21 = Base.isdefined(%20, :x)::Bool ││╻ isassigned + ◌ └─── goto #12 if not %21 ││ + ↑ 11 ─ %23 = Base.getfield(%20, :x)::String │││╻ getproperty + ◌ └─── goto #13 ││ + ◌ 12 ─ Main.throw(%20)::Union{} ││ + ◌ └─── unreachable ││ +12 ◌ 13 ─ return %23 │ +``` + +## Analysis Usage + +When using `EscapeAnalysis` in Julia's high-level compilation pipeline, we can run +`analyze_escapes(ir::IRCode) -> estate::EscapeState` to analyze escape information of each SSA-IR element in `ir`. + +Note that it should be most effective if `analyze_escapes` runs after inlining, +as `EscapeAnalysis`'s interprocedural escape information handling is limited at this moment. + +Since the computational cost of `analyze_escapes` is not that cheap, +it is more ideal if it runs once and succeeding optimization passes incrementally update + the escape information upon IR transformation. + +```@docs +Core.Compiler.EscapeAnalysis.analyze_escapes +Core.Compiler.EscapeAnalysis.EscapeState +Core.Compiler.EscapeAnalysis.EscapeInfo +Core.Compiler.EscapeAnalysis.cache_escapes! +``` + +[^LatticeDesign]: Our type inference implementation takes the alternative approach, + where each lattice property is represented by a special lattice element type object. + It turns out that it started to complicate implementations of the lattice operations + mainly because it often requires conversion rules between each lattice element type object. + And we are working on [overhauling our type inference lattice implementation](https://github.com/JuliaLang/julia/pull/42596) + with `EscapeInfo`-like lattice design. + +[^MM02]: _A Graph-Free approach to Data-Flow Analysis_. + Markas Mohnen, 2002, April. + . + +[^BackandForth]: Our type inference algorithm in contrast is implemented as a forward analysis, + because type information usually flows from "definition" to "usage" and it is more + natural and effective to propagate such information in a forward way. + +[^Dynamism]: In some cases, however, object fields can't be analyzed precisely. + For example, object may escape to somewhere `EscapeAnalysis` can't account for possible memory effects on it, + or fields of the objects simply can't be known because of the lack of type information. + In such cases `AliasInfo` property is raised to the topmost element within its own lattice order, + and it causes succeeding field analysis to be conservative and escape information imposed on + fields of an unanalyzable object to be propagated to the object itself. + +[^JVM05]: _Escape Analysis in the Context of Dynamic Compilation and Deoptimization_. + Thomas Kotzmann and Hanspeter Mössenböck, 2005, June. + . diff --git a/doc/src/devdocs/llvm.md b/doc/src/devdocs/llvm.md index 1e983949ea0b67..840822f1360045 100644 --- a/doc/src/devdocs/llvm.md +++ b/doc/src/devdocs/llvm.md @@ -28,7 +28,7 @@ The difference between an intrinsic and a builtin is that a builtin is a first c that can be used like any other Julia function. An intrinsic can operate only on unboxed data, and therefore its arguments must be statically typed. -### Alias Analysis +### [Alias Analysis](@id LLVM-Alias-Analysis) Julia currently uses LLVM's [Type Based Alias Analysis](https://llvm.org/docs/LangRef.html#tbaa-metadata). To find the comments that document the inclusion relationships, look for `static MDNode*` in diff --git a/stdlib/InteractiveUtils/src/InteractiveUtils.jl b/stdlib/InteractiveUtils/src/InteractiveUtils.jl index 1df8a2ca8f93bb..76edec0d0cc574 100644 --- a/stdlib/InteractiveUtils/src/InteractiveUtils.jl +++ b/stdlib/InteractiveUtils/src/InteractiveUtils.jl @@ -6,7 +6,7 @@ Base.Experimental.@optlevel 1 export apropos, edit, less, code_warntype, code_llvm, code_native, methodswith, varinfo, versioninfo, subtypes, supertypes, @which, @edit, @less, @functionloc, @code_warntype, - @code_typed, @code_lowered, @code_llvm, @code_native, @time_imports, clipboard + @code_typed, @code_lowered, @code_llvm, @code_native, @code_escapes, @time_imports, clipboard import Base.Docs.apropos diff --git a/stdlib/InteractiveUtils/src/macros.jl b/stdlib/InteractiveUtils/src/macros.jl index 0a1fd848b02083..659111b9b2bf25 100644 --- a/stdlib/InteractiveUtils/src/macros.jl +++ b/stdlib/InteractiveUtils/src/macros.jl @@ -208,7 +208,7 @@ macro which(ex0::Symbol) return :(which($__module__, $ex0)) end -for fname in [:code_warntype, :code_llvm, :code_native] +for fname in [:code_warntype, :code_llvm, :code_native, :code_escapes] @eval begin macro ($fname)(ex0...) gen_call_with_extracted_types_and_kwargs(__module__, $(Expr(:quote, fname)), ex0) @@ -350,6 +350,16 @@ See also: [`code_native`](@ref), [`@code_llvm`](@ref), [`@code_typed`](@ref) and """ :@code_native +""" + @code_escapes [options...] f(args...) + +Evaluates the arguments to the function call, determines its types, and then calls +[`code_escapes`](@ref) on the resulting expression. +As with `@code_typed` and its family, any of `code_escapes` keyword arguments can be given +as the optional arguments like `@code_escpase interp=myinterp myfunc(myargs...)`. +""" +:@code_escapes + """ @time_imports diff --git a/test/choosetests.jl b/test/choosetests.jl index e00aedffdd42e0..0415481a49b44a 100644 --- a/test/choosetests.jl +++ b/test/choosetests.jl @@ -142,7 +142,7 @@ function choosetests(choices = []) filtertests!(tests, "subarray") filtertests!(tests, "compiler", ["compiler/inference", "compiler/validation", "compiler/ssair", "compiler/irpasses", "compiler/codegen", - "compiler/inline", "compiler/contextual"]) + "compiler/inline", "compiler/contextual", "compiler/EscapeAnalysis/EscapeAnalysis"]) filtertests!(tests, "stdlib", STDLIBS) # do ambiguous first to avoid failing if ambiguities are introduced by other tests filtertests!(tests, "ambiguous") diff --git a/test/compiler/EscapeAnalysis/EscapeAnalysis.jl b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl new file mode 100644 index 00000000000000..001b513a25e3f8 --- /dev/null +++ b/test/compiler/EscapeAnalysis/EscapeAnalysis.jl @@ -0,0 +1,2052 @@ +@isdefined(EA_AS_PKG) || include(normpath(@__DIR__, "setup.jl")) + +@testset "basics" begin + let # arg return + result = code_escapes((Any,)) do a # return to caller + return nothing + end + @test has_arg_escape(result.state[Argument(2)]) + # return + result = code_escapes((Any,)) do a + return a + end + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_arg_escape(result.state[Argument(1)]) # self + @test !has_return_escape(result.state[Argument(1)], i) # self + @test has_arg_escape(result.state[Argument(2)]) # a + @test has_return_escape(result.state[Argument(2)], i) # a + end + let # global store + result = code_escapes((Any,)) do a + global aa = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + let # global load + result = code_escapes() do + global gr + return gr + end + i = only(findall(has_return_escape, map(i->result.state[SSAValue(i)], 1:length(result.ir.stmts)))) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # global store / load (https://github.com/aviatesk/EscapeAnalysis.jl/issues/56) + result = code_escapes((Any,)) do s + global v + v = s + return v + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + end + let # :gc_preserve_begin / :gc_preserve_end + result = code_escapes((String,)) do s + m = SafeRef(s) + GC.@preserve m begin + return nothing + end + end + i = findfirst(isT(SafeRef{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # :isdefined + result = code_escapes((String, Bool, )) do a, b + if b + s = Ref(a) + end + return @isdefined(s) + end + i = findfirst(isT(Base.RefValue{String}), result.ir.stmts.type) # find allocation statement + @test !isnothing(i) + @test has_no_escape(result.state[SSAValue(i)]) + end + let # ϕ-node + result = code_escapes((Bool,Any,Any)) do cond, a, b + c = cond ? a : b # ϕ(a, b) + return c + end + @assert any(@nospecialize(x)->isa(x, Core.PhiNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], i) # a + @test has_return_escape(result.state[Argument(4)], i) # b + end + let # π-node + result = code_escapes((Any,)) do a + if isa(a, Regex) # a::π(Regex) + return a + end + return nothing + end + @assert any(@nospecialize(x)->isa(x, Core.PiNode), result.ir.stmts.inst) + @test any(findall(isreturn, result.ir.stmts.inst)) do i + has_return_escape(result.state[Argument(2)], i) + end + end + let # φᶜ-node / ϒ-node + result = code_escapes((Any,String)) do a, b + local x::String + try + x = a + catch err + x = b + end + return x + end + @assert any(@nospecialize(x)->isa(x, Core.PhiCNode), result.ir.stmts.inst) + @assert any(@nospecialize(x)->isa(x, Core.UpsilonNode), result.ir.stmts.inst) + i = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], i) + @test has_return_escape(result.state[Argument(3)], i) + end + let # branching + result = code_escapes((Any,Bool,)) do a, c + if c + return nothing # a doesn't escape in this branch + else + return a # a escapes to a caller + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let # loop + result = code_escapes((Int,)) do n + c = SafeRef{Bool}(false) + while n > 0 + rand(Bool) && return c + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) + end + let # try/catch + result = code_escapes((Any,)) do a + try + nothing + catch err + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + try + nothing + finally + return a # return escape + end + end + @test has_return_escape(result.state[Argument(2)]) + end + let # :foreigncall + result = code_escapes((Any,)) do x + ccall(:some_ccall, Any, (Any,), x) + end + @test has_all_escape(result.state[Argument(2)]) + end +end + +let # simple allocation + result = code_escapes((Bool,)) do c + mm = SafeRef{Bool}(c) # just allocated, never escapes + return mm[] ? nothing : 1 + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) +end + +@testset "inter-procedural" begin + # FIXME currently we can't prove the effect-freeness of `getfield(RefValue{String}, :x)` + # because of this check https://github.com/JuliaLang/julia/blob/94b9d66b10e8e3ebdb268e4be5f7e1f43079ad4e/base/compiler/tfuncs.jl#L745 + # and thus it leads to the following two broken tests + let result = @eval Module() begin + @noinline broadcast_NoEscape(a) = (broadcast(identity, a); nothing) + $code_escapes() do + broadcast_NoEscape(Ref("Hi")) + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken has_no_escape(result.state[SSAValue(i)]) + end + let result = @eval Module() begin + @noinline broadcast_NoEscape2(b) = broadcast(identity, b) + $code_escapes() do + broadcast_NoEscape2(Ref("Hi")) + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken has_no_escape(result.state[SSAValue(i)]) + end + let result = @eval Module() begin + @noinline f_GlobalEscape_a(a) = (global globala = a) # obvious escape + $code_escapes() do + f_GlobalEscape_a(Ref("Hi")) + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) && has_thrown_escape(result.state[SSAValue(i)]) + end + # if we can't determine the matching method statically, we should be conservative + let result = @eval Module() $code_escapes((Ref{Any},)) do a + may_exist(a) + end + @test has_all_escape(result.state[Argument(2)]) + end + let result = @eval Module() begin + @noinline broadcast_NoEscape(a) = (broadcast(identity, a); nothing) + $code_escapes((Ref{Any},)) do a + Base.@invokelatest broadcast_NoEscape(a) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + + # handling of simple union-split (just exploit the inliner's effort) + let T = Union{Int,Nothing} + result = @eval Module() begin + @noinline unionsplit_NoEscape_a(a) = string(nothing) + @noinline unionsplit_NoEscape_a(a::Int) = a + 10 + $code_escapes(($T,)) do x + s = $SafeRef{$T}(x) + unionsplit_NoEscape_a(s[]) + return nothing + end + end + inds = findall(isT(SafeRef{T}), result.ir.stmts.type) # find allocation statement + @assert !isempty(inds) + for i in inds + @test has_no_escape(result.state[SSAValue(i)]) + end + end + + # appropriate conversion of inter-procedural context + # https://github.com/aviatesk/EscapeAnalysis.jl/issues/7 + let M = Module() + @eval M @noinline f_NoEscape_a(a) = (println("prevent inlining"); Base.inferencebarrier(nothing)) + + result = @eval M $code_escapes() do + a = Ref("foo") # shouldn't be "return escape" + b = f_NoEscape_a(a) + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + + result = @eval M $code_escapes() do + a = Ref("foo") # still should be "return escape" + b = f_NoEscape_a(a) + return a + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + + # should propagate escape information imposed on return value to the aliased call argument + let result = @eval Module() begin + @noinline f_ReturnEscape_a(a) = (println("prevent inlining"); a) + $code_escapes() do + obj = Ref("foo") # should be "return escape" + ret = f_ReturnEscape_a(obj) + return ret # alias of `obj` + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + end + let result = @eval Module() begin + @noinline f_NoReturnEscape_a(a) = (println("prevent inlining"); identity("hi")) + $code_escapes() do + obj = Ref("foo") # better to not be "return escape" + ret = f_NoReturnEscape_a(obj) + return ret # must not alias to `obj` + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + end +end + +@testset "builtins" begin + let # throw + r = code_escapes((Any,)) do a + throw(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # implicit throws + r = code_escapes((Any,)) do a + getfield(a, :may_not_field) + end + @test has_thrown_escape(r.state[Argument(2)]) + + r = code_escapes((Any,)) do a + sizeof(a) + end + @test has_thrown_escape(r.state[Argument(2)]) + end + + let # :=== + result = code_escapes((Bool, String)) do cond, s + m = cond ? SafeRef(s) : nothing + c = m === nothing + return c + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + let # sizeof + ary = [0,1,2] + result = @eval code_escapes() do + ary = $(QuoteNode(ary)) + sizeof(ary) + end + i = only(findall(isT(Core.Const(ary)), result.ir.stmts.type)) + @test has_no_escape(result.state[SSAValue(i)]) + end + + let # ifelse + result = code_escapes((Bool,)) do c + r = ifelse(c, Ref("yes"), Ref("no")) + return r + end + inds = findall(isnew, result.ir.stmts.inst) + @assert !isempty(inds) + for i in inds + @test has_return_escape(result.state[SSAValue(i)]) + end + end + let # ifelse (with constant condition) + result = code_escapes() do + r = ifelse(true, Ref("yes"), Ref(nothing)) + return r + end + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)]) + elseif isnew(result.ir.stmts.inst[i]) && isT(Base.RefValue{Nothing})(result.ir.stmts.type[i]) + @test has_no_escape(result.state[SSAValue(i)]) + end + end + end + + let # typeassert + result = code_escapes((Any,)) do x + y = x::String + return y + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + end + + let # isdefined + result = code_escapes((Any,)) do x + isdefined(x, :foo) ? x : throw("undefined") + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test !has_all_escape(result.state[Argument(2)]) + + result = code_escapes((Module,)) do m + isdefined(m, 10) # throws + end + @test has_thrown_escape(result.state[Argument(2)]) + end +end + +@testset "flow-sensitivity" begin + # ReturnEscape + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + if cond + return cond + end + return r + end + i = only(findall(isnew, result.ir.stmts.inst)) + rts = findall(isreturn, result.ir.stmts.inst) + @assert length(rts) == 2 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 1 + end + let result = code_escapes((Bool,)) do cond + r = Ref("foo") + cnt = 0 + while rand(Bool) + cnt += 1 + rand(Bool) && return r + end + rand(Bool) && return r + return cnt + end + i = only(findall(isnew, result.ir.stmts.inst)) + rts = findall(isreturn, result.ir.stmts.inst) # return statement + @assert length(rts) == 3 + @test count(rt->has_return_escape(result.state[SSAValue(i)], rt), rts) == 2 + end +end + +@testset "escape through exceptions" begin + M = @eval Module() begin + unsafeget(x) = isassigned(x) ? x[] : throw(x) + @noinline function rethrow_escape!() + try + rethrow() + catch err + Gx[] = err + end + end + @noinline function current_exceptions_escape!() + excs = Base.current_exceptions() + Gx[] = excs + end + const Gx = Ref{Any}() + @__MODULE__ + end + + let # simple: return escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch err + ret = err + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)]) + end + + let # simple: global escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret # prevent DCE + try + s = unsafeget(r) + ret = sizeof(s) + catch err + global g = err + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # account for possible escapes via nested throws + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + throw(err1) + end + catch err2 + Gx[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + r = Ref{String}() + try + try + unsafeget(r) + catch err1 + rethrow(err1) + end + catch err2 + Gx[] = err2 + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + rethrow_escape!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `rethrow` + result = @eval M $code_escapes() do + local t + try + r = Ref{String}() + t = unsafeget(r) + catch err + t = typeof(err) + rethrow_escape!() + end + return t + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + Gx[] = Base.current_exceptions() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + let # account for possible escapes via `Base.current_exceptions` + result = @eval M $code_escapes() do + try + r = Ref{String}() + unsafeget(r) + catch + current_exceptions_escape!() + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + end + + let # contextual: escape information imposed on `err` shouldn't propagate to `r2`, but only to `r1` + result = @eval M $code_escapes() do + r1 = Ref{String}() + r2 = Ref{String}() + local ret + try + s1 = unsafeget(r1) + ret = sizeof(s1) + catch err + global g = err + end + s2 = unsafeget(r2) + return s2, r2 + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i1, i2 = is + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i1)]) + @test !has_all_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[SSAValue(i2)], r) + end + + # XXX test cases below are currently broken because of the technical reason described in `escape_exception!` + + let # limited propagation: exception is caught within a frame => doesn't escape to a caller + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch + ret = nothing + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)], r) + end + let # sequential: escape information imposed on `err1` and `err2 should propagate separately + result = @eval M $code_escapes() do + r1 = Ref{String}() + r2 = Ref{String}() + local ret + try + s1 = unsafeget(r1) + ret = sizeof(s1) + catch err1 + global g = err1 + end + try + s2 = unsafeget(r2) + ret = sizeof(s2) + catch err2 + ret = err2 + end + return ret + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i1, i2 = is + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i1)]) + @test has_return_escape(result.state[SSAValue(i2)], r) + @test_broken !has_all_escape(result.state[SSAValue(i2)]) + end + let # nested: escape information imposed on `inner` shouldn't propagate to `s` + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + try + ret = sizeof(s) + catch inner + return inner + end + catch outer + ret = nothing + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)]) + end + let # merge: escape information imposed on `err1` and `err2 should be merged + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + catch err1 + return err1 + end + try + s = unsafeget(r) + ret = sizeof(s) + catch err2 + return err2 + end + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + rs = findall(isreturn, result.ir.stmts.inst) + @test_broken !has_all_escape(result.state[SSAValue(i)]) + for r in rs + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let # no exception handling: should keep propagating the escape + result = @eval M $code_escapes() do + r = Ref{String}() + local ret + try + s = unsafeget(r) + ret = sizeof(s) + finally + if !@isdefined(ret) + ret = 42 + end + end + return ret + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[SSAValue(i)], r) + end +end + +@testset "field analysis / alias analysis" begin + # escaped allocations + # ------------------- + + # escaped object should escape its fields as well + let result = code_escapes((Any,)) do a + global g = SafeRef{Any}(a) + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + global g = (a,) + nothing + end + i = only(findall(issubT(Tuple), result.ir.stmts.type)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + o0 = SafeRef{Any}(a) + global g = SafeRef(o0) + nothing + end + is = findall(isnew, result.ir.stmts.inst) + @test length(is) == 2 + i0, i1 = is + @test has_all_escape(result.state[SSAValue(i0)]) + @test has_all_escape(result.state[SSAValue(i1)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do a + t0 = (a,) + global g = (t0,) + nothing + end + inds = findall(issubT(Tuple), result.ir.stmts.type) + @assert length(inds) == 2 + for i in inds; @test has_all_escape(result.state[SSAValue(i)]); end + @test has_all_escape(result.state[Argument(2)]) + end + # global escape through `setfield!` + let result = code_escapes((Any,)) do a + r = SafeRef{Any}(:init) + global g = r + r[] = a + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + global g = r + r[] = b + nothing + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test has_all_escape(result.state[SSAValue(i)]) + @test has_all_escape(result.state[Argument(2)]) # a + @test has_all_escape(result.state[Argument(3)]) # b + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + Rx[] = s + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let result = @eval EATModule() begin + const Rx = SafeRef{String}("Rx") + $code_escapes((String,)) do s + setfield!(Rx, :x, s) + Core.sizeof(Rx[]) + end + end + @test has_all_escape(result.state[Argument(2)]) + end + let M = EATModule() + @eval M module ___xxx___ + import ..SafeRef + const Rx = SafeRef("Rx") + end + result = @eval M begin + $code_escapes((String,)) do s + rx = getfield(___xxx___, :Rx) + rx[] = s + nothing + end + end + @test has_all_escape(result.state[Argument(2)]) + end + + # field escape + # ------------ + + # field escape should propagate to :new arguments + let result = code_escapes((String,)) do a + o = SafeRef(a) + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((String,)) do a + t = (a,) + f = t[1] + return f + end + i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((String, String)) do a, b + obj = SafeRefs(a, b) + fld1 = obj[1] + fld2 = obj[2] + return (fld1, fld2) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # field escape should propagate to `setfield!` argument + let result = code_escapes((String,)) do a + o = SafeRef("foo") + o[] = a + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # propagate escape information imposed on return value of `setfield!` call + let result = code_escapes((String,)) do a + obj = SafeRef("foo") + return (obj[] = a) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # nested allocations + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + return o2[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(SafeRef{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(SafeRef{SafeRef{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = (a,) + o2 = (o1,) + return o2[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in 1:length(result.ir.stmts) + if isnew(result.ir.stmts.inst[i]) && isT(Tuple{String})(result.ir.stmts.type[i]) + @test has_return_escape(result.state[SSAValue(i)], r) + elseif isnew(result.ir.stmts.inst[i]) && isT(Tuple{Tuple{String}})(result.ir.stmts.type[i]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + end + let result = code_escapes((String,)) do a + o1 = SafeRef(a) + o2 = SafeRef(o1) + o1′ = o2[] + a′ = o1′[] + return a′ + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2 = SafeRef(o1) + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + for i in findall(isnew, result.ir.stmts.inst) + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes() do + o1 = SafeRef("foo") + o2′ = SafeRef(nothing) + o2 = SafeRef{SafeRef}(o2′) + o2[] = o1 + return o2 + end + r = only(findall(isreturn, result.ir.stmts.inst)) + findall(1:length(result.ir.stmts)) do i + if isnew(result.ir.stmts[i][:inst]) + t = result.ir.stmts[i][:type] + return t === SafeRef{String} || # o1 + t === SafeRef{SafeRef} # o2 + end + return false + end |> x->foreach(x) do i + @test has_return_escape(result.state[SSAValue(i)], r) + end + end + let result = code_escapes((String,)) do x + broadcast(identity, Ref(x)) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + + # ϕ-node allocations + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ = SafeRef{Any}(x) + else + ϕ = SafeRef{Any}(y) + end + return ϕ[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + i = only(findall(isϕ, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Any,Any)) do cond, x, y + if cond + ϕ2 = ϕ1 = SafeRef{Any}(x) + else + ϕ2 = ϕ1 = SafeRef{Any}(y) + end + return ϕ1[], ϕ2[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test has_return_escape(result.state[Argument(4)], r) # y + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # when ϕ-node merges values with different types + let result = code_escapes((Bool,String,String,String)) do cond, x, y, z + local out + if cond + ϕ = SafeRef(x) + out = ϕ[] + else + ϕ = SafeRefs(z, y) + end + return @isdefined(out) ? out : throw(ϕ) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + ϕ = only(findall(isT(Union{SafeRef{String},SafeRefs{String,String}}), result.ir.stmts.type)) + @test has_return_escape(result.state[Argument(3)], r) # x + @test !has_return_escape(result.state[Argument(4)], r) # y + @test has_return_escape(result.state[Argument(5)], r) # z + @test has_thrown_escape(result.state[SSAValue(ϕ)], t) + end + + # alias analysis + # -------------- + + # alias via getfield & Expr(:new) + let result = code_escapes((String,)) do s + r = SafeRef(s) + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + @test !isaliased(Argument(2), SSAValue(i), result.state) + end + let result = code_escapes((String,)) do s + r1 = SafeRef(s) + r2 = SafeRef(r1) + return r2[] + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test !isaliased(SSAValue(i1), SSAValue(i2), result.state) + @test isaliased(SSAValue(i1), val, result.state) + @test !isaliased(SSAValue(i2), val, result.state) + end + let result = code_escapes((String,)) do s + r1 = SafeRef(s) + r2 = SafeRef(r1) + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + end + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((String,)) do s + r = SafeRef(Rx) + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(2)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via getfield & setfield! + let result = code_escapes((String,)) do s + r = Ref{String}() + r[] = s + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + @test !isaliased(Argument(2), SSAValue(i), result.state) + end + let result = code_escapes((String,)) do s + r1 = Ref(s) + r2 = Ref{Base.RefValue{String}}() + r2[] = r1 + return r2[] + end + i1, i2 = findall(isnew, result.ir.stmts.inst) + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test !isaliased(SSAValue(i1), SSAValue(i2), result.state) + @test isaliased(SSAValue(i1), val, result.state) + @test !isaliased(SSAValue(i2), val, result.state) + end + let result = code_escapes((String,)) do s + r1 = Ref{String}() + r2 = Ref{Base.RefValue{String}}() + r2[] = r1 + r1[] = s + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + result = code_escapes((String,)) do s + r1 = Ref{String}() + r2 = Ref{Base.RefValue{String}}() + r1[] = s + r2[] = r1 + return r2[][] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test isaliased(Argument(2), val, result.state) + for i in findall(isnew, result.ir.stmts.inst) + @test !isaliased(SSAValue(i), val, result.state) + end + end + let result = @eval EATModule() begin + const Rx = SafeRef("Rx") + $code_escapes((SafeRef{String}, String,)) do _rx, s + r = SafeRef(_rx) + r[] = Rx + rx = r[] # rx aliased to Rx + rx[] = s + nothing + end + end + i = findfirst(isnew, result.ir.stmts.inst) + @test has_all_escape(result.state[Argument(3)]) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + # alias via typeassert + let result = code_escapes((Any,)) do a + r = a::String + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(2)], r) # a + @test isaliased(Argument(2), val, result.state) # a <-> r + end + let result = code_escapes((Any,)) do a + global g + (g::SafeRef{Any})[] = a + nothing + end + @test has_all_escape(result.state[Argument(2)]) + end + # alias via ifelse + let result = code_escapes((Bool,Any,Any)) do c, a, b + r = ifelse(c, a, b) + return r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(3)], r) # a + @test has_return_escape(result.state[Argument(4)], r) # b + @test !isaliased(Argument(2), val, result.state) # c r + @test isaliased(Argument(3), val, result.state) # a <-> r + @test isaliased(Argument(4), val, result.state) # b <-> r + end + let result = @eval EATModule() begin + const Lx, Rx = SafeRef("Lx"), SafeRef("Rx") + $code_escapes((Bool,String,)) do c, a + r = ifelse(c, Lx, Rx) + r[] = a + nothing + end + end + @test has_all_escape(result.state[Argument(3)]) # a + end + # alias via ϕ-node + let result = code_escapes((Bool,String)) do cond, x + if cond + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + ϕ2[] = x + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(3)], r) # x + @test isaliased(Argument(3), val, result.state) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x + if cond1 + ϕ2 = ϕ1 = SafeRef("foo") + else + ϕ2 = ϕ1 = SafeRef("bar") + end + cond2 && (ϕ2[] = x) + return ϕ1[] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + val = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(4)], r) # x + @test isaliased(Argument(4), val, result.state) # x + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + end + # alias via π-node + let result = code_escapes((Any,)) do x + if isa(x, String) + return x + end + throw("error!") + end + r = only(findall(isreturn, result.ir.stmts.inst)) + rval = (result.ir.stmts.inst[r]::ReturnNode).val::SSAValue + @test has_return_escape(result.state[Argument(2)], r) # x + @test isaliased(Argument(2), rval, result.state) + end + let result = code_escapes((String,)) do x + global g + l = g + if isa(l, SafeRef{String}) + l[] = x + end + nothing + end + @test has_all_escape(result.state[Argument(2)]) # x + end + + # dynamic semantics + # ----------------- + + # conservatively handle untyped objects + let result = @eval code_escapes((Any,Any,)) do T, x + obj = $(Expr(:new, :T, :x)) + end + t = only(findall(isnew, result.ir.stmts.inst)) + @test #=T=# has_thrown_escape(result.state[Argument(2)], t) # T + @test #=x=# has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x, :y)) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + let result = @eval code_escapes((Any,Any,Any,Any)) do T, x, y, z + obj = $(Expr(:new, :T, :x)) + setfield!(obj, :x, y) + return getfield(obj, :x) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test #=x=# has_return_escape(result.state[Argument(3)], r) + @test #=y=# has_return_escape(result.state[Argument(4)], r) + @test #=z=# !has_return_escape(result.state[Argument(5)], r) + end + + # conservatively handle unknown field: + # all fields should be escaped, but the allocation itself doesn't need to be escaped + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRef(a) + return getfield(obj, fld) + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs(a, b) + return getfield(obj, fld) # should escape both `a` and `b` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs(a, b) + return obj[idx] # should escape both `a` and `b` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Symbol)) do a, b, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[2] # should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, Symbol)) do a, fld + obj = SafeRefs("a", "b") + setfield!(obj, fld, a) + return obj[1] # this should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + let result = code_escapes((String, String, Int)) do a, b, idx + obj = SafeRefs("a", "b") + obj[idx] = a + return obj[2] # should escape `a` + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # a + @test !has_return_escape(result.state[Argument(3)], r) # b + @test !is_load_forwardable(result.state[SSAValue(i)]) # obj + end + + # interprocedural + # --------------- + + let result = @eval EATModule() begin + @noinline getx(obj) = obj[] + $code_escapes((String,)) do a + obj = SafeRef(a) + fld = getx(obj) + return fld + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) + # NOTE we can't scalar replace `obj`, but still we may want to stack allocate it + @test_broken is_load_forwardable(result.state[SSAValue(i)]) + end + + # TODO interprocedural field analysis + let result = code_escapes((SafeRef{String},)) do s + s[] = "bar" + global g = s[] + nothing + end + @test_broken !has_all_escape(result.state[Argument(2)]) + end + + # TODO flow-sensitivity? + # ---------------------- + + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(a) + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any)) do a, b + r = SafeRef{Any}(:init) + r[] = a + r[] = b + return r[] + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_return_escape(result.state[Argument(2)], r) # a + @test has_return_escape(result.state[Argument(3)], r) # b + @test is_load_forwardable(result.state[SSAValue(i)]) + end + let result = code_escapes((Any,Any,Bool)) do a, b, cond + r = SafeRef{Any}(:init) + if cond + r[] = a + return r[] + else + r[] = b + return nothing + end + end + i = only(findall(isnew, result.ir.stmts.inst)) + @test is_load_forwardable(result.state[SSAValue(i)]) + r = only(findall(result.ir.stmts.inst) do @nospecialize x + isreturn(x) && isa(x.val, Core.SSAValue) + end) + @test has_return_escape(result.state[Argument(2)], r) # a + @test_broken !has_return_escape(result.state[Argument(3)], r) # b + end + + # handle conflicting field information correctly + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRef("foo") + else + o = SafeRefs("bar", baz) + r = getfield(o, 2) + end + if cnd + o = o::SafeRef + setfield!(o, 1, qux) + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + for new in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(new)]) + end + end + let result = code_escapes((Bool,String,String,)) do cnd, baz, qux + if cnd + o = SafeRefs("foo", "bar") + r = setfield!(o, 2, baz) + else + o = SafeRef(qux) + end + if !cnd + o = o::SafeRef + r = getfield(o, 1) + end + r + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(3)], r) # baz + @test has_return_escape(result.state[Argument(4)], r) # qux + end + + # foreigncall should disable field analysis + let result = code_escapes((Any,Nothing,Int,UInt)) do t, mt, lim, world + ambig = false + min = Ref{UInt}(typemin(UInt)) + max = Ref{UInt}(typemax(UInt)) + has_ambig = Ref{Int32}(0) + mt = ccall(:jl_matching_methods, Any, + (Any, Any, Cint, Cint, UInt, Ptr{UInt}, Ptr{UInt}, Ref{Int32}), + t, mt, lim, ambig, world, min, max, has_ambig)::Union{Array{Any,1}, Bool} + return mt, has_ambig[] + end + for i in findall(isnew, result.ir.stmts.inst) + @test !is_load_forwardable(result.state[SSAValue(i)]) + end + end +end + +# demonstrate the power of our field / alias analysis with a realistic end to end example +abstract type AbstractPoint{T} end +mutable struct MPoint{T} <: AbstractPoint{T} + x::T + y::T +end +add(a::P, b::P) where P<:AbstractPoint = P(a.x + b.x, a.y + b.y) +function compute(T, ax, ay, bx, by) + a = T(ax, ay) + b = T(bx, by) + for i in 0:(100000000-1) + a = add(add(a, b), b) + end + a.x, a.y +end +function compute(a, b) + for i in 0:(100000000-1) + a = add(add(a, b), b) # unreplaceable, since it can be the call argument + end + a.x, a.y +end +function compute!(a, b) + for i in 0:(100000000-1) + a′ = add(add(a, b), b) + a.x = a′.x + a.y = a′.y + end +end +let result = @code_escapes compute(MPoint, 1+.5im, 2+.5im, 2+.25im, 4+.75im) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end +let result = @code_escapes compute(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + for i in findall(1:length(result.ir.stmts)) do i + isϕ(result.ir.stmts[i][:inst]) && isT(MPoint{ComplexF64})(result.ir.stmts[i][:type]) + end + @test !is_load_forwardable(result.state[SSAValue(i)]) + end +end +let result = @code_escapes compute!(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) + for i in findall(isnew, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end + for i in findall(isϕ, result.ir.stmts.inst) + @test is_load_forwardable(result.state[SSAValue(i)]) + end +end + +@testset "array primitives" begin + inbounds = Base.JLOptions().check_bounds == 0 + + # arrayref + let result = code_escapes((Vector{String},Int)) do xs, i + s = Base.arrayref(true, xs, i) + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + let result = code_escapes((Vector{String},Int)) do xs, i + s = Base.arrayref(false, xs, i) + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + inbounds && let result = code_escapes((Vector{String},Int)) do xs, i + s = @inbounds xs[i] + return s + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test !has_return_escape(result.state[Argument(3)], r) # i + end + let result = code_escapes((Vector{String},Bool)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((String,Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError will happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((AbstractVector{String},Int)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{String},Any)) do xs, i + c = Base.arrayref(true, xs, i) # TypeError may happen here + return c + end + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # arrayset + let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + Base.arrayset(false, xs, x, i) + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + inbounds && let result = code_escapes((Vector{String},String,Int,)) do xs, x, i + @inbounds xs[i] = x + return xs + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[Argument(2)], r) # xs + @test !has_thrown_escape(result.state[Argument(2)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # x + end + let result = code_escapes((String,String,String,)) do s, t, u + xs = Vector{String}(undef, 3) + Base.arrayset(true, xs, s, 1) + Base.arrayset(true, xs, t, 2) + Base.arrayset(true, xs, u, 3) + return xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + for i in 2:result.state.nargs + @test has_return_escape(result.state[Argument(i)], r) + end + end + let result = code_escapes((Vector{String},String,Bool,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((String,String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError will happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs::String + @test has_thrown_escape(result.state[Argument(3)], t) # x::String + end + let result = code_escapes((AbstractVector{String},String,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((Vector{String},AbstractString,Int,)) do xs, x, i + Base.arrayset(true, xs, x, i) # TypeError may happen here + return xs + end + t = only(findall(iscall((result.ir, Base.arrayset)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + @test has_thrown_escape(result.state[Argument(3)], t) # x + end + + # arrayref and arrayset + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a[1] + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test !has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes() do + a = Vector{Vector{Any}}(undef, 1) + b = Any[] + a[1] = b + return a + end + r = only(findall(isreturn, result.ir.stmts.inst)) + ai = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Vector{Any}} + end) + bi = only(findall(result.ir.stmts.inst) do @nospecialize x + isarrayalloc(x) && x.args[2] === Vector{Any} + end) + @test has_return_escape(result.state[SSAValue(ai)], r) + @test has_return_escape(result.state[SSAValue(bi)], r) + end + let result = code_escapes((Vector{Any},String,Int,Int)) do xs, s, i, j + x = SafeRef(s) + xs[i] = x + xs[j] # potential error + end + i = only(findall(isnew, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(3)], t) # s + @test has_thrown_escape(result.state[SSAValue(i)], t) # x + end + + # arraysize + let result = code_escapes((Vector{Any},)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Vector{Any},Int,)) do xs, dim + Core.arraysize(xs, dim) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) + end + let result = code_escapes((Any,)) do xs + Core.arraysize(xs, 1) + end + t = only(findall(iscall((result.ir, Core.arraysize)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) + end + + # arraylen + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((String,)) do xs + Base.arraylen(xs) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + let result = code_escapes((Vector{Any},)) do xs + Base.arraylen(xs, 1) + end + t = only(findall(iscall((result.ir, Base.arraylen)), result.ir.stmts.inst)) + @test has_thrown_escape(result.state[Argument(2)], t) # xs + end + + # array resizing + # without BoundsErrors + let result = code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_beg(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + let result = code_escapes((Vector{Any},String)) do xs, x + @ccall jl_array_grow_end(xs::Any, 2::UInt)::Cvoid + xs[1] = x + xs + end + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test !has_thrown_escape(result.state[Argument(2)], t) # xs + @test !has_thrown_escape(result.state[Argument(3)], t) # x + end + # with possible BoundsErrors + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[3] = x + @ccall jl_array_del_beg(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[1,2,3] + xs[1] = x + @ccall jl_array_del_end(xs::Any, 2::UInt)::Cvoid # can potentially throw + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_grow_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + let result = code_escapes((String,)) do x + xs = Any[x] + @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + inbounds && let result = code_escapes((String,)) do x + xs = @inbounds Any[x] + @ccall jl_array_del_at(xs::Any, 1::UInt, 2::UInt)::Cvoid # can potentially throw + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + t = only(findall(isarrayresize, result.ir.stmts.inst)) + @test has_thrown_escape(result.state[SSAValue(i)], t) # xs + @test has_thrown_escape(result.state[Argument(2)], t) # x + end + + # array copy + let result = code_escapes((Vector{Any},)) do xs + return copy(xs) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test_broken !has_return_escape(result.state[Argument(2)], r) + end + let result = code_escapes((String,)) do s + xs = String[s] + xs′ = copy(xs) + return xs′[1] + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i1)]) + @test !has_return_escape(result.state[SSAValue(i2)]) + @test has_return_escape(result.state[Argument(2)], r) # s + end + let result = code_escapes((Vector{Any},)) do xs + xs′ = copy(xs) + return xs′[1] # may potentially throw BoundsError, should escape `xs` conservatively (i.e. escape its elements) + end + i = only(findall(isarraycopy, result.ir.stmts.inst)) + ref = only(findall(iscall((result.ir, Base.arrayref)), result.ir.stmts.inst)) + ret = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i)], ref) + @test_broken !has_return_escape(result.state[SSAValue(i)], ret) + @test has_thrown_escape(result.state[Argument(2)], ref) + @test has_return_escape(result.state[Argument(2)], ret) + end + let result = code_escapes((String,)) do s + xs = Vector{String}(undef, 1) + xs[1] = s + xs′ = copy(xs) + length(xs′) > 2 && throw(xs′) + return xs′ + end + i1 = only(findall(isarrayalloc, result.ir.stmts.inst)) + i2 = only(findall(isarraycopy, result.ir.stmts.inst)) + t = only(findall(iscall((result.ir, throw)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test_broken !has_thrown_escape(result.state[SSAValue(i1)], t) + @test_broken !has_return_escape(result.state[SSAValue(i1)], r) + @test has_thrown_escape(result.state[SSAValue(i2)], t) + @test has_return_escape(result.state[SSAValue(i2)], r) + @test has_thrown_escape(result.state[Argument(2)], t) + @test has_return_escape(result.state[Argument(2)], r) + end + + # isassigned + let result = code_escapes((Vector{Any},Int)) do xs, i + return isassigned(xs, i) + end + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[Argument(2)], r) + @test !has_thrown_escape(result.state[Argument(2)]) + end +end + +# demonstrate array primitive support with a realistic end to end example +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + push!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + Base.JLOptions().check_bounds ≠ 0 && @test has_thrown_escape(result.state[SSAValue(i)]) + @test has_return_escape(result.state[Argument(3)], r) # s + Base.JLOptions().check_bounds ≠ 0 && @test has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((Int,String,)) do n,s + xs = String[] + for i in 1:n + pushfirst!(xs, s) + end + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) # xs + @test has_thrown_escape(result.state[SSAValue(i)]) # xs + @test has_return_escape(result.state[Argument(3)], r) # s + @test has_thrown_escape(result.state[Argument(3)]) # s +end +let result = code_escapes((String,String,String)) do s, t, u + xs = String[] + resize!(xs, 3) + xs[1] = s + xs[1] = t + xs[1] = u + xs + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test has_return_escape(result.state[SSAValue(i)], r) + @test has_thrown_escape(result.state[SSAValue(i)]) # xs + @test has_return_escape(result.state[Argument(2)], r) # s + @test has_return_escape(result.state[Argument(3)], r) # t + @test has_return_escape(result.state[Argument(4)], r) # u +end + +@static if isdefined(Core, :ImmutableArray) + +import Core: ImmutableArray, arrayfreeze, mutating_arrayfreeze, arraythaw + +@testset "ImmutableArray" begin + # arrayfreeze + let result = code_escapes((Vector{Any},)) do xs + arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector,)) do xs + arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray{Any,1},)) do xs + arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = Any[] + arrayfreeze(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(1)]) + end + + # mutating_arrayfreeze + let result = code_escapes((Vector{Any},)) do xs + mutating_arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector,)) do xs + mutating_arrayfreeze(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + mutating_arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray{Any,1},)) do xs + mutating_arrayfreeze(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = Any[] + mutating_arrayfreeze(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(1)]) + end + + # arraythaw + let result = code_escapes((ImmutableArray{Any,1},)) do xs + arraythaw(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((ImmutableArray,)) do xs + arraythaw(xs) + end + @test !has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Any,)) do xs + arraythaw(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes((Vector{Any},)) do xs + arraythaw(xs) + end + @test has_thrown_escape(result.state[Argument(2)]) + end + let result = code_escapes() do + xs = ImmutableArray(Any[]) + arraythaw(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test has_no_escape(result.state[SSAValue(1)]) + end +end + +# demonstrate some arrayfreeze optimizations +# !has_return_escape(ary) means ary is eligible for arrayfreeze to mutating_arrayfreeze optimization +let result = code_escapes((Int,)) do n + xs = collect(1:n) + ImmutableArray(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)]) +end +let result = code_escapes((Vector{Float64},)) do xs + ys = sin.(xs) + ImmutableArray(ys) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)]) +end +let result = code_escapes((Vector{Pair{Int,String}},)) do xs + n = maximum(first, xs) + ys = Vector{String}(undef, n) + for (i, s) in xs + ys[i] = s + end + ImmutableArray(xs) + end + i = only(findall(isarrayalloc, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)]) +end + +end # @static if isdefined(Core, :ImmutableArray) + +# demonstrate a simple type level analysis can sometimes improve the analysis accuracy +# by compensating the lack of yet unimplemented analyses +@testset "special-casing bitstype" begin + let result = code_escapes((Nothing,)) do a + global bb = a + end + @test !(has_all_escape(result.state[Argument(2)])) + end + + let result = code_escapes((Int,)) do a + o = SafeRef(a) + f = o[] + return f + end + i = only(findall(isnew, result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[SSAValue(i)], r) + end + + # an escaped tuple stmt will not propagate to its Int argument (since `Int` is of bitstype) + let result = code_escapes((Int,Any,)) do a, b + t = tuple(a, b) + return t + end + i = only(findall(iscall((result.ir, tuple)), result.ir.stmts.inst)) + r = only(findall(isreturn, result.ir.stmts.inst)) + @test !has_return_escape(result.state[Argument(2)], r) + @test has_return_escape(result.state[Argument(3)], r) + end +end + +# # TODO implement a finalizer elision pass +# mutable struct WithFinalizer +# v +# function WithFinalizer(v) +# x = new(v) +# f(t) = @async println("Finalizing $t.") +# return finalizer(x, x) +# end +# end +# make_m(v = 10) = MyMutable(v) +# function simple(cond) +# m = make_m() +# if cond +# # println(m.v) +# return nothing # <= insert `finalize` call here +# end +# return m +# end + +@static @isdefined(EA_AS_PKG) && @testset "code quality" begin + using JET + + # assert that our main routine are free from (unnecessary) runtime dispatches + + function function_filter(@nospecialize(ft)) + ft === typeof(Core.Compiler.widenconst) && return false # `widenconst` is very untyped, ignore + ft === typeof(EscapeAnalysis.escape_builtin!) && return false # `escape_builtin!` is very untyped, ignore + return true + end + target_modules = (EscapeAnalysis,) + test_opt(only(methods(EscapeAnalysis.analyze_escapes)).sig; + function_filter, + target_modules, + # skip_nonconcrete_calls=false, + ) + for m in methods(EscapeAnalysis.escape_builtin!) + Base._methods_by_ftype(m.sig, 1, Base.get_world_counter()) === false && continue + test_opt(m.sig; + function_filter, + target_modules, + # skip_nonconcrete_calls=false, + ) + end +end diff --git a/test/compiler/EscapeAnalysis/setup.jl b/test/compiler/EscapeAnalysis/setup.jl new file mode 100644 index 00000000000000..4ae10529c41509 --- /dev/null +++ b/test/compiler/EscapeAnalysis/setup.jl @@ -0,0 +1,89 @@ +using Test +if @isdefined(EA_AS_PKG) + import EscapeAnalysis: code_escapes, @code_escapes + using EscapeAnalysis +else + using Core.Compiler.EscapeAnalysis + import Base: code_escapes + import InteractiveUtils: @code_escapes +end +import Core: Argument, SSAValue, ReturnNode + +@static if isdefined(Core.Compiler, :alloc_array_ndims) + import Core.Compiler: alloc_array_ndims +else + function alloc_array_ndims(name::Symbol) + if name === :jl_alloc_array_1d + return 1 + elseif name === :jl_alloc_array_2d + return 2 + elseif name === :jl_alloc_array_3d + return 3 + elseif name === :jl_new_array + return 0 + end + return nothing + end +end + +isT(T) = (@nospecialize x) -> x === T +issubT(T) = (@nospecialize x) -> x <: T +isreturn(@nospecialize x) = isa(x, Core.ReturnNode) && isdefined(x, :val) +isthrow(@nospecialize x) = Meta.isexpr(x, :call) && Core.Compiler.is_throw_call(x) +isnew(@nospecialize x) = Meta.isexpr(x, :new) +isϕ(@nospecialize x) = isa(x, Core.PhiNode) +function with_normalized_name(@nospecialize(f), @nospecialize(x)) + if Meta.isexpr(x, :foreigncall) + name = x.args[1] + nn = EscapeAnalysis.normalize(name) + return isa(nn, Symbol) && f(nn) + end + return false +end +isarrayalloc(@nospecialize x) = with_normalized_name(nn->!isnothing(alloc_array_ndims(nn)), x) +isarrayresize(@nospecialize x) = with_normalized_name(nn->!isnothing(EscapeAnalysis.array_resize_info(nn)), x) +isarraycopy(@nospecialize x) = with_normalized_name(nn->EscapeAnalysis.is_array_copy(nn), x) +import Core.Compiler: argextype, singleton_type +const EMPTY_SPTYPES = Any[] +iscall(y) = @nospecialize(x) -> iscall(y, x) +function iscall((ir, f), @nospecialize(x)) + return iscall(x) do @nospecialize x + Core.Compiler.singleton_type(Core.Compiler.argextype(x, ir, EMPTY_SPTYPES)) === f + end +end +iscall(pred::Function, @nospecialize(x)) = Meta.isexpr(x, :call) && pred(x.args[1]) + +""" + is_load_forwardable(x::EscapeInfo) -> Bool + +Queries if `x` is elibigle for store-to-load forwarding optimization. +""" +function is_load_forwardable(x::EscapeAnalysis.EscapeInfo) + AliasInfo = x.AliasInfo + AliasInfo === false && return true # allows this query to work for immutables since we don't impose escape on them + # NOTE technically we also need to check `!has_thrown_escape(x)` here as well, + # but we can also do equivalent check during forwarding + return isa(AliasInfo, EscapeAnalysis.Indexable) && !AliasInfo.array +end + +let setup_ex = quote + mutable struct SafeRef{T} + x::T + end + Base.getindex(s::SafeRef) = getfield(s, 1) + Base.setindex!(s::SafeRef, x) = setfield!(s, 1, x) + + mutable struct SafeRefs{S,T} + x1::S + x2::T + end + Base.getindex(s::SafeRefs, idx::Int) = getfield(s, idx) + Base.setindex!(s::SafeRefs, x, idx::Int) = setfield!(s, idx, x) + end + global function EATModule(setup_ex = setup_ex) + M = Module() + Core.eval(M, setup_ex) + return M + end + Core.eval(@__MODULE__, setup_ex) +end diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index f355a42d7b08cf..cbbf4375541e1d 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -407,7 +407,7 @@ let m = Meta.@lower 1 + 1 src.ssaflags = fill(Int32(0), nstmts) ir = Core.Compiler.inflate_ir(src, Any[], Any[Any, Any]) @test Core.Compiler.verify_ir(ir) === nothing - ir = @test_nowarn Core.Compiler.sroa_pass!(ir) + ir = @test_nowarn Core.Compiler.sroa_pass!(ir, 0) @test Core.Compiler.verify_ir(ir) === nothing end