diff --git a/Project.toml b/Project.toml index d86475034a..46b37ccb12 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ EnzymeStaticArraysExt = "StaticArrays" BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" -EnzymeCore = "0.8.7" +EnzymeCore = "0.8.8" Enzyme_jll = "0.0.167" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" diff --git a/lib/EnzymeCore/Project.toml b/lib/EnzymeCore/Project.toml index 28f92d9055..da662c545a 100644 --- a/lib/EnzymeCore/Project.toml +++ b/lib/EnzymeCore/Project.toml @@ -1,7 +1,7 @@ name = "EnzymeCore" uuid = "f151be2c-9106-41f4-ab19-57ee4f262869" authors = ["William Moses ", "Valentin Churavy "] -version = "0.8.7" +version = "0.8.8" [compat] Adapt = "3, 4" diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 945951b216..dc33c11110 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -171,7 +171,7 @@ end function has_frule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(forward, TT; world, method_table, caller) @@ -180,7 +180,7 @@ end function has_rrule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(augmented_primal, TT; world, method_table, caller) @@ -192,7 +192,7 @@ end function isapplicable(@nospecialize(f), @nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -211,12 +211,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); if !fullmatch if caller isa Core.MethodInstance add_mt_backedge!(caller, mt, sig) - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - add_mt_backedge!(cspec, mt, sig) - end end end if Core.Compiler.isempty(matches) @@ -228,16 +222,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); edge = Core.Compiler.specialize_method(match)::Core.MethodInstance add_backedge!(caller, edge, sig) end - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - add_backedge!(cspec, edge, sig) - end - end end return true end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index ff4c6c991b..2d02604eda 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -40,6 +40,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter inf_params::InferenceParams opt_params::OptimizationParams + rules_cache::IdDict{Any, Bool} + forward_rules::Bool reverse_rules::Bool deferred_lower::Bool @@ -78,6 +80,7 @@ function EnzymeInterpreter( # parameters for inference and optimization parms, OptimizationParams(), + IdDict{Any, Bool}(), forward_rules, reverse_rules, deferred_lower, @@ -168,6 +171,8 @@ function simplify_kw(@nospecialize(specTypes)) end end +include("tfunc.jl") + import Core.Compiler: CallInfo struct NoInlineCallInfo <: CallInfo info::CallInfo # wrapped call @@ -192,6 +197,7 @@ Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +import .EnzymeRules: FwdConfig, RevConfig, Annotation using Core.Compiler: ArgInfo, StmtInfo, AbsIntState function Core.Compiler.abstract_call_gf_by_type( @nospecialize(interp::EnzymeInterpreter), @@ -212,31 +218,30 @@ function Core.Compiler.abstract_call_gf_by_type( max_methods::Int, ) callinfo = ret.info - method_table = Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) - caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult - callinfo.results - else - nothing - end if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else - if interp.forward_rules - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :frule) - end - end - - if interp.reverse_rules - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + # 1. Check if function is inactive + if is_inactive_from_sig(interp, specTypes, sv) + callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + else + # 2. Check if rule is defined + has_rule = get!(interp.rules_cache, specTypes) do + if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv) + return true + elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv) + return true + else + return false + end end + if has_rule + callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) + end end end @static if VERSION ≥ v"1.11-" diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl new file mode 100644 index 0000000000..701cfa8107 --- /dev/null +++ b/src/compiler/tfunc.jl @@ -0,0 +1,56 @@ +import EnzymeCore: Annotation +import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, inactive, _annotate_tt + +function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + ft, tt = _annotate_tt(TT) + TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...} + return isapplicable(interp, forward, TT, sv) +end + +function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + ft, tt = _annotate_tt(TT) + TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...} + return isapplicable(interp, augmented_primal, TT, sv) +end + + +function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(TT), sv::Core.Compiler.AbsIntState) + return isapplicable(interp, inactive, TT, sv) +end + +# `hasmethod` is a precise match using `Core.Compiler.findsup`, +# but here we want the broader query using `Core.Compiler.findall`. +# Also add appropriate backedges to the caller `MethodInstance` if given. +function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), + @nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + tt = Base.to_tuple_type(TT) + sig = Base.signature_type(f, tt) + mt = ccall(:jl_method_table_for, Any, (Any,), sig) + mt isa Core.MethodTable || return false + result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp); limit=-1) + (result === nothing || result === missing) && return false + @static if isdefined(Core.Compiler, :MethodMatchResult) + (; matches) = result + else + matches = result + end + # also need an edge to the method table in case something gets + # added that did not intersect with any existing method + fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) + if !fullmatch + Core.Compiler.add_mt_backedge!(sv, mt, sig) + end + if Core.Compiler.isempty(matches) + return false + else + for i = 1:Core.Compiler.length(matches) + match = Core.Compiler.getindex(matches, i)::Core.MethodMatch + edge = Core.Compiler.specialize_method(match)::Core.MethodInstance + Core.Compiler.add_backedge!(sv, edge) + end + return true + end +end \ No newline at end of file diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 704ada2b6e..37cb21b08f 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -34,18 +34,14 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 @static if VERSION < v"1.11-" -@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 + @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 else -@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 + @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 end # now test invalidation for `inactive` inactive(::typeof(issue696), args...) = nothing @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -@static if VERSION < v"1.11-" -@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -else @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -end end # module