diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 945951b216..f34fd1de13 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -171,7 +171,7 @@ end function has_frule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(forward, TT; world, method_table, caller) @@ -180,7 +180,7 @@ end function has_rrule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(augmented_primal, TT; world, method_table, caller) @@ -192,7 +192,7 @@ end function isapplicable(@nospecialize(f), @nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -211,12 +211,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); if !fullmatch if caller isa Core.MethodInstance add_mt_backedge!(caller, mt, sig) - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - add_mt_backedge!(cspec, mt, sig) - end end end if Core.Compiler.isempty(matches) @@ -228,16 +222,54 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); edge = Core.Compiler.specialize_method(match)::Core.MethodInstance add_backedge!(caller, edge, sig) end - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - add_backedge!(cspec, edge, sig) - end - end + end + return true + end +end + +function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + ft, tt = _annotate_tt(TT) + TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} + return isapplicable(interp, forward, TT, sv) +end + +function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + ft, tt = _annotate_tt(TT) + TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} + return isapplicable(interp, augmented_primal, TT, sv) +end + +# `hasmethod` is a precise match using `Core.Compiler.findsup`, +# but here we want the broader query using `Core.Compiler.findall`. +# Also add appropriate backedges to the caller `MethodInstance` if given. +function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + tt = Base.to_tuple_type(TT) + sig = Base.signature_type(f, tt) + mt = ccall(:jl_method_table_for, Any, (Any,), sig) + mt isa Core.MethodTable || return false + result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp); limit=-1) + (result === nothing || result === missing) && return false + @static if isdefined(Core.Compiler, :MethodMatchResult) + (; matches) = result + else + matches = result + end + # also need an edge to the method table in case something gets + # added that did not intersect with any existing method + fullmatch = Core.Compiler._any(match::Core.MethodMatch->match.fully_covers, matches) + if !fullmatch + Core.Compiler.add_mt_backedge!(sv, mt, sig) + end + if Core.Compiler.isempty(matches) + return false + else + for i = 1:Core.Compiler.length(matches) + match = Core.Compiler.getindex(matches, i)::Core.MethodMatch + edge = Core.Compiler.specialize_method(match)::Core.MethodInstance + Core.Compiler.add_backedge!(sv, edge) end return true end @@ -267,6 +299,11 @@ function is_inactive_from_sig(@nospecialize(TT); return isapplicable(inactive, TT; world, method_table, caller) end + +function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(TT), sv::Core.Compiler.AbsIntState) + return isapplicable(interp, inactive, TT, sv) +end """ inactive_noinl(func::typeof(f), args...) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index ff4c6c991b..850a17195d 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -192,6 +192,7 @@ Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +import .EnzymeRules: FwdConfig, RevConfig, Annotation using Core.Compiler: ArgInfo, StmtInfo, AbsIntState function Core.Compiler.abstract_call_gf_by_type( @nospecialize(interp::EnzymeInterpreter), @@ -212,30 +213,22 @@ function Core.Compiler.abstract_call_gf_by_type( max_methods::Int, ) callinfo = ret.info - method_table = Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) - caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult - callinfo.results - else - nothing - end if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else - if interp.forward_rules - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :frule) - end - end - - if interp.reverse_rules - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + # 1. Check if function is inactive + if EnzymeRules.is_inactive_from_sig(interp, specTypes, sv) + callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + else + # 2. Check if rule is defined + if interp.forward_rules && EnzymeRules.has_frule_from_sig(interp, specTypes, sv) + callinfo = NoInlineCallInfo(callinfo, atype, :frule) + elseif interp.reverse_rules && EnzymeRules.has_rrule_from_sig(interp, specTypes, sv) + callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end end end