From 192ad1027f9c411ff0bce8d9c38a4cd5bca3a0df Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Thu, 5 Dec 2024 18:13:44 +0100 Subject: [PATCH 1/4] WIP: use IPO cache for results --- src/compiler/interpreter.jl | 93 +++++++++++++++++++++++++++++-------- 1 file changed, 74 insertions(+), 19 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2d02604eda..3ba546a29a 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -173,6 +173,59 @@ end include("tfunc.jl") +struct EnzymeCache + inactive::Bool + has_rule::Bool +end + +if VERSION >= v"1.11.0-" +function CC.ipo_dataflow_analysis!(interp::EnzymeInterpreter, ir::Core.Compiler.IRCode, + caller::Core.Compiler.InferenceResult) + mi = caller.linfo + specTypes = simplify_kw(mi.spectypes) + inactive = false + has_rule = false + if is_inactive_from_sig(interp, specTypes, mi) + inactive = true + else + # 2. Check if rule is defined + if interp.forward_rules && has_frule_from_sig(interp, specTypes, mi) + has_rule = true + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, mi) + has_rule = true + end + end + CC.stack_analysis_result!(caller, EnzymeCache(inactive, has_rule)) + @invoke CC.ipo_dataflow_analysis!(interp::Core.Compiler.AbstractInterpreter, ir::Core.Compiler.IRCode, + caller::Core.Compiler.InferenceResult) +end + +else # v1.10 +# 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis +function Core.Compiler.finish(interp::EnzymeInterpreter, opt::Core.Compiler.OptimizationState, ir::Core.Compiler.IRCode, + caller::Core.Compiler.InferenceResult) + mi = caller.linfo + specTypes = simplify_kw(mi.spectypes) + inactive = false + has_rule = false + if is_inactive_from_sig(interp, specTypes, mi) + inactive = true + else + # 2. Check if rule is defined + if interp.forward_rules && has_frule_from_sig(interp, specTypes, mi) + has_rule = true + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, mi) + has_rule = true + end + end + # HACK: we store the deferred edges in the argescapes field, which is invalid, + # but nobody should be running EA on our results. + caller.argescapes = EnzymeCache(inactive, has_rule) + @invoke CC.finish(interp::Core.Compiler.AbstractInterpreter, opt::Core.Compiler.OptimizationState, + ir::Core.Compiler.IRCode, caller::Core.Compiler.InferenceResult) +end +end + import Core.Compiler: CallInfo struct NoInlineCallInfo <: CallInfo info::CallInfo # wrapped call @@ -219,31 +272,33 @@ function Core.Compiler.abstract_call_gf_by_type( ) callinfo = ret.info specTypes = simplify_kw(atype) + @show ret if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - else - # 1. Check if function is inactive - if is_inactive_from_sig(interp, specTypes, sv) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) - else - # 2. Check if rule is defined - has_rule = get!(interp.rules_cache, specTypes) do - if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) - return true - elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) - return true - else - return false - end - end - if has_rule - callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) - end - end end + # else + # # 1. Check if function is inactive + # if is_inactive_from_sig(interp, specTypes, sv) + # callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + # else + # # 2. Check if rule is defined + # has_rule = get!(interp.rules_cache, specTypes) do + # if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) + # return true + # elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) + # return true + # else + # return false + # end + # end + # if has_rule + # callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) + # end + # end + # end @static if VERSION ≥ v"1.11-" return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) else From d121d48d3ca0651aa05d91f8fb169ea53384427b Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 5 Dec 2024 20:12:54 -0600 Subject: [PATCH 2/4] Update interpreter.jl --- src/compiler/interpreter.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 3ba546a29a..4bf8f21179 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -182,7 +182,7 @@ if VERSION >= v"1.11.0-" function CC.ipo_dataflow_analysis!(interp::EnzymeInterpreter, ir::Core.Compiler.IRCode, caller::Core.Compiler.InferenceResult) mi = caller.linfo - specTypes = simplify_kw(mi.spectypes) + specTypes = simplify_kw(mi.specTypes) inactive = false has_rule = false if is_inactive_from_sig(interp, specTypes, mi) @@ -205,7 +205,7 @@ else # v1.10 function Core.Compiler.finish(interp::EnzymeInterpreter, opt::Core.Compiler.OptimizationState, ir::Core.Compiler.IRCode, caller::Core.Compiler.InferenceResult) mi = caller.linfo - specTypes = simplify_kw(mi.spectypes) + specTypes = simplify_kw(mi.specTypes) inactive = false has_rule = false if is_inactive_from_sig(interp, specTypes, mi) From 49c9e563859777939f9005a1af5bb943b3f6ad4c Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 6 Dec 2024 11:41:31 +0100 Subject: [PATCH 3/4] swap back to the old form --- src/compiler/tfunc.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl index 701cfa8107..b8bd262dc6 100644 --- a/src/compiler/tfunc.jl +++ b/src/compiler/tfunc.jl @@ -1,31 +1,42 @@ import EnzymeCore: Annotation import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, inactive, _annotate_tt +function add_backedge!(caller::Core.MethodInstance, callee::Core.MethodInstance, @nospecialize(sig)) + ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, sig, caller) + return nothing +end + +function add_mt_backedge!(caller::Core.MethodInstance, mt::Core.MethodTable, @nospecialize(sig)) + ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), mt, sig, caller) + return nothing +end + + function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT), caller::MethodInstance)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, forward, TT, sv) + return isapplicable(interp, forward, TT, caller) end function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT), caller::MethodInstance)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, augmented_primal, TT, sv) + return isapplicable(interp, augmented_primal, TT, caller) end function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState) - return isapplicable(interp, inactive, TT, sv) + @nospecialize(TT), caller::MethodInstance) + return isapplicable(interp, inactive, TT, caller) end # `hasmethod` is a precise match using `Core.Compiler.findsup`, # but here we want the broader query using `Core.Compiler.findall`. # Also add appropriate backedges to the caller `MethodInstance` if given. function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(f), @nospecialize(TT), caller::MethodInstance)::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -41,7 +52,7 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), # added that did not intersect with any existing method fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) if !fullmatch - Core.Compiler.add_mt_backedge!(sv, mt, sig) + add_mt_backedge!(caller, mt, sig) end if Core.Compiler.isempty(matches) return false @@ -49,7 +60,7 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), for i = 1:Core.Compiler.length(matches) match = Core.Compiler.getindex(matches, i)::Core.MethodMatch edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - Core.Compiler.add_backedge!(sv, edge) + add_backedge!(caller, edge, sig) end return true end From 114410848ee437a4cbb50385a76e68815239e513 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 6 Dec 2024 12:15:48 +0100 Subject: [PATCH 4/4] fixup impl for 1.10 --- src/compiler/interpreter.jl | 40 ++++++++++--------------------------- 1 file changed, 10 insertions(+), 30 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 4bf8f21179..38aeaf9cc4 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -204,25 +204,26 @@ else # v1.10 # 1.10 doesn't have stack_analysis_result or ipo_dataflow_analysis function Core.Compiler.finish(interp::EnzymeInterpreter, opt::Core.Compiler.OptimizationState, ir::Core.Compiler.IRCode, caller::Core.Compiler.InferenceResult) - mi = caller.linfo - specTypes = simplify_kw(mi.specTypes) + (; src, linfo) = opt + specTypes = simplify_kw(linfo.specTypes) inactive = false has_rule = false - if is_inactive_from_sig(interp, specTypes, mi) + if is_inactive_from_sig(interp, specTypes, linfo) inactive = true else # 2. Check if rule is defined - if interp.forward_rules && has_frule_from_sig(interp, specTypes, mi) + if interp.forward_rules && has_frule_from_sig(interp, specTypes, linfo) has_rule = true - elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, mi) + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, linfo) has_rule = true end end - # HACK: we store the deferred edges in the argescapes field, which is invalid, - # but nobody should be running EA on our results. - caller.argescapes = EnzymeCache(inactive, has_rule) - @invoke CC.finish(interp::Core.Compiler.AbstractInterpreter, opt::Core.Compiler.OptimizationState, + @invoke Core.Compiler.finish(interp::Core.Compiler.AbstractInterpreter, opt::Core.Compiler.OptimizationState, ir::Core.Compiler.IRCode, caller::Core.Compiler.InferenceResult) + # Must happen afterwards + if inactive || has_rule + Core.Compiler.set_inlineable!(src, false) + end end end @@ -272,33 +273,12 @@ function Core.Compiler.abstract_call_gf_by_type( ) callinfo = ret.info specTypes = simplify_kw(atype) - @show ret if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) end - # else - # # 1. Check if function is inactive - # if is_inactive_from_sig(interp, specTypes, sv) - # callinfo = NoInlineCallInfo(callinfo, atype, :inactive) - # else - # # 2. Check if rule is defined - # has_rule = get!(interp.rules_cache, specTypes) do - # if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) - # return true - # elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) - # return true - # else - # return false - # end - # end - # if has_rule - # callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) - # end - # end - # end @static if VERSION ≥ v"1.11-" return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo) else