Skip to content

Commit

Permalink
vc/fixup isapplicable use v2 (#2158)
Browse files Browse the repository at this point in the history
* Forward interp and sv to isapplicable

* invalidation for inactive now works

* move tfunc to compilers

* fixup! move tfunc to compilers

* fixup! fixup! move tfunc to compilers

* add ephermal cache

* bump versions

---------

Co-authored-by: William S. Moses <[email protected]>
  • Loading branch information
vchuravy and wsmoses authored Dec 5, 2024
1 parent 5b85862 commit 6606cd9
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ EnzymeStaticArraysExt = "StaticArrays"
BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.7"
EnzymeCore = "0.8.8"
Enzyme_jll = "0.0.167"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
LLVM = "6.1, 7, 8, 9"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>"]
version = "0.8.7"
version = "0.8.8"

[compat]
Adapt = "3, 4"
Expand Down
22 changes: 3 additions & 19 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ end
function has_frule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(forward, TT; world, method_table, caller)
Expand All @@ -180,7 +180,7 @@ end
function has_rrule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(augmented_primal, TT; world, method_table, caller)
Expand All @@ -192,7 +192,7 @@ end
function isapplicable(@nospecialize(f), @nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
Expand All @@ -211,12 +211,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
if !fullmatch
if caller isa Core.MethodInstance
add_mt_backedge!(caller, mt, sig)
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
add_mt_backedge!(cspec, mt, sig)
end
end
end
if Core.Compiler.isempty(matches)
Expand All @@ -228,16 +222,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(caller, edge, sig)
end
elseif caller isa Core.Compiler.MethodLookupResult
for j = 1:Core.Compiler.length(caller)
cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch
cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(cspec, edge, sig)
end
end
end
return true
end
Expand Down
39 changes: 22 additions & 17 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
inf_params::InferenceParams
opt_params::OptimizationParams

rules_cache::IdDict{Any, Bool}

forward_rules::Bool
reverse_rules::Bool
deferred_lower::Bool
Expand Down Expand Up @@ -78,6 +80,7 @@ function EnzymeInterpreter(
# parameters for inference and optimization
parms,
OptimizationParams(),
IdDict{Any, Bool}(),
forward_rules,
reverse_rules,
deferred_lower,
Expand Down Expand Up @@ -168,6 +171,8 @@ function simplify_kw(@nospecialize(specTypes))
end
end

include("tfunc.jl")

import Core.Compiler: CallInfo
struct NoInlineCallInfo <: CallInfo
info::CallInfo # wrapped call
Expand All @@ -192,6 +197,7 @@ Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) =
Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) =
Core.Compiler.getresult(info.info, idx)

import .EnzymeRules: FwdConfig, RevConfig, Annotation
using Core.Compiler: ArgInfo, StmtInfo, AbsIntState
function Core.Compiler.abstract_call_gf_by_type(
@nospecialize(interp::EnzymeInterpreter),
Expand All @@ -212,31 +218,30 @@ function Core.Compiler.abstract_call_gf_by_type(
max_methods::Int,
)
callinfo = ret.info
method_table = Core.Compiler.method_table(interp)
specTypes = simplify_kw(atype)
caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult
callinfo.results
else
nothing
end

if is_primitive_func(specTypes)
callinfo = NoInlineCallInfo(callinfo, atype, :primitive)
elseif is_alwaysinline_func(specTypes)
callinfo = AlwaysInlineCallInfo(callinfo, atype)
elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
if interp.forward_rules
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
end
end

if interp.reverse_rules
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
# 1. Check if function is inactive
if is_inactive_from_sig(interp, specTypes, sv)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
# 2. Check if rule is defined
has_rule = get!(interp.rules_cache, specTypes) do
if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv)
return true
elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv)
return true
else
return false
end
end
if has_rule
callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule)
end
end
end
@static if VERSION v"1.11-"
Expand Down
56 changes: 56 additions & 0 deletions src/compiler/tfunc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import EnzymeCore: Annotation
import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, inactive, _annotate_tt

function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...}
return isapplicable(interp, forward, TT, sv)
end

function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...}
return isapplicable(interp, augmented_primal, TT, sv)
end


function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(TT), sv::Core.Compiler.AbsIntState)
return isapplicable(interp, inactive, TT, sv)
end

# `hasmethod` is a precise match using `Core.Compiler.findsup`,
# but here we want the broader query using `Core.Compiler.findall`.
# Also add appropriate backedges to the caller `MethodInstance` if given.
function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
@nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
mt isa Core.MethodTable || return false
result = Core.Compiler.findall(sig, Core.Compiler.method_table(interp); limit=-1)
(result === nothing || result === missing) && return false
@static if isdefined(Core.Compiler, :MethodMatchResult)
(; matches) = result
else
matches = result
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches)
if !fullmatch
Core.Compiler.add_mt_backedge!(sv, mt, sig)
end
if Core.Compiler.isempty(matches)
return false
else
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
Core.Compiler.add_backedge!(sv, edge)
end
return true
end
end
8 changes: 2 additions & 6 deletions test/ruleinvalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,14 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}})
end
@test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] 2.0
@static if VERSION < v"1.11-"
@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] 2.0
@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] 2.0
else
@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] 2.0
@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] 2.0
end

# now test invalidation for `inactive`
inactive(::typeof(issue696), args...) = nothing
@test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] 0.0
@static if VERSION < v"1.11-"
@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] 0.0
else
@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] 0.0
end

end # module

4 comments on commit 6606cd9

@wsmoses
Copy link
Member

@wsmoses wsmoses commented on 6606cd9 Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir="lib/EnzymeCore"

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120705

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a EnzymeCore-v0.8.8 -m "<description of version>" 6606cd96184364cb7d39ae0e75259be5066b5630
git push origin EnzymeCore-v0.8.8

@wsmoses
Copy link
Member

@wsmoses wsmoses commented on 6606cd9 Dec 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120707

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.19 -m "<description of version>" 6606cd96184364cb7d39ae0e75259be5066b5630
git push origin v0.13.19

Please sign in to comment.