diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 945951b216..433417bd10 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -192,7 +192,7 @@ end function isapplicable(@nospecialize(f), @nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing) tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -209,35 +209,15 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); end fullmatch = Core.Compiler._any(match::Core.MethodMatch->match.fully_covers, matches) if !fullmatch - if caller isa Core.MethodInstance - add_mt_backedge!(caller, mt, sig) - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - add_mt_backedge!(cspec, mt, sig) - end - end + add_mt_backedge!(caller, mt, sig) end if Core.Compiler.isempty(matches) return false else - if caller isa Core.MethodInstance - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - add_backedge!(caller, edge, sig) - end - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - add_backedge!(cspec, edge, sig) - end - end + for i = 1:Core.Compiler.length(matches) + match = Core.Compiler.getindex(matches, i)::Core.MethodMatch + edge = Core.Compiler.specialize_method(match)::Core.MethodInstance + add_backedge!(caller, edge, sig) end return true end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 1e442482be..6979f77fb6 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -214,28 +214,23 @@ function Core.Compiler.abstract_call_gf_by_type( callinfo = ret.info method_table = Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) - caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult - callinfo.results - else - nothing - end if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + elseif EnzymeRules.is_inactive_from_sig(specTypes; world=interp.world, method_table, caller=sv.linfo) + callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else if interp.forward_rules - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller) + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller=sv.linfo) callinfo = NoInlineCallInfo(callinfo, atype, :frule) - end + end end - + if interp.reverse_rules - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller=sv.linfo) + callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end end end