diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2d02604eda..38aeaf9cc4 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -173,6 +173,60 @@ end include("tfunc.jl") +struct EnzymeCache + inactive::Bool + has_rule::Bool +end + +if VERSION >= v"1.11.0-" +function CC.ipo_dataflow_analysis!(interp::EnzymeInterpreter, ir::Core.Compiler.IRCode, + caller::Core.Compiler.InferenceResult) + mi = caller.linfo + specTypes = simplify_kw(mi.specTypes) + inactive = false + has_rule = false + if is_inactive_from_sig(interp, specTypes, mi) + inactive = true + else + # 2. Check if rule is defined + if interp.forward_rules && has_frule_from_sig(interp, specTypes, mi) + has_rule = true + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, mi) + has_rule = true + end + end + CC.stack_analysis_result!(caller, EnzymeCache(inactive, has_rule)) + @invoke CC.ipo_dataflow_analysis!(interp::Core.Compiler.AbstractInterpreter, ir::Core.Compiler.IRCode, + caller::Core.Compiler.InferenceResult) +end + +else # v1.10 +# 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis +function Core.Compiler.finish(interp::EnzymeInterpreter, opt::Core.Compiler.OptimizationState, ir::Core.Compiler.IRCode, + caller::Core.Compiler.InferenceResult) + (; src, linfo) = opt + specTypes = simplify_kw(linfo.specTypes) + inactive = false + has_rule = false + if is_inactive_from_sig(interp, specTypes, linfo) + inactive = true + else + # 2. Check if rule is defined + if interp.forward_rules && has_frule_from_sig(interp, specTypes, linfo) + has_rule = true + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, linfo) + has_rule = true + end + end + @invoke Core.Compiler.finish(interp::Core.Compiler.AbstractInterpreter, opt::Core.Compiler.OptimizationState, + ir::Core.Compiler.IRCode, caller::Core.Compiler.InferenceResult) + # Must happen afterwards + if inactive || has_rule + Core.Compiler.set_inlineable!(src, false) + end +end +end + import Core.Compiler: CallInfo struct NoInlineCallInfo <: CallInfo info::CallInfo # wrapped call @@ -224,25 +278,6 @@ function Core.Compiler.abstract_call_gf_by_type( callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - else - # 1. Check if function is inactive - if is_inactive_from_sig(interp, specTypes, sv) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) - else - # 2. Check if rule is defined - has_rule = get!(interp.rules_cache, specTypes) do - if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) - return true - elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) - return true - else - return false - end - end - if has_rule - callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) - end - end end @static if VERSION ≥ v"1.11-" return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl index 701cfa8107..b8bd262dc6 100644 --- a/src/compiler/tfunc.jl +++ b/src/compiler/tfunc.jl @@ -1,31 +1,42 @@ import EnzymeCore: Annotation import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, inactive, _annotate_tt +function add_backedge!(caller::Core.MethodInstance, callee::Core.MethodInstance, @nospecialize(sig)) + ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, sig, caller) + return nothing +end + +function add_mt_backedge!(caller::Core.MethodInstance, mt::Core.MethodTable, @nospecialize(sig)) + ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), mt, sig, caller) + return nothing +end + + function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT), caller::MethodInstance)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, forward, TT, sv) + return isapplicable(interp, forward, TT, caller) end function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT), caller::MethodInstance)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, augmented_primal, TT, sv) + return isapplicable(interp, augmented_primal, TT, caller) end function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState) - return isapplicable(interp, inactive, TT, sv) + @nospecialize(TT), caller::MethodInstance) + return isapplicable(interp, inactive, TT, caller) end # `hasmethod` is a precise match using `Core.Compiler.findsup`, # but here we want the broader query using `Core.Compiler.findall`. # Also add appropriate backedges to the caller `MethodInstance` if given. function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(f), @nospecialize(TT), caller::MethodInstance)::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -41,7 +52,7 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), # added that did not intersect with any existing method fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) if !fullmatch - Core.Compiler.add_mt_backedge!(sv, mt, sig) + add_mt_backedge!(caller, mt, sig) end if Core.Compiler.isempty(matches) return false @@ -49,7 +60,7 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), for i = 1:Core.Compiler.length(matches) match = Core.Compiler.getindex(matches, i)::Core.MethodMatch edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - Core.Compiler.add_backedge!(sv, edge) + add_backedge!(caller, edge, sig) end return true end