Skip to content

Commit

Permalink
Fix methodinstance usage and backedges
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 14, 2024
1 parent 7c0823f commit 9e50cc5
Show file tree
Hide file tree
Showing 11 changed files with 186 additions and 189 deletions.
16 changes: 8 additions & 8 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0))
end

opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Reverse, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -536,7 +536,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value.
) where {FA<:Annotation,CMode<:Mode,Nargs}
tt = vaEltypeof(args...)
rt = Compiler.primal_return_type(
mode,
mode isa ForwardMode ? Forward : Reverse,
eltype(FA),
tt,
)
Expand Down Expand Up @@ -632,7 +632,7 @@ f(x) = x*x
tt = vaEltypeof(args...)

opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Forward, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -968,7 +968,7 @@ result, ∂v, ∂A

tt′ = Tuple{args...}
opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Reverse, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -1098,7 +1098,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float

tt′ = Tuple{args...}
opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), tt)
my_methodinstance(Forward, eltype(FA), tt)
else
Val(0)
end
Expand Down Expand Up @@ -1166,7 +1166,7 @@ end

primal_tt = Tuple{map(eltype, args)...}
opt_mi = if RABI <: NonGenABI
Compiler.fspec(eltype(FA), TT)
my_methodinstance(Forward, eltype(FA), primal_tt)
else
Val(0)
end
Expand Down Expand Up @@ -1196,7 +1196,7 @@ const tape_cache = Dict{UInt,Type}()

const tape_cache_lock = ReentrantLock()

import .Compiler: fspec, remove_innerty, UnknownTapeType
import .Compiler: remove_innerty, UnknownTapeType

@inline function tape_type(
parent_job::Union{GPUCompiler.CompilerJob,Nothing},
Expand Down Expand Up @@ -1246,7 +1246,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType

primal_tt = Tuple{map(eltype, args)...}

mi = Compiler.fspec(eltype(FA), TT)
mi = my_methodinstance(parent_job === nothing ? Reverse : GPUCompiler.get_interpreter(parent_job), eltype(FA), primal_tt)

target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(
Expand Down
1 change: 1 addition & 0 deletions src/analyses/activity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ end
EnzymeCore.EnzymeRules.inactive_type(T)
else
inmi = my_methodinstance(
nothing,
typeof(EnzymeCore.EnzymeRules.inactive_type),
Tuple{Type{T}},
world,
Expand Down
40 changes: 15 additions & 25 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ include("llvm/passes.jl")
include("typeutils/make_zero.jl")

function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt)
funcspec = my_methodinstance(typeof(f), tt, world)
funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world)
nested_codegen!(mode, mod, funcspec, world)
end

Expand Down Expand Up @@ -1124,10 +1124,6 @@ function __init__()
)
register_alloc_rules()
register_llvm_rules()

# Force compilation of AD stack
# thunk = Enzyme.Compiler.thunk(Enzyme.Compiler.fspec(typeof(Base.identity), Tuple{Active{Float64}}), Const{typeof(Base.identity)}, Active, Tuple{Active{Float64}}, #=Split=# Val(Enzyme.API.DEM_ReverseModeCombined), #=width=#Val(1), #=ModifiedBetween=#Val((false,false)), Val(#=ReturnPrimal=#false), #=ShadowInit=#Val(false), NonGenABI)
# thunk(Const(Base.identity), Active(1.0), 1.0)
end

# Define EnzymeTarget
Expand Down Expand Up @@ -1258,6 +1254,8 @@ Create the methodinstance pair, and lookup the primal return type.
@nospecialize(TT::Type),
world::Union{UInt,Nothing} = nothing,
)

fdsafdsafsa
# primal function. Inferred here to get return type
_tt = (TT.parameters...,)

