Skip to content

Commit

Permalink
irinterp: improve semi-concrete interpretation accuracy
Browse files Browse the repository at this point in the history
By enforcing re-inference on calls with all constant arguments.

While it's debatable whether this approach is the most efficient, it was
the easiest choice given that `used_ssas` based on `IncrementaCompact`
wasn't an option for irinterp.

- fixes #52202
- fixes #50037
  • Loading branch information
aviatesk committed Nov 23, 2023
1 parent 10d58eb commit 6680dd3
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 2 deletions.
6 changes: 4 additions & 2 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -903,12 +903,14 @@ end
is_all_const_arg(arginfo::ArgInfo, start::Int) = is_all_const_arg(arginfo.argtypes, start::Int)
function is_all_const_arg(argtypes::Vector{Any}, start::Int)
for i = start:length(argtypes)
a = widenslotwrapper(argtypes[i])
isa(a, Const) || isconstType(a) || issingletontype(a) || return false
argtype = widenslotwrapper(argtypes[i])
is_const_argtype(argtype) || return false
end
return true
end

is_const_argtype(@nospecialize argtype) = isa(argtype, Const) || isconstType(argtype) || issingletontype(argtype)

any_conditional(argtypes::Vector{Any}) = any(@nospecialize(x)->isa(x, Conditional), argtypes)
any_conditional(arginfo::ArgInfo) = any_conditional(arginfo.argtypes)

Expand Down
12 changes: 12 additions & 0 deletions base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,15 @@ end
populate_def_use_map!(tpdum::TwoPhaseDefUseMap, ir::IRCode) =
populate_def_use_map!(tpdum, BBScanner(ir))

function is_all_const_call(@nospecialize(stmt), interp::AbstractInterpreter, irsv::IRInterpretationState)
isexpr(stmt, :call) || return false
@inbounds for i = 2:length(stmt.args)
argtype = abstract_eval_value(interp, stmt.args[i], nothing, irsv)
is_const_argtype(argtype) || return false
end
return true
end

function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState;
externally_refined::Union{Nothing,BitSet} = nothing)
(; ir, tpdum, ssa_refined) = irsv
Expand All @@ -302,6 +311,9 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
if has_flag(flag, IR_FLAG_REFINED)
any_refined = true
sub_flag!(inst, IR_FLAG_REFINED)
elseif is_all_const_call(stmt, interp, irsv)
# force reinference on calls with all constant arguments
any_refined = true
end
for ur in userefs(stmt)
val = ur[]
Expand Down
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ mul_wrappers = [
@test @inferred(f(A)) === A
g(A) = LinearAlgebra.wrap(A, 'T')
@test @inferred(g(A)) === transpose(A)
# https://github.com/JuliaLang/julia/issues/52202
@test Base.infer_return_type((Vector{Float64},)) do v
LinearAlgebra.wrap(v, 'N')
end == Vector{Float64}
end

@testset "matrices with zero dimensions" begin
Expand Down
7 changes: 7 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5593,3 +5593,10 @@ end |> only === Float64
@test Base.infer_exception_type(c::Bool -> c ? 1 : 2) == Union{}
@test Base.infer_exception_type(c::Missing -> c ? 1 : 2) == TypeError
@test Base.infer_exception_type(c::Any -> c ? 1 : 2) == TypeError

# semi-concrete interpretation accuracy
# https://github.com/JuliaLang/julia/issues/50037
@inline countvars50037(bitflags::Int, var::Int) = bitflags >> 0
@test Base.infer_return_type() do var::Int
Val(countvars50037(1, var))
end == Val{1}

0 comments on commit 6680dd3

Please sign in to comment.