Skip to content

Commit

Permalink
aot: move jl_insert_backedges to Julia side
Browse files Browse the repository at this point in the history
With #56447, the dependency between `jl_insert_backedges`
and method insertion has been eliminated, allowing `jl_insert_backedges`
to be performed after loading. As a result, it is now possible to move
`jl_insert_backedges` to the Julia side.

Currently this commit simply moves the implementation without adding
any new features.
  • Loading branch information
aviatesk committed Jan 14, 2025
1 parent 9b1ea1a commit 0b2575d
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 384 deletions.
5 changes: 2 additions & 3 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -560,12 +560,11 @@ function store_backedges(caller::CodeInstance, edges::SimpleVector)
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, item, caller)
i += 2
continue
end
# `invoke` edge
if isa(callee, Method)
elseif isa(callee, Method)
# ignore `Method`-edges (from e.g. failed `abstract_call_method`)
i += 2
continue
# `invoke` edge
elseif isa(callee, CodeInstance)
callee = get_ci_mi(callee)
end
Expand Down
1 change: 1 addition & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ include("uuid.jl")
include("pkgid.jl")
include("toml_parser.jl")
include("linking.jl")
include("staticdata.jl")
include("loading.jl")

# misc useful functions & macros
Expand Down
17 changes: 10 additions & 7 deletions base/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1280,17 +1280,20 @@ function _include_from_serialized(pkg::PkgId, path::String, ocachepath::Union{No
sv = try
if ocachepath !== nothing
@debug "Loading object cache file $ocachepath for $(repr("text/plain", pkg))"
ccall(:jl_restore_package_image_from_file, Any, (Cstring, Any, Cint, Cstring, Cint), ocachepath, depmods, false, pkg.name, ignore_native)
ccall(:jl_restore_package_image_from_file, Ref{SimpleVector}, (Cstring, Any, Cint, Cstring, Cint),
ocachepath, depmods, #=completeinfo=#false, pkg.name, ignore_native)
else
@debug "Loading cache file $path for $(repr("text/plain", pkg))"
ccall(:jl_restore_incremental, Any, (Cstring, Any, Cint, Cstring), path, depmods, false, pkg.name)
ccall(:jl_restore_incremental, Ref{SimpleVector}, (Cstring, Any, Cint, Cstring),
path, depmods, #=completeinfo=#false, pkg.name)
end
finally
lock(require_lock)
end
if isa(sv, Exception)
return sv
end

edges = sv[3]::Vector{Any}
ext_edges = sv[4]::Union{Nothing,Vector{Any}}
StaticData.insert_backedges(edges, ext_edges)

restored = register_restored_modules(sv, pkg, path)

Expand Down Expand Up @@ -4198,7 +4201,7 @@ function precompile(@nospecialize(argt::Type))
end

# Variants that work for `invoke`d calls for which the signature may not be sufficient
precompile(mi::Core.MethodInstance, world::UInt=get_world_counter()) =
precompile(mi::MethodInstance, world::UInt=get_world_counter()) =
(ccall(:jl_compile_method_instance, Cvoid, (Any, Ptr{Cvoid}, UInt), mi, C_NULL, world); return true)

"""
Expand All @@ -4214,7 +4217,7 @@ end

function precompile(@nospecialize(argt::Type), m::Method)
atype, sparams = ccall(:jl_type_intersection_with_env, Any, (Any, Any), argt, m.sig)::SimpleVector
mi = Core.Compiler.specialize_method(m, atype, sparams)
mi = Base.Compiler.specialize_method(m, atype, sparams)
return precompile(mi)
end

Expand Down
303 changes: 303 additions & 0 deletions base/staticdata.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module StaticData

using Core: CodeInstance, MethodInstance
using Base: get_world_counter

const WORLD_AGE_REVALIDATION_SENTINEL::UInt = 1 # needs to sync with staticdata.c
const _jl_debug_method_invalidation = Ref{Union{Nothing,Vector{Any}}}(nothing)
debug_method_invalidation(onoff::Bool) =
_jl_debug_method_invalidation[] = onoff ? Any[] : nothing

function get_ci_mi(codeinst::CodeInstance)
def = codeinst.def
if def isa Core.ABIOverride
return def.def
else
return def::MethodInstance
end
end

# Restore backedges to external targets
# `edges` = [caller1, ...], the list of worklist-owned code instances internally
# `ext_ci_list` = [caller1, ...], the list of worklist-owned code instances externally
function insert_backedges(edges::Vector{Any}, ext_ci_list::Union{Nothing,Vector{Any}})
# determine which CodeInstance objects are still valid in our image
# to enable any applicable new codes
stack = CodeInstance[]
visiting = IdDict{CodeInstance,Int}()
_insert_backedges(edges, stack, visiting)
if ext_ci_list !== nothing
_insert_backedges(ext_ci_list, stack, visiting, #=external=#true)
end
end

function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, external::Bool=false)
for i = 1:length(edges)
codeinst = edges[i]::CodeInstance
verify_method_graph(codeinst, stack, visiting)
minvalid = codeinst.min_world
maxvalid = codeinst.max_world
if maxvalid minvalid
if get_world_counter() == maxvalid
# if this callee is still valid, add all the backedges
Base.Compiler.store_backedges(codeinst, codeinst.edges)
end
if get_world_counter() == maxvalid
maxvalid = typemax(UInt)
@atomic codeinst.max_world = maxvalid
end
if external
caller = get_ci_mi(codeinst)
@assert isdefined(codeinst, :inferred) # See #53586, #53109
inferred = @ccall jl_rettype_inferred(
codeinst.owner::Any, caller::Any, minvalid::UInt, maxvalid::UInt)::Any
if inferred !== nothing
# We already got a code instance for this world age range from
# somewhere else - we don't need this one.
else
@ccall jl_mi_cache_insert(caller::Any, codeinst::Any)::Cvoid
end
end
end
end
end

function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
@assert isempty(stack); @assert isempty(visiting);
child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting)
@assert child_cycle == 0
@assert isempty(stack); @assert isempty(visiting);
if Threads.maxthreadid() == 1 # a different thread might simultaneously come to a different, but equally valid, alternative result
@assert maxworld == 0 || codeinst.min_world == minworld
@assert codeinst.max_world == maxworld
end
end

# Test all edges relevant to a method:
# - Visit the entire call graph, starting from edges[idx] to determine if that method is valid
# - Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
# and slightly modified with an early termination option once the computation reaches its minimum
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
world = codeinst.min_world
let max_valid2 = codeinst.max_world
if max_valid2 WORLD_AGE_REVALIDATION_SENTINEL
return 0, world, max_valid2
end
end
current_world = get_world_counter()
local minworld::UInt, maxworld::UInt = 1, current_world
@assert get_ci_mi(codeinst).def isa Method
if haskey(visiting, codeinst)
return visiting[codeinst], minworld, maxworld
end
push!(stack, codeinst)
depth = length(stack)
visiting[codeinst] = depth
# TODO JL_TIMING(VERIFY_IMAGE, VERIFY_Methods)
callees = codeinst.edges
# verify current edges
if isempty(callees)
# quick return: no edges to verify (though we probably shouldn't have gotten here from WORLD_AGE_REVALIDATION_SENTINEL)
elseif maxworld == unsafe_load(cglobal(:jl_require_world, UInt))
# if no new worlds were allocated since serializing the base module, then no new validation is worth doing right now either
minworld = maxworld
else
j = 1
while j length(callees)
local min_valid2::UInt, max_valid2::UInt
edge = callees[j]
@assert !(edge isa Method) # `Method`-edge isn't allowed for the optimized one-edge format
if edge isa Core.BindingPartition
j += 1
continue
end
if edge isa CodeInstance
edge = get_ci_mi(edge)
end
if edge isa MethodInstance
sig = typeintersect((edge.def::Method).sig, edge.specTypes) # TODO??
min_valid2, max_valid2, matches = verify_call(sig, callees, j, 1, world)
j += 1
elseif edge isa Int
sig = callees[j+1]
min_valid2, max_valid2, matches = verify_call(sig, callees, j+2, edge, world)
j += 2 + edge
edge = sig
else
callee = callees[j+1]
if callee isa Core.MethodTable # skip the legacy edge (missing backedge)
j += 2
continue
end
if callee isa CodeInstance
callee = get_ci_mi(callee)
end
if callee isa MethodInstance
meth = callee.def::Method
else
meth = callee::Method
end
min_valid2, max_valid2 = verify_invokesig(edge, meth, world)
matches = nothing
j += 2
end
if minworld < min_valid2
minworld = min_valid2
end
if maxworld > max_valid2
maxworld = max_valid2
end
invalidations = _jl_debug_method_invalidation[]
if max_valid2 typemax(UInt) && invalidations !== nothing
push!(invalidations, edge, "insert_backedges_callee", codeinst, matches)
end
if max_valid2 == 0 && invalidations === nothing
break
end
end
end
# verify recursive edges (if valid, or debugging)
cycle = depth
cause = codeinst
if maxworld 0 || _jl_debug_method_invalidation[] !== nothing
for j = 1:length(callees)
edge = callees[j]
if !(edge isa CodeInstance)
continue
end
callee = edge
local min_valid2::UInt, max_valid2::UInt
child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting)
if minworld < min_valid2
minworld = min_valid2
end
if minworld > max_valid2
max_valid2 = 0
end
if maxworld > max_valid2
cause = callee
maxworld = max_valid2
end
if max_valid2 == 0
# found what we were looking for, so terminate early
break
elseif child_cycle 0 && child_cycle < cycle
# record the cycle will resolve at depth "cycle"
cycle = child_cycle
end
end
end
if maxworld 0 && cycle depth
return cycle, minworld, maxworld
end
# If we are the top of the current cycle, now mark all other parts of
# our cycle with what we found.
# Or if we found a failed edge, also mark all of the other parts of the
# cycle as also having a failed edge.
while length(stack) depth
child = pop!(stack)
if Threads.maxthreadid() == 1 # a different thread might simultaneously come to a different, but equally valid, alternative result
@assert child.max_world == WORLD_AGE_REVALIDATION_SENTINEL
@assert minworld child.min_world
end
if maxworld 0
@atomic child.min_world = minworld
end
@atomic child.max_world = maxworld
@assert visiting[child] == length(stack) + 1
delete!(visiting, child)
invalidations = _jl_debug_method_invalidation[]
if invalidations !== nothing && maxworld < current_world
push!(invalidations, child, "verify_methods", cause)
end
end
return 0, minworld, maxworld
end

function verify_call(@nospecialize(sig), expecteds::Core.SimpleVector, i::Int, n::Int, world::UInt)
# verify that these edges intersect with the same methods as before
lim = _jl_debug_method_invalidation[] !== nothing ? Int(typemax(Int32)) : n
minworld = Ref{UInt}(1)
maxworld = Ref{UInt}(typemax(UInt))
has_ambig = Ref{Int32}(0)
result = Base._methods_by_ftype(sig, nothing, lim, world, #=ambig=#false, minworld, maxworld, has_ambig)
if result === nothing
maxworld[] = 0
else
# setdiff!(result, expected)
if length(result) n
maxworld[] = 0
end
ins = 0
for k = 1:length(result)
match = result[k]::Core.MethodMatch
local found = false
for j = 1:n
t = expecteds[i+j-1]
if t isa Method
meth = t
else
if t isa CodeInstance
t = get_ci_mi(t)
else
t = t::MethodInstance
end
meth = t.def::Method
end
if match.method == meth
found = true
break
end
end
if !found
# intersection has a new method or a method was
# deleted--this is now probably no good, just invalidate
# everything about it now
maxworld[] = 0
if _jl_debug_method_invalidation[] === nothing
break
end
ins += 1
result[ins] = match.method
end
end
if maxworld[] typemax(UInt) && _jl_debug_method_invalidation[] !== nothing
resize!(result, ins)
end
end
return minworld[], maxworld[], result
end

function verify_invokesig(@nospecialize(invokesig), expected::Method, world::UInt)
@assert invokesig isa Type
local minworld::UInt, maxworld::UInt
if invokesig === expected.sig
# the invoke match is `expected` for `expected->sig`, unless `expected` is invalid
minworld = expected.primary_world
maxworld = expected.deleted_world
@assert minworld world
if maxworld < world
maxworld = 0
end
else
minworld = 1
maxworld = typemax(UInt)
mt = Base.get_methodtable(expected)
if mt === nothing
maxworld = 0
else
matched, valid_worlds = Base.Compiler._findsup(invokesig, mt, world)
minworld, maxworld = valid_worlds.min_world, valid_worlds.max_world
if matched === nothing
maxworld = 0
elseif matched.method != expected
maxworld = 0
end
end
end
return minworld, maxworld
end

end # module StaticData
Loading

0 comments on commit 0b2575d

Please sign in to comment.