Skip to content

Commit

Permalink
add ephermal cache
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Dec 4, 2024
1 parent 0039b1c commit bf6c435
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
inf_params::InferenceParams
opt_params::OptimizationParams

rules_cache::IdDict{Any, Bool}

forward_rules::Bool
reverse_rules::Bool
deferred_lower::Bool
Expand Down Expand Up @@ -78,6 +80,7 @@ function EnzymeInterpreter(
# parameters for inference and optimization
parms,
OptimizationParams(),
IdDict{Any, Bool}(),
forward_rules,
reverse_rules,
deferred_lower,
Expand Down Expand Up @@ -227,11 +230,18 @@ function Core.Compiler.abstract_call_gf_by_type(
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
# 2. Check if rule is defined
if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
has_rule = get!(interp.rules_cache, specTypes) do
if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv)
return true
elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv)
return true
else
return false
end
end
if has_rule
callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule)
end
end
end
@static if VERSION v"1.11-"
Expand Down

0 comments on commit bf6c435

Please sign in to comment.