Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vc/fixup isapplicable use v2 #2158

Merged
merged 8 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ EnzymeStaticArraysExt = "StaticArrays"
BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.7"
EnzymeCore = "0.8.8"
Enzyme_jll = "0.0.167"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
LLVM = "6.1, 7, 8, 9"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.8.7"
version = "0.8.8"

[compat]
Adapt = "3, 4"
Expand Down
22 changes: 3 additions & 19 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ end
function has_frule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(forward, TT; world, method_table, caller)
Expand All @@ -180,7 +180,7 @@ end
function has_rrule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(augmented_primal, TT; world, method_table, caller)
Expand All @@ -192,7 +192,7 @@ end
function isapplicable(@nospecialize(f), @nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
Expand All @@ -211,12 +211,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
if !fullmatch
if caller isa Core.MethodInstance
add_mt_backedge!(caller, mt, sig)
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
add_mt_backedge!(cspec, mt, sig)
end
end
end
if Core.Compiler.isempty(matches)
Expand All @@ -228,16 +222,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(caller, edge, sig)
end
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(cspec, edge, sig)
end
end
end
return true
end
Expand Down
39 changes: 22 additions & 17 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
inf_params::InferenceParams
opt_params::OptimizationParams

rules_cache::IdDict{Any, Bool}

forward_rules::Bool
reverse_rules::Bool
deferred_lower::Bool
Expand Down Expand Up @@ -78,6 +80,7 @@ function EnzymeInterpreter(
# parameters for inference and optimization
parms,
OptimizationParams(),
IdDict{Any, Bool}(),
forward_rules,
reverse_rules,
deferred_lower,
Expand Down Expand Up @@ -168,6 +171,8 @@ function simplify_kw(@nospecialize(specTypes))
end
end

include("tfunc.jl")

import Core.Compiler: CallInfo
struct NoInlineCallInfo <: CallInfo
info::CallInfo # wrapped call
Expand All @@ -192,6 +197,7 @@ Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) =
Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) =
Core.Compiler.getresult(info.info, idx)

import .EnzymeRules: FwdConfig, RevConfig, Annotation
using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
function Core.Compiler.abstract_call_gf_by_type(
@nospecialize(interp::EnzymeInterpreter),
Expand All @@ -212,31 +218,30 @@ function Core.Compiler.abstract_call_gf_by_type(
max_methods::Int,
)
callinfo = ret.info
method_table = Core.Compiler.method_table(interp)
specTypes = simplify_kw(atype)
caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult
callinfo.results
else
nothing
end

if is_primitive_func(specTypes)
callinfo = NoInlineCallInfo(callinfo, atype, :primitive)
elseif is_alwaysinline_func(specTypes)
callinfo = AlwaysInlineCallInfo(callinfo, atype)
elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
if interp.forward_rules
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
end
end

if interp.reverse_rules
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
# 1. Check if function is inactive
if is_inactive_from_sig(interp, specTypes, sv)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
# 2. Check if rule is defined
has_rule = get!(interp.rules_cache, specTypes) do
if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv)
return true
elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv)
return true
else
return false
end
end
if has_rule
callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule)
end
end
end
@static if VERSION ≥ v"1.11-"
Expand Down
56 changes: 56 additions & 0 deletions src/compiler/tfunc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import EnzymeCore: Annotation
import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, inactive, _annotate_tt

function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...}
return isapplicable(interp, forward, TT, sv)
end

function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...}
return isapplicable(interp, augmented_primal, TT, sv)
end


function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT), sv::Core.Compiler.AbsIntState)
return isapplicable(interp, inactive, TT, sv)
end

# `hasmethod` is a precise match using `Core.Compiler.findsup`,
# but here we want the broader query using `Core.Compiler.findall`.
# Also add appropriate backedges to the caller `MethodInstance` if given.
function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
mt isa Core.MethodTable || return false
result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp); limit=-1)
(result === nothing || result === missing) && return false
@static if isdefined(Core.Compiler, :MethodMatchResult)
(; matches) = result
else
matches = result
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches)
if !fullmatch
Core.Compiler.add_mt_backedge!(sv, mt, sig)
end
if Core.Compiler.isempty(matches)
return false
else
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
Core.Compiler.add_backedge!(sv, edge)
end
return true
end
end
8 changes: 2 additions & 6 deletions test/ruleinvalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,14 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}})
end
@test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0
@static if VERSION < v"1.11-"
@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0
@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0
else
@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0
@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0
end

# now test invalidation for `inactive`
inactive(::typeof(issue696), args...) = nothing
@test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0
@static if VERSION < v"1.11-"
@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0
else
@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0
end

end # module
Loading