From a5e5ae40d05c2e99fbcddd31fea3fda770b7fcb2 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 14 Dec 2024 18:28:48 -0500 Subject: [PATCH] fix --- src/compiler.jl | 42 ++++++++---- src/compiler/interpreter.jl | 125 ++++-------------------------------- src/compiler/tfunc.jl | 21 +++--- test/ruleinvalidation.jl | 13 +--- 4 files changed, 57 insertions(+), 144 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2d9b12630c..97cd22031b 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3556,13 +3556,13 @@ function GPUCompiler.codegen( caller = mi if mode == API.DEM_ForwardMode has_custom_rule = - EnzymeRules.has_frule_from_sig(specTypes; world, method_table) # , caller) + EnzymeRules.has_frule_from_sig(specTypes; world, method_table, caller) if has_custom_rule @safe_debug "Found frule for" mi.specTypes end else has_custom_rule = - EnzymeRules.has_rrule_from_sig(specTypes; world, method_table) # , caller) + EnzymeRules.has_rrule_from_sig(specTypes; world, method_table, caller) if has_custom_rule @safe_debug "Found rrule for" mi.specTypes end @@ -3577,7 +3577,7 @@ function GPUCompiler.codegen( actualRetType = k.ci.rettype end - if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table) #, caller) + if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table, caller) push!(return_attributes(llvmfn), EnumAttribute("noalias")) for u in LLVM.uses(llvmfn) c = LLVM.user(u) @@ -3801,7 +3801,7 @@ end end continue end - if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table) && # , caller) && + if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) && Enzyme.has_method( Tuple{typeof(EnzymeRules.inactive),specTypes.parameters...}, world, @@ -3819,7 +3819,7 @@ end ) continue end - if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table) && #, caller) && + if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table, caller) && has_method( Tuple{typeof(EnzymeRules.inactive_noinl),specTypes.parameters...}, world, @@ -5568,18 +5568,34 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA:: new_ci.min_world = world new_ci.max_world = max_world[] - edges = Core.MethodInstance[mi] + edges = Any[mi] if Mode == API.DEM_ForwardMode - push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, world)) - Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) + fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), fwd_sig)::Core.MethodTable) + push!(edges, fwd_sig) else - push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, world)) + rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable) + push!(edges, rev_sig) + + rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable) + push!(edges, rev_sig) + end + + ina_sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Any}} + push!(edges, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable) + push!(edges, ina_sig) + + for gen_sig in ( + Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}}, + Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, + Tuple{typeof(EnzymeRules.inactive_type), Type}, + ) + push!(edges, ccall(:jl_method_table_for, Any, (Any,), gen_sig)::Core.MethodTable) + push!(edges, gen_sig) end - - push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, world)) - push!(edges, GPUCompiler.methodinstance(typeof(Compiler.Interpreter.rule_backedge_holder), Tuple{Val{0}}, world)) - Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(Val(0))) new_ci.edges = edges diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 68fbd5ec07..6f4cc99295 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -24,105 +24,6 @@ else import Core.Compiler: get_world_counter, get_world_counter as get_inference_world end -function rule_backedge_holder_generator(world::UInt, source, self, ft::Type) - @nospecialize - sig = Tuple{typeof(Base.identity), Int} - min_world = Ref{UInt}(typemin(UInt)) - max_world = Ref{UInt}(typemax(UInt)) - has_ambig = Ptr{Int32}(C_NULL) - mthds = Base._methods_by_ftype( - sig, - nothing, - -1, #=lim=# - world, - false, #=ambig=# - min_world, - max_world, - has_ambig, - ) - mtypes, msp, m = mthds[1] - mi = ccall( - :jl_specializations_get_linfo, - Ref{Core.MethodInstance}, - (Any, Any, Any), - m, - mtypes, - msp, - ) - ci = Core.Compiler.retrieve_code_info(mi, world)::Core.Compiler.CodeInfo - - # prepare a new code info - new_ci = copy(ci) - empty!(new_ci.code) - @static if isdefined(Core, :DebugInfo) - new_ci.debuginfo = Core.DebugInfo(:none) - else - empty!(new_ci.codelocs) - resize!(new_ci.linetable, 1) # see note below - end - empty!(new_ci.ssaflags) - new_ci.ssavaluetypes = 0 - new_ci.min_world = min_world[] - new_ci.max_world = max_world[] - - ### TODO: backedge from inactive, augmented_primal, forward, reverse - edges = Any[] - - if ft == typeof(EnzymeRules.augmented_primal) - rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} - push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable) - push!(edges, rev_sig) - - rev_sig = Tuple{typeof(EnzymeRules.reverse), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Union{Type{<:Enzyme.EnzymeCore.Annotation}, Enzyme.EnzymeCore.Active}, Any, Vararg{Enzyme.EnzymeCore.Annotation}} - push!(edges, ccall(:jl_method_table_for, Any, (Any,), rev_sig)::Core.MethodTable) - push!(edges, rev_sig) - elseif ft == typeof(EnzymeRules.forward) - fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} - push!(edges, ccall(:jl_method_table_for, Any, (Any,), fwd_sig)::Core.MethodTable) - push!(edges, fwd_sig) - elseif ft == typeof(EnzymeRules.inactive) - ina_sig = Tuple{typeof(EnzymeRules.inactive), Vararg{Any}} - push!(edges, ccall(:jl_method_table_for, Any, (Any,), ina_sig)::Core.MethodTable) - push!(edges, ina_sig) - else - for gen_sig in ( - Tuple{typeof(EnzymeRules.inactive_noinl), Vararg{Any}}, - Tuple{typeof(EnzymeRules.noalias), Vararg{Any}}, - Tuple{typeof(EnzymeRules.inactive_type), Type}, - ) - push!(edges, ccall(:jl_method_table_for, Any, (Any,), gen_sig)::Core.MethodTable) - push!(edges, gen_sig) - end - end - - new_ci.edges = edges - - # XXX: setting this edge does not give us proper method invalidation, see - # JuliaLang/julia#34962 which demonstrates we also need to "call" the kernel. - # invoking `code_llvm` also does the necessary codegen, as does calling the - # underlying C methods -- which GPUCompiler does, so everything Just Works. - - # prepare the slots - new_ci.slotnames = Symbol[Symbol("#self#"), :ft] - new_ci.slotflags = UInt8[0x00 for i = 1:2] - - # return the codegen world age - push!(new_ci.code, Core.Compiler.ReturnNode(0)) - push!(new_ci.ssaflags, 0x00) # Julia's native compilation pipeline (and its verifier) expects `ssaflags` to be the same length as `code` - @static if isdefined(Core, :DebugInfo) - else - push!(new_ci.codelocs, 1) # see note below - end - new_ci.ssavaluetypes += 1 - - return new_ci -end - -@eval Base.@assume_effects :removable :foldable :nothrow @inline function rule_backedge_holder(ft) - $(Expr(:meta, :generated_only)) - $(Expr(:meta, :generated, rule_backedge_holder_generator)) -end - struct EnzymeInterpreter{T} <: AbstractInterpreter @static if HAS_INTEGRATED_CACHE token::Any @@ -319,32 +220,32 @@ function Core.Compiler.abstract_call_gf_by_type( callinfo = AlwaysInlineCallInfo(callinfo, atype) else method_table = Core.Compiler.method_table(interp) - if EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table) + if is_inactive_from_sig(interp, specTypes, sv) callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else if interp.forward_rules - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table) + if has_frule_from_sig(interp, specTypes, sv) callinfo = NoInlineCallInfo(callinfo, atype, :frule) end end if interp.reverse_rules - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table) + if has_rrule_from_sig(interp, specTypes, sv) callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end end end - if interp.forward_rules - Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, interp.world)::Core.MethodInstance) - Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) - end - if interp.reverse_rules - Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, interp.world)::Core.MethodInstance) - Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.augmented_primal)) - end - Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, interp.world)::Core.MethodInstance) - Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(typeof(EnzymeRules.inactive))) + # if interp.forward_rules + # Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.forward)}, interp.world)::Core.MethodInstance) + # Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.forward)) + # end + # if interp.reverse_rules + # Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.augmented_primal)}, interp.world)::Core.MethodInstance) + # Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(EnzymeRules.augmented_primal)) + # end + # Core.Compiler.add_backedge!(sv, GPUCompiler.methodinstance(typeof(Enzyme.Compiler.Interpreter.rule_backedge_holder), Tuple{typeof(EnzymeRules.inactive)}, interp.world)::Core.MethodInstance) + # Enzyme.Compiler.Interpreter.rule_backedge_holder(Base.inferencebarrier(typeof(EnzymeRules.inactive))) end @static if VERSION ≥ v"1.11-" diff --git a/src/compiler/tfunc.jl b/src/compiler/tfunc.jl index 701cfa8107..eba2f08d00 100644 --- a/src/compiler/tfunc.jl +++ b/src/compiler/tfunc.jl @@ -2,30 +2,32 @@ import EnzymeCore: Annotation import EnzymeCore.EnzymeRules: FwdConfig, RevConfig, forward, augmented_primal, inactive, _annotate_tt function has_frule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, forward, TT, sv) + fwd_sig = Tuple{typeof(EnzymeRules.forward), <:EnzymeRules.FwdConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + return isapplicable(interp, forward, TT, sv, fwd_sig) end function has_rrule_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, partialedge::Bool=true)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig,<:Annotation{ft},Type{<:Annotation},tt...} - return isapplicable(interp, augmented_primal, TT, sv) + rev_sig = Tuple{typeof(EnzymeRules.augmented_primal), <:EnzymeRules.RevConfig, <:Enzyme.EnzymeCore.Annotation, Type{<:Enzyme.EnzymeCore.Annotation},Vararg{Enzyme.EnzymeCore.Annotation}} + return isapplicable(interp, augmented_primal, TT, sv, rev_sig) end function is_inactive_from_sig(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(TT), sv::Core.Compiler.AbsIntState) - return isapplicable(interp, inactive, TT, sv) + @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState) + return isapplicable(interp, inactive, TT, sv, TT) end # `hasmethod` is a precise match using `Core.Compiler.findsup`, # but here we want the broader query using `Core.Compiler.findall`. # Also add appropriate backedges to the caller `MethodInstance` if given. function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), - @nospecialize(f), @nospecialize(TT), sv::Core.Compiler.AbsIntState)::Bool + @nospecialize(f), @nospecialize(TT::Type), sv::Core.Compiler.AbsIntState, @nospecialize(partialsig::Type))::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -41,7 +43,8 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), # added that did not intersect with any existing method fullmatch = Core.Compiler._any(match::Core.MethodMatch -> match.fully_covers, matches) if !fullmatch - Core.Compiler.add_mt_backedge!(sv, mt, sig) + pmt = ccall(:jl_method_table_for, Any, (Any,), partialsig)::Core.MethodTable + Core.Compiler.add_mt_backedge!(sv, pmt, partialsig) end if Core.Compiler.isempty(matches) return false @@ -53,4 +56,4 @@ function isapplicable(@nospecialize(interp::Core.Compiler.AbstractInterpreter), end return true end -end \ No newline at end of file +end diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 501b0aac10..62579e2415 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -33,18 +33,11 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) Base.delete_method(m) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 -@static if VERSION < v"1.11-" - @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 -else - @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 -end +@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 # now test invalidation for `inactive` inactive(::typeof(issue696), args...) = nothing @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -@static if VERSION < v"1.11-" - @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -else - @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -end +@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 + end # module