Skip to content

Commit

Permalink
Make Expr(:invoke) target be a CodeInstance, not MethodInstance (#54899)
Browse files Browse the repository at this point in the history
This changes our IR representation to use a CodeInstance directly as
the invoke function target to specify the ABI in its entirety, instead
of just the MethodInstance (specifically for the rettype). That allows
removing the lookup call at that point to decide upon the ABI. It is
based around the idea that eventually we now keep track of these
anyways to form a graph of the inferred edge data, for use later in
validation anyways (instead of attempting to invert the backedges graph
in staticdata_utils.c), so we might as well use the same target type
for the :invoke call representation also now.
  • Loading branch information
vtjnash authored Nov 21, 2024
1 parent 859c25a commit c31710a
Show file tree
Hide file tree
Showing 23 changed files with 151 additions and 144 deletions.
11 changes: 6 additions & 5 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,11 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(fun
if mi === nothing || !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv)
csig = get_compileable_sig(method, sig, match.sparams)
if csig !== nothing && (!seenall || csig !== sig) # corresponds to whether the first look already looked at this, so repeating abstract_call_method is not useful
#println(sig, " changed to ", csig, " for ", method)
sp_ = ccall(:jl_type_intersection_with_env, Any, (Any, Any), csig, method.sig)::SimpleVector
if match.sparams === sp_[2]
mresult = abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false, false), sv)::Future
isready(mresult) || return false # wait for mresult Future to resolve off the callstack before continuing
end
sparams = sp_[2]::SimpleVector
mresult = abstract_call_method(interp, method, csig, sparams, multiple_matches, StmtInfo(false, false), sv)::Future
isready(mresult) || return false # wait for mresult Future to resolve off the callstack before continuing
end
end
end
Expand Down Expand Up @@ -1365,7 +1365,8 @@ function const_prop_call(interp::AbstractInterpreter,
pop!(callstack)
return nothing
end
inf_result.ci_as_edge = codeinst_as_edge(interp, frame)
existing_edge = result.edge
inf_result.ci_as_edge = codeinst_as_edge(interp, frame, existing_edge)
@assert frame.frameid != 0 && frame.cycleid == frame.frameid
@assert frame.parentid == sv.frameid
@assert inf_result.result !== nothing
Expand Down
5 changes: 4 additions & 1 deletion Compiler/src/ssair/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1068,7 +1068,10 @@ end

# escape statically-resolved call, i.e. `Expr(:invoke, ::MethodInstance, ...)`
function escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any})
mi = first(args)::MethodInstance
mi = first(args)
if !(mi isa MethodInstance)
mi = (mi::CodeInstance).def # COMBAK get escape info directly from CI instead?
end
first_idx, last_idx = 2, length(args)
add_liveness_changes!(astate, pc, args, first_idx, last_idx)
# TODO inspect `astate.ir.stmts[pc][:info]` and use const-prop'ed `InferenceResult` if available
Expand Down
57 changes: 34 additions & 23 deletions Compiler/src/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct SomeCase
end

