Skip to content

Commit

Permalink
fix return type
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 1, 2024
1 parent f7b523f commit c192b72
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ end
function has_frule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:FwdConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(forward, TT; world, method_table, caller)
Expand All @@ -180,7 +180,7 @@ end
function has_rrule_from_sig(@nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance,Core.Compiler.MethodLookupResult}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
ft, tt = _annotate_tt(TT)
TT = Tuple{<:RevConfig, <:Annotation{ft}, Type{<:Annotation}, tt...}
return isapplicable(augmented_primal, TT; world, method_table, caller)
Expand All @@ -192,7 +192,7 @@ end
function isapplicable(@nospecialize(f), @nospecialize(TT);
world::UInt=Base.get_world_counter(),
method_table::Union{Nothing,Core.Compiler.MethodTableView}=nothing,
caller::Union{Nothing,Core.MethodInstance}=nothing)
caller::Union{Nothing,Core.MethodInstance}=nothing)::Bool
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
Expand All @@ -215,11 +215,13 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
end
if Core.Compiler.isempty(matches)
return false
elseif caller isa Core.MethodInstance
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(caller, edge, sig)
else
if caller isa Core.MethodInstance
for i = 1:Core.Compiler.length(matches)
match = Core.Compiler.getindex(matches, i)::Core.MethodMatch
edge = Core.Compiler.specialize_method(match)::Core.MethodInstance
add_backedge!(caller, edge, sig)
end
end
return true
end
Expand Down

0 comments on commit c192b72

Please sign in to comment.