From 39a99ab3b864cf1381651fb5f18a87972cec5e07 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sat, 30 Nov 2024 20:24:52 -0500 Subject: [PATCH] more cleanup --- src/rules/customrules.jl | 61 ++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 40 deletions(-) diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 156e7c4fec..2aaa560201 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -523,14 +523,16 @@ end typeof(EnzymeRules.forward) end @safe_debug "Applying custom forward rule" TT = TT, functy = functy - try + fmi = try fmi = my_methodinstance(functy, TT, world) - fwd_RT = primal_return_type_world(Forward, world, rmi) + fwd_RT = primal_return_type_world(Forward, world, fmi) + fmi catch e TT = Tuple{typeof(world),functy,TT.parameters...} fmi = my_methodinstance(typeof(custom_rule_method_error), TT, world) pushfirst!(args, LLVM.ConstantInt(world)) fwd_RT = Union{} + fmi end llvmf = nested_codegen!(mode, mod, fmi, world) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) @@ -776,10 +778,9 @@ end mode = get_mode(gutils) - ami = nothing augprimal_tt = copy(activity) - if isKWCall + functy = if isKWCall popfirst!(augprimal_tt) @assert kwtup !== nothing insert!(augprimal_tt, 1, kwtup) @@ -788,50 +789,30 @@ end insert!(augprimal_tt, 5, Type{RT}) augprimal_TT = Tuple{augprimal_tt...} - kwfunc = Core.kwfunc(EnzymeRules.augmented_primal) - try - ami = my_methodinstance(Core.Typeof(kwfunc), augprimal_TT, world) - @safe_debug "Applying custom augmented_primal rule (kwcall)" TT = augprimal_TT - catch e - augprimal_TT = Tuple{typeof(world),typeof(kwfunc),augprimal_TT.parameters...} - ami = my_methodinstance( - typeof(custom_rule_method_error), - augprimal_TT, - world, - ) - if forward - pushfirst!(args, LLVM.ConstantInt(world)) - end - end + Core.Typeof(Core.kwfunc(EnzymeRules.augmented_primal)) else @assert kwtup === nothing insert!(augprimal_tt, 1, C) insert!(augprimal_tt, 3, Type{RT}) augprimal_TT = Tuple{augprimal_tt...} - try - ami = my_methodinstance( - Core.Typeof(EnzymeRules.augmented_primal), - augprimal_TT, - world, - ) - @safe_debug "Applying custom augmented_primal rule" TT = augprimal_TT - catch e - augprimal_TT = Tuple{ - typeof(world), - typeof(EnzymeRules.augmented_primal), - augprimal_TT.parameters..., - } - ami = my_methodinstance( - typeof(custom_rule_method_error), - augprimal_TT, - world, - ) - if forward - pushfirst!(args, LLVM.ConstantInt(world)) - end + typeof(EnzymeRules.augmented_primal) + end + + ami = try + my_methodinstance(functy, augprimal_TT, world) + catch e + augprimal_TT = Tuple{typeof(world),functy,augprimal_TT.parameters...} + ami = my_methodinstance( + typeof(custom_rule_method_error), + augprimal_TT, + world, + ) + if forward + pushfirst!(args, LLVM.ConstantInt(world)) end end + @safe_debug "Applying custom augmented_primal rule" TT = augprimal_TT, functy=functy return ami, augprimal_TT, (