Skip to content

Commit

Permalink
compiler: Strengthen some assertions and fix a couple small bugs (#56449
Browse files Browse the repository at this point in the history
)
  • Loading branch information
Keno authored Nov 5, 2024
1 parent 8e0a171 commit 567a7ca
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 39 deletions.
25 changes: 11 additions & 14 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3315,16 +3315,13 @@ function abstract_eval_binding_partition!(interp::AbstractInterpreter, g::Global
end

function abstract_eval_partition_load(interp::AbstractInterpreter, partition::Core.BindingPartition)
consistent = inaccessiblememonly = ALWAYS_FALSE
nothrow = false
generic_effects = Effects(EFFECTS_TOTAL; consistent, nothrow, inaccessiblememonly)
if is_some_guard(binding_kind(partition))
if InferenceParams(interp).assume_bindings_static
return RTEffects(Union{}, UndefVarError, EFFECTS_THROWS)
else
# We do not currently assume an invalidation for guard -> defined transitions
# return RTEffects(Union{}, UndefVarError, EFFECTS_THROWS)
return RTEffects(Any, UndefVarError, generic_effects)
return RTEffects(Any, UndefVarError, generic_getglobal_effects)
end
end

Expand All @@ -3335,20 +3332,20 @@ function abstract_eval_partition_load(interp::AbstractInterpreter, partition::Co

rt = partition_restriction(partition)

if InferenceParams(interp).assume_bindings_static
return RTEffects(rt, UndefVarError, generic_getglobal_effects)
end

function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::AbsIntState)
partition = abstract_eval_binding_partition!(interp, g, sv)
ret = abstract_eval_partition_load(interp, partition)
if ret.rt !== Union{} && ret.exct === UndefVarError && InferenceParams(interp).assume_bindings_static
if isdefined(g, :binding) && isdefined(g.binding, :value)
return RTEffects(rt, Union{}, Effecst(generic_effects, nothrow=true))
return RTEffects(ret.rt, Union{}, Effects(generic_getglobal_effects, nothrow=true))
end
# We do not assume in general that assigned global bindings remain assigned.
# The existence of pkgimages allows them to revert in practice.
end

return RTEffects(rt, UndefVarError, generic_effects)
end

function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::AbsIntState)
partition = abstract_eval_binding_partition!(interp, g, sv)
return abstract_eval_partition_load(interp, partition)
return ret
end

function global_assignment_exct(interp::AbstractInterpreter, sv::AbsIntState, g::GlobalRef, @nospecialize(newty))
Expand Down Expand Up @@ -4045,7 +4042,6 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
takeprev = 0
while takenext >= frame.frameid
callee = takenext == 0 ? frame : callstack[takenext]::InferenceState
interp = callee.interp
if !isempty(callstack)
if length(callstack) - frame.frameid >= minwarn
topmethod = callstack[1].linfo
Expand All @@ -4059,6 +4055,7 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
takenext = length(callstack)
end
end
interp = callee.interp
nextstateid = takenext + 1 - frame.frameid
while length(nextstates) < nextstateid
push!(nextstates, CurrentState())
Expand Down
40 changes: 26 additions & 14 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,24 @@ struct HandlerInfo
handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc
end

struct WorldWithRange
this::UInt
valid_worlds::WorldRange
function WorldWithRange(world::UInt, valid_worlds::WorldRange)
if !(world in valid_worlds)
error("invalid age range update")
end
return new(world, valid_worlds)
end
end

intersect(world::WorldWithRange, valid_worlds::WorldRange) =
WorldWithRange(world.this, intersect(world.valid_worlds, valid_worlds))

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
world::UInt
world::WorldWithRange
mod::Module
sptypes::Vector{VarState}
slottypes::Vector{Any}
Expand Down Expand Up @@ -265,7 +279,6 @@ mutable struct InferenceState
#= results =#
result::InferenceResult # remember where to put the result
unreachable::BitSet # statements that were found to be statically unreachable
valid_worlds::WorldRange
bestguess #::Type
exc_bestguess
ipo_effects::Effects
Expand Down Expand Up @@ -353,10 +366,10 @@ mutable struct InferenceState
parentid = frameid = cycleid = 0

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, spec_info,
mi, WorldWithRange(world, valid_worlds), mod, sptypes, slottypes, src, cfg, spec_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, edges, stmt_info,
tasks, pclimitations, limitations, cycle_backedges, callstack, parentid, frameid, cycleid,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
result, unreachable, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
interp)

