From bf6c435f13ceb26f6b99b3d13c8cf1f54c04f08d Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 4 Dec 2024 15:32:29 +0100 Subject: [PATCH] add ephermal cache --- src/compiler/interpreter.jl | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index dcec119f17..2d02604eda 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -40,6 +40,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams + rules_cache::IdDict{Any, Bool} + forward_rules::Bool reverse_rules::Bool deferred_lower::Bool @@ -78,6 +80,7 @@ function EnzymeInterpreter( # parameters for inference and optimization parms, OptimizationParams(), + IdDict{Any, Bool}(), forward_rules, reverse_rules, deferred_lower, @@ -227,11 +230,18 @@ function Core.Compiler.abstract_call_gf_by_type( callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else # 2. Check if rule is defined - if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) - callinfo = NoInlineCallInfo(callinfo, atype, :frule) - elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + has_rule = get!(interp.rules_cache, specTypes) do + if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) + return true + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) + return true + else + return false + end end + if has_rule + callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) + end end end @static if VERSION ≥ v"1.11-"