Skip to content

Commit

Permalink
World backedge holder
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 7, 2024
1 parent 4f160a0 commit c568216
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 16 deletions.
17 changes: 16 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5554,7 +5554,22 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA::
# new_ci.min_world = min_world[]
new_ci.min_world = world
new_ci.max_world = max_world[]
new_ci.edges = Core.MethodInstance[mi]

edges = Any[mi]

if Mode == API.DEM_Forward
sig = Tuple{typeof(Compiler.rule_backedge_holder), typeof(EnzymeRules.forward)}
push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
else
sig = Tuple{typeof(Compiler.rule_backedge_holder), typeof(EnzymeRules.augmented_primal)}
push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
end

sig = Tuple{typeof(Compiler.rule_backedge_holder), Val{0}}
push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))

new_ci.edges = edges

# XXX: setting this edge does not give us proper method invalidation, see
# JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
# invoking `code_llvm` also does the necessary codegen, as does calling the
Expand Down
36 changes: 21 additions & 15 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter
inf_params::InferenceParams
opt_params::OptimizationParams

rules_cache::IdDict{Any, Bool}

forward_rules::Bool
reverse_rules::Bool
deferred_lower::Bool
Expand Down Expand Up @@ -103,6 +101,7 @@ Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
get_inference_world(@nospecialize(interp::EnzymeInterpreter)) = interp.world
Core.Compiler.get_inference_cache(@nospecialize(interp::EnzymeInterpreter)) = interp.local_cache

@static if HAS_INTEGRATED_CACHE
Core.Compiler.cache_owner(@nospecialize(interp::EnzymeInterpreter)) = interp.token
else
Expand Down Expand Up @@ -225,25 +224,32 @@ function Core.Compiler.abstract_call_gf_by_type(
elseif is_alwaysinline_func(specTypes)
callinfo = AlwaysInlineCallInfo(callinfo, atype)
else
# 1. Check if function is inactive
if is_inactive_from_sig(interp, specTypes, sv)

if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :inactive)
else
# 2. Check if rule is defined
has_rule = get!(interp.rules_cache, specTypes) do
if interp.forward_rules && has_frule_from_sig(interp, specTypes, sv)
return true
elseif interp.reverse_rules && has_rrule_from_sig(interp, specTypes, sv)
return true
else
return false
if interp.forward_rules
if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :frule)
end
end

if interp.reverse_rules
if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table)
callinfo = NoInlineCallInfo(callinfo, atype, :rrule)
end
end
if has_rule
callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule)
end
end

if interp.forward_rules
Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Compiler.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, interp.world)::Core.MethodInstance)
end
if interp.reverse_rules
Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Compiler.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, interp.world)::Core.MethodInstance)
end
Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Compiler.rule_backedge_holder), Tuple{Val{0}}, interp.world)::Core.MethodInstance)
end

@static if VERSION v"1.11-"
return Core.Compiler.CallMeta(ret.rt, ret.exct, ret.effects, callinfo)
else
Expand Down
91 changes: 91 additions & 0 deletions src/compiler/tfunc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,95 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter),
end
return true
end
end

