From ca722d7a1f9721cbed697f7a70e7c05171c18a9f Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Fri, 6 Dec 2024 23:20:55 -0600 Subject: [PATCH] more --- src/compiler.jl | 3 ++- src/compiler/interpreter.jl | 16 +++++++++++----- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 243d2cad1a..f561f5e8c0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5561,9 +5561,10 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA:: push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, world)) Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) else - push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_forward)}, world)) + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, world)) end + push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, world)) push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{Val{0}}, world)) Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(Val(0))) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index ffcf6b0f0d..fb24f38cb0 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -130,8 +130,9 @@ end begin fwd_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}) - rev_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_forward)}) - gen_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_forward)}) + rev_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}) + ina_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}) + gen_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{Val{0}}) fwd_sig = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} @@ -141,8 +142,13 @@ begin EnzymeRules.add_mt_backedge!(rev_rule_be, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable, rev_sig) - for gen_sig in ( + for ina_sig in ( Tuple{typeof(EnzymeRules.inactive), Vararg{Annotation}}, + ) + EnzymeRules.add_mt_backedge!(ina_rule_be, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable, ina_sig) + end + + for gen_sig in ( Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Annotation}}, Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, Tuple{typeof(EnzymeRules.inactive_type), Type}, @@ -376,8 +382,8 @@ function Core.Compiler.abstract_call_gf_by_type( Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, interp.world)::Core.MethodInstance) Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.augmented_primal)) end - Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{Val{0}}, interp.world)::Core.MethodInstance) - Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(Val(0))) + Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, interp.world)::Core.MethodInstance) + Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(typeof(EnzymeRules.inactive))) end @static if VERSION ≥ v"1.11-"