diff --git a/src/compiler.jl b/src/compiler.jl index ff127abf23a..30f9c29e984 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3063,7 +3063,8 @@ function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wr logic = Logic() TA = TypeAnalysis(logic, rules) - retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? Ptr{actualRetType} : actualRetType + retT = (!isa(actualRetType, Union) && GPUCompiler.deserves_retbox(actualRetType)) ? + Ptr{actualRetType} : actualRetType retTT = typetree(retT, ctx, dl, seen) typeInfo = FnTypeInfo(retTT, args_typeInfo, args_known_values) @@ -4105,7 +4106,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function if RetActivity <: Const metadata(sretPtr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end - metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, dl, seen), ctx) + metadata(sretPtr)["enzyme_type"] = to_md(typetree(Ptr{actualRetType}, ctx, + dl, seen), ctx) push!(wrapper_args, sretPtr) end if returnRoots && !in(1, parmsRemoved) @@ -4133,7 +4135,8 @@ function lower_convention(functy::Type, mod::LLVM.Module, entry_f::LLVM.Function metadata(ptr)["enzyme_inactive"] = MDNode(LLVM.Metadata[]) end ctx = LLVM.context(entry_f) - metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), ctx) + metadata(ptr)["enzyme_type"] = to_md(typetree(Ptr{arg.typ}, ctx, dl, seen), + ctx) if LLVM.addrspace(ty) != 0 ptr = addrspacecast!(builder, ptr, ty) end diff --git a/src/rules/customrules.jl b/src/rules/customrules.jl index f012400c871..7ec09e2c1d5 100644 --- a/src/rules/customrules.jl +++ b/src/rules/customrules.jl @@ -850,7 +850,7 @@ function enzyme_custom_common_rev(forward::Bool, B, orig::LLVM.CallInst, gutils, idx = 0 dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(orig))))) - Tys2 = (eltype(A) for A in activity[2+isKWCall:end] if A <: Active) + Tys2 = (eltype(A) for A in activity[(2 + isKWCall):end] if A <: Active) seen = TypeTreeTable() for (v, Ty) in zip(actives, Tys2) TT = typetree(Ty, ctx, dl, seen)