Expand Down Expand Up @@ -2123,7 +2121,7 @@ function create_abi_wrapper(
push!(realparms, val)
elseif T <: BatchDuplicatedFunc
Func = get_func(T)
funcspec = my_methodinstance(Func, Tuple{}, world)
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, Tuple{}, world)
llvmf = nested_codegen!(Mode, mod, funcspec, world)
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
Func_RT = return_type(interp, funcspec)
Expand Down Expand Up @@ -4520,7 +4518,7 @@ end
((LLVM.DoubleType(), Float64, ""), (LLVM.FloatType(), Float32, "f"))
fname = String(name) * pf
if haskey(functions(mod), fname)
funcspec = my_methodinstance(fnty, Tuple{JT}, world)
funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, world)
llvmf = nested_codegen!(mode, mod, funcspec, world)
push!(function_attributes(llvmf), StringAttribute("implements", fname))
end
Expand Down Expand Up @@ -5545,17 +5543,13 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA::
primal_tt = Tuple{map(eltype, TT.parameters)...}
# look up the method match
method_error = :(throw(MethodError($ft, $primal_tt, $world)))
sig = Tuple{ft, primal_tt.parameters...}

min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
sig, #=mt=# nothing, world, min_world, max_world)
match === nothing && return stub(world, source, method_error)

# look up the method and code instance
mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any), match.method, match.spec_types, match.sparams)

mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world)

mi === nothing && return stub(world, source, method_error)

ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo

Expand Down Expand Up @@ -5681,18 +5675,14 @@ function deferred_id_generator(world::UInt, source::LineNumberNode, @nospecializ
primal_tt = Tuple{map(eltype, TT.parameters)...}
# look up the method match
method_error = :(throw(MethodError($ft, $primal_tt, $world)))
sig = Tuple{ft, primal_tt.parameters...}

min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
sig, #=mt=# nothing, world, min_world, max_world)
match === nothing && return stub(world, source, method_error)

# look up the method and code instance
mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any), match.method, match.spec_types, match.sparams)

mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world)

mi === nothing && return stub(world, source, method_error)

ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo

# prepare a new code info
Expand Down
80 changes: 22 additions & 58 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,34 +68,31 @@ function rule_backedge_holder_generator(world::UInt, source, self, ft::Type)
### TODO: backedge from inactive, augmented_primal, forward, reverse
edges = Any[]

@static if false
if ft == typeof(EnzymeRules.augmented_primal)
# this is illegal
# sig = Tuple{typeof(EnzymeRules.augmented_primal), <:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}
# push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.augmented_primal), Tuple{<:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}, world))
rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable)
push!(edges, rev_sig)

rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}}
push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable)
push!(edges, rev_sig)
elseif ft == typeof(EnzymeRules.forward)
# this is illegal
# sig = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}
# push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.forward), Tuple{<:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}, world))
fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
push!(edges, ccall(:jl_method_table_for, Any, (Any,), fwd_sig)::Core.MethodTable)
push!(edges, fwd_sig)
elseif ft == typeof(EnzymeRules.inactive)
ina_sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Any}}
push!(edges, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable)
push!(edges, ina_sig)
else
# sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Annotation}}
# push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive), Tuple{Vararg{Annotation}}, world))

# sig = Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Annotation}}
# push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive_noinl), Tuple{Vararg{Annotation}}, world))

# sig = Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}
# push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.noalias), Tuple{Vararg{Any}}, world))

# sig = Tuple{typeof(EnzymeRules.inactive_type), Type}
# push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig))
push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive_type), Tuple{Type}, world))
end
for gen_sig in (
Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}},
Tuple{typeof(EnzymeRules.noalias), Vararg{Any}},
Tuple{typeof(EnzymeRules.inactive_type), Type},
)
push!(edges, ccall(:jl_method_table_for, Any, (Any,), gen_sig)::Core.MethodTable)
push!(edges, gen_sig)
end
end

new_ci.edges = edges
Expand Down Expand Up @@ -126,39 +123,6 @@ end
$(Expr(:meta, :generated, rule_backedge_holder_generator))
end

