Skip to content

Commit

Permalink
Forward interp and sv to isapplicable
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Dec 4, 2024
1 parent b046156 commit 6ffb629
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 36 deletions.
75 changes: 56 additions & 19 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ end
function has_frule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(forward, TT; world, method_table, caller)
Expand All @@ -180,7 +180,7 @@ end
function has_rrule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(augmented_primal, TT; world, method_table, caller)
Expand All @@ -192,7 +192,7 @@ end
function isapplicable(@nospecialize(f), @nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
Expand All @@ -211,12 +211,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
if !fullmatch
if caller isa Core.MethodInstance
add_mt_backedge!(caller, mt, sig)
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
add_mt_backedge!(cspec, mt, sig)
end
end
end
if Core.Compiler.isempty(matches)
Expand All @@ -228,16 +222,54 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(caller, edge, sig)
end
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(cspec, edge, sig)
end
end
end
return true
end
end

function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(interp, forward, TT, sv)
end

function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(interp, augmented_primal, TT, sv)
end

# `hasmethod` is a precise match using `Core.Compiler.findsup`,
# but here we want the broader query using `Core.Compiler.findall`.
# Also add appropriate backedges to the caller `MethodInstance` if given.
function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
mt isa Core.MethodTable || return false
result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp); limit=-1)
(result === nothing || result === missing) && return false
@static if isdefined(Core.Compiler, :MethodMatchResult)
(; matches) = result
else
matches = result
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
fullmatch = Core.Compiler._any(match::Core.MethodMatch->match.fully_covers, matches)
if !fullmatch
Core.Compiler.add_mt_backedge!(sv, mt, sig)
end
if Core.Compiler.isempty(matches)
return false
else
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
Core.Compiler.add_backedge!(sv, edge)
end
return true
end
Expand Down Expand Up @@ -267,6 +299,11 @@ function is_inactive_from_sig(@nospecialize(TT);
return isapplicable(inactive, TT; world, method_table, caller)
end


function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT), sv::Core.Compiler.AbsIntState)
return isapplicable(interp, inactive, TT, sv)
end
"""
inactive_noinl(func::typeof(f), args...)
Expand Down
27 changes: 10 additions & 17 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) =
Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) =
Core.Compiler.getresult(info.info, idx)

import .EnzymeRules: FwdConfig, RevConfig, Annotation
using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
function Core.Compiler.abstract_call_gf_by_type(
@nospecialize(interp::EnzymeInterpreter),
Expand All @@ -212,30 +213,22 @@ function Core.Compiler.abstract_call_gf_by_type(
max_methods::Int,
)
callinfo = ret.info
method_table = Core.Compiler.method_table(interp)
specTypes = simplify_kw(atype)
caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult
callinfo.results
else
nothing
end

if is_primitive_func(specTypes)
callinfo = NoInlineCallInfo(callinfo, atype, :primitive)
elseif is_alwaysinline_func(specTypes)
callinfo = AlwaysInlineCallInfo(callinfo, atype)
elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
if interp.forward_rules
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
end
end

if interp.reverse_rules
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
# 1. Check if function is inactive
if EnzymeRules.is_inactive_from_sig(interp, specTypes, sv)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
# 2. Check if rule is defined
if interp.forward_rules && EnzymeRules.has_frule_from_sig(interp, specTypes, sv)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
elseif interp.reverse_rules && EnzymeRules.has_rrule_from_sig(interp, specTypes, sv)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
end
end
end
Expand Down

0 comments on commit 6ffb629

Please sign in to comment.