Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
vchuravy and github-actions[bot] authored Mar 19, 2024
1 parent 3d82861 commit 2f92152
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
9 changes: 6 additions & 3 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3063,7 +3063,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr
logic = Logic()
TA = TypeAnalysis(logic, rules)

retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType
retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ?
Ptr{actualRetType} : actualRetType
retTT = typetree(retT, ctx, dl, seen)

typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values)
Expand Down Expand Up @@ -4105,7 +4106,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
if RetActivity <: Const
metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
end
metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, dl, seen), ctx)
metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx,
dl, seen), ctx)
push!(wrapper_args, sretPtr)
end
if returnRoots && !in(1, parmsRemoved)
Expand Down Expand Up @@ -4133,7 +4135,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function
metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[])
end
ctx = LLVM.context(entry_f)
metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), ctx)
metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen),
ctx)
if LLVM.addrspace(ty) != 0
ptr = addrspacecast!(builder, ptr, ty)
end
Expand Down
2 changes: 1 addition & 1 deletion src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils,

idx = 0
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig)))))
Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active)
Tys2 = (eltype(A) for A in activity[(2 + isKWCall):end] if A <: Active)
seen = TypeTreeTable()
for (v, Ty) in zip(actives, Tys2)
TT = typetree(Ty, ctx, dl, seen)
Expand Down

0 comments on commit 2f92152

Please sign in to comment.