From 1c39f7c699a43094b496c427137d4c3daea531b6 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 1 Dec 2024 17:39:11 +0100 Subject: [PATCH 1/8] Use right caller for isapplicable usage --- lib/EnzymeCore/src/rules.jl | 32 ++++++-------------------------- src/compiler/interpreter.jl | 19 +++++++------------ 2 files changed, 13 insertions(+), 38 deletions(-) diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 945951b216..433417bd10 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -192,7 +192,7 @@ end function isapplicable(@nospecialize(f), @nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing) tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -209,35 +209,15 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); end fullmatch = Core.Compiler._any(match::Core.MethodMatch->match.fully_covers, matches) if !fullmatch - if caller isa Core.MethodInstance - add_mt_backedge!(caller, mt, sig) - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - add_mt_backedge!(cspec, mt, sig) - end - end + add_mt_backedge!(caller, mt, sig) end if Core.Compiler.isempty(matches) return false else - if caller isa Core.MethodInstance - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - add_backedge!(caller, edge, sig) - end - elseif caller isa Core.Compiler.MethodLookupResult - for j = 1:Core.Compiler.length(caller) - cmatch = Core.Compiler.getindex(caller, j)::Core.MethodMatch - cspec = Core.Compiler.specialize_method(cmatch)::Core.MethodInstance - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - add_backedge!(cspec, edge, sig) - end - end + for i = 1:Core.Compiler.length(matches) + match = Core.Compiler.getindex(matches, i)::Core.MethodMatch + edge = Core.Compiler.specialize_method(match)::Core.MethodInstance + add_backedge!(caller, edge, sig) end return true end diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index ff4c6c991b..dff08d577b 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -214,28 +214,23 @@ function Core.Compiler.abstract_call_gf_by_type( callinfo = ret.info method_table = Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) - caller = if callinfo isa Core.Compiler.MethodMatchInfo && callinfo.results isa Core.Compiler.MethodLookupResult - callinfo.results - else - nothing - end if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - elseif EnzymeRules.is_inactive_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + elseif EnzymeRules.is_inactive_from_sig(specTypes; world=interp.world, method_table, caller=sv.linfo) + callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else if interp.forward_rules - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller) + if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller=sv.linfo) callinfo = NoInlineCallInfo(callinfo, atype, :frule) - end + end end - + if interp.reverse_rules - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller=sv.linfo) + callinfo = NoInlineCallInfo(callinfo, atype, :rrule) end end end From 22760366cc6dae97a1d90342d1626abaa910e208 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Sun, 1 Dec 2024 19:18:21 +0100 Subject: [PATCH 2/8] use absint directly for applicable check --- src/compiler/interpreter.jl | 50 +++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 10 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index dff08d577b..7b47eedcd3 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -192,6 +192,7 @@ Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +import .EnzymeRules: FwdConfig, RevConfig, Annotation using Core.Compiler: ArgInfo, StmtInfo, AbsIntState function Core.Compiler.abstract_call_gf_by_type( @nospecialize(interp::EnzymeInterpreter), @@ -212,25 +213,54 @@ function Core.Compiler.abstract_call_gf_by_type( max_methods::Int, ) callinfo = ret.info - method_table = Core.Compiler.method_table(interp) specTypes = simplify_kw(atype) if is_primitive_func(specTypes) callinfo = NoInlineCallInfo(callinfo, atype, :primitive) elseif is_alwaysinline_func(specTypes) callinfo = AlwaysInlineCallInfo(callinfo, atype) - elseif EnzymeRules.is_inactive_from_sig(specTypes; world=interp.world, method_table, caller=sv.linfo) - callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else - if interp.forward_rules - if EnzymeRules.has_frule_from_sig(specTypes; world = interp.world, method_table, caller=sv.linfo) - callinfo = NoInlineCallInfo(callinfo, atype, :frule) + (;fargs, argtypes) = arginfo + # 1. Check if function is inactive + inactive_arginfo = ArgInfo(nothing, pushfirst!(copy(argtypes), Core.Const(EnzymeRules.inactive))) + inactive_atype = Tuple{typeof(EnzymeRules.inactive), atype.parameters...} + inactive_meta = @invoke Core.Compiler.abstract_call_gf_by_type( + interp::AbstractInterpreter, + EnzymeRules.inactive::Any, + inactive_arginfo::ArgInfo, + si::StmtInfo, + inactive_atype::Any, + sv::AbsIntState, + max_methods::Int, + ) + if Core.Compiler.nmatches(inactive_meta.info) != 0 + callinfo = NoInlineCallInfo(callinfo, atype, :inactive) + else + # 2. Check if rule is defined + if interp.forward_rules + rulef = EnzymeRules.forward + ft, tt = EnzymeRules._annotate_tt(atype) + rule_atype = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} + rule_argtypes = Any[Core.Const(EnzymeRules.forward), FwdConfig, Annotation{ft}, Type{<:Annotation}, tt...] + else + rulef = EnzymeRules.reverse + ft, tt = EnzymeRules._annotate_tt(atype) + rule_atype = Tuple{typeof(EnzymeRules.reverse), <:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} + rule_argtypes = Any[Core.Const(EnzymeRules.reverse), RevConfig, Annotation{ft}, Type{<:Annotation}, tt...] end - end - if interp.reverse_rules - if EnzymeRules.has_rrule_from_sig(specTypes; world = interp.world, method_table, caller=sv.linfo) - callinfo = NoInlineCallInfo(callinfo, atype, :rrule) + rule_arginfo = ArgInfo(nothing, rule_argtypes) + rule_meta = @invoke Core.Compiler.abstract_call_gf_by_type( + interp::AbstractInterpreter, + rulef::Any, + rule_arginfo::ArgInfo, + si::StmtInfo, + rule_atype::Any, + sv::AbsIntState, + max_methods::Int, + ) + if Core.Compiler.nmatches(rule_meta.info) != 0 + callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) end end end From fc1ea08ffd0638ee98843051be283f64fedcd2dc Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 1 Dec 2024 14:46:07 -0500 Subject: [PATCH 3/8] fix nothing error --- lib/EnzymeCore/src/rules.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index 433417bd10..db7f5a6cf2 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -209,11 +209,13 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); end fullmatch = Core.Compiler._any(match::Core.MethodMatch->match.fully_covers, matches) if !fullmatch - add_mt_backedge!(caller, mt, sig) + if caller isa Core.MethodInstance + add_mt_backedge!(caller, mt, sig) + end end if Core.Compiler.isempty(matches) return false - else + elseif caller isa Core.MethodInstance for i = 1:Core.Compiler.length(matches) match = Core.Compiler.getindex(matches, i)::Core.MethodMatch edge = Core.Compiler.specialize_method(match)::Core.MethodInstance From 61278dec5e8c5d3ead046cc657e6ad501893c9a5 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 1 Dec 2024 14:55:08 -0500 Subject: [PATCH 4/8] fix return type --- lib/EnzymeCore/src/rules.jl | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/lib/EnzymeCore/src/rules.jl b/lib/EnzymeCore/src/rules.jl index db7f5a6cf2..dc33c11110 100644 --- a/lib/EnzymeCore/src/rules.jl +++ b/lib/EnzymeCore/src/rules.jl @@ -171,7 +171,7 @@ end function has_frule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(forward, TT; world, method_table, caller) @@ -180,7 +180,7 @@ end function has_rrule_from_sig(@nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool ft, tt = _annotate_tt(TT) TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} return isapplicable(augmented_primal, TT; world, method_table, caller) @@ -192,7 +192,7 @@ end function isapplicable(@nospecialize(f), @nospecialize(TT); world::UInt=Base.get_world_counter(), method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing, - caller::Union{Nothing,Core.MethodInstance}=nothing) + caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool tt = Base.to_tuple_type(TT) sig = Base.signature_type(f, tt) mt = ccall(:jl_method_table_for, Any, (Any,), sig) @@ -215,11 +215,13 @@ function isapplicable(@nospecialize(f), @nospecialize(TT); end if Core.Compiler.isempty(matches) return false - elseif caller isa Core.MethodInstance - for i = 1:Core.Compiler.length(matches) - match = Core.Compiler.getindex(matches, i)::Core.MethodMatch - edge = Core.Compiler.specialize_method(match)::Core.MethodInstance - add_backedge!(caller, edge, sig) + else + if caller isa Core.MethodInstance + for i = 1:Core.Compiler.length(matches) + match = Core.Compiler.getindex(matches, i)::Core.MethodMatch + edge = Core.Compiler.specialize_method(match)::Core.MethodInstance + add_backedge!(caller, edge, sig) + end end return true end From 400e81c64b21836cf787dba040a168fc5aa3e1b8 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Sun, 1 Dec 2024 15:31:28 -0500 Subject: [PATCH 5/8] Fix nmatches type mismatch --- src/compiler/interpreter.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 7b47eedcd3..b48fa914a5 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -233,7 +233,7 @@ function Core.Compiler.abstract_call_gf_by_type( sv::AbsIntState, max_methods::Int, ) - if Core.Compiler.nmatches(inactive_meta.info) != 0 + if inactive_meta.info isa Core.Compiler.MethodMatchInfo && Core.Compiler.nmatches(inactive_meta.info) != 0 callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else # 2. Check if rule is defined @@ -259,7 +259,7 @@ function Core.Compiler.abstract_call_gf_by_type( sv::AbsIntState, max_methods::Int, ) - if Core.Compiler.nmatches(rule_meta.info) != 0 + if rule_meta.info isa Core.Compiler.MethodMatchInfo && Core.Compiler.nmatches(rule_meta.info) != 0 callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) end end From 065668032b9c12e49fbcead612bcc26334131f7d Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 2 Dec 2024 09:11:13 +0100 Subject: [PATCH 6/8] use tfunc for isapplicable directly --- src/compiler/interpreter.jl | 31 ++++++------------------------- 1 file changed, 6 insertions(+), 25 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index b48fa914a5..0667ecd19f 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -222,44 +222,25 @@ function Core.Compiler.abstract_call_gf_by_type( else (;fargs, argtypes) = arginfo # 1. Check if function is inactive - inactive_arginfo = ArgInfo(nothing, pushfirst!(copy(argtypes), Core.Const(EnzymeRules.inactive))) - inactive_atype = Tuple{typeof(EnzymeRules.inactive), atype.parameters...} - inactive_meta = @invoke Core.Compiler.abstract_call_gf_by_type( - interp::AbstractInterpreter, - EnzymeRules.inactive::Any, - inactive_arginfo::ArgInfo, - si::StmtInfo, - inactive_atype::Any, - sv::AbsIntState, - max_methods::Int, - ) - if inactive_meta.info isa Core.Compiler.MethodMatchInfo && Core.Compiler.nmatches(inactive_meta.info) != 0 + inactive_argtypes = pushfirst!(copy(argtypes), Core.Const(EnzymeRules.inactive)) + inactive_meta = abstract_applicable(interp, inactive_argtypes, sv, max_methods) # Does backedge handling internally + + if inactive_meta.rt !== Core.Const(false) # Ugh it may be Const(true), Const(false), Bool callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else # 2. Check if rule is defined if interp.forward_rules rulef = EnzymeRules.forward ft, tt = EnzymeRules._annotate_tt(atype) - rule_atype = Tuple{typeof(EnzymeRules.forward), <:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} rule_argtypes = Any[Core.Const(EnzymeRules.forward), FwdConfig, Annotation{ft}, Type{<:Annotation}, tt...] else rulef = EnzymeRules.reverse ft, tt = EnzymeRules._annotate_tt(atype) - rule_atype = Tuple{typeof(EnzymeRules.reverse), <:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...} rule_argtypes = Any[Core.Const(EnzymeRules.reverse), RevConfig, Annotation{ft}, Type{<:Annotation}, tt...] end - rule_arginfo = ArgInfo(nothing, rule_argtypes) - rule_meta = @invoke Core.Compiler.abstract_call_gf_by_type( - interp::AbstractInterpreter, - rulef::Any, - rule_arginfo::ArgInfo, - si::StmtInfo, - rule_atype::Any, - sv::AbsIntState, - max_methods::Int, - ) - if rule_meta.info isa Core.Compiler.MethodMatchInfo && Core.Compiler.nmatches(rule_meta.info) != 0 + rule_meta = abstract_applicable(interp, rule_argtypes, sv, max_methods) # Does backedge handling internally + if rule_meta.rt !== Core.Const(false) # Ugh it may be Const(true), Const(false), Bool callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) end end From 6beac14d40fe03c4da64ad92a743f8938792610e Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 2 Dec 2024 13:30:50 +0100 Subject: [PATCH 7/8] safe WIP --- src/compiler/interpreter.jl | 37 +++++++++++++++++++++++++++---------- 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 0667ecd19f..7afb49c071 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -192,6 +192,21 @@ Core.Compiler.getsplit_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult_impl(info::AlwaysInlineCallInfo, idx::Int) = Core.Compiler.getresult(info.info, idx) +function annotate(@nospecialize(T)) + T = widenconst(T) + if Core.Compiler.isvarargtype(T) + VA = T + T = annotate(Core.Compiler.unwrapva(VA)) + if isdefined(VA, :N) + return Vararg{T, VA.N} + else + return Vararg{T} + end + else + return Annotation{T} + end +end + import .EnzymeRules: FwdConfig, RevConfig, Annotation using Core.Compiler: ArgInfo, StmtInfo, AbsIntState function Core.Compiler.abstract_call_gf_by_type( @@ -222,25 +237,27 @@ function Core.Compiler.abstract_call_gf_by_type( else (;fargs, argtypes) = arginfo # 1. Check if function is inactive - inactive_argtypes = pushfirst!(copy(argtypes), Core.Const(EnzymeRules.inactive)) - inactive_meta = abstract_applicable(interp, inactive_argtypes, sv, max_methods) # Does backedge handling internally + inactive_argtypes = Any[Core.Const(Core.applicable), Core.Const(EnzymeRules.inactive)] + append!(inactive_argtypes, argtypes) - if inactive_meta.rt !== Core.Const(false) # Ugh it may be Const(true), Const(false), Bool + inactive_meta = Core.Compiler.abstract_applicable(interp, inactive_argtypes, sv, #=max_methods=# -1) # Does backedge handling internally + if inactive_meta.rt != Core.Const(false) # It may be Const(true), Const(false), Bool callinfo = NoInlineCallInfo(callinfo, atype, :inactive) else # 2. Check if rule is defined + tt = Core.Compiler.anymap(annotate, argtypes) if interp.forward_rules rulef = EnzymeRules.forward - ft, tt = EnzymeRules._annotate_tt(atype) - rule_argtypes = Any[Core.Const(EnzymeRules.forward), FwdConfig, Annotation{ft}, Type{<:Annotation}, tt...] + rule_argtypes = Any[Core.Const(Core.applicable), Core.Const(EnzymeRules.forward), FwdConfig, tt[1], Type{<:Annotation}, tt[2:end]...] else rulef = EnzymeRules.reverse - ft, tt = EnzymeRules._annotate_tt(atype) - rule_argtypes = Any[Core.Const(EnzymeRules.reverse), RevConfig, Annotation{ft}, Type{<:Annotation}, tt...] + rule_argtypes = Any[Core.Const(Core.applicable), Core.Const(EnzymeRules.reverse), RevConfig, tt[1], Type{<:Annotation}, tt[2:end]...] end - - rule_meta = abstract_applicable(interp, rule_argtypes, sv, max_methods) # Does backedge handling internally - if rule_meta.rt !== Core.Const(false) # Ugh it may be Const(true), Const(false), Bool + Base.@show rule_argtypes + Base.@show Core.Compiler.argtypes_to_type(rule_argtypes) + rule_meta = Core.Compiler.abstract_applicable(interp, rule_argtypes, sv, #=max_methods=# -1) # Does backedge handling internally + @show rule_meta.rt + if rule_meta.rt != Core.Const(false) # It may be Const(true), Const(false), Bool callinfo = NoInlineCallInfo(callinfo, atype, interp.forward_rules ? :frule : :rrule) end end From 736ddc86c1c49411b36bed52798c6e9ea88b2bc0 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 4 Dec 2024 14:05:36 +0100 Subject: [PATCH 8/8] invalidation for inactive now works --- test/ruleinvalidation.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/ruleinvalidation.jl b/test/ruleinvalidation.jl index 704ada2b6e..37cb21b08f 100644 --- a/test/ruleinvalidation.jl +++ b/test/ruleinvalidation.jl @@ -34,18 +34,14 @@ for m in methods(forward, Tuple{Any,Const{typeof(issue696)},Vararg{Any}}) end @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 @static if VERSION < v"1.11-" -@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 + @test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 else -@test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 + @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 2.0 end # now test invalidation for `inactive` inactive(::typeof(issue696), args...) = nothing @test autodiff(Forward, issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -@static if VERSION < v"1.11-" -@test_broken autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -else @test autodiff(Forward, call_issue696, Duplicated(1.0, 1.0))[1] ≈ 0.0 -end end # module