From 873a1e41a2dd84276b6ef6a3a47c452850808c50 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Fri, 22 Nov 2024 01:26:08 +0900 Subject: [PATCH] inference: add missing modeling for `swapglobal!` (#56623) This was missed from JuliaLang/julia#56299. --- Compiler/src/abstractinterpretation.jl | 51 ++++++++++++++++++++------ Compiler/test/inference.jl | 19 ++++++++++ 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/Compiler/src/abstractinterpretation.jl b/Compiler/src/abstractinterpretation.jl index 64181f685e665..a3abbf814165a 100644 --- a/Compiler/src/abstractinterpretation.jl +++ b/Compiler/src/abstractinterpretation.jl @@ -2329,13 +2329,13 @@ function abstract_throw_methoderror(interp::AbstractInterpreter, argtypes::Vecto elseif !isvarargtype(argtypes[2]) MethodError else - ⊔ = join(typeinf_lattice(interp)) - MethodError ⊔ ArgumentError + Union{MethodError, ArgumentError} end return Future(CallMeta(Union{}, exct, EFFECTS_THROWS, NoCallInfo())) end const generic_getglobal_effects = Effects(EFFECTS_THROWS, consistent=ALWAYS_FALSE, inaccessiblememonly=ALWAYS_FALSE) +const generic_getglobal_exct = Union{ArgumentError, TypeError, ConcurrencyViolationError, UndefVarError} function abstract_eval_getglobal(interp::AbstractInterpreter, sv::AbsIntState, saw_latestworld::Bool, @nospecialize(M), @nospecialize(s)) ⊑ = partialorder(typeinf_lattice(interp)) if M isa Const && s isa Const @@ -2373,8 +2373,7 @@ function abstract_eval_getglobal(interp::AbstractInterpreter, sv::AbsIntState, s elseif !isvarargtype(argtypes[end]) || length(argtypes) > 5 return CallMeta(Union{}, ArgumentError, EFFECTS_THROWS, NoCallInfo()) else - return CallMeta(Any, Union{ArgumentError, UndefVarError, TypeError, ConcurrencyViolationError}, - generic_getglobal_effects, NoCallInfo()) + return CallMeta(Any, generic_getglobal_exct, generic_getglobal_effects, NoCallInfo()) end end @@ -2440,6 +2439,8 @@ function abstract_eval_setglobal!(interp::AbstractInterpreter, sv::AbsIntState, return merge_exct(cm, goe) end +const generic_setglobal!_exct = Union{ArgumentError, TypeError, ErrorException, ConcurrencyViolationError} + function abstract_eval_setglobal!(interp::AbstractInterpreter, sv::AbsIntState, saw_latestworld::Bool, argtypes::Vector{Any}) if length(argtypes) == 4 return abstract_eval_setglobal!(interp, sv, saw_latestworld, argtypes[2], argtypes[3], argtypes[4]) @@ -2448,7 +2449,35 @@ function abstract_eval_setglobal!(interp::AbstractInterpreter, sv::AbsIntState, elseif !isvarargtype(argtypes[end]) || length(argtypes) > 6 return CallMeta(Union{}, ArgumentError, EFFECTS_THROWS, NoCallInfo()) else - return CallMeta(Any, Union{ArgumentError, TypeError, ErrorException, ConcurrencyViolationError}, setglobal!_effects, NoCallInfo()) + return CallMeta(Any, generic_setglobal!_exct, setglobal!_effects, NoCallInfo()) + end +end + +function abstract_eval_swapglobal!(interp::AbstractInterpreter, sv::AbsIntState, saw_latestworld::Bool, + @nospecialize(M), @nospecialize(s), @nospecialize(v)) + scm = abstract_eval_setglobal!(interp, sv, saw_latestworld, M, s, v) + scm.rt === Bottom && return scm + gcm = abstract_eval_getglobal(interp, sv, saw_latestworld, M, s) + return CallMeta(gcm.rt, Union{scm.exct,gcm.exct}, merge_effects(scm.effects, gcm.effects), NoCallInfo()) +end + +function abstract_eval_swapglobal!(interp::AbstractInterpreter, sv::AbsIntState, saw_latestworld::Bool, + @nospecialize(M), @nospecialize(s), @nospecialize(v), @nospecialize(order)) + scm = abstract_eval_setglobal!(interp, sv, saw_latestworld, M, s, v, order) + scm.rt === Bottom && return scm + gcm = abstract_eval_getglobal(interp, sv, saw_latestworld, M, s, order) + return CallMeta(gcm.rt, Union{scm.exct,gcm.exct}, merge_effects(scm.effects, gcm.effects), NoCallInfo()) +end + +function abstract_eval_swapglobal!(interp::AbstractInterpreter, sv::AbsIntState, saw_latestworld::Bool, argtypes::Vector{Any}) + if length(argtypes) == 4 + return abstract_eval_swapglobal!(interp, sv, saw_latestworld, argtypes[2], argtypes[3], argtypes[4]) + elseif length(argtypes) == 5 + return abstract_eval_swapglobal!(interp, sv, saw_latestworld, argtypes[2], argtypes[3], argtypes[4], argtypes[5]) + elseif !isvarargtype(argtypes[end]) || length(argtypes) > 6 + return CallMeta(Union{}, ArgumentError, EFFECTS_THROWS, NoCallInfo()) + else + return CallMeta(Any, Union{generic_getglobal_exct,generic_setglobal!_exct}, setglobal!_effects, NoCallInfo()) end end @@ -2467,20 +2496,18 @@ function abstract_eval_setglobalonce!(interp::AbstractInterpreter, sv::AbsIntSta elseif !isvarargtype(argtypes[end]) || length(argtypes) > 6 return CallMeta(Union{}, ArgumentError, EFFECTS_THROWS, NoCallInfo()) else - return CallMeta(Bool, Union{ArgumentError, TypeError, ErrorException, ConcurrencyViolationError}, setglobal!_effects, NoCallInfo()) + return CallMeta(Bool, generic_setglobal!_exct, setglobal!_effects, NoCallInfo()) end end function abstract_eval_replaceglobal!(interp::AbstractInterpreter, sv::AbsIntState, saw_latestworld::Bool, argtypes::Vector{Any}) if length(argtypes) in (5, 6, 7) (M, s, x, v) = argtypes[2], argtypes[3], argtypes[4], argtypes[5] - T = nothing if isa(M, Const) && isa(s, Const) M, s = M.val, s.val - if !(M isa Module && s isa Symbol) - return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo()) - end + M isa Module || return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo()) + s isa Symbol || return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo()) partition = abstract_eval_binding_partition!(interp, GlobalRef(M, s), sv) rte = abstract_eval_partition_load(interp, partition) if binding_kind(partition) == BINDING_KIND_GLOBAL @@ -2507,7 +2534,7 @@ function abstract_eval_replaceglobal!(interp::AbstractInterpreter, sv::AbsIntSta elseif !isvarargtype(argtypes[end]) || length(argtypes) > 8 return CallMeta(Union{}, ArgumentError, EFFECTS_THROWS, NoCallInfo()) else - return CallMeta(Any, Union{ArgumentError, TypeError, ErrorException, ConcurrencyViolationError}, setglobal!_effects, NoCallInfo()) + return CallMeta(Any, Union{generic_getglobal_exct,generic_setglobal!_exct}, setglobal!_effects, NoCallInfo()) end end @@ -2547,6 +2574,8 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f), return Future(abstract_eval_getglobal(interp, sv, si.saw_latestworld, argtypes)) elseif f === Core.setglobal! return Future(abstract_eval_setglobal!(interp, sv, si.saw_latestworld, argtypes)) + elseif f === Core.swapglobal! + return Future(abstract_eval_swapglobal!(interp, sv, si.saw_latestworld, argtypes)) elseif f === Core.setglobalonce! return Future(abstract_eval_setglobalonce!(interp, sv, si.saw_latestworld, argtypes)) elseif f === Core.replaceglobal! diff --git a/Compiler/test/inference.jl b/Compiler/test/inference.jl index b8c869d737510..560b9da02e643 100644 --- a/Compiler/test/inference.jl +++ b/Compiler/test/inference.jl @@ -6097,3 +6097,22 @@ global setglobal!_must_throw::Int = 42 @test Base.infer_return_type((String,)) do x setglobal!(@__MODULE__, :setglobal!_must_throw, x) end === Union{} + +global swapglobal!_xxx::Int = 42 +@test Base.infer_return_type((Int,)) do x + swapglobal!(@__MODULE__, :swapglobal!_xxx, x) +end === Int +@test Base.infer_return_type((String,)) do x + swapglobal!(@__MODULE__, :swapglobal!_xxx, x) +end === Union{} + +global swapglobal!_must_throw +@newinterp SwapGlobalInterp +let CC = Base.Compiler + CC.InferenceParams(::SwapGlobalInterp) = CC.InferenceParams(; assume_bindings_static=true) +end +function func_swapglobal!_must_throw(x) + swapglobal!(@__MODULE__, :swapglobal!_must_throw, x) +end +@test Base.infer_return_type(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) === Union{} +@test !Base.Compiler.is_effect_free(Base.infer_effects(func_swapglobal!_must_throw, (Int,); interp=SwapGlobalInterp()) )