Skip to content

Commit

Permalink
TA pass julia type info via param attr
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 20, 2024
1 parent 778deec commit bf2291b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 121 deletions.
61 changes: 45 additions & 16 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2902,7 +2902,7 @@ include("rules/activityrules.jl")
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: DuplicatedNoNeed = API.DFT_DUP_NONEED
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where A <: BatchDuplicatedNoNeed = API.DFT_DUP_NONEED

function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, jlrules,expectedTapeType, loweredArgs, boxedArgs)
function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs)
world = job.world
interp = GPUCompiler.get_interpreter(job)
rt = job.config.params.rt
Expand Down Expand Up @@ -3051,11 +3051,6 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef},
Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef)),
)
for jl in jlrules
rules[jl] = @cfunction(julia_type_rule,
UInt8, (Cint, API.CTypeTreeRef, Ptr{API.CTypeTreeRef},
Ptr{API.IntList}, Csize_t, LLVM.API.LLVMValueRef))
end

logic = Logic()
TA = TypeAnalysis(logic, rules)
Expand Down Expand Up @@ -4491,21 +4486,62 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};

jlargs = classify_arguments(mi.specTypes, function_type(f), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved)

ctx = LLVM.context(f)

for arg in jlargs
if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam
continue
end
push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzymejl_parmtype", string(convert(UInt, unsafe_to_pointer(arg.typ)))))
push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzymejl_parmtype_ref", string(UInt(arg.cc))))

byref = arg.cc

rest = typetree(arg.typ, ctx, dl)

if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF
# adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader
# object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int}
# aka the next field after this in the bigger object isn't guaranteed to also be the same.
if allocatedinline(arg.typ)
shift!(rest, dl, 0, sizeof(arg.typ), 0)
end
merge!(rest, TypeTree(API.DT_Pointer, ctx))
only!(rest, -1)
else
# canonicalize wrt size
end
push!(parameter_attributes(f, arg.codegen.i), StringAttribute("enzyme_type", string(rest)))
end

if sret !== nothing
idx = 0
if !in(0, parmsRemoved)
rest = typetree(sret, ctx, dl)
push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_type", string(rest)))
idx+=1
end
if returnRoots !== nothing
if !in(1, parmsRemoved)
rest = TypeTree(API.DT_Pointer, -1, ctx)
push!(parameter_attributes(f, idx+1), StringAttribute("enzyme_type", string(rest)))
end
end
end

if llRT !== nothing && LLVM.return_type(LLVM.function_type(f)) != LLVM.VoidType()
@assert !retRemoved
rest = typetree(llRT, ctx, dl)
push!(return_attributes(f), StringAttribute("enzyme_type", string(rest)))
end

push!(function_attributes(f), StringAttribute("enzyme_ta_norecur"))
end


custom = Dict{String, LLVM.API.LLVMLinkage}()
must_wrap = false

foundTys = Dict{String, Tuple{LLVM.FunctionType, Core.MethodInstance}}()

world = job.world
interp = GPUCompiler.get_interpreter(job)
method_table = Core.Compiler.method_table(interp)
Expand Down Expand Up @@ -4566,7 +4602,6 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
end

julia_activity_rule(llvmfn)
foundTys[k_name] = (LLVM.function_type(llvmfn), mi)
if has_custom_rule
handleCustom("enzyme_custom", [StringAttribute("enzyme_preserve_primal", "*")])
continue
Expand Down Expand Up @@ -4814,15 +4849,9 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};

if params.run_enzyme
# Generate the adjoint
jlrules = String["enzyme_custom"]
for (fname, (ftyp, mi)) in foundTys
haskey(functions(mod), fname) || continue
push!(jlrules, fname)
end

memcpy_alloca_to_loadstore(mod)

adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, jlrules, expectedTapeType, loweredArgs, boxedArgs)
adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs)
toremove = []
# Inline the wrapper
for f in functions(mod)
Expand Down
105 changes: 0 additions & 105 deletions src/rules/typerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,108 +116,3 @@ function alloc_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeT
end
return UInt8(false)
end