begin
# Forward-rule catch all
fwd_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)})
# Reverse-rule catch all
rev_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)})
# Inactive-rule catch all
ina_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)})
# All other derivative-related catch all (just for autodiff, not inference), including inactive_noinl, noalias, and inactive_type
gen_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{Val{0}})


fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
EnzymeRules.add_mt_backedge!(fwd_rule_be, ccall(:jl_method_table_for, Any, (Any,), fwd_sig)::Core.MethodTable, fwd_sig)

rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}}
EnzymeRules.add_mt_backedge!(rev_rule_be, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable, rev_sig)


for ina_sig in (
Tuple{typeof(EnzymeRules.inactive), Vararg{Any}},
)
EnzymeRules.add_mt_backedge!(ina_rule_be, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable, ina_sig)
end

for gen_sig in (
Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}},
Tuple{typeof(EnzymeRules.noalias), Vararg{Any}},
Tuple{typeof(EnzymeRules.inactive_type), Type},
)
EnzymeRules.add_mt_backedge!(gen_rule_be, ccall(:jl_method_table_for, Any, (Any,), gen_sig)::Core.MethodTable, gen_sig)
end
end

struct EnzymeInterpreter{T} <: AbstractInterpreter
@static if HAS_INTEGRATED_CACHE
token::Any
Expand Down
19 changes: 7 additions & 12 deletions src/compiler/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ function get_job(
tt = Tuple{map(eltype, types.parameters)...}


primal, rt = if world isa Nothing
fspec(Core.Typeof(func), types), Compiler.primal_return_type(mode == API.DEM_ForwardMode ? Forward : Reverse, Core.Typeof(func), tt)
else
fspec(Core.Typeof(func), types, world), Compiler.primal_return_type_world(mode == API.DEM_ForwardMode ? Forward : Reverse, world, Core.Typeof(func), tt)
primal, rt =
if world isa Nothing
world=Base.get_world_counter()
end

primal = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, eltype(Core.Typeof(func)), Tuple{map(eltype, types.parameters)...}, world)
rt = Compiler.primal_return_type_world(mode == API.DEM_ForwardMode ? Forward : Reverse, world, Core.Typeof(func), tt)

rt = A{rt}
target = Compiler.EnzymeTarget()
if modifiedBetween === nothing
Expand All @@ -47,18 +49,11 @@ function get_job(
ErrIfFuncWritten,
RuntimeActivity,
)
if world isa Nothing
return Compiler.CompilerJob(
primal,
CompilerConfig(target, params; kernel = false),
)
else
return Compiler.CompilerJob(
return Compiler.CompilerJob(
primal,
CompilerConfig(target, params; kernel = false),
world,
)
end
end

function reflect(
Expand Down
16 changes: 2 additions & 14 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -474,24 +474,12 @@ const generic_method_offsets = Dict{String,Tuple{Int,Int}}((
"ijl_apply_generic" => (1, 2),
))

@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Union{Nothing,Core.MethodTable})
return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), sig, mt, world) !== nothing
end

@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.InternalMethodTable)
return has_method(sig, mt.world, nothing)
end

@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.OverlayMethodTable)
return has_method(sig, mt.world, mt.mt) || has_method(sig, mt.world, nothing)
end

@inline function is_inactive(@nospecialize(tys::Union{Vector{Union{Type,Core.TypeofVararg}}, Core.SimpleVector}), world::UInt, @nospecialize(mt))
specTypes = Interpreter.simplify_kw(Tuple{tys...})
if has_method(Tuple{typeof(EnzymeRules.inactive),tys...}, world, mt)
if Enzyme.has_method(Tuple{typeof(EnzymeRules.inactive),tys...}, world, mt)
return true
end
if has_method(Tuple{typeof(EnzymeRules.inactive_noinl),tys...}, world, mt)
if Enzyme.has_method(Tuple{typeof(EnzymeRules.inactive_noinl),tys...}, world, mt)
return true
end
return false
Expand Down
Loading

0 comments on commit 9e50cc5

Please sign in to comment.