From 5a9c70ce7567ea7c375d5ea5465523de8d7e224b Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 20 Mar 2024 15:26:00 -0700 Subject: [PATCH] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 124 +++++++++++++++++++++++++++++++----------------- 1 file changed, 81 insertions(+), 43 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6f248ded4d..b2a22bf0ea 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -2905,7 +2905,7 @@ include("rules/activityrules.jl") function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) world = job.world interp = GPUCompiler.get_interpreter(job) - rt = job.config.params.rt + rt = job.config.params.rt shadow_init = job.config.params.shadowInit ctx = context(mod) dl = string(LLVM.datalayout(mod)) @@ -4484,62 +4484,85 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; throw(AssertionError("Wrong number of parameters $(string(f)) expectLen=$expectLen swiftself=$swiftself sret=$sret returnRoots=$returnRoots spec=$(mi.specTypes.parameters) retRem=$retRemoved parmsRem=$parmsRemoved")) end - jlargs = classify_arguments(mi.specTypes, function_type(f), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved) + jlargs = classify_arguments( + mi.specTypes, + function_type(f), + sret !== nothing, + returnRoots !== nothing, + swiftself, + parmsRemoved, + ) - ctx = LLVM.context(f) + ctx = LLVM.context(f) for arg in jlargs if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam continue end - push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))))) - push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzymejl_parmtype_ref", string(UInt(arg.cc)))) + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute( + "enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ))) + ), + ) + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute("enzymejl_parmtype_ref", string(UInt(arg.cc))), + ) - byref = arg.cc + byref = arg.cc rest = typetree(arg.typ, ctx, dl) - if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF - # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader - # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} - # aka the next field after this in the bigger object isn't guaranteed to also be the same. - if allocatedinline(arg.typ) - shift!(rest, dl, 0, sizeof(arg.typ), 0) - end - merge!(rest, TypeTree(API.DT_Pointer, ctx)) - only!(rest, -1) - else - # canonicalize wrt size - end - push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_type", string(rest))) + if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF + # adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader + # object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int} + # aka the next field after this in the bigger object isn't guaranteed to also be the same. + if allocatedinline(arg.typ) + shift!(rest, dl, 0, sizeof(arg.typ), 0) + end + merge!(rest, TypeTree(API.DT_Pointer, ctx)) + only!(rest, -1) + else + # canonicalize wrt size + end + push!( + parameter_attributes(f, arg.codegen.i), + StringAttribute("enzyme_type", string(rest)), + ) + end + + if sret !== nothing + idx = 0 + if !in(0, parmsRemoved) + rest = typetree(sret, ctx, dl) + push!( + parameter_attributes(f, idx + 1), + StringAttribute("enzyme_type", string(rest)), + ) + idx += 1 + end + if returnRoots !== nothing + if !in(1, parmsRemoved) + rest = TypeTree(API.DT_Pointer, -1, ctx) + push!( + parameter_attributes(f, idx + 1), + StringAttribute("enzyme_type", string(rest)), + ) + end + end + end + + if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() + @assert !retRemoved + rest = typetree(llRT, ctx, dl) + push!(return_attributes(f), StringAttribute("enzyme_type", string(rest))) end - if sret !== nothing - idx = 0 - if !in(0, parmsRemoved) - rest = typetree(sret, ctx, dl) - push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_type", string(rest))) - idx+=1 - end - if returnRoots !== nothing - if !in(1, parmsRemoved) - rest = TypeTree(API.DT_Pointer, -1, ctx) - push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_type", string(rest))) - end - end - end - - if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType() - @assert !retRemoved - rest = typetree(llRT, ctx, dl) - push!(return_attributes(f), StringAttribute("enzyme_type", string(rest))) - end - push!(function_attributes(f), StringAttribute("enzyme_ta_norecur")) end - - custom = Dict{String, LLVM.API.LLVMLinkage}() + custom = Dict{String,LLVM.API.LLVMLinkage}() must_wrap = false world = job.world @@ -4851,7 +4874,22 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget}; # Generate the adjoint memcpy_alloca_to_loadstore(mod) - adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs) + adjointf, augmented_primalf, TapeType = enzyme!( + job, + mod, + primalf, + TT, + mode, + width, + parallel, + actualRetType, + abiwrap, + modifiedBetween, + returnPrimal, + expectedTapeType, + loweredArgs, + boxedArgs, + ) toremove = [] # Inline the wrapper for f in functions(mod)