Expand All @@ -372,7 +385,7 @@ mutable struct InferenceState
# Apply generated function restrictions
if src.min_world != 1 || src.max_world != typemax(UInt)
# From generated functions
this.valid_worlds = WorldRange(src.min_world, src.max_world)
update_valid_age!(this, WorldRange(src.min_world, src.max_world))
end

return this
Expand Down Expand Up @@ -772,14 +785,13 @@ mutable struct IRInterpretationState
const spec_info::SpecInfo
const ir::IRCode
const mi::MethodInstance
const world::UInt
world::WorldWithRange
curridx::Int
const argtypes_refined::Vector{Bool}
const sptypes::Vector{VarState}
const tpdum::TwoPhaseDefUseMap
const ssa_refined::BitSet
const lazyreachability::LazyCFGReachability
valid_worlds::WorldRange
const tasks::Vector{WorkThunk}
const edges::Vector{Any}
callstack #::Vector{AbsIntState}
Expand Down Expand Up @@ -809,8 +821,8 @@ mutable struct IRInterpretationState
tasks = WorkThunk[]
edges = Any[]
callstack = AbsIntState[]
return new(spec_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazyreachability, valid_worlds, tasks, edges, callstack, 0, 0)
return new(spec_info, ir, mi, WorldWithRange(world, valid_worlds), curridx, argtypes_refined, ir.sptypes, tpdum,
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0)
end
end

Expand Down Expand Up @@ -910,8 +922,8 @@ spec_info(sv::IRInterpretationState) = sv.spec_info
propagate_inbounds(sv::AbsIntState) = spec_info(sv).propagate_inbounds
method_for_inference_limit_heuristics(sv::AbsIntState) = spec_info(sv).method_for_inference_limit_heuristics

frame_world(sv::InferenceState) = sv.world
frame_world(sv::IRInterpretationState) = sv.world
frame_world(sv::InferenceState) = sv.world.this
frame_world(sv::IRInterpretationState) = sv.world.this

function is_effect_overridden(sv::AbsIntState, effect::Symbol)
if is_effect_overridden(frame_instance(sv), effect)
Expand All @@ -933,9 +945,8 @@ has_conditional(::AbstractLattice, ::IRInterpretationState) = false

# work towards converging the valid age range for sv
function update_valid_age!(sv::AbsIntState, valid_worlds::WorldRange)
valid_worlds = sv.valid_worlds = intersect(valid_worlds, sv.valid_worlds)
@assert sv.world in valid_worlds "invalid age range update"
return valid_worlds
sv.world = intersect(sv.world, valid_worlds)
return sv.world.valid_worlds
end