# Generator for `rule_backedge_holder`.
#
# Builds a `CodeInfo` whose body just returns the constant `0`, and whose `edges`
# list names the `EnzymeRules` signatures selected by `ft`:
#   * `typeof(EnzymeRules.augmented_primal)` -> the `augmented_primal` signature
#   * `typeof(EnzymeRules.forward)`          -> the `forward` signature
#   * anything else                          -> `inactive`, `inactive_noinl`, `noalias`
#
# Arguments follow the generated-function generator convention used in this file:
# `world` is the world age to look methods up in, `source` and `self` are unused
# here, and `ft` is the type of the holder's single argument.
function rule_backedge_holder_generator(world::UInt, source, self, ft::Type)
    @nospecialize

    # If the holder was called with a type as its argument, the generator sees
    # `Type{X}`; unwrap that to `X`. Otherwise `ft` already is the type to compare.
    rule_ft = ft
    if ft isa DataType && ft <: Type && length(ft.parameters) == 1
        rule_ft = ft.parameters[1]
    end

    # Use `identity(::Int)` purely as a template: it supplies a valid two-slot
    # `CodeInfo` that is emptied and refilled below.
    sig = Tuple{typeof(Base.identity), Int}
    method_table = nothing  # use the default (global) method table
    min_world = Ref{UInt}(typemin(UInt))
    max_world = Ref{UInt}(typemax(UInt))
    has_ambig = Ptr{Int32}(C_NULL)
    mthds = Base._methods_by_ftype(
        sig,
        method_table,
        -1, #=lim=#
        world,
        false, #=ambig=#
        min_world,
        max_world,
        has_ambig,
    )
    mtypes, msp, m = mthds[1]
    mi = ccall(
        :jl_specializations_get_linfo,
        Ref{Core.MethodInstance},
        (Any, Any, Any),
        m,
        mtypes,
        msp,
    )
    ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo

    # prepare a new code info
    new_ci = copy(ci)
    empty!(new_ci.code)
    @static if isdefined(Core, :DebugInfo)
        new_ci.debuginfo = Core.DebugInfo(:none)
    else
        empty!(new_ci.codelocs)
        # keep a single line-table entry for the one statement pushed below
        resize!(new_ci.linetable, 1)
    end
    empty!(new_ci.ssaflags)
    new_ci.ssavaluetypes = 0
    new_ci.min_world = min_world[]
    new_ci.max_world = max_world[]

    ### TODO: backedge from inactive, augmented_primal, forward, reverse
    # NOTE(review): each edge is pushed as a `(method table, signature)` tuple;
    # confirm this is the entry layout the runtime expects for `CodeInfo.edges`.
    edges = Any[]

    if rule_ft == typeof(EnzymeRules.augmented_primal)
        sig = Tuple{typeof(EnzymeRules.augmented_primal), <:RevConfig, <:Annotation, Type{<:Annotation},Vararg{<:Annotation}}
        push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
    elseif rule_ft == typeof(EnzymeRules.forward)
        sig = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{<:Annotation}}
        push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
    else
        sig = Tuple{typeof(EnzymeRules.inactive), Vararg{<:Annotation}}
        push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))

        sig = Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{<:Annotation}}
        push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))

        sig = Tuple{typeof(EnzymeRules.noalias), Vararg{<:Annotation}}
        push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
    end
    # No `@show`/printing here: generators must not perform I/O.
    new_ci.edges = edges

    # XXX: setting this edge does not give us proper method invalidation, see
    #      JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel.
    #      invoking `code_llvm` also does the necessary codegen, as does calling the
    #      underlying C methods -- which GPUCompiler does, so everything Just Works.

    # prepare the slots
    new_ci.slotnames = Symbol[Symbol("#self#"), :ft]
    new_ci.slotflags = UInt8[0x00 for i = 1:2]

    # the holder's body is a single statement returning the constant 0
    push!(new_ci.code, Core.Compiler.ReturnNode(0))
    push!(new_ci.ssaflags, 0x00)   # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code`
    @static if isdefined(Core, :DebugInfo)
    else
        push!(new_ci.codelocs, 1)  # location entry for the statement pushed above
    end
    new_ci.ssavaluetypes += 1

    return new_ci
end

# Generated-only placeholder function whose body is produced by
# `rule_backedge_holder_generator`. The generator attaches edges for the relevant
# `EnzymeRules` signatures to the returned `CodeInfo`.
#
# The argument is deliberately left untyped: the signatures built elsewhere for
# this function pass `typeof(EnzymeRules.forward)`,
# `typeof(EnzymeRules.augmented_primal)` and `Val{0}` as the argument type, none
# of which is a `Type`, so a `::Type` restriction would not match them.
@eval Base.@assume_effects :removable :foldable :nothrow @inline function rule_backedge_holder(ft)
    # `:generated_only` marks the method as having no non-generated fallback body.
    $(Expr(:meta, :generated_only))
    # Register the generator that builds the method body.
    $(Expr(:meta, :generated, rule_backedge_holder_generator))
end

0 comments on commit c568216

Please sign in to comment.