Skip to content

Commit

Permalink
fewer calls in custom rules
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 1, 2024
1 parent 348303f commit e147a57
Showing 1 changed file with 31 additions and 64 deletions.
95 changes: 31 additions & 64 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,30 +516,23 @@ end
world = enzyme_extract_world(fn)
@safe_debug "Trying to apply custom forward rule" TT isKWCall
llvmf = nothing
if isKWCall
if EnzymeRules.isapplicable(kwfunc, TT; world)
@safe_debug "Applying custom forward rule (kwcall)" TT
llvmf = nested_codegen!(mode, mod, kwfunc, TT, world)
fwd_RT = Compiler.primal_return_type_world(Forward, world, Core.Typeof(kwfunc), TT)
else
TT = Tuple{typeof(world),typeof(kwfunc),TT.parameters...}
llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
fwd_RT = Union{}
end

functy = if isKWCall
rkwfunc = typeof(Core.kwfunc(EnzymeRules.forward))
else
if EnzymeRules.isapplicable(EnzymeRules.forward, TT; world)
@safe_debug "Applying custom forward rule" TT
llvmf = nested_codegen!(mode, mod, EnzymeRules.forward, TT, world)
fwd_RT = Compiler.primal_return_type_world(Forward, world, typeof(EnzymeRules.forward), TT)
else
TT = Tuple{typeof(world),typeof(EnzymeRules.forward),TT.parameters...}
llvmf = nested_codegen!(mode, mod, custom_rule_method_error, TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
fwd_RT = Union{}
end
typeof(EnzymeRules.forward)
end

@safe_debug "Applying custom forward rule" TT = TT, functy = functy
try
fmi = my_methodinstance(functy, TT, world)
fwd_RT = primal_return_type_world(Forward, world, rmi)
catch e
TT = Tuple{typeof(world),functy,TT.parameters...}
fmi = my_methodinstance(typeof(custom_rule_method_error), TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
fwd_RT = Union{}
end
llvmf = nested_codegen!(mode, mod, fmi, world)
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))

swiftself = has_swiftself(llvmf)
Expand Down Expand Up @@ -998,50 +991,24 @@ function enzyme_custom_common_rev(
end
rev_TT = Tuple{tt...}

if isKWCall
rkwfunc = Core.kwfunc(EnzymeRules.reverse)
if EnzymeRules.isapplicable(rkwfunc, rev_TT; world)
@safe_debug "Applying custom reverse rule (kwcall)" TT = rev_TT
try
llvmf = nested_codegen!(mode, mod, rkwfunc, rev_TT, world)
rev_RT = Compiler.primal_return_type_world(Reverse, world, Core.Typeof(rkwfunc), rev_TT)
catch e
rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...}
llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
rev_RT = Union{}
applicablefn = false
end
else
rev_TT = Tuple{typeof(world),typeof(rkwfunc),rev_TT.parameters...}
llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
rev_RT = Union{}
applicablefn = false
end
functy = if isKWCall
rkwfunc = typeof(Core.kwfunc(EnzymeRules.reverse))
else
if EnzymeRules.isapplicable(EnzymeRules.reverse, rev_TT; world)
@safe_debug "Applying custom reverse rule" TT = rev_TT
try
llvmf = nested_codegen!(mode, mod, EnzymeRules.reverse, rev_TT, world)
rev_RT = Compiler.primal_return_type_world(Reverse, world, typeof(EnzymeRules.reverse), rev_TT)
catch e
rev_TT =
Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...}
llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
rev_RT = Union{}
applicablefn = false
end
else
rev_TT =
Tuple{typeof(world),typeof(EnzymeRules.reverse),rev_TT.parameters...}
llvmf = nested_codegen!(mode, mod, custom_rule_method_error, rev_TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
rev_RT = Union{}
applicablefn = false
end
typeof(EnzymeRules.reverse)
end

@safe_debug "Applying custom reverse rule" TT = rev_TT, functy=functy
try
rmi = my_methodinstance(functy, rev_TT, world)
rev_RT = return_type(interp, rmi)
catch e
rev_TT = Tuple{typeof(world),functy,rev_TT.parameters...}
rmi = my_methodinstance(typeof(custom_rule_method_error), rev_TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
rev_RT = Union{}
applicablefn = false
end
llvmf = nested_codegen!(mode, mod, rmi, world)
end

push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
Expand Down

0 comments on commit e147a57

Please sign in to comment.