Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 15, 2024
1 parent 060626a commit 10c8379
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ if VERSION >= v"1.11.0-DEV.1552"
last_ina_rule_world::Tuple
end

@inline EnzymeCacheToken(target_type, always_inline, method_table, param_type, is_forward, is_reverse) =
@inline EnzymeCacheToken(target_type::Type, always_inline::Any, method_table::Core.MethodTable, param_type::Type, world::UInt, is_forward::Bool, is_reverse::Bool) =
EnzymeCacheToken(target_type, always_inline, method_table, param_type,
is_forward ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.forward, Tuple{<:EnzymeCore.EnzymeRules.FwdConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing,
is_reverse ? (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:EnzymeCore.EnzymeRules.RevConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,) : nothing,
Expand All @@ -1218,6 +1218,7 @@ if VERSION >= v"1.11.0-DEV.1552"
job.config.always_inline,
GPUCompiler.method_table(job),
typeof(job.config.params),
job.world,
job.config.params.mode == API.DEM_ForwardMode,
job.config.params.mode != API.DEM_ForwardMode
)
Expand Down
3 changes: 2 additions & 1 deletion src/compiler/reflection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ function get_job(
world=Base.get_world_counter()
end

primal = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, eltype(Core.Typeof(func)), Tuple{map(eltype, types.parameters)...}, world)
primal = my_methodinstance(mode == API.DEM_ForwardMode ? Forward : Reverse, Core.Typeof(func), tt, world)
rt = Compiler.primal_return_type_world(mode == API.DEM_ForwardMode ? Forward : Reverse, world, Core.Typeof(func), tt)

@assert primal !== nothing
rt = A{rt}
target = Compiler.EnzymeTarget()
if modifiedBetween === nothing
Expand Down
2 changes: 2 additions & 0 deletions src/typeutils/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ function primal_interp_world(
false,
GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=#
EnzymeCompilerParams,
world,
false,
true
)
Expand All @@ -47,6 +48,7 @@ function primal_interp_world(
false,
GPUCompiler.GLOBAL_METHOD_TABLE, #=job.config.always_inline=#
EnzymeCompilerParams,
world,
true,
false
)
Expand Down

0 comments on commit 10c8379

Please sign in to comment.