Skip to content

Commit

Permalink
Use right caller for isapplicable usage
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Dec 1, 2024
1 parent 2839d3f commit faa4ef3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 38 deletions.
32 changes: 6 additions & 26 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ end
function isapplicable(@nospecialize(f), @nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
Expand All @@ -209,35 +209,15 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
end
fullmatch = Core.Compiler._any(match::Core.MethodMatch->match.fully_covers, matches)
if !fullmatch
if caller isa Core.MethodInstance
add_mt_backedge!(caller, mt, sig)
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
add_mt_backedge!(cspec, mt, sig)
end
end
add_mt_backedge!(caller, mt, sig)
end
if Core.Compiler.isempty(matches)
return false
else
if caller isa Core.MethodInstance
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(caller, edge, sig)
end
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(cspec, edge, sig)
end
end
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(caller, edge, sig)
end
return true
end
Expand Down
19 changes: 7 additions & 12 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,28 +214,23 @@ function Core.Compiler.abstract_call_gf_by_type(
callinfo = ret.info
method_table = Core.Compiler.method_table(interp)
specTypes = simplify_kw(atype)
caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult
callinfo.results
else
nothing
end

if is_primitive_func(specTypes)
callinfo = NoInlineCallInfo(callinfo, atype, :primitive)
elseif is_alwaysinline_func(specTypes)
callinfo = AlwaysInlineCallInfo(callinfo, atype)
elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
elseif EnzymeRules.is_inactive_from_sig(specTypes; world=interp.world, method_table, caller=sv.linfo)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
if interp.forward_rules
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller)
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller=sv.linfo)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
end
end
end

if interp.reverse_rules
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller=sv.linfo)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
end
end
end
Expand Down

0 comments on commit faa4ef3

Please sign in to comment.