function julia_type_rule(direction::Cint, ret::API.CTypeTreeRef, args::Ptr{API.CTypeTreeRef}, known_values::Ptr{API.IntList}, numArgs::Csize_t, val::LLVM.API.LLVMValueRef)::UInt8
inst = LLVM.Instruction(val)
ctx = LLVM.context(inst)

mi, RT = enzyme_custom_extract_mi(inst)

ops = collect(operands(inst))[1:end-1]
called = LLVM.called_operand(inst)


llRT, sret, returnRoots = get_return_info(RT)
retRemoved, parmsRemoved = removed_ret_parms(inst)

dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst)))))


expectLen = (sret !== nothing) + (returnRoots !== nothing)
for source_typ in mi.specTypes.parameters
if isghostty(source_typ) || Core.Compiler.isconstType(source_typ)
continue
end
expectLen+=1
end
expectLen -= length(parmsRemoved)

# TODO fix the attributor inlining such that this can assert always true
if expectLen == length(ops)

f = LLVM.called_operand(inst)
swiftself = any(any(map(k->kind(k)==kind(EnumAttribute("swiftself")), collect(parameter_attributes(f, i)))) for i in 1:length(collect(parameters(f))))
jlargs = classify_arguments(mi.specTypes, called_type(inst), sret !== nothing, returnRoots !== nothing, swiftself, parmsRemoved)


for arg in jlargs
if arg.cc == GPUCompiler.GHOST || arg.cc == RemovedParam
continue
end

typ, byref = enzyme_extract_parm_type(f, arg.codegen.i)
@assert typ == arg.typ

op_idx = arg.codegen.i
rest = typetree(arg.typ, ctx, dl)
@assert arg.cc == byref
if byref == GPUCompiler.BITS_REF || byref == GPUCompiler.MUT_REF
# adjust first path to size of type since if arg.typ is {[-1]:Int}, that doesn't mean the broader
# object passing this in by ref isnt a {[-1]:Pointer, [-1,-1]:Int}
# aka the next field after this in the bigger object isn't guaranteed to also be the same.
if allocatedinline(arg.typ)
shift!(rest, dl, 0, sizeof(arg.typ), 0)
end
merge!(rest, TypeTree(API.DT_Pointer, ctx))
only!(rest, -1)
else
# canonicalize wrt size
end
PTT = unsafe_load(args, op_idx)
changed, legal = API.EnzymeCheckedMergeTypeTree(PTT, rest)
if !legal
function c(io)
println(io, "Illegal type analysis update from julia rule of method ", mi)
println(io, "Found type ", arg.typ, " at index ", arg.codegen.i, " of ", string(rest))
t = API.EnzymeTypeTreeToString(PTT)
println(io, "Prior type ", Base.unsafe_string(t))
println(io, inst)
API.EnzymeStringFree(t)
end
msg = sprint(c)

bt = GPUCompiler.backtrace(inst)
ir = sprint(io->show(io, parent_scope(inst)))

sval = ""
# data = API.EnzymeTypeAnalyzerRef(data)
# ip = API.EnzymeTypeAnalyzerToString(data)
# sval = Base.unsafe_string(ip)
# API.EnzymeStringFree(ip)
throw(IllegalTypeAnalysisException(msg, sval, ir, bt))
end
end

if sret !== nothing
idx = 0
if !in(0, parmsRemoved)
API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(sret, ctx, dl))
idx+=1
end
if returnRoots !== nothing
if !in(1, parmsRemoved)
allpointer = TypeTree(API.DT_Pointer, -1, ctx)
API.EnzymeMergeTypeTree(unsafe_load(args, idx+1), typetree(returnRoots, ctx, dl))
end
end
end

end

if llRT !== nothing && value_type(inst) != LLVM.VoidType()
@assert !retRemoved
API.EnzymeMergeTypeTree(ret, typetree(llRT, ctx, dl))
end

return UInt8(false)
end

0 comments on commit bf2291b

Please sign in to comment.