diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2d02604eda..3ba546a29a 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -173,6 +173,59 @@ end include("tfunc.jl") +struct EnzymeCache + inactive::Bool + has_rule::Bool +end + +if VERSION >= v"1.11.0-" +function CC.ipo_dataflow_analysis!(interp::EnzymeInterpreter, ir::Core.Compiler.IRCode, + caller::Core.Compiler.InferenceResult) + mi = caller.linfo + specTypes = simplify_kw(mi.spectypes) + inactive = false + has_rule = false + if is_inactive_from_sig(interp, specTypes, mi) + inactive = true + else + # 2. Check if rule is defined + if interp.forward_rules && has_frule_from_sig(interp, specTypes, mi) + has_rule = true + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, mi) + has_rule = true + end + end + CC.stack_analysis_result!(caller, EnzymeCache(inactive, has_rule)) + @invoke CC.ipo_dataflow_analysis!(interp::Core.Compiler.AbstractInterpreter, ir::Core.Compiler.IRCode, + caller::Core.Compiler.InferenceResult) +end + +else # v1.10 +# 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis +function Core.Compiler.finish(interp::EnzymeInterpreter, opt::Core.Compiler.OptimizationState, ir::Core.Compiler.IRCode, + caller::Core.Compiler.InferenceResult) + mi = caller.linfo + specTypes = simplify_kw(mi.spectypes) + inactive = false + has_rule = false + if is_inactive_from_sig(interp, specTypes, mi) + inactive = true + else + # 2. Check if rule is defined + if interp.forward_rules && has_frule_from_sig(interp, specTypes, mi) + has_rule = true + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, mi) + has_rule = true + end + end + # HACK: we store the deferred edges in the argescapes field, which is invalid, + # but nobody should be running EA on our results. + caller.argescapes = EnzymeCache(inactive, has_rule) + @invoke CC.finish(interp::Core.Compiler.AbstractInterpreter, opt::Core.Compiler.OptimizationState, + ir::Core.Compiler.IRCode, caller::Core.Compiler.InferenceResult) +end +end + import Core.Compiler: CallInfo struct NoInlineCallInfo <: CallInfo info::CallInfo # wrapped call @@ -219,31 +272,33 @@ function Core.Compiler.abstract_call_gf_by_type( ) callinfo = ret.info specTypes = simplify_kw(atype) + @show ret if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - else - # 1. Check if function is inactive - if is_inactive_from_sig(interp, specTypes, sv) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) - else - # 2. Check if rule is defined - has_rule = get!(interp.rules_cache, specTypes) do - if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) - return true - elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) - return true - else - return false - end - end - if has_rule - callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) - end - end end + # else + # # 1. Check if function is inactive + # if is_inactive_from_sig(interp, specTypes, sv) + # callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + # else + # # 2. Check if rule is defined + # has_rule = get!(interp.rules_cache, specTypes) do + # if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) + # return true + # elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) + # return true + # else + # return false + # end + # end + # if has_rule + # callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) + # end + # end + # end @static if VERSION ≥ v"1.11-" return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) else