From 258c5eddc7f7cea15bb43a3618b691fe60f63c69 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 22:14:23 -0500 Subject: [PATCH] final fix --- src/compiler.jl | 64 ++++++++++---- src/compiler/interpreter.jl | 161 ++++++++++++++++++++++++++++++++---- src/compiler/tfunc.jl | 15 ++-- src/rules/customrules.jl | 6 +- src/typeutils/inference.jl | 2 + test/ruleinvalidation.jl | 6 +- 6 files changed, 211 insertions(+), 43 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 97cd22031b..e01f3e2d08 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -361,6 +361,8 @@ function prepare_llvm(mod::LLVM.Module, job, meta) end end +const mod_to_edges = Dict{LLVM.Module, Vector{Any}}() + function nested_codegen!( mode::API.CDerivativeMode, mod::LLVM.Module, @@ -390,6 +392,11 @@ function nested_codegen!( permit_inlining!(f) end + edges = get(mod_to_edges, mod, nothing) + @assert edges !== nothing + edges = edges::Vector{Any} + push!(edges, funcspec) + # Apply first stage of optimization's so that this module is at the same stage as `mod` optimize!(otherMod, JIT.get_tm()) # 4) Link the corresponding module @@ -1193,9 +1200,18 @@ if VERSION >= v"1.11.0-DEV.1552" always_inline::Any method_table::Core.MethodTable param_type::Type - is_fwd::Bool + last_fwd_rule_world::Union{Nothing, Tuple} + last_rev_rule_world::Union{Nothing, Tuple} + last_ina_rule_world::Tuple end + @inline EnzymeCacheToken(target_type, always_inline, method_table, param_type, is_forward, is_reverse) = + EnzymeCacheToken(target_type, always_inline, method_table, param_type, + is_forward ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.forward, Tuple{<:FwdConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing, + is_reverse ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:RevConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing, + (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)...,) + ) + GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = EnzymeCacheToken( typeof(job.config.target), @@ -1203,6 +1219,7 @@ if VERSION >= v"1.11.0-DEV.1552" GPUCompiler.method_table(job), typeof(job.config.params), job.config.params.mode == API.DEM_ForwardMode, + job.config.params.mode != API.DEM_ForwardMode ) GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) = @@ -3234,6 +3251,7 @@ function GPUCompiler.codegen( if params.run_enzyme # @assert eltype(params.rt) != Union{} end + expectedTapeType = params.expectedTapeType mode = params.mode TT = params.TT @@ -3275,6 +3293,8 @@ function GPUCompiler.codegen( GPUCompiler.prepare_job!(primal_job) mod, meta = GPUCompiler.emit_llvm(primal_job; libraries=true, toplevel=toplevel, optimize=false, cleanup=false, only_entry=false, validate=false) + edges = Any[] + mod_to_edges[mod] = edges prepare_llvm(mod, primal_job, meta) for f in functions(mod) @@ -3556,13 +3576,13 @@ function GPUCompiler.codegen( caller = mi if mode == API.DEM_ForwardMode has_custom_rule = - EnzymeRules.has_frule_from_sig(specTypes; world, method_table, caller) + EnzymeRules.has_frule_from_sig(specTypes; world, method_table) #, caller) if has_custom_rule @safe_debug "Found frule for" mi.specTypes end else has_custom_rule = - EnzymeRules.has_rrule_from_sig(specTypes; world, method_table, caller) + EnzymeRules.has_rrule_from_sig(specTypes; world, method_table) # , caller) if has_custom_rule @safe_debug "Found rrule for" mi.specTypes end @@ -3577,7 +3597,7 @@ function GPUCompiler.codegen( actualRetType = k.ci.rettype end - if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table, caller) + if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table) #, caller) push!(return_attributes(llvmfn), EnumAttribute("noalias")) for u in LLVM.uses(llvmfn) c = LLVM.user(u) @@ -3801,7 +3821,7 @@ end end continue end - if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && + if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table) && #, caller) && Enzyme.has_method( Tuple{typeof(EnzymeRules.inactive),specTypes.parameters...}, world, @@ -3819,7 +3839,7 @@ end ) continue end - if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table, caller) && + if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table) # , caller) && has_method( Tuple{typeof(EnzymeRules.inactive_noinl),specTypes.parameters...}, world, @@ -4588,10 +4608,12 @@ end isempty(LLVM.blocks(fn)) && continue linkage!(fn, LLVM.API.LLVMLinkerPrivateLinkage) end + + delete!(mod_to_edges, mod) use_primal = mode == API.DEM_ReverseModePrimal entry = use_primal ? augmented_primalf : adjointf - return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType) + return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType, edges) end # Compiler result @@ -4599,6 +4621,7 @@ struct CompileResult{AT,PT} adjoint::AT primal::PT TapeType::Type + edges::Vector{Any} end @inline (thunk::PrimalErrorThunk{PT,FA,RT,TT,Width,ReturnPrimal})( @@ -5224,12 +5247,13 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) +function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, edges::Vector{Any}, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), Val((Symbol(mod), Symbol(primal_name))), - TapeType + TapeType, + edges ) end @@ -5261,16 +5285,17 @@ function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module end end - return CompileResult(adjoint_ptr, primal_ptr, TapeType) + return CompileResult(adjoint_ptr, primal_ptr, TapeType, edges) end const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{String, Nothing}, Type, String} +function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, String, Union{String, Nothing}, Type, String} mod, meta = codegen(:llvm, job; optimize = false) adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf + adjoint_name = name(adjointf) if augmented_primalf !== nothing @@ -5303,7 +5328,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{Str else "" end - return (mod, adjoint_name, primal_name, meta.TapeType, prepost) + return (mod, meta.edges, adjoint_name, primal_name, meta.TapeType, prepost) end const cache = Dict{UInt,CompileResult}() @@ -5322,10 +5347,10 @@ const cache_lock = ReentrantLock() asm = _thunk(job) obj = _link(job, asm...) if obj.adjoint isa Ptr{Nothing} - autodiff_cache[obj.adjoint] = (asm[2], asm[5]) + autodiff_cache[obj.adjoint] = (asm[3], asm[6]) end - if obj.primal isa Ptr{Nothing} && asm[3] isa String - autodiff_cache[obj.primal] = (asm[3], asm[5]) + if obj.primal isa Ptr{Nothing} && asm[4] isa String + autodiff_cache[obj.primal] = (asm[4], asm[6]) end cache[key] = obj end @@ -5349,7 +5374,8 @@ end @nospecialize(ABI::Type), ErrIfFuncWritten::Bool, RuntimeActivity::Bool, -) + edges::Union{Nothing, Vector{Any}} +) target = Compiler.EnzymeTarget() params = Compiler.EnzymeCompilerParams( Tuple{FA,TT.parameters...}, @@ -5430,6 +5456,11 @@ end compile_result = cached_compilation(job) + if edges !== nothing + for e in compile_result.edges + push!(edges, e) + end + end if !run_enzyme ErrT = PrimalErrorThunk{typeof(compile_result.adjoint),FA,rt2,TT,width,ReturnPrimal} if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient @@ -5622,6 +5653,7 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA:: ABI, ErrIfFuncWritten, RuntimeActivity, + edges ) finally deactivate(ctx) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 6f4cc99295..a45ab24d1a 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -24,6 +24,92 @@ else import Core.Compiler: get_world_counter, get_world_counter as get_inference_world end +function rule_backedge_holder_generator(world::UInt, source, self, ft::Type) + @nospecialize + sig = Tuple{typeof(Base.identity), Int} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + has_ambig = Ptr{Int32}(C_NULL) + mthds = Base._methods_by_ftype( + sig, + nothing, + -1, #=lim=# + world, + false, #=ambig=# + min_world, + max_world, + has_ambig, + ) + mtypes, msp, m = mthds[1] + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + m, + mtypes, + msp, + ) + ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + @static if isdefined(Core, :DebugInfo) + new_ci.debuginfo = Core.DebugInfo(:none) + else + empty!(new_ci.codelocs) + resize!(new_ci.linetable, 1) # see note below + end + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + + ### TODO: backedge from inactive, augmented_primal, forward, reverse + edges = Any[] + + if ft == typeof(EnzymeRules.augmented_primal) + sig = Tuple{typeof(EnzymeRules.augmented_primal), <:RevConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), sig)) + push!(edges, sig) + elseif ft == typeof(EnzymeRules.forward) + sig = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation, Type{<:Annotation},Vararg{Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), sig)) + push!(edges, sig) + else + sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), sig)) + push!(edges, sig) + end + + new_ci.edges = edges + + # XXX: setting this edge does not give us proper method invalidation, see + # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. + # invoking `code_llvm` also does the necessary codegen, as does calling the + # underlying C methods -- which GPUCompiler does, so everything Just Works. + + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), :ft] + new_ci.slotflags = UInt8[0x00 for i = 1:2] + + # return the codegen world age + push!(new_ci.code, Core.Compiler.ReturnNode(world)) + push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` + @static if isdefined(Core, :DebugInfo) + else + push!(new_ci.codelocs, 1) # see note below + end + new_ci.ssavaluetypes += 1 + + return new_ci +end + +@eval Base.@assume_effects :removable :foldable :nothrow @inline function rule_backedge_holder(ft) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, rule_backedge_holder_generator)) +end + struct EnzymeInterpreter{T} <: AbstractInterpreter @static if HAS_INTEGRATED_CACHE token::Any @@ -47,6 +133,33 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter handler::T end + +function get_rule_signatures(f, TT, world) + fwdrules_meths = Base._methods(f, TT, -1, world)::Vector + sigs = Type[] + for rule in fwdrules_meths + push!(sigs, (rule::Core.MethodMatch).method.sig) + end + return Base.IdSet{Type}(sigs) +end + +function rule_sigs_equal(a, b) + if length(a) != length(b) + return false + end + for v in a + if v in b + continue + end + return false + end + return true +end + +const LastFwdWorld = Ref(Base.IdSet{Type}()) +const LastRevWorld = Ref(Base.IdSet{Type}()) +const LastInaWorld = Ref(Base.IdSet{Type}()) + function EnzymeInterpreter( cache_or_token, mt::Union{Nothing,Core.MethodTable}, @@ -63,6 +176,37 @@ function EnzymeInterpreter( else InferenceParams(; unoptimize_throw_blocks=false) end + + @static if HAS_INTEGRATED_CACHE + + else + cache_or_token = cache_or_token::CodeCache + invalid = false + if forward_rules + fwdrules = get_rule_signatures(EnzymeRules.forward, Tuple{<:FwdConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world) + if !rule_sigs_equal(fwdrules, LastFwdWorld[]) + LastFwdWorld[] = fwdrules + invalid = true + end + end + if reverse_rules + revrules = get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:RevConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world) + if !rule_sigs_equal(revrules, LastRevWorld[]) + LastRevWorld[] = revrules + invalid = true + end + end + + inarules = get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world) + if !rule_sigs_equal(inarules, LastInaWorld[]) + LastInaWorld[] = inarules + invalid = true + end + + if invalid + Base.empty!(cache_or_token) + end + end return EnzymeInterpreter( cache_or_token, @@ -220,32 +364,21 @@ function Core.Compiler.abstract_call_gf_by_type( callinfo = AlwaysInlineCallInfo(callinfo, atype) else method_table = Core.Compiler.method_table(interp) - if is_inactive_from_sig(interp, specTypes, sv) + if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else if interp.forward_rules - if has_frule_from_sig(interp, specTypes, sv) + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) callinfo = NoInlineCallInfo(callinfo, atype, :frule) end end if interp.reverse_rules - if has_rrule_from_sig(interp, specTypes, sv) + if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end end end - - # if interp.forward_rules - # Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, interp.world)::Core.MethodInstance) - # Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) - # end - # if interp.reverse_rules - # Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, interp.world)::Core.MethodInstance) - # Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.augmented_primal)) - # end - # Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, interp.world)::Core.MethodInstance) - # Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(typeof(EnzymeRules.inactive))) end @static if VERSION ≥ v"1.11-" diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl index a3c4544fe1..57ec0053d1 100644 --- a/src/compiler/tfunc.jl +++ b/src/compiler/tfunc.jl @@ -41,8 +41,9 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), end # also need an edge to the method table in case something gets # added that did not intersect with any existing method - fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) - if !fullmatch + # fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) + # if !fullmatch + if true if partialsig === nothing Core.Compiler.add_mt_backedge!(sv, mt, sig) else @@ -53,11 +54,11 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), if Core.Compiler.isempty(matches) return false else - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - Core.Compiler.add_backedge!(sv, edge) - end + #for i = 1:Core.Compiler.length(matches) + # match = Core.Compiler.getindex(matches, i)::Core.MethodMatch + # edge = Core.Compiler.specialize_method(match)::Core.MethodInstance + # Core.Compiler.add_backedge!(sv, edge) + #end return true end end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 622e8f601e..8c033c2449 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -537,7 +537,11 @@ end fwd_RT = fwd_RT::Type llvmf = nested_codegen!(mode, mod, fmi, world) - push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) + + @show fmi, fwd_RT, world + println(string(llvmf)) + + # push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) swiftself = has_swiftself(llvmf) if swiftself diff --git a/src/typeutils/inference.jl b/src/typeutils/inference.jl index 066ee80bfb..7898969afe 100644 --- a/src/typeutils/inference.jl +++ b/src/typeutils/inference.jl @@ -26,6 +26,7 @@ function primal_interp_world( GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# EnzymeCompilerParams, false, + true ) else Enzyme.Compiler.GLOBAL_REV_CACHE @@ -47,6 +48,7 @@ function primal_interp_world( GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=# EnzymeCompilerParams, true, + false ) else Enzyme.Compiler.GLOBAL_FWD_CACHE diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index ba21964305..1771e032a3 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -33,11 +33,7 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) Base.delete_method(m) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 -@static if VERSION < v"1.11-" - @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 -else - @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 -end +@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 # now test invalidation for `inactive` inactive(::typeof(issue696), args...) = nothing