Skip to content

Commit

Permalink
Fix higher order codegen (#2161)
Browse files Browse the repository at this point in the history
* Fix higher order codegen

* fix

* fix

* working

* Update validation.jl

* handle, again

* Update validation.jl
  • Loading branch information
wsmoses authored Dec 7, 2024
1 parent 865cced commit 3edec40
Show file tree
Hide file tree
Showing 5 changed files with 245 additions and 229 deletions.
26 changes: 21 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5226,12 +5226,12 @@ end
# JIT
##

function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType))
function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String)
if job.config.params.ABI <: InlineABI
return CompileResult(
Val((Symbol(mod), Symbol(adjoint_name))),
Val((Symbol(mod), Symbol(primal_name))),
TapeType,
TapeType
)
end

Expand Down Expand Up @@ -5269,7 +5269,7 @@ end
const DumpPostOpt = Ref(false)

# actual compilation
function _thunk(job, postopt::Bool = true)
function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{String, Nothing}, Type, String}
mod, meta = codegen(:llvm, job; optimize = false)
adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf

Expand All @@ -5287,7 +5287,12 @@ function _thunk(job, postopt::Bool = true)
end

# Run post optimization pipeline
if postopt
prepost = if postopt
mstr = if job.config.params.ABI <: InlineABI
""
else
string(mod)
end
if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI
post_optimze!(mod, JIT.get_tm())
if DumpPostOpt[]
Expand All @@ -5296,12 +5301,17 @@ function _thunk(job, postopt::Bool = true)
else
propagate_returned!(mod)
end
mstr
else
""
end
return (mod, adjoint_name, primal_name, meta.TapeType)
return (mod, adjoint_name, primal_name, meta.TapeType, prepost)
end

const cache = Dict{UInt,CompileResult}()

const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String, String}}()

const cache_lock = ReentrantLock()
@inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult
key = hash(job)
Expand All @@ -5313,6 +5323,12 @@ const cache_lock = ReentrantLock()
if obj === nothing
asm = _thunk(job)
obj = _link(job, asm...)
if obj.adjoint isa Ptr{Nothing}
autodiff_cache[obj.adjoint] = (asm[2], asm[5])
end
if obj.primal isa Ptr{Nothing} && asm[3] isa String
autodiff_cache[obj.primal] = (asm[3], asm[5])
end
cache[key] = obj
end
obj
Expand Down
25 changes: 1 addition & 24 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter

forward_rules::Bool
reverse_rules::Bool
deferred_lower::Bool
broadcast_rewrite::Bool
handler::T
end
Expand All @@ -55,7 +54,6 @@ function EnzymeInterpreter(
world::UInt,
forward_rules::Bool,
reverse_rules::Bool,
deferred_lower::Bool = true,
broadcast_rewrite::Bool = true,
handler = nothing
)
Expand Down Expand Up @@ -83,7 +81,6 @@ function EnzymeInterpreter(
IdDict{Any, Bool}(),
forward_rules,
reverse_rules,
deferred_lower,
broadcast_rewrite,
handler
)
Expand All @@ -94,10 +91,9 @@ EnzymeInterpreter(
mt::Union{Nothing,Core.MethodTable},
world::UInt,
mode::API.CDerivativeMode,
deferred_lower::Bool = true,
broadcast_rewrite::Bool = true,
handler = nothing
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler)
) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, broadcast_rewrite, handler)

Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params
Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params
Expand Down Expand Up @@ -865,25 +861,6 @@ function abstract_call_known(
end
end

if interp.deferred_lower && f === Enzyme.autodiff && length(argtypes) >= 4
if widenconst(argtypes[2]) <: Enzyme.Mode &&
widenconst(argtypes[3]) <: Enzyme.Annotation &&
widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation}
arginfo2 = ArgInfo(
fargs isa Nothing ? nothing :
[:(Enzyme.autodiff_deferred), fargs[2:end]...],
[Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...],
)
return Base.@invoke abstract_call_known(
interp::AbstractInterpreter,
Enzyme.autodiff_deferred::Any,
arginfo2::ArgInfo,
si::StmtInfo,
sv::AbsIntState,
max_methods::Int,
)
end
end
if interp.handler != nothing
return interp.handler(interp, f, arginfo, si, sv, max_methods)
end
Expand Down
Loading

0 comments on commit 3edec40

Please sign in to comment.