Skip to content

Commit

Permalink
final fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 15, 2024
1 parent 36632e4 commit 258c5ed
Show file tree
Hide file tree
Showing 6 changed files with 211 additions and 43 deletions.
64 changes: 48 additions & 16 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ function prepare_llvm(mod::LLVM.Module, job, meta)
end
end

# Per-module scratch table: maps an `LLVM.Module` that is currently being compiled to
# the `Vector{Any}` of edges collected for it.
# Within this file, `GPUCompiler.codegen` inserts a fresh `Any[]` for the module right
# after emitting it and removes the entry again before returning, while
# `nested_codegen!` looks the module up and pushes each nested `funcspec` onto its list.
# NOTE(review): none of the accesses visible here take a lock — confirm that
# compilation is serialized by the callers.
const mod_to_edges = Dict{LLVM.Module, Vector{Any}}()

function nested_codegen!(
mode::API.CDerivativeMode,
mod::LLVM.Module,
Expand Down Expand Up @@ -390,6 +392,11 @@ function nested_codegen!(
permit_inlining!(f)
end

edges = get(mod_to_edges, mod, nothing)
@assert edges !== nothing
edges = edges::Vector{Any}
push!(edges, funcspec)

# Apply the first stage of optimizations so that this module is at the same stage as `mod`
optimize!(otherMod, JIT.get_tm())
# 4) Link the corresponding module
Expand Down Expand Up @@ -1193,16 +1200,26 @@ if VERSION >= v"1.11.0-DEV.1552"
always_inline::Any
method_table::Core.MethodTable
param_type::Type
is_fwd::Bool
last_fwd_rule_world::Union{Nothing, Tuple}
last_rev_rule_world::Union{Nothing, Tuple}
last_ina_rule_world::Tuple
end

# Convenience constructor for `EnzymeCacheToken`.
#
# Instead of taking the three rule-signature fields directly, it derives them from
# `get_rule_signatures` evaluated at `world`:
#   * forward-rule signatures are collected only when `is_forward` is true
#     (the field is `nothing` otherwise),
#   * augmented-primal (reverse) rule signatures only when `is_reverse` is true
#     (the field is `nothing` otherwise),
#   * `EnzymeRules.inactive` signatures are always collected.
#
# `world` is a required keyword. It used to be read as a free variable inside the
# body, which is not bound by any argument of this method, so the lookup could not
# resolve to the world age of the job being compiled.
@inline function EnzymeCacheToken(
    target_type,
    always_inline,
    method_table,
    param_type,
    is_forward,
    is_reverse;
    world::UInt,
)
    fwd_rule_sigs = if is_forward
        (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.forward, Tuple{<:FwdConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,)
    else
        nothing
    end
    rev_rule_sigs = if is_reverse
        (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.augmented_primal, Tuple{<:RevConfig, <:Annotation, Type{<:Annotation}, Vararg{Annotation}}, world)...,)
    else
        nothing
    end
    ina_rule_sigs = (Enzyme.Compiler.Interpreter.get_rule_signatures(EnzymeRules.inactive, Tuple{Vararg{Any}}, world)...,)

    # Seven positional arguments: dispatches to the struct's field-wise constructor.
    return EnzymeCacheToken(
        target_type,
        always_inline,
        method_table,
        param_type,
        fwd_rule_sigs,
        rev_rule_sigs,
        ina_rule_sigs,
    )
end

# Code-instance cache token for Enzyme compiler jobs.
# Forward-mode jobs snapshot only the forward rules and every other mode snapshots
# only the reverse rules; the inactive rules are always included by the constructor.
# NOTE(review): assumes `CompilerJob` exposes the world age it compiles in as
# `job.world` — confirm against the GPUCompiler version in use.
GPUCompiler.ci_cache_token(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
    EnzymeCacheToken(
        typeof(job.config.target),
        job.config.always_inline,
        GPUCompiler.method_table(job),
        typeof(job.config.params),
        job.config.params.mode == API.DEM_ForwardMode,
        job.config.params.mode != API.DEM_ForwardMode;
        world = job.world,
    )

GPUCompiler.get_interpreter(job::CompilerJob{<:Any,<:AbstractEnzymeCompilerParams}) =
Expand Down Expand Up @@ -3234,6 +3251,7 @@ function GPUCompiler.codegen(
if params.run_enzyme
# @assert eltype(params.rt) != Union{}
end

expectedTapeType = params.expectedTapeType
mode = params.mode
TT = params.TT
Expand Down Expand Up @@ -3275,6 +3293,8 @@ function GPUCompiler.codegen(

GPUCompiler.prepare_job!(primal_job)
mod, meta = GPUCompiler.emit_llvm(primal_job; libraries=true, toplevel=toplevel, optimize=false, cleanup=false, only_entry=false, validate=false)
edges = Any[]
mod_to_edges[mod] = edges

prepare_llvm(mod, primal_job, meta)
for f in functions(mod)
Expand Down Expand Up @@ -3556,13 +3576,13 @@ function GPUCompiler.codegen(
caller = mi
if mode == API.DEM_ForwardMode
has_custom_rule =
EnzymeRules.has_frule_from_sig(specTypes; world, method_table, caller)
EnzymeRules.has_frule_from_sig(specTypes; world, method_table) #, caller)
if has_custom_rule
@safe_debug "Found frule for" mi.specTypes
end
else
has_custom_rule =
EnzymeRules.has_rrule_from_sig(specTypes; world, method_table, caller)
EnzymeRules.has_rrule_from_sig(specTypes; world, method_table) # , caller)
if has_custom_rule
@safe_debug "Found rrule for" mi.specTypes
end
Expand All @@ -3577,7 +3597,7 @@ function GPUCompiler.codegen(
actualRetType = k.ci.rettype
end

if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table, caller)
if EnzymeRules.noalias_from_sig(mi.specTypes; world, method_table) #, caller)
push!(return_attributes(llvmfn), EnumAttribute("noalias"))
for u in LLVM.uses(llvmfn)
c = LLVM.user(u)
Expand Down Expand Up @@ -3801,7 +3821,7 @@ end
end
continue
end
if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table, caller) &&
if EnzymeRules.is_inactive_from_sig(specTypes; world, method_table) && #, caller) &&
Enzyme.has_method(
Tuple{typeof(EnzymeRules.inactive),specTypes.parameters...},
world,
Expand All @@ -3819,7 +3839,7 @@ end
)
continue
end
if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table, caller) &&
if EnzymeRules.is_inactive_noinl_from_sig(specTypes; world, method_table) # , caller) &&
has_method(
Tuple{typeof(EnzymeRules.inactive_noinl),specTypes.parameters...},
world,
Expand Down Expand Up @@ -4588,17 +4608,20 @@ end
isempty(LLVM.blocks(fn)) && continue
linkage!(fn, LLVM.API.LLVMLinkerPrivateLinkage)
end

delete!(mod_to_edges, mod)

use_primal = mode == API.DEM_ReverseModePrimal
entry = use_primal ? augmented_primalf : adjointf
return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType)
return mod, (; adjointf, augmented_primalf, entry, compiled = meta.compiled, TapeType, edges)
end

# Compiler result
# Bundle handed back by `_link` and stored in the compilation cache.
# `AT` / `PT` are the concrete types of the adjoint and primal handles; in `_link`
# these are either `Val((Symbol(mod), Symbol(name)))` values (InlineABI branch) or
# the `adjoint_ptr` / `primal_ptr` values produced on the non-inline path.
struct CompileResult{AT,PT}
# Handle for the adjoint entry point.
adjoint::AT
# Handle for the (augmented) primal entry point.
primal::PT
# Tape type reported by codegen (`meta.TapeType`).
TapeType::Type
# Edges collected during codegen (`meta.edges`), carried alongside the compiled code.
edges::Vector{Any}
end

@inline (thunk::PrimalErrorThunk{PT,FA,RT,TT,Width,ReturnPrimal})(
Expand Down Expand Up @@ -5224,12 +5247,13 @@ end
# JIT
##

function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String)
function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, edges::Vector{Any}, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String)
if job.config.params.ABI <: InlineABI
return CompileResult(
Val((Symbol(mod), Symbol(adjoint_name))),
Val((Symbol(mod), Symbol(primal_name))),
TapeType
TapeType,
edges
)
end

Expand Down Expand Up @@ -5261,16 +5285,17 @@ function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module
end
end

return CompileResult(adjoint_ptr, primal_ptr, TapeType)
return CompileResult(adjoint_ptr, primal_ptr, TapeType, edges)
end

const DumpPostOpt = Ref(false)

# actual compilation
function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{String, Nothing}, Type, String}
function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, Vector{Any}, String, Union{String, Nothing}, Type, String}
mod, meta = codegen(:llvm, job; optimize = false)
adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf


adjoint_name = name(adjointf)

if augmented_primalf !== nothing
Expand Down Expand Up @@ -5303,7 +5328,7 @@ function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{Str
else
""
end
return (mod, adjoint_name, primal_name, meta.TapeType, prepost)
return (mod, meta.edges, adjoint_name, primal_name, meta.TapeType, prepost)
end

# Memo table of finished compilations: `UInt` key => `CompileResult`.
# Filled via `cache[key] = obj` once `_thunk` and `_link` have produced a result.
# NOTE(review): a `cache_lock` ReentrantLock is declared next to this table;
# presumably it guards accesses — confirm at the use sites.
const cache = Dict{UInt,CompileResult}()
Expand All @@ -5322,10 +5347,10 @@ const cache_lock = ReentrantLock()
asm = _thunk(job)
obj = _link(job, asm...)
if obj.adjoint isa Ptr{Nothing}
autodiff_cache[obj.adjoint] = (asm[2], asm[5])
autodiff_cache[obj.adjoint] = (asm[3], asm[6])
end
if obj.primal isa Ptr{Nothing} && asm[3] isa String
autodiff_cache[obj.primal] = (asm[3], asm[5])
if obj.primal isa Ptr{Nothing} && asm[4] isa String
autodiff_cache[obj.primal] = (asm[4], asm[6])
end
cache[key] = obj
end
Expand All @@ -5349,7 +5374,8 @@ end
@nospecialize(ABI::Type),
ErrIfFuncWritten::Bool,
RuntimeActivity::Bool,
)
edges::Union{Nothing, Vector{Any}}
)
target = Compiler.EnzymeTarget()
params = Compiler.EnzymeCompilerParams(
Tuple{FA,TT.parameters...},
Expand Down Expand Up @@ -5430,6 +5456,11 @@ end


compile_result = cached_compilation(job)
if edges !== nothing
for e in compile_result.edges
push!(edges, e)
end
end
if !run_enzyme
ErrT = PrimalErrorThunk{typeof(compile_result.adjoint),FA,rt2,TT,width,ReturnPrimal}
if Mode == API.DEM_ReverseModePrimal || Mode == API.DEM_ReverseModeGradient
Expand Down Expand Up @@ -5622,6 +5653,7 @@ function thunk_generator(world::UInt, source::LineNumberNode, @nospecialize(FA::
ABI,
ErrIfFuncWritten,
RuntimeActivity,
edges
)
finally
deactivate(ctx)
Expand Down
Loading

0 comments on commit 258c5ed

Please sign in to comment.