From 49c9e563859777939f9005a1af5bb943b3f6ad4c Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 6 Dec 2024 11:41:31 +0100 Subject: [PATCH] swap back to the old form --- src/compiler/tfunc.jl | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl index 701cfa8107..b8bd262dc6 100644 --- a/src/compiler/tfunc.jl +++ b/src/compiler/tfunc.jl @@ -1,31 +1,42 @@ import EnzymeCore: Annotation import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, inactive, _annotate_tt +function add_backedge!(caller::Core.MethodInstance, callee::Core.MethodInstance, @nospecialize(sig)) + ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, sig, caller) + return nothing +end + +function add_mt_backedge!(caller::Core.MethodInstance, mt::Core.MethodTable, @nospecialize(sig)) + ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), mt, sig, caller) + return nothing +end + + function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT), caller::MethodInstance)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, forward, TT, sv) + return isapplicable(interp, forward, TT, caller) end function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT), caller::MethodInstance)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, augmented_primal, TT, sv) + return isapplicable(interp, augmented_primal, TT, caller) end function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState) - return isapplicable(interp, inactive, TT, sv) + @nospecialize(TT), caller::MethodInstance) + return isapplicable(interp, inactive, TT, caller) end # `hasmethod` is a precise match using `Core.Compiler.findsup`, # but here we want the broader query using `Core.Compiler.findall`. # Also add appropriate backedges to the caller `MethodInstance` if given. function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(f), @nospecialize(TT), caller::MethodInstance)::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -41,7 +52,7 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), # added that did not intersect with any existing method fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) if !fullmatch - Core.Compiler.add_mt_backedge!(sv, mt, sig) + add_mt_backedge!(caller, mt, sig) end if Core.Compiler.isempty(matches) return false @@ -49,7 +60,7 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), for i = 1:Core.Compiler.length(matches) match = Core.Compiler.getindex(matches, i)::Core.MethodMatch edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - Core.Compiler.add_backedge!(sv, edge) + add_backedge!(caller, edge, sig) end return true end