diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 945951b216..dc33c11110 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -171,7 +171,7 @@ end function has_frule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(forward, TT; world, method_table, caller) @@ -180,7 +180,7 @@ end function has_rrule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(augmented_primal, TT; world, method_table, caller) @@ -192,7 +192,7 @@ end function isapplicable(@nospecialize(f), @nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -211,12 +211,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); if !fullmatch if caller isa Core.MethodInstance add_mt_backedge!(caller, mt, sig) - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - add_mt_backedge!(cspec, mt, sig) - end end end if Core.Compiler.isempty(matches) @@ -228,16 +222,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); edge = Core.Compiler.specialize_method(match)::Core.MethodInstance add_backedge!(caller, edge, sig) end - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - add_backedge!(cspec, edge, sig) - end - end end return true end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index ff4c6c991b..7afb49c071 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -192,6 +192,22 @@ Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +function annotate(@nospecialize(T)) + T = widenconst(T) + if Core.Compiler.isvarargtype(T) + VA = T + T = annotate(Core.Compiler.unwrapva(VA)) + if isdefined(VA, :N) + return Vararg{T, VA.N} + else + return Vararg{T} + end + else + return Annotation{T} + end +end + +import .EnzymeRules: FwdConfig, RevConfig, Annotation using Core.Compiler: ArgInfo, StmtInfo, AbsIntState function Core.Compiler.abstract_call_gf_by_type( @nospecialize(interp::EnzymeInterpreter), @@ -212,30 +228,37 @@ function Core.Compiler.abstract_call_gf_by_type( max_methods::Int, ) callinfo = ret.info - method_table = Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) - caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult - callinfo.results - else - nothing - end if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else - if interp.forward_rules - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :frule) - end - end - - if interp.reverse_rules - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + (;fargs, argtypes) = arginfo + # 1. Check if function is inactive + inactive_argtypes = Any[Core.Const(Core.applicable), Core.Const(EnzymeRules.inactive)] + append!(inactive_argtypes, argtypes) + + inactive_meta = Core.Compiler.abstract_applicable(interp, inactive_argtypes, sv, #=max_methods=# -1) # Does backedge handling internally + if inactive_meta.rt != Core.Const(false) # It may be Const(true), Const(false), Bool + callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + else + # 2. Check if rule is defined + tt = Core.Compiler.anymap(annotate, argtypes) + if interp.forward_rules + rulef = EnzymeRules.forward + rule_argtypes = Any[Core.Const(Core.applicable), Core.Const(EnzymeRules.forward), FwdConfig, tt[1], Type{<:Annotation}, tt[2:end]...] + else + rulef = EnzymeRules.reverse + rule_argtypes = Any[Core.Const(Core.applicable), Core.Const(EnzymeRules.reverse), RevConfig, tt[1], Type{<:Annotation}, tt[2:end]...] + end + Base.@show rule_argtypes + Base.@show Core.Compiler.argtypes_to_type(rule_argtypes) + rule_meta = Core.Compiler.abstract_applicable(interp, rule_argtypes, sv, #=max_methods=# -1) # Does backedge handling internally + @show rule_meta.rt + if rule_meta.rt != Core.Const(false) # It may be Const(true), Const(false), Bool + callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) end end end diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 704ada2b6e..37cb21b08f 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -34,18 +34,14 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 @static if VERSION < v"1.11-" -@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 + @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 else -@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 + @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 end # now test invalidation for `inactive` inactive(::typeof(issue696), args...) = nothing @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -@static if VERSION < v"1.11-" -@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -else @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -end end # module