diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index 8a56481a82..156e7c4fec 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -516,30 +516,23 @@ end world = enzyme_extract_world(fn) @safe_debug "Trying to apply custom forward rule" TT isKWCall llvmf = nothing - if isKWCall - if EnzymeRules.isapplicable(kwfunc, TT; world) - @safe_debug "Applying custom forward rule (kwcall)" TT - llvmf = nested_codegen!(mode, mod, kwfunc, TT, world) - fwd_RT = Compiler.primal_return_type_world(Forward, world, Core.Typeof(kwfunc), TT) - else - TT = Tuple{typeof(world),typeof(kwfunc),TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - fwd_RT = Union{} - end + + functy = if isKWCall + rkwfunc = typeof(Core.kwfunc(EnzymeRules.forward)) else - if EnzymeRules.isapplicable(EnzymeRules.forward, TT; world) - @safe_debug "Applying custom forward rule" TT - llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT, world) - fwd_RT = Compiler.primal_return_type_world(Forward, world, typeof(EnzymeRules.forward), TT) - else - TT = Tuple{typeof(world),typeof(EnzymeRules.forward),TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - fwd_RT = Union{} - end + typeof(EnzymeRules.forward) end - + @safe_debug "Applying custom forward rule" TT = TT, functy = functy + try + fmi = my_methodinstance(functy, TT, world) + fwd_RT = primal_return_type_world(Forward, world, rmi) + catch e + TT = Tuple{typeof(world),functy,TT.parameters...} + fmi = my_methodinstance(typeof(custom_rule_method_error), TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + fwd_RT = Union{} + end + llvmf = nested_codegen!(mode, mod, fmi, world) push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0)) swiftself = has_swiftself(llvmf) @@ -998,50 +991,24 @@ function enzyme_custom_common_rev( end rev_TT = Tuple{tt...} - if isKWCall - rkwfunc = Core.kwfunc(EnzymeRules.reverse) - if EnzymeRules.isapplicable(rkwfunc, rev_TT; world) - @safe_debug "Applying custom reverse rule (kwcall)" TT = rev_TT - try - llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world) - rev_RT = Compiler.primal_return_type_world(Reverse, world, Core.Typeof(rkwfunc), rev_TT) - catch e - rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - rev_RT = Union{} - applicablefn = false - end - else - rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - rev_RT = Union{} - applicablefn = false - end + functy = if isKWCall + rkwfunc = typeof(Core.kwfunc(EnzymeRules.reverse)) else - if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world) - @safe_debug "Applying custom reverse rule" TT = rev_TT - try - llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world) - rev_RT = Compiler.primal_return_type_world(Reverse, world, typeof(EnzymeRules.reverse), rev_TT) - catch e - rev_TT = - Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - rev_RT = Union{} - applicablefn = false - end - else - rev_TT = - Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...} - llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world) - pushfirst!(args, LLVM.ConstantInt(world)) - rev_RT = Union{} - applicablefn = false - end + typeof(EnzymeRules.reverse) + end + + @safe_debug "Applying custom reverse rule" TT = rev_TT, functy=functy + try + rmi = my_methodinstance(functy, rev_TT, world) + rev_RT = return_type(interp, rmi) + catch e + rev_TT = Tuple{typeof(world),functy,rev_TT.parameters...} + rmi = my_methodinstance(typeof(custom_rule_method_error), rev_TT, world) + pushfirst!(args, LLVM.ConstantInt(world)) + rev_RT = Union{} + applicablefn = false end + llvmf = nested_codegen!(mode, mod, rmi, world) end push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))