From bf2291b34920fb163b0bde0e2e102564fdd0517c Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Tue, 20 Feb 2024 16:25:07 -0500 Subject: [PATCH] TA pass julia type info via param attr --- src/compiler.jl | 61 +++++++++++++++++------- src/rules/typerules.jl | 105 ----------------------------------------- 2 files changed, 45 insertions(+), 121 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8405ccbdec..6f248ded4d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2902,7 +2902,7 @@ include("rules/activityrules.jl") @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: DuplicatedNoNeed = API.DFT_DUP_NONEED @inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedNoNeed = API.DFT_DUP_NONEED -function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, jlrules,expectedTapeType, loweredArgs, boxedArgs) +function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) world = job.world interp = GPUCompiler.get_interpreter(job) rt = job.config.params.rt @@ -3051,11 +3051,6 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)), ) - for jl in jlrules - rules[jl] = @cfunction(julia_type_rule, - UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)) - end logic = Logic() TA = TypeAnalysis(logic, rules) @@ -4491,21 +4486,62 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; jlargs = classify_arguments(mi.specTypes, function_type(f), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) + ctx = LLVM.context(f) + for arg in jlargs if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam continue end push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzymejl_parmtype_ref", string(UInt(arg.cc)))) + + byref = arg.cc + + rest = typetree(arg.typ, ctx, dl) + + if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF + # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader + # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} + # aka the next field after this in the bigger object isn't guaranteed to also be the same. + if allocatedinline(arg.typ) + shift!(rest, dl, 0, sizeof(arg.typ), 0) + end + merge!(rest, TypeTree(API.DT_Pointer, ctx)) + only!(rest, -1) + else + # canonicalize wrt size + end + push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_type", string(rest))) end + + if sret !== nothing + idx = 0 + if !in(0, parmsRemoved) + rest = typetree(sret, ctx, dl) + push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_type", string(rest))) + idx+=1 + end + if returnRoots !== nothing + if !in(1, parmsRemoved) + rest = TypeTree(API.DT_Pointer, -1, ctx) + push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_type", string(rest))) + end + end + end + + if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() + @assert !retRemoved + rest = typetree(llRT, ctx, dl) + push!(return_attributes(f), StringAttribute("enzyme_type", string(rest))) + end + + push!(function_attributes(f), StringAttribute("enzyme_ta_norecur")) end custom = Dict{String, LLVM.API.LLVMLinkage}() must_wrap = false - foundTys = Dict{String, Tuple{LLVM.FunctionType, Core.MethodInstance}}() - world = job.world interp = GPUCompiler.get_interpreter(job) method_table = Core.Compiler.method_table(interp) @@ -4566,7 +4602,6 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; end julia_activity_rule(llvmfn) - foundTys[k_name] = (LLVM.function_type(llvmfn), mi) if has_custom_rule handleCustom("enzyme_custom", [StringAttribute("enzyme_preserve_primal", "*")]) continue @@ -4814,15 +4849,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; if params.run_enzyme # Generate the adjoint - jlrules = String["enzyme_custom"] - for (fname, (ftyp, mi)) in foundTys - haskey(functions(mod), fname) || continue - push!(jlrules, fname) - end - memcpy_alloca_to_loadstore(mod) - adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, jlrules, expectedTapeType, loweredArgs, boxedArgs) + adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) toremove = [] # Inline the wrapper for f in functions(mod) diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index 2c3e51e2ff..cc5cea622e 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -116,108 +116,3 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT end return UInt8(false) end - -function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8 - inst = LLVM.Instruction(val) - ctx = LLVM.context(inst) - - mi, RT = enzyme_custom_extract_mi(inst) - - ops = collect(operands(inst))[1:end-1] - called = LLVM.called_operand(inst) - - - llRT, sret, returnRoots = get_return_info(RT) - retRemoved, parmsRemoved = removed_ret_parms(inst) - - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - - - expectLen = (sret !== nothing) + (returnRoots !== nothing) - for source_typ in mi.specTypes.parameters - if isghostty(source_typ) || Core.Compiler.isconstType(source_typ) - continue - end - expectLen+=1 - end - expectLen -= length(parmsRemoved) - - # TODO fix the attributor inlining such that this can assert always true - if expectLen == length(ops) - - f = LLVM.called_operand(inst) - swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(f, i)))) for i in 1:length(collect(parameters(f)))) - jlargs = classify_arguments(mi.specTypes, called_type(inst), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) - - - for arg in jlargs - if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam - continue - end - - typ, byref = enzyme_extract_parm_type(f, arg.codegen.i) - @assert typ == arg.typ - - op_idx = arg.codegen.i - rest = typetree(arg.typ, ctx, dl) - @assert arg.cc == byref - if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF - # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader - # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} - # aka the next field after this in the bigger object isn't guaranteed to also be the same. - if allocatedinline(arg.typ) - shift!(rest, dl, 0, sizeof(arg.typ), 0) - end - merge!(rest, TypeTree(API.DT_Pointer, ctx)) - only!(rest, -1) - else - # canonicalize wrt size - end - PTT = unsafe_load(args, op_idx) - changed, legal = API.EnzymeCheckedMergeTypeTree(PTT, rest) - if !legal - function c(io) - println(io, "Illegal type analysis update from julia rule of method ", mi) - println(io, "Found type ", arg.typ, " at index ", arg.codegen.i, " of ", string(rest)) - t = API.EnzymeTypeTreeToString(PTT) - println(io, "Prior type ", Base.unsafe_string(t)) - println(io, inst) - API.EnzymeStringFree(t) - end - msg = sprint(c) - - bt = GPUCompiler.backtrace(inst) - ir = sprint(io->show(io, parent_scope(inst))) - - sval = "" - # data = API.EnzymeTypeAnalyzerRef(data) - # ip = API.EnzymeTypeAnalyzerToString(data) - # sval = Base.unsafe_string(ip) - # API.EnzymeStringFree(ip) - throw(IllegalTypeAnalysisException(msg, sval, ir, bt)) - end - end - - if sret !== nothing - idx = 0 - if !in(0, parmsRemoved) - API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(sret, ctx, dl)) - idx+=1 - end - if returnRoots !== nothing - if !in(1, parmsRemoved) - allpointer = TypeTree(API.DT_Pointer, -1, ctx) - API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(returnRoots, ctx, dl)) - end - end - end - - end - - if llRT !== nothing && value_type(inst) != LLVM.VoidType() - @assert !retRemoved - API.EnzymeMergeTypeTree(ret, typetree(llRT, ctx, dl)) - end - - return UInt8(false) -end