"""
Expand Down Expand Up @@ -1131,6 +1142,7 @@ function Future{T}(f, prev::Future{S}, interp::AbstractInterpreter, sv::AbsIntSt
else
@assert Core._hasmethod(Tuple{Core.Typeof(f), S, typeof(interp), typeof(sv)})
result = Future{T}()
@assert !isa(sv, InferenceState) || interp === sv.interp
push!(sv.tasks, function (interp, sv)
result[] = f(later[], interp, sv) # capture just later, instead of all of prev
return true
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ struct InliningState{Interp<:AbstractInterpreter}
interp::Interp
end
function InliningState(sv::InferenceState, interp::AbstractInterpreter)
return InliningState(sv.edges, sv.world, interp)
return InliningState(sv.edges, frame_world(sv), interp)
end
function InliningState(interp::AbstractInterpreter)
return InliningState(Any[], get_inference_world(interp), interp)
Expand Down
20 changes: 10 additions & 10 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end
function _typeinf_identifier(frame::Core.Compiler.InferenceState)
mi_info = InferenceFrameInfo(
frame.linfo,
frame.world,
frame_world(sv),
copy(frame.sptypes),
copy(frame.slottypes),
length(frame.result.argtypes),
Expand Down Expand Up @@ -173,7 +173,7 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
# all frames in the cycle should have the same bits of `valid_worlds` and `effects`
# that are simply the intersection of each partial computation, without having
# dependencies on each other (unlike rt and exct)
cycle_valid_worlds = intersect(cycle_valid_worlds, caller.valid_worlds)
cycle_valid_worlds = intersect(cycle_valid_worlds, caller.world.valid_worlds)
cycle_valid_effects = merge_effects(cycle_valid_effects, caller.ipo_effects)
end
for frameid = cycleid:length(frames)
Expand All @@ -197,7 +197,7 @@ function finish_cycle(::AbstractInterpreter, frames::Vector{AbsIntState}, cyclei
end

function adjust_cycle_frame!(sv::InferenceState, cycle_valid_worlds::WorldRange, cycle_valid_effects::Effects)
sv.valid_worlds = cycle_valid_worlds
update_valid_age!(sv, cycle_valid_worlds)
sv.ipo_effects = cycle_valid_effects
# traverse the callees of this cycle that are tracked within `sv.cycle_backedges`
# and adjust their statements so that they are consistent with the new `cycle_valid_effects`
Expand Down Expand Up @@ -403,13 +403,13 @@ function finishinfer!(me::InferenceState, interp::AbstractInterpreter)
end
end
result = me.result
result.valid_worlds = me.valid_worlds
result.valid_worlds = me.world.valid_worlds
result.result = bestguess
ipo_effects = result.ipo_effects = me.ipo_effects = adjust_effects(me)
result.exc_result = me.exc_bestguess = refine_exception_type(me.exc_bestguess, ipo_effects)
me.src.rettype = widenconst(ignorelimited(bestguess))
me.src.min_world = first(me.valid_worlds)
me.src.max_world = last(me.valid_worlds)
me.src.min_world = first(me.world.valid_worlds)
me.src.max_world = last(me.world.valid_worlds)
istoplevel = !(me.linfo.def isa Method)
istoplevel || compute_edges!(me) # don't add backedges to toplevel method instance

Expand Down Expand Up @@ -637,7 +637,7 @@ function merge_call_chain!(::AbstractInterpreter, parent::InferenceState, child:
end

function add_cycle_backedge!(caller::InferenceState, frame::InferenceState)
update_valid_age!(caller, frame.valid_worlds)
update_valid_age!(caller, frame.world.valid_worlds)
backedge = (caller, caller.currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
return frame
Expand Down Expand Up @@ -730,7 +730,7 @@ end
function codeinst_as_edge(interp::AbstractInterpreter, sv::InferenceState)
mi = sv.linfo
owner = cache_owner(interp)
min_world, max_world = first(sv.valid_worlds), last(sv.valid_worlds)
min_world, max_world = first(sv.world.valid_worlds), last(sv.world.valid_worlds)
if max_world >= get_world_counter()
max_world = typemax(UInt)
end
Expand Down Expand Up @@ -816,7 +816,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
# while splitting off the rest of the work for this caller into a separate workq thunk
let mresult = Future{MethodCallResult}()
push!(caller.tasks, function get_infer_result(interp, caller)
update_valid_age!(caller, frame.valid_worlds)
update_valid_age!(caller, frame.world.valid_worlds)
local isinferred = is_inferred(frame)
local edge = isinferred ? edge_ci : nothing
local effects = isinferred ? frame.result.ipo_effects : # effects are adjusted already within `finish` for ipo_effects
Expand All @@ -842,7 +842,7 @@ function typeinf_edge(interp::AbstractInterpreter, method::Method, @nospecialize
end
# return the current knowledge about this cycle
frame = frame::InferenceState
update_valid_age!(caller, frame.valid_worlds)
update_valid_age!(caller, frame.world.valid_worlds)
effects = adjust_effects(effects_for_cycle(frame.ipo_effects), method)
bestguess = frame.bestguess
exc_bestguess = refine_exception_type(frame.exc_bestguess, effects)
Expand Down

0 comments on commit 567a7ca

Please sign in to comment.