Skip to content

Commit

Permalink
more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 1, 2024
1 parent e147a57 commit 39a99ab
Showing 1 changed file with 21 additions and 40 deletions.
61 changes: 21 additions & 40 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -523,14 +523,16 @@ end
typeof(EnzymeRules.forward)
end
@safe_debug "Applying custom forward rule" TT = TT, functy = functy
try
fmi = try
fmi = my_methodinstance(functy, TT, world)
fwd_RT = primal_return_type_world(Forward, world, rmi)
fwd_RT = primal_return_type_world(Forward, world, fmi)
fmi
catch e
TT = Tuple{typeof(world),functy,TT.parameters...}
fmi = my_methodinstance(typeof(custom_rule_method_error), TT, world)
pushfirst!(args, LLVM.ConstantInt(world))
fwd_RT = Union{}
fmi
end
llvmf = nested_codegen!(mode, mod, fmi, world)
push!(function_attributes(llvmf), EnumAttribute("alwaysinline", 0))
Expand Down Expand Up @@ -776,10 +778,9 @@ end

mode = get_mode(gutils)

ami = nothing

augprimal_tt = copy(activity)
if isKWCall
functy = if isKWCall
popfirst!(augprimal_tt)
@assert kwtup !== nothing
insert!(augprimal_tt, 1, kwtup)
Expand All @@ -788,50 +789,30 @@ end
insert!(augprimal_tt, 5, Type{RT})

augprimal_TT = Tuple{augprimal_tt...}
kwfunc = Core.kwfunc(EnzymeRules.augmented_primal)
try
ami = my_methodinstance(Core.Typeof(kwfunc), augprimal_TT, world)
@safe_debug "Applying custom augmented_primal rule (kwcall)" TT = augprimal_TT
catch e
augprimal_TT = Tuple{typeof(world),typeof(kwfunc),augprimal_TT.parameters...}
ami = my_methodinstance(
typeof(custom_rule_method_error),
augprimal_TT,
world,
)
if forward
pushfirst!(args, LLVM.ConstantInt(world))
end
end
Core.Typeof(Core.kwfunc(EnzymeRules.augmented_primal))
else
@assert kwtup === nothing
insert!(augprimal_tt, 1, C)
insert!(augprimal_tt, 3, Type{RT})

augprimal_TT = Tuple{augprimal_tt...}
try
ami = my_methodinstance(
Core.Typeof(EnzymeRules.augmented_primal),
augprimal_TT,
world,
)
@safe_debug "Applying custom augmented_primal rule" TT = augprimal_TT
catch e
augprimal_TT = Tuple{
typeof(world),
typeof(EnzymeRules.augmented_primal),
augprimal_TT.parameters...,
}
ami = my_methodinstance(
typeof(custom_rule_method_error),
augprimal_TT,
world,
)
if forward
pushfirst!(args, LLVM.ConstantInt(world))
end
typeof(EnzymeRules.augmented_primal)
end

ami = try
my_methodinstance(functy, augprimal_TT, world)
catch e
augprimal_TT = Tuple{typeof(world),functy,augprimal_TT.parameters...}
ami = my_methodinstance(
typeof(custom_rule_method_error),
augprimal_TT,
world,
)
if forward
pushfirst!(args, LLVM.ConstantInt(world))
end
end
@safe_debug "Applying custom augmented_primal rule" TT = augprimal_TT, functy=functy
return ami,
augprimal_TT,
(
Expand Down

0 comments on commit 39a99ab

Please sign in to comment.