From 3edec409c4e43590320df0b02a3463a24638ae0e Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 7 Dec 2024 00:03:35 -0600 Subject: [PATCH] Fix higher order codegen (#2161) * Fix higher order codegen * fix * fix * working * Update validation.jl * handle, again * Update validation.jl --- src/compiler.jl | 26 +++- src/compiler/interpreter.jl | 25 +--- src/compiler/validation.jl | 231 ++++++------------------------------ src/llvm/transforms.jl | 188 +++++++++++++++++++++++++++++ src/rules/parallelrules.jl | 4 +- 5 files changed, 245 insertions(+), 229 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 0155e5da34..36d9c1473d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -5226,12 +5226,12 @@ end # JIT ## -function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType)) +function _link(@nospecialize(job::CompilerJob{<:EnzymeTarget}), mod::LLVM.Module, adjoint_name::String, @nospecialize(primal_name::Union{String, Nothing}), @nospecialize(TapeType), prepost::String) if job.config.params.ABI <: InlineABI return CompileResult( Val((Symbol(mod), Symbol(adjoint_name))), Val((Symbol(mod), Symbol(primal_name))), - TapeType, + TapeType ) end @@ -5269,7 +5269,7 @@ end const DumpPostOpt = Ref(false) # actual compilation -function _thunk(job, postopt::Bool = true) +function _thunk(job, postopt::Bool = true)::Tuple{LLVM.Module, String, Union{String, Nothing}, Type, String} mod, meta = codegen(:llvm, job; optimize = false) adjointf, augmented_primalf = meta.adjointf, meta.augmented_primalf @@ -5287,7 +5287,12 @@ function _thunk(job, postopt::Bool = true) end # Run post optimization pipeline - if postopt + prepost = if postopt + mstr = if job.config.params.ABI <: InlineABI + "" + else + string(mod) + end if job.config.params.ABI <: FFIABI || job.config.params.ABI <: NonGenABI post_optimze!(mod, JIT.get_tm()) if DumpPostOpt[] @@ -5296,12 +5301,17 @@ function _thunk(job, postopt::Bool = true) else propagate_returned!(mod) end + mstr + else + "" end - return (mod, adjoint_name, primal_name, meta.TapeType) + return (mod, adjoint_name, primal_name, meta.TapeType, prepost) end const cache = Dict{UInt,CompileResult}() +const autodiff_cache = Dict{Ptr{Cvoid},Tuple{String, String}}() + const cache_lock = ReentrantLock() @inline function cached_compilation(@nospecialize(job::CompilerJob))::CompileResult key = hash(job) @@ -5313,6 +5323,12 @@ const cache_lock = ReentrantLock() if obj === nothing asm = _thunk(job) obj = _link(job, asm...) + if obj.adjoint isa Ptr{Nothing} + autodiff_cache[obj.adjoint] = (asm[2], asm[5]) + end + if obj.primal isa Ptr{Nothing} && asm[3] isa String + autodiff_cache[obj.primal] = (asm[3], asm[5]) + end cache[key] = obj end obj diff --git a/src/compiler/interpreter.jl b/src/compiler/interpreter.jl index 2d02604eda..2f9d1fbf60 100644 --- a/src/compiler/interpreter.jl +++ b/src/compiler/interpreter.jl @@ -44,7 +44,6 @@ struct EnzymeInterpreter{T} <: AbstractInterpreter forward_rules::Bool reverse_rules::Bool - deferred_lower::Bool broadcast_rewrite::Bool handler::T end @@ -55,7 +54,6 @@ function EnzymeInterpreter( world::UInt, forward_rules::Bool, reverse_rules::Bool, - deferred_lower::Bool = true, broadcast_rewrite::Bool = true, handler = nothing ) @@ -83,7 +81,6 @@ function EnzymeInterpreter( IdDict{Any, Bool}(), forward_rules, reverse_rules, - deferred_lower, broadcast_rewrite, handler ) @@ -94,10 +91,9 @@ EnzymeInterpreter( mt::Union{Nothing,Core.MethodTable}, world::UInt, mode::API.CDerivativeMode, - deferred_lower::Bool = true, broadcast_rewrite::Bool = true, handler = nothing -) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, deferred_lower, broadcast_rewrite, handler) +) = EnzymeInterpreter(cache_or_token, mt, world, mode == API.DEM_ForwardMode, mode == API.DEM_ReverseModeCombined || mode == API.DEM_ReverseModePrimal || mode == API.DEM_ReverseModeGradient, broadcast_rewrite, handler) Core.Compiler.InferenceParams(@nospecialize(interp::EnzymeInterpreter)) = interp.inf_params Core.Compiler.OptimizationParams(@nospecialize(interp::EnzymeInterpreter)) = interp.opt_params @@ -865,25 +861,6 @@ function abstract_call_known( end end - if interp.deferred_lower && f === Enzyme.autodiff && length(argtypes) >= 4 - if widenconst(argtypes[2]) <: Enzyme.Mode && - widenconst(argtypes[3]) <: Enzyme.Annotation && - widenconst(argtypes[4]) <: Type{<:Enzyme.Annotation} - arginfo2 = ArgInfo( - fargs isa Nothing ? nothing : - [:(Enzyme.autodiff_deferred), fargs[2:end]...], - [Core.Const(Enzyme.autodiff_deferred), argtypes[2:end]...], - ) - return Base.@invoke abstract_call_known( - interp::AbstractInterpreter, - Enzyme.autodiff_deferred::Any, - arginfo2::ArgInfo, - si::StmtInfo, - sv::AbsIntState, - max_methods::Int, - ) - end - end if interp.handler != nothing return interp.handler(interp, f, arginfo, si, sv, max_methods) end diff --git a/src/compiler/validation.jl b/src/compiler/validation.jl index e109415d0f..525e4d874c 100644 --- a/src/compiler/validation.jl +++ b/src/compiler/validation.jl @@ -129,9 +129,7 @@ end function memoize!(ptr::Ptr{Cvoid}, fn::String)::String fn = get(ptr_map, ptr, fn) - if !haskey(ptr_map, ptr) - ptr_map[ptr] = fn - else + if haskey(ptr_map, ptr) @assert ptr_map[ptr] == fn end return fn @@ -185,194 +183,6 @@ function check_ir(@nospecialize(job::CompilerJob), mod::LLVM.Module) end end -# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA -function rewrite_ccalls!(mod::LLVM.Module) - for f in collect(functions(mod)) - replaceAndErase = Tuple{Instruction,Instruction}[] - for bb in blocks(f), inst in instructions(bb) - if isa(inst, LLVM.CallInst) - fn = called_operand(inst) - changed = false - B = IRBuilder() - position!(B, inst) - if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" - uservals = LLVM.Value[] - for lval in collect(arguments(inst)) - llty = value_type(lval) - if isa(llty, LLVM.PointerType) - push!(uservals, lval) - continue - end - vals = get_julia_inner_types(B, nothing, lval) - for v in vals - if isa(v, LLVM.PointerNull) - subchanged = true - continue - end - push!(uservals, v) - end - if length(vals) == 1 && vals[1] == lval - continue - end - changed = true - end - if changed - prevname = LLVM.name(inst) - LLVM.name!(inst, "") - if !isdefined(LLVM, :OperandBundleDef) - newinst = call!( - B, - called_type(inst), - called_operand(inst), - uservals, - collect(operand_bundles(inst)), - prevname, - ) - else - newinst = call!( - B, - called_type(inst), - called_operand(inst), - uservals, - collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), - prevname, - ) - end - for idx in [ - LLVM.API.LLVMAttributeFunctionIndex, - LLVM.API.LLVMAttributeReturnIndex, - [ - LLVM.API.LLVMAttributeIndex(i) for - i = 1:(length(arguments(inst))) - ]..., - ] - idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j = 1:count - LLVM.API.LLVMAddCallSiteAttribute( - newinst, - idx, - unsafe_load(Attrs, j), - ) - end - Libc.free(Attrs) - end - API.EnzymeCopyMetadata(newinst, inst) - callconv!(newinst, callconv(inst)) - push!(replaceAndErase, (inst, newinst)) - end - continue - end - if !isdefined(LLVM, :OperandBundleDef) - newbundles = OperandBundle[] - else - newbundles = OperandBundleDef[] - end - for bunduse in operand_bundles(inst) - if isdefined(LLVM, :OperandBundleDef) - bunduse = LLVM.OperandBundleDef(bunduse) - end - - if !isdefined(LLVM, :OperandBundleDef) - if LLVM.tag(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - else - if LLVM.tag_name(bunduse) != "jl_roots" - push!(newbundles, bunduse) - continue - end - end - uservals = LLVM.Value[] - subchanged = false - for lval in LLVM.inputs(bunduse) - llty = value_type(lval) - if isa(llty, LLVM.PointerType) - push!(uservals, lval) - continue - end - vals = get_julia_inner_types(B, nothing, lval) - for v in vals - if isa(v, LLVM.PointerNull) - subchanged = true - continue - end - push!(uservals, v) - end - if length(vals) == 1 && vals[1] == lval - continue - end - subchanged = true - end - if !subchanged - push!(newbundles, bunduse) - continue - end - changed = true - if !isdefined(LLVM, :OperandBundleDef) - push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) - else - push!( - newbundles, - OperandBundleDef(LLVM.tag_name(bunduse), uservals), - ) - end - end - changed = false - if changed - prevname = LLVM.name(inst) - LLVM.name!(inst, "") - newinst = call!( - B, - called_type(inst), - called_operand(inst), - collect(arguments(inst)), - newbundles, - prevname, - ) - for idx in [ - LLVM.API.LLVMAttributeFunctionIndex, - LLVM.API.LLVMAttributeReturnIndex, - [ - LLVM.API.LLVMAttributeIndex(i) for - i = 1:(length(arguments(inst))) - ]..., - ] - idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) - count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) - Attrs = Base.unsafe_convert( - Ptr{LLVM.API.LLVMAttributeRef}, - Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), - ) - LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) - for j = 1:count - LLVM.API.LLVMAddCallSiteAttribute( - newinst, - idx, - unsafe_load(Attrs, j), - ) - end - Libc.free(Attrs) - end - API.EnzymeCopyMetadata(newinst, inst) - callconv!(newinst, callconv(inst)) - push!(replaceAndErase, (inst, newinst)) - end - end - end - for (inst, newinst) in replaceAndErase - replace_uses!(inst, newinst) - LLVM.API.LLVMInstructionEraseFromParent(inst) - end - end -end - function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod::LLVM.Module) imported = Set(String[]) if haskey(functions(mod), "malloc") @@ -390,14 +200,14 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod replace_uses!(f, LLVM.Value(LLVM.API.LLVMConstPointerCast(mfn, value_type(f)))) eraseInst(mod, f) end - rewrite_ccalls!(mod) + Compiler.rewrite_ccalls!(mod) del = LLVM.Function[] for f in collect(functions(mod)) if in(f, del) continue end - check_ir!(job, errors, imported, f, del) + check_ir!(job, errors, imported, f, del, mod) end for d in del LLVM.API.LLVMDeleteFunction(d) @@ -408,7 +218,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod if in(f, del) continue end - check_ir!(job, errors, imported, f, del) + check_ir!(job, errors, imported, f, del, mod) end for d in del LLVM.API.LLVMDeleteFunction(d) @@ -417,7 +227,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, mod return errors end -function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, f::LLVM.Function, deletedfns::Vector{LLVM.Function}, mod::LLVM.Module) calls = LLVM.CallInst[] isInline = API.EnzymeGetCLBool(cglobal((:EnzymeInline, API.libEnzyme))) != 0 mod = LLVM.parent(f) @@ -643,7 +453,7 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp while length(calls) > 0 inst = pop!(calls) - check_ir!(job, errors, imported, inst, calls) + check_ir!(job, errors, imported, inst, calls, mod) end return errors end @@ -690,7 +500,7 @@ end import GPUCompiler: DYNAMIC_CALL, DELAYED_BINDING, RUNTIME_FUNCTION, UNKNOWN_FUNCTION, POINTER_FUNCTION import GPUCompiler: backtrace, isintrinsic -function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}) +function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imported::Set{String}, inst::LLVM.CallInst, calls::Vector{LLVM.CallInst}, mod::LLVM.Module) world = job.world interp = GPUCompiler.get_interpreter(job) method_table = Core.Compiler.method_table(interp) @@ -1211,13 +1021,36 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp ptr_val = convert(Int, ptr_arg) ptr = Ptr{Cvoid}(ptr_val) + if haskey(autodiff_cache, ptr) + pname, pmod = autodiff_cache[ptr] + + @assert !haskey(functions(mod), pname) + + pmod = parse(LLVM.Module, pmod) + + @assert haskey(functions(pmod), pname) + + for fn in functions(pmod) + if !isempty(LLVM.blocks(fn)) + linkage!(fn, LLVM.name(fn) != pname ? LLVM.API.LLVMInternalLinkage : LLVM.API.LLVMExternalLinkage) + end + end + + GPUCompiler.link_library!(mod, pmod) + + replaceWith = functions(mod)[pname] + push!(function_attributes(replaceWith), EnumAttribute("alwaysinline")) + linkage!(functions(mod)[pname], LLVM.API.LLVMInternalLinkage) + replace_uses!(ptr_arg, LLVM.const_pointercast(replaceWith, value_type(ptr_arg))) + return errors + end + # look it up in the Julia JIT cache frames = ccall(:jl_lookup_code_address, Any, (Ptr{Cvoid}, Cint), ptr, 0) if length(frames) >= 1 fn, file, line, linfo, fromC, inlined = last(frames) - # Remember pointer in our global map fn = FFI.memoize!(ptr, string(fn)) if length(fn) > 1 && fromC @@ -1229,6 +1062,8 @@ function check_ir!(@nospecialize(job::CompilerJob), errors::Vector{IRError}, imp fn, LLVM.API.LLVMGetCalledFunctionType(inst), ) + # Remember pointer for subsequent restoration + push!(function_attributes(LLVM.Function(lfn)), StringAttribute("enzymejl_needs_restoration", string(reinterpret(UInt, ptr)))) else lfn = LLVM.API.LLVMConstBitCast( lfn, diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index aebb8bab5c..2f9c61c0b4 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -1,4 +1,192 @@ +# Rewrite calls with "jl_roots" to only have the jl_value_t attached and not { { {} addrspace(10)*, [1 x [2 x i64]], i64, i64 }, [2 x i64] } %unbox110183_replacementA +function rewrite_ccalls!(mod::LLVM.Module) + for f in collect(functions(mod)) + replaceAndErase = Tuple{Instruction,Instruction}[] + for bb in blocks(f), inst in instructions(bb) + if isa(inst, LLVM.CallInst) + fn = called_operand(inst) + changed = false + B = IRBuilder() + position!(B, inst) + if isa(fn, LLVM.Function) && LLVM.name(fn) == "llvm.julia.gc_preserve_begin" + uservals = LLVM.Value[] + for lval in collect(arguments(inst)) + llty = value_type(lval) + if isa(llty, LLVM.PointerType) + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + changed = true + end + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + if !isdefined(LLVM, :OperandBundleDef) + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(operand_bundles(inst)), + prevname, + ) + else + newinst = call!( + B, + called_type(inst), + called_operand(inst), + uservals, + collect(map(LLVM.OperandBundleDef, operand_bundles(inst))), + prevname, + ) + end + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + continue + end + if !isdefined(LLVM, :OperandBundleDef) + newbundles = OperandBundle[] + else + newbundles = OperandBundleDef[] + end + for bunduse in operand_bundles(inst) + if isdefined(LLVM, :OperandBundleDef) + bunduse = LLVM.OperandBundleDef(bunduse) + end + + if !isdefined(LLVM, :OperandBundleDef) + if LLVM.tag(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + else + if LLVM.tag_name(bunduse) != "jl_roots" + push!(newbundles, bunduse) + continue + end + end + uservals = LLVM.Value[] + subchanged = false + for lval in LLVM.inputs(bunduse) + llty = value_type(lval) + if isa(llty, LLVM.PointerType) + push!(uservals, lval) + continue + end + vals = get_julia_inner_types(B, nothing, lval) + for v in vals + if isa(v, LLVM.PointerNull) + subchanged = true + continue + end + push!(uservals, v) + end + if length(vals) == 1 && vals[1] == lval + continue + end + subchanged = true + end + if !subchanged + push!(newbundles, bunduse) + continue + end + changed = true + if !isdefined(LLVM, :OperandBundleDef) + push!(newbundles, OperandBundle(LLVM.tag(bunduse), uservals)) + else + push!( + newbundles, + OperandBundleDef(LLVM.tag_name(bunduse), uservals), + ) + end + end + changed = false + if changed + prevname = LLVM.name(inst) + LLVM.name!(inst, "") + newinst = call!( + B, + called_type(inst), + called_operand(inst), + collect(arguments(inst)), + newbundles, + prevname, + ) + for idx in [ + LLVM.API.LLVMAttributeFunctionIndex, + LLVM.API.LLVMAttributeReturnIndex, + [ + LLVM.API.LLVMAttributeIndex(i) for + i = 1:(length(arguments(inst))) + ]..., + ] + idx = reinterpret(LLVM.API.LLVMAttributeIndex, idx) + count = LLVM.API.LLVMGetCallSiteAttributeCount(inst, idx) + Attrs = Base.unsafe_convert( + Ptr{LLVM.API.LLVMAttributeRef}, + Libc.malloc(sizeof(LLVM.API.LLVMAttributeRef) * count), + ) + LLVM.API.LLVMGetCallSiteAttributes(inst, idx, Attrs) + for j = 1:count + LLVM.API.LLVMAddCallSiteAttribute( + newinst, + idx, + unsafe_load(Attrs, j), + ) + end + Libc.free(Attrs) + end + API.EnzymeCopyMetadata(newinst, inst) + callconv!(newinst, callconv(inst)) + push!(replaceAndErase, (inst, newinst)) + end + end + end + for (inst, newinst) in replaceAndErase + replace_uses!(inst, newinst) + LLVM.API.LLVMInstructionEraseFromParent(inst) + end + end +end + function force_recompute!(mod::LLVM.Module) for f in functions(mod), bb in blocks(f) iter = LLVM.API.LLVMGetFirstInstruction(bb) diff --git a/src/rules/parallelrules.jl b/src/rules/parallelrules.jl index 78c9cd9ce8..d4356aba61 100644 --- a/src/rules/parallelrules.jl +++ b/src/rules/parallelrules.jl @@ -275,7 +275,7 @@ end world, ) - cmod, fwdmodenm, _, _ = _thunk(ejob, false) #=postopt=# + cmod, fwdmodenm, _, _, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod) @@ -334,7 +334,7 @@ end world, ) - cmod, adjointnm, augfwdnm, TapeType = _thunk(ejob, false) #=postopt=# + cmod, adjointnm, augfwdnm, TapeType, _ = _thunk(ejob, false) #=postopt=# LLVM.link!(mod, cmod)