diff --git a/src/Enzyme.jl b/src/Enzyme.jl index 92ec9a623f..9258079ba6 100644 --- a/src/Enzyme.jl +++ b/src/Enzyme.jl @@ -370,7 +370,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) FTy = Core.Typeof(f.val) rt = if A isa UnionAll - Compiler.primal_return_type(mode, FTy, tt) + Compiler.primal_return_type(Reverse, FTy, tt) else eltype(A) end @@ -410,7 +410,7 @@ Enzyme.autodiff(ReverseWithPrimal, x->x*x, Active(3.0)) end opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) + my_methodinstance(Reverse, eltype(FA), tt) else Val(0) end @@ -536,7 +536,7 @@ Like [`autodiff`](@ref) but will try to guess the activity of the return value. ) where {FA<:Annotation,CMode<:Mode,Nargs} tt = vaEltypeof(args...) rt = Compiler.primal_return_type( - mode, + mode isa ForwardMode ? Forward : Reverse, eltype(FA), tt, ) @@ -632,7 +632,7 @@ f(x) = x*x tt = vaEltypeof(args...) opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) + my_methodinstance(Forward, eltype(FA), tt) else Val(0) end @@ -687,7 +687,7 @@ code, as well as high-order differentiation. A2 = A if A isa UnionAll - rt = Compiler.primal_return_type(mode, FTy, tt) + rt = Compiler.primal_return_type(Reverse, FTy, tt) A2 = A{rt} if rt == Union{} rt = Nothing @@ -840,7 +840,7 @@ code, as well as high-order differentiation. FT = Core.Typeof(f.val) if RT isa UnionAll - rt = Compiler.primal_return_type(mode, FT, tt) + rt = Compiler.primal_return_type(Forward, FT, tt) if rt == Union{} rt = Nothing end @@ -968,7 +968,7 @@ result, ∂v, ∂A tt′ = Tuple{args...} opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) + my_methodinstance(Reverse, eltype(FA), tt) else Val(0) end @@ -1098,7 +1098,7 @@ forward = autodiff_thunk(Forward, Const{typeof(f)}, Duplicated, Duplicated{Float tt′ = Tuple{args...} opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), tt′) + my_methodinstance(Forward, eltype(FA), tt) else Val(0) end @@ -1166,7 +1166,7 @@ end primal_tt = Tuple{map(eltype, args)...} opt_mi = if RABI <: NonGenABI - Compiler.fspec(eltype(FA), TT) + my_methodinstance(Forward, eltype(FA), primal_tt) else Val(0) end @@ -1196,7 +1196,7 @@ const tape_cache = Dict{UInt,Type}() const tape_cache_lock = ReentrantLock() -import .Compiler: fspec, remove_innerty, UnknownTapeType +import .Compiler: remove_innerty, UnknownTapeType @inline function tape_type( parent_job::Union{GPUCompiler.CompilerJob,Nothing}, @@ -1246,7 +1246,7 @@ import .Compiler: fspec, remove_innerty, UnknownTapeType primal_tt = Tuple{map(eltype, args)...} - mi = Compiler.fspec(eltype(FA), TT) + mi = my_methodinstance(parent_job === nothing ? Reverse : GPUCompiler.get_interpreter(parent_job), eltype(FA), primal_tt) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams( @@ -1382,7 +1382,7 @@ result, ∂v, ∂A TT = Tuple{args...} primal_tt = Tuple{map(eltype, args)...} - rt0 = Compiler.primal_return_type(mode, eltype(FA), primal_tt) + rt0 = Compiler.primal_return_type(Reverse, eltype(FA), primal_tt) rt = Compiler.remove_innerty(A2){rt0} diff --git a/src/analyses/activity.jl b/src/analyses/activity.jl index 3c29838e70..61d2f35ab7 100644 --- a/src/analyses/activity.jl +++ b/src/analyses/activity.jl @@ -249,6 +249,7 @@ end EnzymeCore.EnzymeRules.inactive_type(T) else inmi = my_methodinstance( + nothing, typeof(EnzymeCore.EnzymeRules.inactive_type), Tuple{Type{T}}, world, diff --git a/src/compiler.jl b/src/compiler.jl index ddccde1a24..61eeb81105 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -318,7 +318,7 @@ include("llvm/passes.jl") include("typeutils/make_zero.jl") function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, @nospecialize(f), @nospecialize(tt::Type), world::UInt) - funcspec = my_methodinstance(typeof(f), tt, world) + funcspec = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, typeof(f), tt, world) nested_codegen!(mode, mod, funcspec, world) end @@ -361,6 +361,8 @@ function prepare_llvm(mod::LLVM.Module, job, meta) end end +const mod_to_edges = Dict{LLVM.Module, Vector{Any}}() + function nested_codegen!( mode::API.CDerivativeMode, mod::LLVM.Module, @@ -390,6 +392,11 @@ function nested_codegen!( permit_inlining!(f) end + edges = get(mod_to_edges, mod, nothing) + @assert edges !== nothing + edges = edges::Vector{Any} + push!(edges, funcspec) + # Apply first stage of optimization's so that this module is at the same stage as `mod` optimize!(otherMod, JIT.get_tm()) # 4) Link the corresponding module @@ -1124,10 +1131,6 @@ function __init__() ) register_alloc_rules() register_llvm_rules() - - # Force compilation of AD stack - # thunk = Enzyme.Compiler.thunk(Enzyme.Compiler.fspec(typeof(Base.identity), Tuple{Active{Float64}}), Const{typeof(Base.identity)}, Active, Tuple{Active{Float64}}, #=Split=# Val(Enzyme.API.DEM_ReverseModeCombined), #=width=#Val(1), #=ModifiedBetween=#Val((false,false)), Val(#=ReturnPrimal=#false), #=ShadowInit=#Val(false), NonGenABI) - # thunk(Const(Base.identity), Active(1.0), 1.0) end # Define EnzymeTarget @@ -1197,16 +1200,27 @@ if VERSION >= v"1.11.0-DEV.1552" always_inline::Any method_table::Core.MethodTable param_type::Type - is_fwd::Bool + last_fwd_rule_world::Union{Nothing, Tuple} + last_rev_rule_world::Union{Nothing, Tuple} + last_ina_rule_world::Tuple end + @inline EnzymeCacheToken(target_type::Type, always_inline::Any, method_table::Core.MethodTable, param_type::Type, world::UInt, is_forward::Bool, is_reverse::Bool) = + EnzymeCacheToken(target_type, always_inline, method_table, param_type, + is_forward ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.forward, Tuple{<:EnzymeCore.EnzymeRules.FwdConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing, + is_reverse ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:EnzymeCore.EnzymeRules.RevConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing, + (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)...,) + ) + GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = EnzymeCacheToken( typeof(job.config.target), job.config.always_inline, GPUCompiler.method_table(job), typeof(job.config.params), + job.world, job.config.params.mode == API.DEM_ForwardMode, + job.config.params.mode != API.DEM_ForwardMode ) GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = @@ -1258,6 +1272,8 @@ Create the methodinstance pair, and lookup the primal return type. @nospecialize(TT::Type), world::Union{UInt,Nothing} = nothing, ) + +fdsafdsafsa # primal function. Inferred here to get return type _tt = (TT.parameters...,) @@ -2123,7 +2139,7 @@ function create_abi_wrapper( push!(realparms, val) elseif T <: BatchDuplicatedFunc Func = get_func(T) - funcspec = my_methodinstance(Func, Tuple{}, world) + funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, Func, Tuple{}, world) llvmf = nested_codegen!(Mode, mod, funcspec, world) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) Func_RT = return_type(interp, funcspec) @@ -3236,6 +3252,7 @@ function GPUCompiler.codegen( if params.run_enzyme # @assert eltype(params.rt) != Union{} end + expectedTapeType = params.expectedTapeType mode = params.mode TT = params.TT @@ -3277,6 +3294,8 @@ function GPUCompiler.codegen( GPUCompiler.prepare_job!(primal_job) mod, meta = GPUCompiler.emit_llvm(primal_job; libraries=true, toplevel=toplevel, optimize=false, cleanup=false, only_entry=false, validate=false) + edges = Any[] + mod_to_edges[mod] = edges prepare_llvm(mod, primal_job, meta) for f in functions(mod) @@ -3555,16 +3574,15 @@ function GPUCompiler.codegen( specTypes = Interpreter.simplify_kw(mi.specTypes) - caller = mi if mode == API.DEM_ForwardMode has_custom_rule = - EnzymeRules.has_frule_from_sig(specTypes; world, method_table, caller) + EnzymeRules.has_frule_from_sig(specTypes; world, method_table) if has_custom_rule @safe_debug "Found frule for" mi.specTypes end else has_custom_rule = - EnzymeRules.has_rrule_from_sig(specTypes; world, method_table, caller) + EnzymeRules.has_rrule_from_sig(specTypes; world, method_table) if has_custom_rule @safe_debug "Found rrule for" mi.specTypes end @@ -3579,7 +3597,8 @@ function GPUCompiler.codegen( actualRetType = k.ci.rettype end - if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table, caller) + if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table) + push!(edges, mi) push!(return_attributes(llvmfn), EnumAttribute("noalias")) for u in LLVM.uses(llvmfn) c = LLVM.user(u) @@ -3803,12 +3822,8 @@ end end continue end - if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && - has_method( - Tuple{typeof(EnzymeRules.inactive),specTypes.parameters...}, - world, - method_table, - ) + if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table) + push!(edges, mi) handleCustom( llvmfn, "enz_noop", @@ -3821,12 +3836,8 @@ end ) continue end - if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table, caller) && - has_method( - Tuple{typeof(EnzymeRules.inactive_noinl),specTypes.parameters...}, - world, - method_table, - ) + if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table) + push!(edges, mi) handleCustom( llvmfn, "enz_noop", @@ -4520,7 +4531,7 @@ end ((LLVM.DoubleType(), Float64, ""), (LLVM.FloatType(), Float32, "f")) fname = String(name) * pf if haskey(functions(mod), fname) - funcspec = my_methodinstance(fnty, Tuple{JT}, world) + funcspec = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, fnty, Tuple{JT}, world) llvmf = nested_codegen!(mode, mod, funcspec, world) push!(function_attributes(llvmf), StringAttribute("implements", fname)) end @@ -4590,10 +4601,12 @@ end isempty(LLVM.blocks(fn)) && continue linkage!(fn, LLVM.API.LLVMLinkerPrivateLinkage) end + + delete!(mod_to_edges, mod) use_primal = mode == API.DEM_ReverseModePrimal entry = use_primal ? augmented_primalf : adjointf - return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType) + return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType, edges) end # Compiler result @@ -4601,6 +4614,7 @@ struct CompileResult{AT,PT} adjoint::AT primal::PT TapeType::Type + edges::Vector{Any} end @inline (thunk::PrimalErrorThunk{PT,FA,RT,TT,Width,ReturnPrimal})( @@ -5226,12 +5240,13 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) +function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, edges::Vector{Any}, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), Val((Symbol(mod), Symbol(primal_name))), - TapeType + TapeType, + edges ) end @@ -5263,16 +5278,17 @@ function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module end end - return CompileResult(adjoint_ptr, primal_ptr, TapeType) + return CompileResult(adjoint_ptr, primal_ptr, TapeType, edges) end const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{String, Nothing}, Type, String} +function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, String, Union{String, Nothing}, Type, String} mod, meta = codegen(:llvm, job; optimize = false) adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf + adjoint_name = name(adjointf) if augmented_primalf !== nothing @@ -5305,7 +5321,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{Str else "" end - return (mod, adjoint_name, primal_name, meta.TapeType, prepost) + return (mod, meta.edges, adjoint_name, primal_name, meta.TapeType, prepost) end const cache = Dict{UInt,CompileResult}() @@ -5324,10 +5340,10 @@ const cache_lock = ReentrantLock() asm = _thunk(job) obj = _link(job, asm...) if obj.adjoint isa Ptr{Nothing} - autodiff_cache[obj.adjoint] = (asm[2], asm[5]) + autodiff_cache[obj.adjoint] = (asm[3], asm[6]) end - if obj.primal isa Ptr{Nothing} && asm[3] isa String - autodiff_cache[obj.primal] = (asm[3], asm[5]) + if obj.primal isa Ptr{Nothing} && asm[4] isa String + autodiff_cache[obj.primal] = (asm[4], asm[6]) end cache[key] = obj end @@ -5351,7 +5367,8 @@ end @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, -) + edges::Union{Nothing, Vector{Any}} +) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams( Tuple{FA,TT.parameters...}, @@ -5432,6 +5449,11 @@ end compile_result = cached_compilation(job) + if edges !== nothing + for e in compile_result.edges + push!(edges, e) + end + end if !run_enzyme ErrT = PrimalErrorThunk{typeof(compile_result.adjoint),FA,rt2,TT,width,ReturnPrimal} if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient @@ -5528,6 +5550,7 @@ end ABI, ErrIfFuncWritten, RuntimeActivity, + nothing ) finally deactivate(ctx) @@ -5545,17 +5568,13 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA:: primal_tt = Tuple{map(eltype, TT.parameters)...} # look up the method match method_error = :(throw(MethodError($ft, $primal_tt, $world))) - sig = Tuple{ft, primal_tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - match = ccall(:jl_gf_invoke_lookup_worlds, Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - sig, #=mt=# nothing, world, min_world, max_world) - match === nothing && return stub(world, source, method_error) - - # look up the method and code instance - mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, - (Any, Any, Any), match.method, match.spec_types, match.sparams) + + mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world) + + mi === nothing && return stub(world, source, method_error) ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo @@ -5574,18 +5593,34 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA:: new_ci.min_world = world new_ci.max_world = max_world[] - edges = Core.MethodInstance[mi] + edges = Any[mi] if Mode == API.DEM_ForwardMode - push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, world)) - Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) + fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), fwd_sig)::Core.MethodTable) + push!(edges, fwd_sig) else - push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, world)) + rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable) + push!(edges, rev_sig) + + rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable) + push!(edges, rev_sig) + end + + ina_sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Any}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable) + push!(edges, ina_sig) + + for gen_sig in ( + Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}}, + Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, + Tuple{typeof(EnzymeRules.inactive_type), Type}, + ) + push!(edges, ccall(:jl_method_table_for, Any, (Any,), gen_sig)::Core.MethodTable) + push!(edges, gen_sig) end - - push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, world)) - push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{Val{0}}, world)) - Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(Val(0))) new_ci.edges = edges @@ -5612,6 +5647,7 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA:: ABI, ErrIfFuncWritten, RuntimeActivity, + edges ) finally deactivate(ctx) @@ -5681,18 +5717,14 @@ function deferred_id_generator(world::UInt, source::LineNumberNode, @nospecializ primal_tt = Tuple{map(eltype, TT.parameters)...} # look up the method match method_error = :(throw(MethodError($ft, $primal_tt, $world))) - sig = Tuple{ft, primal_tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - match = ccall(:jl_gf_invoke_lookup_worlds, Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - sig, #=mt=# nothing, world, min_world, max_world) - match === nothing && return stub(world, source, method_error) - - # look up the method and code instance - mi = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, - (Any, Any, Any), match.method, match.spec_types, match.sparams) + mi = my_methodinstance(Mode == API.DEM_ForwardMode ? Forward : Reverse, ft, primal_tt, world, min_world, max_world) + + mi === nothing && return stub(world, source, method_error) + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo # prepare a new code info diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 60751a1004..a45ab24d1a 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -29,7 +29,7 @@ function rule_backedge_holder_generator(world::UInt, source, self, ft::Type) sig = Tuple{typeof(Base.identity), Int} min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - has_ambig = Ptr{Int32}(C_NULL) + has_ambig = Ptr{Int32}(C_NULL) mthds = Base._methods_by_ftype( sig, nothing, @@ -68,34 +68,18 @@ function rule_backedge_holder_generator(world::UInt, source, self, ft::Type) ### TODO: backedge from inactive, augmented_primal, forward, reverse edges = Any[] - @static if false if ft == typeof(EnzymeRules.augmented_primal) - # this is illegal - # sig = Tuple{typeof(EnzymeRules.augmented_primal), <:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} - # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) - push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.augmented_primal), Tuple{<:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}, world)) + sig = Tuple{typeof(EnzymeRules.augmented_primal), <:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), sig)) + push!(edges, sig) elseif ft == typeof(EnzymeRules.forward) - # this is illegal - # sig = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} - # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) - push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.forward), Tuple{<:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}}, world)) + sig = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), sig)) + push!(edges, sig) else - # sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Annotation}} - # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) - push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive), Tuple{Vararg{Annotation}}, world)) - - # sig = Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Annotation}} - # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) - push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive_noinl), Tuple{Vararg{Annotation}}, world)) - - # sig = Tuple{typeof(EnzymeRules.noalias), Vararg{Any}} - # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) - push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.noalias), Tuple{Vararg{Any}}, world)) - - # sig = Tuple{typeof(EnzymeRules.inactive_type), Type} - # push!(edges, (ccall(:jl_method_table_for, Any, (Any,), sig), sig)) - push!(edges, GPUCompiler.generic_methodinstance(typeof(EnzymeRules.inactive_type), Tuple{Type}, world)) - end + sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), sig)) + push!(edges, sig) end new_ci.edges = edges @@ -110,7 +94,7 @@ function rule_backedge_holder_generator(world::UInt, source, self, ft::Type) new_ci.slotflags = UInt8[0x00 for i = 1:2] # return the codegen world age - push!(new_ci.code, Core.Compiler.ReturnNode(0)) + push!(new_ci.code, Core.Compiler.ReturnNode(world)) push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` @static if isdefined(Core, :DebugInfo) else @@ -126,39 +110,6 @@ end $(Expr(:meta, :generated, rule_backedge_holder_generator)) end -begin - # Forward-rule catch all - fwd_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}) - # Reverse-rule catch all - rev_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}) - # Inactive-rule catch all - ina_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}) - # All other derivative-related catch all (just for autodiff, not inference), including inactive_noinl, noalias, and inactive_type - gen_rule_be = GPUCompiler.methodinstance(typeof(rule_backedge_holder), Tuple{Val{0}}) - - - fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} - EnzymeRules.add_mt_backedge!(fwd_rule_be, ccall(:jl_method_table_for, Any, (Any,), fwd_sig)::Core.MethodTable, fwd_sig) - - rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} - EnzymeRules.add_mt_backedge!(rev_rule_be, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable, rev_sig) - - - for ina_sig in ( - Tuple{typeof(EnzymeRules.inactive), Vararg{Any}}, - ) - EnzymeRules.add_mt_backedge!(ina_rule_be, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable, ina_sig) - end - - for gen_sig in ( - Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}}, - Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, - Tuple{typeof(EnzymeRules.inactive_type), Type}, - ) - EnzymeRules.add_mt_backedge!(gen_rule_be, ccall(:jl_method_table_for, Any, (Any,), gen_sig)::Core.MethodTable, gen_sig) - end -end - struct EnzymeInterpreter{T} <: AbstractInterpreter @static if HAS_INTEGRATED_CACHE token::Any @@ -182,6 +133,33 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter handler::T end + +function get_rule_signatures(f, TT, world) + fwdrules_meths = Base._methods(f, TT, -1, world)::Vector + sigs = Type[] + for rule in fwdrules_meths + push!(sigs, (rule::Core.MethodMatch).method.sig) + end + return Base.IdSet{Type}(sigs) +end + +function rule_sigs_equal(a, b) + if length(a) != length(b) + return false + end + for v in a + if v in b + continue + end + return false + end + return true +end + +const LastFwdWorld = Ref(Base.IdSet{Type}()) +const LastRevWorld = Ref(Base.IdSet{Type}()) +const LastInaWorld = Ref(Base.IdSet{Type}()) + function EnzymeInterpreter( cache_or_token, mt::Union{Nothing,Core.MethodTable}, @@ -198,6 +176,37 @@ function EnzymeInterpreter( else InferenceParams(; unoptimize_throw_blocks=false) end + + @static if HAS_INTEGRATED_CACHE + + else + cache_or_token = cache_or_token::CodeCache + invalid = false + if forward_rules + fwdrules = get_rule_signatures(EnzymeRules.forward, Tuple{<:FwdConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world) + if !rule_sigs_equal(fwdrules, LastFwdWorld[]) + LastFwdWorld[] = fwdrules + invalid = true + end + end + if reverse_rules + revrules = get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:RevConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world) + if !rule_sigs_equal(revrules, LastRevWorld[]) + LastRevWorld[] = revrules + invalid = true + end + end + + inarules = get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world) + if !rule_sigs_equal(inarules, LastInaWorld[]) + LastInaWorld[] = inarules + invalid = true + end + + if invalid + Base.empty!(cache_or_token) + end + end return EnzymeInterpreter( cache_or_token, @@ -370,17 +379,6 @@ function Core.Compiler.abstract_call_gf_by_type( end end end - - if interp.forward_rules - Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, interp.world)::Core.MethodInstance) - Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) - end - if interp.reverse_rules - Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, interp.world)::Core.MethodInstance) - Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.augmented_primal)) - end - Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, interp.world)::Core.MethodInstance) - Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(typeof(EnzymeRules.inactive))) end @static if VERSION ≥ v"1.11-" diff --git a/src/compiler/orcv2.jl b/src/compiler/orcv2.jl index 1640b05db2..3e2dc8a9f3 100644 --- a/src/compiler/orcv2.jl +++ b/src/compiler/orcv2.jl @@ -208,7 +208,7 @@ function get_trampoline(job) # 2 Add a module defining "foo.rt.impl" to the JITDylib. # 2. Call MR.replace(symbolAliases({"my_deferred_decision_sym.1" -> "foo.rt.impl"})). GPUCompiler.JuliaContext() do ctx - mod, adjoint_name, primal_name = Compiler._thunk(job) + mod, edges, adjoint_name, primal_name = Compiler._thunk(job) func_name = use_primal ? primal_name : adjoint_name other_name = !use_primal ? primal_name : adjoint_name diff --git a/src/compiler/reflection.jl b/src/compiler/reflection.jl index 40b1293d8d..dcddf0acd4 100644 --- a/src/compiler/reflection.jl +++ b/src/compiler/reflection.jl @@ -19,13 +19,14 @@ function get_job( tt = Tuple{map(eltype, types.parameters)...} - - primal, rt = if world isa Nothing - fspec(Core.Typeof(func), types), Compiler.primal_return_type(mode == API.DEM_ForwardMode ? Forward : Reverse, Core.Typeof(func), tt) - else - fspec(Core.Typeof(func), types, world), Compiler.primal_return_type_world(mode == API.DEM_ForwardMode ? Forward : Reverse, world, Core.Typeof(func), tt) + if world isa Nothing + world=Base.get_world_counter() end + primal = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, Core.Typeof(func), tt, world) + rt = Compiler.primal_return_type_world(mode == API.DEM_ForwardMode ? Forward : Reverse, world, Core.Typeof(func), tt) + + @assert primal !== nothing rt = A{rt} target = Compiler.EnzymeTarget() if modifiedBetween === nothing @@ -47,18 +48,11 @@ function get_job( ErrIfFuncWritten, RuntimeActivity, ) - if world isa Nothing - return Compiler.CompilerJob( - primal, - CompilerConfig(target, params; kernel = false), - ) - else - return Compiler.CompilerJob( + return Compiler.CompilerJob( primal, CompilerConfig(target, params; kernel = false), world, ) - end end function reflect( diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl index 701cfa8107..57ec0053d1 100644 --- a/src/compiler/tfunc.jl +++ b/src/compiler/tfunc.jl @@ -2,30 +2,32 @@ import EnzymeCore: Annotation import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, inactive, _annotate_tt function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, forward, TT, sv) + fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + return isapplicable(interp, forward, TT, sv, fwd_sig) end function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, augmented_primal, TT, sv) + rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + return isapplicable(interp, augmented_primal, TT, sv, rev_sig) end function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState) - return isapplicable(interp, inactive, TT, sv) + @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState) + return isapplicable(interp, inactive, TT, sv, nothing) end # `hasmethod` is a precise match using `Core.Compiler.findsup`, # but here we want the broader query using `Core.Compiler.findall`. # Also add appropriate backedges to the caller `MethodInstance` if given. function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(f), @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, @nospecialize(partialsig::Union{Type, Nothing}))::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -39,18 +41,24 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), end # also need an edge to the method table in case something gets # added that did not intersect with any existing method - fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) - if !fullmatch - Core.Compiler.add_mt_backedge!(sv, mt, sig) + # fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) + # if !fullmatch + if true + if partialsig === nothing + Core.Compiler.add_mt_backedge!(sv, mt, sig) + else + pmt = ccall(:jl_method_table_for, Any, (Any,), partialsig) + Core.Compiler.add_mt_backedge!(sv, pmt, partialsig) + end end if Core.Compiler.isempty(matches) return false else - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - Core.Compiler.add_backedge!(sv, edge) - end + #for i = 1:Core.Compiler.length(matches) + # match = Core.Compiler.getindex(matches, i)::Core.MethodMatch + # edge = Core.Compiler.specialize_method(match)::Core.MethodInstance + # Core.Compiler.add_backedge!(sv, edge) + #end return true end -end \ No newline at end of file +end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index 20be3891a0..dbd0c14773 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -474,24 +474,12 @@ const generic_method_offsets = Dict{String,Tuple{Int,Int}}(( "ijl_apply_generic" => (1, 2), )) -@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Union{Nothing,Core.MethodTable}) - return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), sig, mt, world) !== nothing -end - -@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.InternalMethodTable) - return has_method(sig, mt.world, nothing) -end - -@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.OverlayMethodTable) - return has_method(sig, mt.world, mt.mt) || has_method(sig, mt.world, nothing) -end - @inline function is_inactive(@nospecialize(tys::Union{Vector{Union{Type,Core.TypeofVararg}}, Core.SimpleVector}), world::UInt, @nospecialize(mt)) specTypes = Interpreter.simplify_kw(Tuple{tys...}) - if has_method(Tuple{typeof(EnzymeRules.inactive),tys...}, world, mt) + if Enzyme.has_method(Tuple{typeof(EnzymeRules.inactive),tys...}, world, mt) return true end - if has_method(Tuple{typeof(EnzymeRules.inactive_noinl),tys...}, world, mt) + if Enzyme.has_method(Tuple{typeof(EnzymeRules.inactive_noinl),tys...}, world, mt) return true end return false diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 6956842462..bfde266a8d 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -523,21 +523,21 @@ end typeof(EnzymeRules.forward) end @safe_debug "Applying custom forward rule" TT = TT, functy = functy - fmi, fwd_RT = try - fmi = my_methodinstance(functy, TT, world) - fwd_RT = primal_return_type_world(Forward, world, fmi) - fmi, fwd_RT - catch e + fmi = my_methodinstance(Forward, functy, TT, world) + if fmi === nothing TT = Tuple{typeof(world),functy,TT.parameters...} - fmi = my_methodinstance(typeof(custom_rule_method_error), TT, world) + fmi = my_methodinstance(Forward, typeof(custom_rule_method_error), TT, world) pushfirst!(args, LLVM.ConstantInt(world)) fwd_RT = Union{} - fmi, fwd_RT + else + fwd_RT = primal_return_type_world(Forward, world, fmi) end + fmi = fmi::Core.MethodInstance fwd_RT = fwd_RT::Type llvmf = nested_codegen!(mode, mod, fmi, world) + push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) swiftself = has_swiftself(llvmf) @@ -802,11 +802,11 @@ end typeof(EnzymeRules.augmented_primal) end - ami = try - my_methodinstance(functy, augprimal_TT, world) - catch e + ami = my_methodinstance(Reverse, functy, augprimal_TT, world) + if ami === nothing augprimal_TT = Tuple{typeof(world),functy,augprimal_TT.parameters...} ami = my_methodinstance( + Reverse, typeof(custom_rule_method_error), augprimal_TT, world, @@ -816,6 +816,7 @@ end end ami end + ami = ami::Core.MethodInstance @safe_debug "Applying custom augmented_primal rule" TT = augprimal_TT, functy=functy return ami, @@ -984,18 +985,18 @@ function enzyme_custom_common_rev( end @safe_debug "Applying custom reverse rule" TT = rev_TT, functy=functy - rmi, rev_RT = try - rmi = my_methodinstance(functy, rev_TT, world) - rev_RT = return_type(interp, rmi) - rmi, rev_RT - catch e + rmi = my_methodinstance(Reverse, functy, rev_TT, world) + + if rmi === nothing rev_TT = Tuple{typeof(world),functy,rev_TT.parameters...} rmi = my_methodinstance(typeof(custom_rule_method_error), rev_TT, world) pushfirst!(args, LLVM.ConstantInt(world)) rev_RT = Union{} applicablefn = false - rmi, rev_RT + else + rev_RT = return_type(interp, rmi) end + rmi = rmi::Core.MethodInstance rev_RT = rev_RT::Type llvmf = nested_codegen!(mode, mod, rmi, world) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index d4356aba61..41992a4ff8 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -238,7 +238,8 @@ end pfuncT = funcT - mi2 = fspec(funcT, e_tt, world) + mi2 = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, funcT, Tuple{map(eltype, e_tt.parameters)...}, world) + @assert mi2 !== nothing refed = false @@ -275,7 +276,7 @@ end world, ) - cmod, fwdmodenm, _, _, _ = _thunk(ejob, false) #=postopt=# + cmod, edges, fwdmodenm, _, _, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) @@ -306,7 +307,8 @@ end funcT = Core.Typeof(referenceCaller) dupClosure = false modifiedBetween = (false, modifiedBetween...) - mi2 = fspec(funcT, e_tt, world) + mi2 = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, funcT, Tuple{map(eltype, e_tt.parameters)...}, world) + @assert mi2 !== nothing end end @@ -334,7 +336,7 @@ end world, ) - cmod, adjointnm, augfwdnm, TapeType, _ = _thunk(ejob, false) #=postopt=# + cmod, edges, adjointnm, augfwdnm, TapeType, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) diff --git a/src/sugar.jl b/src/sugar.jl index b93b7fb0eb..574ffa8cbf 100644 --- a/src/sugar.jl +++ b/src/sugar.jl @@ -13,7 +13,7 @@ end target = Compiler.DefaultCompilerTarget() params = Compiler.PrimalCompilerParams(API.DEM_ForwardMode) - mi = my_methodinstance(fn, Tuple{T, Int}) + mi = my_methodinstance(nothing, fn, Tuple{T, Int}) job = GPUCompiler.CompilerJob(mi, GPUCompiler.CompilerConfig(target, params; kernel = false)) GPUCompiler.prepare_job!(job) @@ -899,7 +899,7 @@ this function will retun an AbstractArray of shape `size(output)` of values of t Core.Typeof(f) end - rt = Compiler.primal_return_type(mode, FRT, tt) + rt = Compiler.primal_return_type(Reverse, FRT, tt) ModifiedBetweenT = (false, false) FA = Const{FRT} diff --git a/src/typeutils/inference.jl b/src/typeutils/inference.jl index cd000b70c9..764f55ab23 100644 --- a/src/typeutils/inference.jl +++ b/src/typeutils/inference.jl @@ -25,7 +25,9 @@ function primal_interp_world( false, GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# EnzymeCompilerParams, + world, false, + true ) else Enzyme.Compiler.GLOBAL_REV_CACHE @@ -46,7 +48,9 @@ function primal_interp_world( false, GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# EnzymeCompilerParams, + world, true, + false ) else Enzyme.Compiler.GLOBAL_FWD_CACHE @@ -97,41 +101,19 @@ function primal_return_type_generator(world::UInt, source, self, @nospecialize(m # look up the method method_error = :(throw(MethodError(ft, tt, $world))) - sig = Tuple{ft,tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - has_ambig = Ptr{Int32}(C_NULL) # don't care about ambiguous results - #interp = primal_interp_world(mode, world) - #method_table = Core.Compiler.method_table(interp) - method_table = nothing - mthds = Base._methods_by_ftype( - sig, - method_table, - -1, #=lim=# - world, - false, #=ambig=# - min_world, - max_world, - has_ambig, - ) + + mi = my_methodinstance(mode, ft, tt, world, min_world, max_world) + stub = Core.GeneratedFunctionStub( identity, Core.svec(:methodinstance, :mode, :ft, :tt), Core.svec(), ) - mthds === nothing && return stub(world, source, method_error) - length(mthds) == 1 || return stub(world, source, method_error) - - # look up the method and code instance - mtypes, msp, m = mthds[1] - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - m, - mtypes, - msp, - ) + mi === nothing && return stub(world, source, method_error) + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo # prepare a new code info diff --git a/src/utils.jl b/src/utils.jl index 0f92fc1f5d..aa94e766f8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -169,33 +169,120 @@ using Base: _methods_by_ftype # Julia compiler integration +@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Union{Nothing,Core.MethodTable}) + return ccall(:jl_gf_invoke_lookup, Any, (Any, Any, UInt), sig, mt, world) !== nothing +end -if VERSION >= v"1.11.0-DEV.1552" +@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.InternalMethodTable) + return has_method(sig, mt.world, nothing) +end +@inline function has_method(@nospecialize(sig::Type), world::UInt, mt::Core.Compiler.OverlayMethodTable) + return has_method(sig, mt.world, mt.mt) || has_method(sig, mt.world, nothing) +end + +@inline function lookup_world( + @nospecialize(sig::Type), + world::UInt, + mt::Union{Nothing,Core.MethodTable}, + min_world::Ref{UInt}, + max_world::Ref{UInt}, +) + res = ccall( + :jl_gf_invoke_lookup_worlds, + Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, + mt, + world, + min_world, + max_world, + ) + return res +end -const prevmethodinstance = GPUCompiler.generic_methodinstance +@inline function lookup_world( + @nospecialize(sig::Type), + world::UInt, + mt::Core.Compiler.InternalMethodTable, + min_world::Ref{UInt}, + max_world::Ref{UInt}, +) + res = lookup_world(sig, mt.world, nothing, min_world, max_world) + return res +end -function methodinstance_generator(world::UInt, source, self, @nospecialize(ft::Type), @nospecialize(tt::Type)) +@inline function lookup_world( + @nospecialize(sig::Type), + world::UInt, + mt::Core.Compiler.OverlayMethodTable, + min_world::Ref{UInt}, + max_world::Ref{UInt}, +) + res = lookup_world(sig, mt.world, mt.mt, min_world, max_world) + if res !== nothing + return res + else + return lookup_world(sig, mt.world, nothing, min_world, max_world) + end +end + +@inline function my_methodinstance(@nospecialize(method_table::Union{Core.Compiler.MethodTableView, Nothing}), @nospecialize(ft::Type), @nospecialize(tt::Type), world::UInt, min_world::Union{Nothing, Base.RefValue{UInt}}=nothing, max_world::Union{Nothing, Base.RefValue{UInt}}=nothing)::Union{Core.MethodInstance, Nothing} + + if min_world === nothing + min_world = Ref{UInt}(typemin(UInt)) + end + if max_world === nothing + max_world = Ref{UInt}(typemax(UInt)) + end + + sig = Tuple{ft, tt.parameters...} + + lookup_result = lookup_world( + sig, world, method_table, min_world, max_world + ) + if lookup_result === nothing + return nothing + end + + match = lookup_result::Core.MethodMatch + + mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, + (Any, Any, Any), match.method, match.spec_types, match.sparams) + return mi::Core.MethodInstance +end + +@inline function my_methodinstance(@nospecialize(interp::Core.Compiler.AbstractInterpreter), @nospecialize(ft::Type), @nospecialize(tt::Type), min_world::Union{Nothing, Base.RefValue{UInt}}=nothing, max_world::Union{Nothing, Base.RefValue{UInt}}=nothing)::Union{Core.MethodInstance, Nothing} + my_methodinstance(Core.Compiler.method_table(interp), ft, tt, interp.world, min_world, max_world) +end + +@inline function my_methodinstance(@nospecialize(mode::Union{EnzymeCore.ForwardMode, EnzymeCore.ReverseMode}), @nospecialize(ft::Type), @nospecialize(tt::Type), world::UInt, min_world::Union{Nothing, Base.RefValue{UInt}}=nothing, max_world::Union{Nothing, Base.RefValue{UInt}}=nothing)::Union{Core.MethodInstance, Nothing} + interp = if mode === Nothing + Base.NativeInterpreter(; world) + else + @assert mode == Forward || mode == Reverse + Compiler.primal_interp_world(mode, world) + end + my_methodinstance(interp, ft, tt, min_world, max_world) +end + +function methodinstance_generator(world::UInt, source, self, @nospecialize(mode::Type), @nospecialize(ft::Type), @nospecialize(tt::Type)) @nospecialize @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) ft = ft.parameters[1] tt = tt.parameters[1] - stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ft, :tt), Core.svec()) + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :mode, :ft, :tt), Core.svec()) # look up the method match method_error = :(throw(MethodError(ft, tt, $world))) - sig = Tuple{ft, tt.parameters...} min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) - match = ccall(:jl_gf_invoke_lookup_worlds, Any, - (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), - sig, #=mt=# nothing, world, min_world, max_world) - match === nothing && return stub(world, source, method_error) - # look up the method and code instance - mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, - (Any, Any, Any), match.method, match.spec_types, match.sparams) + mi = my_methodinstance(mode.instance, ft, tt, world, min_world, max_world) + + mi === nothing && return stub(world, source, method_error) + ci = Core.Compiler.retrieve_code_info(mi, world) # prepare a new code info @@ -212,8 +299,8 @@ function methodinstance_generator(world::UInt, source, self, @nospecialize(ft::T new_ci.edges = MethodInstance[mi] # prepare the slots - new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt] - new_ci.slotflags = UInt8[0x00 for i = 1:3] + new_ci.slotnames = Symbol[Symbol("#self#"), :mode, :ft, :tt] + new_ci.slotflags = UInt8[0x00 for i = 1:4] # return the method instance push!(new_ci.code, Core.Compiler.ReturnNode(mi)) @@ -225,23 +312,15 @@ function methodinstance_generator(world::UInt, source, self, @nospecialize(ft::T return new_ci end -@eval function prevmethodinstance(ft, tt)::Core.MethodInstance +@eval function prevmethodinstance(mode, ft, tt)::Core.MethodInstance $(Expr(:meta, :generated_only)) $(Expr(:meta, :generated, methodinstance_generator)) end # XXX: version of Base.method_instance that uses a function type -@inline function my_methodinstance(@nospecialize(ft::Type), @nospecialize(tt::Type), - world::Integer=tls_world_age())::Core.MethodInstance +@inline function my_methodinstance(@nospecialize(mode::Union{Nothing, EnzymeCore.ForwardMode, EnzymeCore.ReverseMode}), @nospecialize(ft::Type), @nospecialize(tt::Type))::Core.MethodInstance sig = GPUCompiler.signature_type_by_tt(ft, tt) - if Base.isdispatchtuple(sig) # JuliaLang/julia#52233 - return GPUCompiler.methodinstance(ft, tt, world)::Core.MethodInstance - else - return prevmethodinstance(ft, tt, world)::Core.MethodInstance - end -end -else - import GPUCompiler: methodinstance as my_methodinstance + return prevmethodinstance(mode, ft, tt)::Core.MethodInstance end export my_methodinstance diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 501b0aac10..fff95880f5 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -33,18 +33,19 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) Base.delete_method(m) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 -@static if VERSION < v"1.11-" - @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 -else - @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 -end +@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 # now test invalidation for `inactive` inactive(::typeof(issue696), args...) = nothing @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -@static if VERSION < v"1.11-" - @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -else - @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 +@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 + +# check that `Base.delete_method` works as expected +for m in methods(inactive, Tuple{typeof(issue696),Vararg{Any}}) + Base.delete_method(m) end + +@test_broken autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 +@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 + end # module