Skip to content

Commit

Permalink
aot: move jl_insert_backedges to Julia side
Browse files Browse the repository at this point in the history
With #56447, the dependency between `jl_insert_backedges`
and method insertion has been eliminated, allowing `jl_insert_backedges`
to be performed after loading. As a result, it is now possible to move
`jl_insert_backedges` to the Julia side.

Currently this commit simply moves the implementation without adding
any new features.
  • Loading branch information
aviatesk committed Dec 10, 2024
1 parent 7192df7 commit 1d190d1
Show file tree
Hide file tree
Showing 5 changed files with 312 additions and 370 deletions.
301 changes: 298 additions & 3 deletions base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1280,10 +1280,12 @@ function _include_from_serialized(pkg::PkgId, path::String, ocachepath::Union{No
sv = try
if ocachepath !== nothing
@debug "Loading object cache file $ocachepath for $(repr("text/plain", pkg))"
ccall(:jl_restore_package_image_from_file, Any, (Cstring, Any, Cint, Cstring, Cint), ocachepath, depmods, false, pkg.name, ignore_native)
ccall(:jl_restore_package_image_from_file, Any, (Cstring, Any, Cint, Cstring, Cint),
ocachepath, depmods, #=completeinfo=#false, pkg.name, ignore_native)
else
@debug "Loading cache file $path for $(repr("text/plain", pkg))"
ccall(:jl_restore_incremental, Any, (Cstring, Any, Cint, Cstring), path, depmods, false, pkg.name)
ccall(:jl_restore_incremental, Any, (Cstring, Any, Cint, Cstring),
path, depmods, #=completeinfo=#false, pkg.name)
end
finally
lock(require_lock)
Expand All @@ -1292,6 +1294,10 @@ function _include_from_serialized(pkg::PkgId, path::String, ocachepath::Union{No
return sv
end

edges = sv[3]::Vector{Any}
ext_edges = sv[4]::Union{Nothing,Vector{Any}}
insert_backedges(edges, ext_edges)

restored = register_restored_modules(sv, pkg, path)

for M in restored
Expand Down Expand Up @@ -1413,6 +1419,295 @@ function register_restored_modules(sv::SimpleVector, pkg::PkgId, path::String)
return restored
end

# Restore backedges to external targets.
# `edges` = [caller1, ...]: the worklist-owned code instances (internal list).
# `ext_ci_list` = [caller1, ...]: the worklist-owned code instances (external
# list), or `nothing` when there is no such list.
function insert_backedges(edges::Vector{Any}, ext_ci_list::Union{Nothing,Vector{Any}})
    # Work out which CodeInstance objects are still valid in our image so that
    # any applicable new code can be enabled. Both passes share one traversal
    # stack and one "currently visiting" table.
    worklist = CodeInstance[]
    depths = IdDict{CodeInstance,Int}()
    _insert_backedges(edges, worklist, depths)
    ext_ci_list === nothing && return nothing
    _insert_backedges(ext_ci_list, worklist, depths, #=external=#true)
    return nothing
end

# Re-validate each code instance in `edges` and, for those that are still valid,
# restore their backedges and extend their world range.
#
# - `edges`: list whose elements must all be `CodeInstance`s.
# - `stack`, `visiting`: scratch state handed to `verify_method_graph`; both are
#   expected to be empty between elements.
# - `external`: when `true`, a still-valid code instance is also inserted into
#   the cache of its `def`, unless `jl_rettype_inferred` already finds an entry
#   for the same owner and world range.
function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance},
                           visiting::IdDict{CodeInstance,Int}, external::Bool=false)
    for i in eachindex(edges)
        codeinst = edges[i]::CodeInstance
        # Validation writes its result back into `codeinst.min_world` and
        # `codeinst.max_world`, which are re-read just below.
        verify_method_graph(codeinst, stack, visiting)
        minvalid = codeinst.min_world
        maxvalid = codeinst.max_world
        if maxvalid >= minvalid
            if get_world_counter() == maxvalid
                # if this callee is still valid, add all the backedges
                Core.Compiler.store_backedges(codeinst, codeinst.edges)
            end
            if get_world_counter() == maxvalid
                # valid up to the current world: mark it as valid indefinitely
                maxvalid = typemax(UInt)
                @atomic codeinst.max_world = maxvalid
            end
            if external
                caller = codeinst.def
                @assert isdefined(codeinst, :inferred) # See #53586, #53109
                inferred = @ccall jl_rettype_inferred(
                    codeinst.owner::Any, caller::Any, minvalid::UInt, maxvalid::UInt)::Any
                if inferred !== nothing
                    # We already got a code instance for this world age range from
                    # somewhere else - we don't need this one.
                else
                    @ccall jl_mi_cache_insert(caller::Any, codeinst::Any)::Cvoid
                end
            end
        end
    end
end

# Validate `codeinst` by running `verify_method` on it as the root of a traversal.
# `stack` and `visiting` are caller-owned scratch containers; both must be empty
# on entry, and both are empty again when this function returns.
function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
    @assert isempty(stack)
    @assert isempty(visiting)
    child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting)
    # a root call must not end inside an unresolved cycle
    @assert child_cycle == 0
    @assert isempty(stack)
    # cleared unconditionally (not asserted empty), so any entries left behind
    # by the traversal are dropped here
    empty!(visiting)
    if Threads.maxthreadid() == 1 # a different thread might simultaneously come to a different, but equally valid, alternative result
        # the world range stored on `codeinst` must match what was just computed
        @assert maxworld == 0 || codeinst.min_world == minworld
        @assert codeinst.max_world == maxworld
    end
end

# `max_world` value that `verify_method` treats as "not yet re-validated": code
# instances carrying any other `max_world` are returned as-is without checking.
const WORLD_AGE_REVALIDATION_SENTINEL::UInt = 1 # needs to sync with staticdata.c
# Invalidation log: `nothing` when logging is off, otherwise a vector that the
# verification functions push their findings onto.
const _jl_debug_method_invalidation = Ref{Union{Nothing,Vector{Any}}}(nothing)
# Turn invalidation logging on (installing a fresh, empty log) or off.
# Returns the newly installed value: the new log vector, or `nothing`.
debug_method_invalidation(onoff::Bool) =
    _jl_debug_method_invalidation[] = onoff ? Any[] : nothing

# Test all edges relevant to a method:
# - Visit the entire call graph, starting from `codeinst`, to determine if that method is valid
# - Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
#   and slightly modified with an early termination option once the computation reaches its minimum
#
# Arguments:
# - `codeinst`: the code instance to validate.
# - `stack`: the code instances currently being validated, in visiting order.
# - `visiting`: maps each code instance on `stack` to its (1-based) stack depth.
#
# Returns `(cycle, minworld, maxworld)`:
# - `cycle` is `0` when `codeinst` has been fully resolved (its `min_world` /
#   `max_world` fields now hold the result). Otherwise it is the stack depth of
#   the cycle member that will resolve `codeinst` later.
# - `minworld`/`maxworld` is the world range found so far; `maxworld == 0` means
#   an edge failed validation.
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
    world = codeinst.min_world
    let max_valid2 = codeinst.max_world
        if max_valid2 != WORLD_AGE_REVALIDATION_SENTINEL
            # not marked for revalidation: report the stored range unchanged
            return 0, world, max_valid2
        end
    end
    current_world = get_world_counter()
    local minworld::UInt, maxworld::UInt = 1, current_world
    @assert codeinst.def.def isa Method
    if haskey(visiting, codeinst)
        # already on the stack: we closed a cycle, report where it started
        return visiting[codeinst], minworld, maxworld
    end
    push!(stack, codeinst)
    depth = length(stack)
    visiting[codeinst] = depth
    # TODO JL_TIMING(VERIFY_IMAGE, VERIFY_Methods)
    callees = codeinst.edges
    # verify current edges
    if isempty(callees)
        # quick return: no edges to verify (though we probably shouldn't have gotten here from WORLD_AGE_REVALIDATION_SENTINEL)
    elseif maxworld == unsafe_load(cglobal(:jl_require_world, UInt))
        # if no new worlds were allocated since serializing the base module, then no new validation is worth doing right now either
        minworld = maxworld
    else
        # `callees` is a flat list of variable-length records; `j` is advanced
        # by each branch according to the record it consumed.
        j = 1
        while j <= length(callees)
            local min_valid2::UInt, max_valid2::UInt
            edge = callees[j]
            @assert !(edge isa Method) # `Method`-edge isn't allowed for the optimized one-edge format
            if edge isa CodeInstance
                edge = edge.def
            end
            if edge isa MethodInstance
                # one-entry record: the entry itself is the single expected target
                sig = typeintersect((edge.def::Method).sig, edge.specTypes) # TODO??
                min_valid2, max_valid2, matches = verify_call(sig, callees, j, 1, world)
                j += 1
                sig = nothing
            elseif edge isa Int
                # record `n, sig, target_1, ..., target_n`
                sig = callees[j+1]
                min_valid2, max_valid2, matches = verify_call(sig, callees, j+2, edge, world)
                j += 2 + edge
                edge = sig
            else
                # record `invokesig, callee`
                callee = callees[j+1]
                if callee isa Core.MethodTable # skip the legacy edge (missing backedge)
                    j += 2
                    continue
                end
                if callee isa CodeInstance
                    callee = callee.def
                end
                if callee isa MethodInstance
                    meth = callee.def::Method
                else
                    meth = callee::Method
                end
                min_valid2, max_valid2 = verify_invokesig(edge, meth, world)
                matches = nothing
                j += 2
            end
            # intersect the running range with this edge's range
            if minworld < min_valid2
                minworld = min_valid2
            end
            if maxworld > max_valid2
                maxworld = max_valid2
            end
            invalidations = _jl_debug_method_invalidation[]
            if max_valid2 != typemax(UInt) && invalidations !== nothing
                push!(invalidations, edge, "insert_backedges_callee", codeinst, matches)
            end
            if max_valid2 == 0 && invalidations === nothing
                # already invalid and nobody is logging: no need to look further
                break
            end
        end
    end
    # verify recursive edges (if valid, or debugging)
    cycle = depth
    cause = codeinst
    if maxworld != 0 || _jl_debug_method_invalidation[] !== nothing
        for j = 1:length(callees)
            local min_valid2::UInt, max_valid2::UInt
            edge = callees[j]
            if !(edge isa CodeInstance)
                continue
            end
            callee = edge
            child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting)
            if minworld < min_valid2
                minworld = min_valid2
            end
            if minworld > max_valid2
                # empty intersection: treat the callee as invalid
                max_valid2 = 0
            end
            if maxworld > max_valid2
                cause = callee
                maxworld = max_valid2
            end
            if max_valid2 == 0
                # found what we were looking for, so terminate early
                break
            elseif child_cycle != 0 && child_cycle < cycle
                # record the cycle will resolve at depth "cycle"
                cycle = child_cycle
            end
        end
    end
    if maxworld != 0 && cycle != depth
        # still valid and part of a cycle rooted higher up: leave ourselves on
        # the stack for that root to finalize
        return cycle, minworld, maxworld
    end
    # If we are the top of the current cycle, now mark all other parts of
    # our cycle with what we found.
    # Or if we found a failed edge, also mark all of the other parts of the
    # cycle as also having a failed edge.
    while length(stack) >= depth
        child = pop!(stack)
        if Threads.maxthreadid() == 1 # a different thread might simultaneously come to a different, but equally valid, alternative result
            @assert child.max_world == WORLD_AGE_REVALIDATION_SENTINEL
            @assert minworld <= child.min_world
        end
        if maxworld != 0
            @atomic child.min_world = minworld
        end
        @atomic child.max_world = maxworld
        # Retire the entry of the node that was just popped (not `codeinst`,
        # which is only the last node this loop pops).
        @assert visiting[child] == length(stack) + 1
        delete!(visiting, child)
        invalidations = _jl_debug_method_invalidation[]
        if invalidations !== nothing && maxworld < current_world
            push!(invalidations, child, "verify_methods", cause)
        end
    end
    return 0, minworld, maxworld
end

# Check that the call signature `sig` still matches exactly the methods that
# were recorded for it.
#
# - `sig`: the call signature to look up.
# - `expecteds`: edge list holding the recorded targets; the `n` entries starting
#   at index `i` are the expected ones. Each entry is a `Method`, a
#   `MethodInstance`, or a `CodeInstance`.
# - `world`: the world age in which the lookup is performed.
#
# Returns `(minworld, maxworld, result)`. `maxworld` is `0` when the lookup
# failed, returned a different number of matches than `n`, or returned a method
# that is not among the expected ones. `result` is the value returned by
# `_methods_by_ftype`; when invalidation logging is enabled and the range was
# narrowed, it is trimmed to just the unexpected matches.
function verify_call(@nospecialize(sig), expecteds::SimpleVector, i::Int, n::Int, world::UInt)
    # verify that these edges intersect with the same methods as before
    # (when logging, lift the limit so that all unexpected matches can be collected)
    lim = _jl_debug_method_invalidation[] !== nothing ? Int(typemax(Int32)) : n
    minworld = RefValue{UInt}(1)
    maxworld = RefValue{UInt}(typemax(UInt))
    has_ambig = RefValue{Int32}(0)
    result = _methods_by_ftype(sig, nothing, lim, world, #=ambig=#false, minworld, maxworld, has_ambig)
    if result === nothing
        maxworld[] = 0
    else
        # setdiff!(result, expected)
        if length(result) != n
            maxworld[] = 0
        end
        ins = 0 # number of unexpected matches compacted to the front of `result`
        for k = 1:length(result)
            match = result[k]::Core.MethodMatch
            local found = false
            for j = 1:n
                t = expecteds[i+j-1]
                if t isa CodeInstance
                    t = t.def
                end
                if t isa Method
                    meth = t
                else
                    meth = (t::MethodInstance).def::Method
                end
                if match.method == meth
                    found = true
                    break
                end
            end
            if !found
                # intersection has a new method or a method was
                # deleted--this is now probably no good, just invalidate
                # everything about it now
                maxworld[] = 0
                if _jl_debug_method_invalidation[] === nothing
                    break
                end
                ins += 1
                result[ins] = match
            end
        end
        if maxworld[] != typemax(UInt) && _jl_debug_method_invalidation[] !== nothing
            resize!(result, ins)
        end
    end
    return minworld[], maxworld[], result
end

# Check that an `invoke` with signature `invokesig` still resolves to the
# method `expected` in world `world`.
#
# Returns `(minworld, maxworld)`; `maxworld == 0` means the edge is no longer
# valid (the method was deleted before `world`, no method table or match was
# found, or the lookup now selects a different method).
function verify_invokesig(@nospecialize(invokesig), expected::Method, world::UInt)
    @assert invokesig isa Type
    local minworld::UInt, maxworld::UInt
    if invokesig === expected.sig
        # the invoke match is `expected` for `expected->sig`, unless `expected` is invalid
        minworld = expected.primary_world
        maxworld = expected.deleted_world
        @assert minworld <= world
        if maxworld < world
            maxworld = 0
        end
    else
        minworld = 1
        maxworld = typemax(UInt)
        mt = get_methodtable(expected)
        if mt === nothing
            maxworld = 0
        else
            # the lookup writes the match's world range through these refs
            min_valid = RefValue{UInt}(minworld)
            max_valid = RefValue{UInt}(maxworld)
            matches = @ccall jl_gf_invoke_lookup_worlds(
                invokesig::Any, mt::Any, world::UInt, min_valid::Ptr{Csize_t}, max_valid::Ptr{Csize_t})::Any
            minworld, maxworld = min_valid[], max_valid[]
            if matches === nothing
                maxworld = 0
            else
                matches = matches::Core.MethodMatch
                if matches.method != expected
                    maxworld = 0
                end
            end
        end
    end
    return minworld, maxworld
end

function run_module_init(mod::Module, i::Int=1)
# `i` informs ordering for the `@time_imports` report formatting
if TIMING_IMPORTS[] == 0
Expand Down Expand Up @@ -4198,7 +4493,7 @@ function precompile(@nospecialize(argt::Type))
end

# Variants that work for `invoke`d calls for which the signature may not be sufficient
precompile(mi::Core.MethodInstance, world::UInt=get_world_counter()) =
precompile(mi::MethodInstance, world::UInt=get_world_counter()) =
(ccall(:jl_compile_method_instance, Cvoid, (Any, Ptr{Cvoid}, UInt), mi, C_NULL, world); return true)

"""
Expand Down
2 changes: 1 addition & 1 deletion base/pcre.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ THREAD_MATCH_CONTEXTS::Vector{Ptr{Cvoid}} = [C_NULL]
PCRE_COMPILE_LOCK = nothing

_tid() = Int(ccall(:jl_threadid, Int16, ())) + 1
_mth() = Int(Core.Intrinsics.atomic_pointerref(cglobal(:jl_n_threads, Cint), :acquire))
_mth() = Base.Threads.maxthreadid()

function get_local_match_context()
tid = _tid()
Expand Down
13 changes: 6 additions & 7 deletions src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ External links:
#include "julia_assert.h"

static const size_t WORLD_AGE_REVALIDATION_SENTINEL = 0x1;
size_t jl_require_world = ~(size_t)0;
JL_DLLEXPORT size_t jl_require_world = ~(size_t)0;

#include "staticdata_utils.c"
#include "precompile_utils.c"
Expand Down Expand Up @@ -4104,12 +4104,12 @@ static jl_value_t *jl_restore_package_image_from_stream(void* pkgimage_handle, i
jl_atomic_store_release(&jl_world_counter, world);
// now permit more methods to be added again
JL_UNLOCK(&world_counter_lock);
// but one of those immediate users is going to be our cache insertions
jl_insert_backedges((jl_array_t*)edges, (jl_array_t*)new_ext_cis); // restore existing caches (needs to be last)
// reinit ccallables
jl_reinit_ccallable(&ccallable_list, base, pkgimage_handle);
arraylist_free(&ccallable_list);

jl_value_t *ext_edges = new_ext_cis ? (jl_value_t*)new_ext_cis : jl_nothing;

if (completeinfo) {
cachesizes_sv = jl_alloc_svec(7);
jl_svecset(cachesizes_sv, 0, jl_box_long(cachesizes.sysdata));
Expand All @@ -4119,12 +4119,11 @@ static jl_value_t *jl_restore_package_image_from_stream(void* pkgimage_handle, i
jl_svecset(cachesizes_sv, 4, jl_box_long(cachesizes.reloclist));
jl_svecset(cachesizes_sv, 5, jl_box_long(cachesizes.gvarlist));
jl_svecset(cachesizes_sv, 6, jl_box_long(cachesizes.fptrlist));
restored = (jl_value_t*)jl_svec(7, restored, init_order, extext_methods,
new_ext_cis ? (jl_value_t*)new_ext_cis : jl_nothing,
method_roots_list, edges, cachesizes_sv);
restored = (jl_value_t*)jl_svec(7, restored, init_order, edges, ext_edges,
extext_methods, method_roots_list, cachesizes_sv);
}
else {
restored = (jl_value_t*)jl_svec(2, restored, init_order);
restored = (jl_value_t*)jl_svec(4, restored, init_order, edges, ext_edges);
}
}
}
Expand Down
Loading

0 comments on commit 1d190d1

Please sign in to comment.