struct InvokeCase
invoke::MethodInstance
invoke::Union{CodeInstance,MethodInstance}
effects::Effects
info::CallInfo
end
Expand Down Expand Up @@ -764,19 +764,20 @@ function rewrite_apply_exprargs!(todo::Vector{Pair{Int,Any}},
return new_argtypes
end

function compileable_specialization(mi::MethodInstance, effects::Effects,
function compileable_specialization(code::Union{MethodInstance,CodeInstance}, effects::Effects,
et::InliningEdgeTracker, @nospecialize(info::CallInfo), state::InliningState)
mi = code isa CodeInstance ? code.def : code
mi_invoke = mi
method, atype, sparams = mi.def::Method, mi.specTypes, mi.sparam_vals
if OptimizationParams(state.interp).compilesig_invokes
new_atype = get_compileable_sig(method, atype, sparams)
new_atype === nothing && return nothing
if atype !== new_atype
sp_ = ccall(:jl_type_intersection_with_env, Any, (Any, Any), new_atype, method.sig)::SimpleVector
if sparams === sp_[2]::SimpleVector
mi_invoke = specialize_method(method, new_atype, sparams)
mi_invoke === nothing && return nothing
end
sparams = sp_[2]::SimpleVector
mi_invoke = specialize_method(method, new_atype, sparams)
mi_invoke === nothing && return nothing
code = mi_invoke
end
else
# If this caller does not want us to optimize calls to use their
Expand All @@ -786,8 +787,15 @@ function compileable_specialization(mi::MethodInstance, effects::Effects,
return nothing
end
end
add_inlining_edge!(et, mi_invoke) # to the dispatch lookup
return InvokeCase(mi_invoke, effects, info)
# prefer using a CodeInstance gotten from the cache, since that is where the invoke target should get compiled to normally
# TODO: can this code be gotten directly from inference sometimes?
code = get(code_cache(state), mi_invoke, nothing)
if !isa(code, CodeInstance)
#println("missing code for ", mi_invoke, " for ", mi)
code = mi_invoke
end
add_inlining_edge!(et, code) # to the code and edges
return InvokeCase(code, effects, info)
end

struct InferredResult
Expand Down Expand Up @@ -844,18 +852,18 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
src = @atomic :monotonic inferred_result.inferred
effects = decode_effects(inferred_result.ipo_purity_bits)
edge = inferred_result
else # there is no cached source available, bail out
else # there is no cached source available for this, but there might be code for the compilation sig
return compileable_specialization(mi, Effects(), et, info, state)
end

# the duplicated check might have been done already within `analyze_method!`, but still
# we need it here too since we may come here directly using a constant-prop' result
if !OptimizationParams(state.interp).inlining || is_stmt_noinline(flag)
return compileable_specialization(edge.def, effects, et, info, state)
return compileable_specialization(edge, effects, et, info, state)
end

src_inlining_policy(state.interp, src, info, flag) ||
return compileable_specialization(edge.def, effects, et, info, state)
return compileable_specialization(edge, effects, et, info, state)

add_inlining_edge!(et, edge)
if inferred_result isa CodeInstance
Expand Down Expand Up @@ -1423,18 +1431,19 @@ end

function semiconcrete_result_item(result::SemiConcreteResult,
@nospecialize(info::CallInfo), flag::UInt32, state::InliningState)
mi = result.edge.def
code = result.edge
mi = code.def
et = InliningEdgeTracker(state)

if (!OptimizationParams(state.interp).inlining || is_stmt_noinline(flag) ||
# For `NativeInterpreter`, `SemiConcreteResult` may be produced for
# a `@noinline`-declared method when it's marked as `@constprop :aggressive`.
# Suppress the inlining here (unless inlining is requested at the callsite).
(is_declared_noinline(mi.def::Method) && !is_stmt_inline(flag)))
return compileable_specialization(mi, result.effects, et, info, state)
return compileable_specialization(code, result.effects, et, info, state)
end
src_inlining_policy(state.interp, result.ir, info, flag) ||
return compileable_specialization(mi, result.effects, et, info, state)
return compileable_specialization(code, result.effects, et, info, state)

add_inlining_edge!(et, result.edge)
preserve_local_sources = OptimizationParams(state.interp).preserve_local_sources
Expand Down Expand Up @@ -1466,7 +1475,7 @@ may_inline_concrete_result(result::ConcreteResult) =
function concrete_result_item(result::ConcreteResult, @nospecialize(info::CallInfo), state::InliningState)
if !may_inline_concrete_result(result)
et = InliningEdgeTracker(state)
return compileable_specialization(result.edge.def, result.effects, et, info, state)
return compileable_specialization(result.edge, result.effects, et, info, state)
end
@assert result.effects === EFFECTS_TOTAL
return ConstantCase(quoted(result.result), result.edge)
Expand Down Expand Up @@ -1522,11 +1531,7 @@ function handle_modifyop!_call!(ir::IRCode, idx::Int, stmt::Expr, info::ModifyOp
match = info.results[1]::MethodMatch
match.fully_covers || return nothing
edge = info.edges[1]
if edge === nothing
edge = specialize_method(match)
else
edge = edge.def
end
edge === nothing && return nothing
case = compileable_specialization(edge, Effects(), InliningEdgeTracker(state), info, state)
case === nothing && return nothing
stmt.head = :invoke_modify
Expand Down Expand Up @@ -1564,8 +1569,11 @@ function handle_finalizer_call!(ir::IRCode, idx::Int, stmt::Expr, info::Finalize
# `Core.Compiler` data structure into the global cache
item1 = cases[1].item
if isa(item1, InliningTodo)
push!(stmt.args, true)
push!(stmt.args, item1.mi)
code = get(code_cache(state), item1.mi, nothing) # COMBAK: this seems like a bad design, can we use stmt_info instead to store the correct info?
if code isa CodeInstance
push!(stmt.args, true)
push!(stmt.args, code)
end
elseif isa(item1, InvokeCase)
push!(stmt.args, false)
push!(stmt.args, item1.invoke)
Expand All @@ -1578,7 +1586,10 @@ end

function handle_invoke_expr!(todo::Vector{Pair{Int,Any}}, ir::IRCode,
idx::Int, stmt::Expr, @nospecialize(info::CallInfo), flag::UInt32, sig::Signature, state::InliningState)
mi = stmt.args[1]::MethodInstance
mi = stmt.args[1]
if !(mi isa MethodInstance)
mi = (mi::CodeInstance).def
end
case = resolve_todo(mi, info, flag, state)
handle_single_case!(todo, ir, idx, stmt, case, false)
return nothing
Expand Down
16 changes: 10 additions & 6 deletions Compiler/src/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@ end

function abstract_eval_invoke_inst(interp::AbstractInterpreter, inst::Instruction, irsv::IRInterpretationState)
stmt = inst[:stmt]
mi = stmt.args[1]::MethodInstance
world = frame_world(irsv)
mi_cache = WorldView(code_cache(interp), world)
code = get(mi_cache, mi, nothing)
code === nothing && return Pair{Any,Tuple{Bool,Bool}}(nothing, (false, false))
ci = stmt.args[1]
if ci isa MethodInstance
world = frame_world(irsv)
mi_cache = WorldView(code_cache(interp), world)
code = get(mi_cache, ci, nothing)
code === nothing && return Pair{Any,Tuple{Bool,Bool}}(nothing, (false, false))
else
code = ci::CodeInstance
end
argtypes = collect_argtypes(interp, stmt.args[2:end], StatementState(nothing, false), irsv)
argtypes === nothing && return Pair{Any,Tuple{Bool,Bool}}(Bottom, (false, false))
return concrete_eval_invoke(interp, code, argtypes, irsv)
Expand Down Expand Up @@ -160,7 +164,7 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
result isa Future && (result = result[])
(; rt, effects) = result
add_flag!(inst, flags_for_effects(effects))
elseif head === :invoke
elseif head === :invoke # COMBAK: || head === :invoke_modifyfield (similar to call, but for args[2:end])
rt, (nothrow, noub) = abstract_eval_invoke_inst(interp, inst, irsv)
if nothrow
add_flag!(inst, IR_FLAG_NOTHROW)
Expand Down
12 changes: 6 additions & 6 deletions Compiler/src/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
# at the end of the intrinsic. Detect that here.
if length(stmt.args) == 4 && stmt.args[4] === nothing
# constant case
elseif length(stmt.args) == 5 && stmt.args[4] isa Bool && stmt.args[5] isa MethodInstance
elseif length(stmt.args) == 5 && stmt.args[4] isa Bool && stmt.args[5] isa Core.CodeInstance
# inlining case
else
continue
Expand Down Expand Up @@ -1522,9 +1522,9 @@ end
# NOTE we resolve the inlining source here as we don't want to serialize `Core.Compiler`
# data structure into the global cache (see the comment in `handle_finalizer_call!`)
function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,
mi::MethodInstance, @nospecialize(info::CallInfo), inlining::InliningState,
code::CodeInstance, @nospecialize(info::CallInfo), inlining::InliningState,
attach_after::Bool)
code = get(code_cache(inlining), mi, nothing)
mi = code.def
et = InliningEdgeTracker(inlining)
if code isa CodeInstance
if use_const_api(code)
Expand Down Expand Up @@ -1671,11 +1671,11 @@ function try_resolve_finalizer!(ir::IRCode, alloc_idx::Int, finalizer_idx::Int,
if inline === nothing
# No code in the function - Nothing to do
else
mi = finalizer_stmt.args[5]::MethodInstance
if inline::Bool && try_inline_finalizer!(ir, argexprs, loc, mi, info, inlining, attach_after)
ci = finalizer_stmt.args[5]::CodeInstance
if inline::Bool && try_inline_finalizer!(ir, argexprs, loc, ci, info, inlining, attach_after)
# the finalizer body has been inlined
else
newinst = add_flag(NewInstruction(Expr(:invoke, mi, argexprs...), Nothing), flag)
newinst = add_flag(NewInstruction(Expr(:invoke, ci, argexprs...), Nothing), flag)
insert_node!(ir, loc, newinst, attach_after)
end
end
Expand Down
8 changes: 6 additions & 2 deletions Compiler/src/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,14 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), code::Union{IRCode,Co
print(io, ", ")
print(io, stmt.typ)
print(io, ")")
elseif isexpr(stmt, :invoke) && length(stmt.args) >= 2 && isa(stmt.args[1], MethodInstance)
elseif isexpr(stmt, :invoke) && length(stmt.args) >= 2 && isa(stmt.args[1], Union{MethodInstance,CodeInstance})
stmt = stmt::Expr
# TODO: why is this here, and not in Base.show_unquoted
printstyled(io, " invoke "; color = :light_black)
mi = stmt.args[1]::Core.MethodInstance
mi = stmt.args[1]
if !(mi isa Core.MethodInstance)
mi = (mi::Core.CodeInstance).def
end
show_unquoted(io, stmt.args[2], indent)
print(io, "(")
# XXX: this is wrong if `sig` is not a concretetype method
Expand All @@ -110,6 +113,7 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), code::Union{IRCode,Co
end
join(io, (print_arg(i) for i = 3:length(stmt.args)), ", ")
print(io, ")")
# TODO: if we have a CodeInstance, should we print that rettype info here, which may differ (wider or narrower than the ssavaluetypes)
elseif isexpr(stmt, :call) && length(stmt.args) >= 1 && label_dynamic_calls
ft = maybe_argextype(stmt.args[1], code, sptypes)
f = singleton_type(ft)
Expand Down
21 changes: 15 additions & 6 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,10 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
maybe_validate_code(me.linfo, me.src, "inferred")

# finish populating inference results into the CodeInstance if possible, and maybe cache that globally for use elsewhere
if isdefined(result, :ci) && !limited_ret
if isdefined(result, :ci)
result_type = result.result
@assert !(result_type === nothing || result_type isa LimitedAccuracy)
result_type isa LimitedAccuracy && (result_type = result_type.typ)
@assert !(result_type === nothing)
if isa(result_type, Const)
rettype_const = result_type.val
const_flags = is_result_constabi_eligible(result) ? 0x3 : 0x2
Expand Down Expand Up @@ -760,16 +761,24 @@ function MethodCallResult(::AbstractInterpreter, sv::AbsIntState, method::Method
return MethodCallResult(rt, exct, effects, edge, edgecycle, edgelimited, volatile_inf_result)
end

# allocate a dummy `edge::CodeInstance` to be added by `add_edges!`
function codeinst_as_edge(interp::AbstractInterpreter, sv::InferenceState)
# allocate a dummy `edge::CodeInstance` to be added by `add_edges!`, reusing an existing_edge if possible
# TODO: fill this in fully correctly (currently IPO info such as effects and return types are lost)
function codeinst_as_edge(interp::AbstractInterpreter, sv::InferenceState, @nospecialize existing_edge)
mi = sv.linfo
owner = cache_owner(interp)
min_world, max_world = first(sv.world.valid_worlds), last(sv.world.valid_worlds)
if max_world >= get_world_counter()
max_world = typemax(UInt)
end
edges = Core.svec(sv.edges...)
ci = CodeInstance(mi, owner, Any, Any, nothing, nothing, zero(Int32),
if existing_edge isa CodeInstance
# return an existing_edge, if the existing edge has more restrictions already (more edges and narrower worlds)
if existing_edge.min_world >= min_world &&
existing_edge.max_world <= max_world &&
existing_edge.edges == edges
return existing_edge
end
end
ci = CodeInstance(mi, cache_owner(interp), Any, Any, nothing, nothing, zero(Int32),
min_world, max_world, zero(UInt32), nothing, zero(UInt8), nothing, edges)
if max_world == typemax(UInt)
# if we can record all of the backedges in the global reverse-cache,
Expand Down
16 changes: 8 additions & 8 deletions Compiler/test/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ f29083(;μ,σ) = μ + σ*randn()
g29083() = f29083=2.0=0.1)
let c = code_typed(g29083, ())[1][1].code
# make sure no call to kwfunc remains
@test !any(e->(isa(e,Expr) && (e.head === :invoke && e.args[1].def.name === :kwfunc)), c)
@test !any(e->(isa(e,Expr) && (e.head === :invoke && e.args[1].def.def.name === :kwfunc)), c)
end

@testset "issue #19122: [no]inline of short func. def. with return type annotation" begin
Expand Down Expand Up @@ -723,7 +723,7 @@ mktempdir() do dir
ci, rt = only(code_typed(issue42246))
if any(ci.code) do stmt
Meta.isexpr(stmt, :invoke) &&
stmt.args[1].def.name === nameof(IOBuffer)
stmt.args[1].def.def.name === nameof(IOBuffer)
end
exit(0)
else
Expand Down Expand Up @@ -1797,7 +1797,7 @@ end

isinvokemodify(y) = @nospecialize(x) -> isinvokemodify(y, x)
isinvokemodify(sym::Symbol, @nospecialize(x)) = isinvokemodify(mi->mi.def.name===sym, x)
isinvokemodify(pred::Function, @nospecialize(x)) = isexpr(x, :invoke_modify) && pred(x.args[1]::MethodInstance)
isinvokemodify(pred::Function, @nospecialize(x)) = isexpr(x, :invoke_modify) && pred((x.args[1]::CodeInstance).def)

mutable struct Atomic{T}
@atomic x::T
Expand Down Expand Up @@ -2131,15 +2131,15 @@ let src = code_typed1((Type,)) do x
end
@test count(src.code) do @nospecialize x
isinvoke(:no_compile_sig_invokes, x) &&
(x.args[1]::MethodInstance).specTypes == Tuple{typeof(no_compile_sig_invokes),Any}
(x.args[1]::Core.CodeInstance).def.specTypes == Tuple{typeof(no_compile_sig_invokes),Any}
end == 1
end
let src = code_typed1((Type,); interp=NoCompileSigInvokes()) do x
no_compile_sig_invokes(x)
end
@test count(src.code) do @nospecialize x
isinvoke(:no_compile_sig_invokes, x) &&
(x.args[1]::MethodInstance).specTypes == Tuple{typeof(no_compile_sig_invokes),Type}
(x.args[1]::Core.CodeInstance).def.specTypes == Tuple{typeof(no_compile_sig_invokes),Type}
end == 1
end
# test the union split case
Expand All @@ -2148,19 +2148,19 @@ let src = code_typed1((Union{DataType,UnionAll},)) do x
end
@test count(src.code) do @nospecialize x
isinvoke(:no_compile_sig_invokes, x) &&
(x.args[1]::MethodInstance).specTypes == Tuple{typeof(no_compile_sig_invokes),Any}
(x.args[1]::Core.CodeInstance).def.specTypes == Tuple{typeof(no_compile_sig_invokes),Any}
end == 2
end
let src = code_typed1((Union{DataType,UnionAll},); interp=NoCompileSigInvokes()) do x
no_compile_sig_invokes(x)
end
@test count(src.code) do @nospecialize x
isinvoke(:no_compile_sig_invokes, x) &&
(x.args[1]::MethodInstance).specTypes == Tuple{typeof(no_compile_sig_invokes),DataType}
(x.args[1]::Core.CodeInstance).def.specTypes == Tuple{typeof(no_compile_sig_invokes),DataType}
end == 1
@test count(src.code) do @nospecialize x
isinvoke(:no_compile_sig_invokes, x) &&
(x.args[1]::MethodInstance).specTypes == Tuple{typeof(no_compile_sig_invokes),UnionAll}
(x.args[1]::Core.CodeInstance).def.specTypes == Tuple{typeof(no_compile_sig_invokes),UnionAll}
end == 1
end

Expand Down
Loading

0 comments on commit c31710a

Please sign in to comment.