Skip to content

Commit

Permalink
aot: move jl_insert_backedges to Julia side
Browse files Browse the repository at this point in the history
With #56447, the dependency between `jl_insert_backedges`
and method insertion has been eliminated, allowing `jl_insert_backedges`
to be performed after loading. As a result, it is now possible to move
`jl_insert_backedges` to the Julia side.

Currently this commit simply moves the implementation without adding
any new features.
  • Loading branch information
aviatesk committed Nov 19, 2024
1 parent af9e6e3 commit 2c8035a
Show file tree
Hide file tree
Showing 7 changed files with 316 additions and 371 deletions.
9 changes: 4 additions & 5 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,13 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
end

# record the backedges
function store_backedges(caller::CodeInstance, edges::Vector{Any})
function store_backedges(caller::CodeInstance, edges::Union{Vector{Any},SimpleVector})
isa(caller.def.def, Method) || return # don't add backedges to toplevel method instance
for itr in BackedgeIterator(edges)
callee = itr.caller
for (; callee, sig) in BackedgeIterator(edges)
if isa(callee, MethodInstance)
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, sig, caller)
else
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, sig, caller)
end
end
nothing
Expand Down
18 changes: 9 additions & 9 deletions Compiler/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,11 @@ is_no_constprop(method::Union{Method,CodeInfo}) = method.constprop == 0x02
"""
BackedgeIterator(backedges::Vector{Any})
Return an iterator over a list of backedges. Iteration returns `(sig, caller)` elements,
Return an iterator over a list of backedges. Iteration returns `(sig, callee)` elements,
which will be one of the following:
- `BackedgePair(nothing, caller::MethodInstance)`: a call made by ordinary inferable dispatch
- `BackedgePair(invokesig::Type, caller::MethodInstance)`: a call made by `invoke(f, invokesig, args...)`
- `BackedgePair(nothing, callee::MethodInstance)`: a call made by ordinary inferable dispatch
- `BackedgePair(invokesig::Type, callee::MethodInstance)`: a call made by `invoke(f, invokesig, args...)`
- `BackedgePair(specsig::Type, mt::MethodTable)`: an abstract call
# Examples
Expand All @@ -234,22 +234,22 @@ julia> callyou(2.0)
julia> mi = which(callme, (Any,)).specializations
MethodInstance for callme(::Float64)
julia> @eval Core.Compiler for (; sig, caller) in BackedgeIterator(Main.mi.backedges)
julia> @eval Core.Compiler for (; sig, callee) in BackedgeIterator(Main.mi.backedges)
println(sig)
println(caller)
println(callee)
end
nothing
callyou(Float64) from callyou(Any)
```
"""
struct BackedgeIterator
backedges::Vector{Any}
struct BackedgeIterator{Vec}
backedges::Vec
end

struct BackedgePair
sig # ::Union{Nothing,Type}
caller::Union{MethodInstance,MethodTable}
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable}) = new(sig, caller)
callee::Union{MethodInstance,MethodTable}
BackedgePair(@nospecialize(sig), callee::Union{MethodInstance,MethodTable}) = new(sig, callee)
end

function iterate(iter::BackedgeIterator, i::Int=1)
Expand Down
293 changes: 290 additions & 3 deletions base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1280,10 +1280,12 @@ function _include_from_serialized(pkg::PkgId, path::String, ocachepath::Union{No
sv = try
if ocachepath !== nothing
@debug "Loading object cache file $ocachepath for $(repr("text/plain", pkg))"
ccall(:jl_restore_package_image_from_file, Any, (Cstring, Any, Cint, Cstring, Cint), ocachepath, depmods, false, pkg.name, ignore_native)
ccall(:jl_restore_package_image_from_file, Any, (Cstring, Any, Cint, Cstring, Cint),
ocachepath, depmods, #=completeinfo=#false, pkg.name, ignore_native)
else
@debug "Loading cache file $path for $(repr("text/plain", pkg))"
ccall(:jl_restore_incremental, Any, (Cstring, Any, Cint, Cstring), path, depmods, false, pkg.name)
ccall(:jl_restore_incremental, Any, (Cstring, Any, Cint, Cstring),
path, depmods, #=completeinfo=#false, pkg.name)
end
finally
lock(require_lock)
Expand All @@ -1292,6 +1294,10 @@ function _include_from_serialized(pkg::PkgId, path::String, ocachepath::Union{No
return sv
end

edges = sv[3]::Vector{Any}
ext_edges = sv[4]::Union{Nothing,Vector{Any}}
insert_backedges(edges, ext_edges)

restored = register_restored_modules(sv, pkg, path)

for M in restored
Expand Down Expand Up @@ -1413,6 +1419,287 @@ function register_restored_modules(sv::SimpleVector, pkg::PkgId, path::String)
return restored
end

# Restore backedges to external targets
# `edges` = [caller1, ...], the list of worklist-owned code instances internally
# `ext_ci_list` = [caller1, ...], the list of worklist-owned code instances externally
function insert_backedges(edges::Vector{Any}, ext_ci_list::Union{Nothing,Vector{Any}})
# determine which CodeInstance objects are still valid in our image
# to enable any applicable new codes
stack = CodeInstance[]
visiting = IdDict{CodeInstance,Int}()
_insert_backedges(edges, stack, visiting)
if ext_ci_list !== nothing
_insert_backedges(ext_ci_list, stack, visiting, #=external=#true)
end
end

function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, external::Bool=false)
for i = 1:length(edges)
codeinst = edges[i]::CodeInstance
verify_method_graph(codeinst, stack, visiting)
minvalid = codeinst.min_world
maxvalid = codeinst.max_world
if maxvalid minvalid
if get_world_counter() == maxvalid
# if this callee is still valid, add all the backedges
Core.Compiler.store_backedges(codeinst, codeinst.edges)
end
if get_world_counter() == maxvalid
maxvalid = typemax(UInt)
@atomic codeinst.max_world = maxvalid
end
if external
caller = codeinst.def
@assert isdefined(codeinst, :inferred) # See #53586, #53109
inferred = @ccall jl_rettype_inferred(
codeinst.owner::Any, caller::Any, minvalid::UInt, maxvalid::UInt)::Any
if inferred !== nothing
# We already got a code instance for this world age range from
# somewhere else - we don't need this one.
else
@ccall jl_mi_cache_insert(caller::Any, codeinst::Any)::Cvoid
end
end
end
end
end

function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
@assert isempty(stack)
@assert isempty(visiting)
child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting)
@assert child_cycle == 0
@assert isempty(stack)
empty!(visiting)
if Threads.maxthreadid() == 1 # a different thread might simultaneously come to a different, but equally valid, alternative result
@assert maxworld == 0 || codeinst.min_world == minworld
@assert codeinst.max_world == maxworld
end
end

const WORLD_AGE_REVALIDATION_SENTINEL::UInt = 1 # needs to sync with staticdata.c
const _jl_debug_method_invalidation = Ref{Union{Nothing,Vector{Any}}}(nothing)
debug_method_invalidation(onoff::Bool) =
_jl_debug_method_invalidation[] = onoff ? Any[] : nothing

# Test all edges relevant to a method:
# - Visit the entire call graph, starting from edges[idx] to determine if that method is valid
# - Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
# and slightly modified with an early termination option once the computation reaches its minimum
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
world = codeinst.min_world
let max_valid2 = codeinst.max_world
if max_valid2 WORLD_AGE_REVALIDATION_SENTINEL
return 0, world, max_valid2
end
end
current_world = get_world_counter()
local minworld::UInt, maxworld::UInt = 1, current_world
@assert codeinst.def.def isa Method
if haskey(visiting, codeinst)
return visiting[codeinst], minworld, maxworld
end
push!(stack, codeinst)
depth = length(stack)
visiting[codeinst] = depth
# TODO JL_TIMING(VERIFY_IMAGE, VERIFY_Methods)
callees = codeinst.edges
# verify current edges
j = 1
while j length(callees)
local min_valid2::UInt, max_valid2::UInt
edge = callees[j]
@assert !(edge isa Method) # `Method`-edge isn't allowed for the optimized one-edge format
if edge isa CodeInstance
edge = edge.def
end
if edge isa MethodInstance
sig = typeintersect((edge.def::Method).sig, edge.specTypes) # TODO??
min_valid2, max_valid2, matches = verify_call(sig, callees, j, 1, world)
j += 1
sig = nothing
elseif edge isa Int
sig = callees[j+1]
min_valid2, max_valid2, matches = verify_call(sig, callees, j+2, edge, world)
j += 2 + edge
edge = sig
elseif edge isa Core.MethodTable # skip the legacy edge (missing backedge)
j += 2
continue
else
callee = callees[j+1]
if callee isa CodeInstance
callee = callee.def
end
if callee isa MethodInstance
meth = callee.def::Method
else
meth = callee::Method
end
min_valid2, max_valid2 = verify_invokesig(edge, meth, world)
matches = nothing
j += 2
end
if minworld < min_valid2
minworld = min_valid2
end
if maxworld > max_valid2
maxworld = max_valid2
end
invalidations = _jl_debug_method_invalidation[]
if max_valid2 typemax(UInt) && invalidations !== nothing
push!(invalidations, edge, "insert_backedges_callee", codeinst, matches)
end
if max_valid2 == 0 && invalidations === nothing
break
end
end
# verify recursive edges (if valid, or debugging)
cycle = depth
cause = codeinst
if maxworld 0 || _jl_debug_method_invalidation[] !== nothing
for j = 1:length(callees)
local min_valid2::UInt, max_valid2::UInt
edge = callees[j]
if !(edge isa CodeInstance)
continue
end
callee = edge
child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting)
if minworld < min_valid2
minworld = min_valid2
end
if minworld > max_valid2
max_valid2 = 0
end
if maxworld > max_valid2
cause = callee
maxworld = max_valid2
end
if max_valid2 == 0
# found what we were looking for, so terminate early
break
elseif child_cycle 0 && child_cycle < cycle
# record the cycle will resolve at depth "cycle"
cycle = child_cycle;
end
end
end
if maxworld 0 && cycle depth
return cycle, minworld, maxworld
end
# If we are the top of the current cycle, now mark all other parts of
# our cycle with what we found.
# Or if we found a failed edge, also mark all of the other parts of the
# cycle as also having a failed edge.
while length(stack) depth
child = pop!(stack)
if Threads.maxthreadid() == 1 # a different thread might simultaneously come to a different, but equally valid, alternative result
@assert child.max_world == WORLD_AGE_REVALIDATION_SENTINEL
@assert minworld child.min_world
end
if maxworld 0
@atomic child.min_world = minworld
end
@atomic child.max_world = maxworld
@assert visiting[codeinst] == length(stack) + 1
delete!(visiting, codeinst)
invalidations = _jl_debug_method_invalidation[]
if invalidations !== nothing && maxworld < current_world
push!(invalidations, child, "verify_methods", cause)
end
end
return 0, minworld, maxworld
end

function verify_call(@nospecialize(sig), expecteds::SimpleVector, i::Int, n::Int, world::UInt)
# verify that these edges intersect with the same methods as before
lim = _jl_debug_method_invalidation[] !== nothing ? Int(typemax(Int32)) : n
minworld = RefValue{UInt}(1)
maxworld = RefValue{UInt}(typemax(UInt))
has_ambig = RefValue{Int32}(0)
result = _methods_by_ftype(sig, nothing, lim, world, #=ambig=#false, minworld, maxworld, has_ambig)
if result === nothing
maxworld[] = 0
else
# setdiff!(result, expected)
if length(result) n
maxworld[] = 0
end
ins = 0
for k = 1:length(result)
match = result[k]::Core.MethodMatch
local found = false
for j = 1:n
t = expecteds[i+j-1]
if t isa CodeInstance
t = t.def
end
if t isa Method
meth = t
else
meth = (t::MethodInstance).def::Method
end
if match.method == meth
found = true
break
end
end
if !found
# intersection has a new method or a method was
# deleted--this is now probably no good, just invalidate
# everything about it now
maxworld[] = 0
if _jl_debug_method_invalidation[] === nothing
break
end
ins += 1
result[ins] = match
end
end
if maxworld[] typemax(UInt) && _jl_debug_method_invalidation[] !== nothing
resize!(result, ins)
end
end
return minworld[], maxworld[], result
end

function verify_invokesig(@nospecialize(invokesig), expected::Method, world::UInt)
@assert invokesig isa Type
local minworld::UInt, maxworld::UInt
if invokesig === expected.sig
# the invoke match is `expected` for `expected->sig`, unless `expected` is invalid
minworld = expected.primary_world
maxworld = expected.deleted_world
@assert minworld world
if maxworld < world
maxworld = 0
end
else
minworld = 1
maxworld = typemax(UInt)
mt = get_methodtable(expected)
if mt === nothing
maxworld = 0
else
min_valid = RefValue{UInt}(minworld)
max_valid = RefValue{UInt}(maxworld)
matches = @ccall jl_gf_invoke_lookup_worlds(
invokesig::Any, mt::Any, world::UInt, min_valid::Ptr{Csize_t}, max_valid::Ptr{Csize_t})::Any
minworld, maxworld = min_valid[], max_valid[]
if matches === nothing
maxworld = 0
else
matches = matches::Core.MethodMatch
if matches.method != expected
maxworld = 0
end
end
end
end
return minworld, maxworld
end

function run_module_init(mod::Module, i::Int=1)
# `i` informs ordering for the `@time_imports` report formatting
if TIMING_IMPORTS[] == 0
Expand Down Expand Up @@ -4190,7 +4477,7 @@ function precompile(@nospecialize(argt::Type))
end

# Variants that work for `invoke`d calls for which the signature may not be sufficient
precompile(mi::Core.MethodInstance, world::UInt=get_world_counter()) =
precompile(mi::MethodInstance, world::UInt=get_world_counter()) =
(ccall(:jl_compile_method_instance, Cvoid, (Any, Ptr{Cvoid}, UInt), mi, C_NULL, world); return true)

"""
Expand Down
2 changes: 1 addition & 1 deletion base/pcre.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ THREAD_MATCH_CONTEXTS::Vector{Ptr{Cvoid}} = [C_NULL]
PCRE_COMPILE_LOCK = nothing

_tid() = Int(ccall(:jl_threadid, Int16, ())) + 1
_mth() = Int(Core.Intrinsics.atomic_pointerref(cglobal(:jl_n_threads, Cint), :acquire))
_mth() = Base.Threads.maxthreadid()

function get_local_match_context()
tid = _tid()
Expand Down
Loading

0 comments on commit 2c8035a

Please sign in to comment.