Skip to content

Commit

Permalink
Fix partial store
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 5, 2024
1 parent 6606cd9 commit 48ac1f7
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 0 deletions.
26 changes: 26 additions & 0 deletions src/llvm/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,13 @@ function propagate_returned!(mod::LLVM.Module)
end
else
for u in LLVM.uses(un)
u = LLVM.user(u)
if u isa LLVM.CallInst
op = LLVM.called_operand(u)
if op isa LLVM.Function && LLVM.name(op) == "llvm.enzymefakeread"
continue
end
end
hasAnyUse = true
break
end
Expand Down Expand Up @@ -2038,6 +2045,25 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
if isempty(blocks(fn))
continue
end

rt = LLVM.return_type(LLVM.function_type(fn))
if rt isa LLVM.PointerType && addrspace(rt) == 10
for u in LLVM.uses(fn)
u = LLVM.user(u)
if isa(u, LLVM.CallInst)
B = IRBuilder()
nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
position!(B, nextInst)
cl = call!(B, funcT, rfunc, LLVM.Value[u])
LLVM.API.LLVMAddCallSiteAttribute(
cl,
LLVM.API.LLVMAttributeIndex(1),
EnumAttribute("nocapture"),
)
end
end
end

# Ensure that interprocedural optimizations do not delete the use of returnRoots (or shadows)
# if inactive sret, this will only occur on 2. If active sret, inactive retRoot, can on 3, and
# active both can occur on 4. If the original sret is removed (at index 1) we no longer need
Expand Down
80 changes: 80 additions & 0 deletions test/passes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using Enzyme, LLVM, Test


@testset "Partial return preservation" begin
LLVM.Context() do ctx
mod = parse(LLVM.Module, """
source_filename = "start"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
target triple = "x86_64-linux-gnu"
declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}**, i64, {} addrspace(10)*) local_unnamed_addr #5
define internal fastcc nonnull {} addrspace(10)* @inner({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%newstruct = call noalias nonnull dereferenceable(16) {} addrspace(10)* @julia.gc_alloc_obj({}** null, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 129778359735376 to {}*) to {} addrspace(10)*)) #30
%a31 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*
%a32 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %a31, i64 1
store atomic {} addrspace(10)* %v1, {} addrspace(10)* addrspace(11)* %a31 release, align 8
%a33 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)*
%a34 = getelementptr inbounds i8, i8 addrspace(11)* %a33, i64 8
%a35 = bitcast i8 addrspace(11)* %a34 to {} addrspace(10)* addrspace(11)*
store atomic {} addrspace(10)* %v2, {} addrspace(10)* addrspace(11)* %a35 release, align 8
ret {} addrspace(10)* %newstruct
}
define {} addrspace(10)* @caller({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%ac = call fastcc nonnull {} addrspace(10)* @inner({} addrspace(10)* %v1, {} addrspace(10)* %v2)
%b = addrspacecast {} addrspace(10)* %ac to {} addrspace(10)* addrspace(11)*
%c = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %b unordered, align 8
ret {} addrspace(10)* %c
}
attributes #5 = { inaccessiblememonly mustprogress nofree nounwind willreturn allockind("alloc,uninitialized") allocsize(1) "enzyme_no_escaping_allocation" "enzymejl_world"="31504" }
""")

Enzyme.Compiler.removeDeadArgs!(mod, Enzyme.Compiler.JIT.get_tm())

callfn = LLVM.functions(mod)["inner"]
@test length(collect(filter(Base.Fix2(isa, LLVM.StoreInst), collect(instructions(first(blocks(callfn))))))) == 2
end
end


@testset "Dead return removal" begin
LLVM.Context() do ctx
mod = parse(LLVM.Module, """
source_filename = "start"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
target triple = "x86_64-linux-gnu"
declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}**, i64, {} addrspace(10)*) local_unnamed_addr #5
define internal fastcc nonnull {} addrspace(10)* @julia_MyPrognosticVars_161({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%newstruct = call noalias nonnull dereferenceable(16) {} addrspace(10)* @julia.gc_alloc_obj({}** null, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 129778359735376 to {}*) to {} addrspace(10)*)) #30
%a31 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*
%a32 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %a31, i64 1
store atomic {} addrspace(10)* %v1, {} addrspace(10)* addrspace(11)* %a31 release, align 8
%a33 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)*
%a34 = getelementptr inbounds i8, i8 addrspace(11)* %a33, i64 8
%a35 = bitcast i8 addrspace(11)* %a34 to {} addrspace(10)* addrspace(11)*
store atomic {} addrspace(10)* %v2, {} addrspace(10)* addrspace(11)* %a35 release, align 8
ret {} addrspace(10)* %newstruct
}
define void @caller({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%ac = call fastcc nonnull {} addrspace(10)* @julia_MyPrognosticVars_161({} addrspace(10)* %v1, {} addrspace(10)* %v2)
ret void
}
attributes #5 = { inaccessiblememonly mustprogress nofree nounwind willreturn allockind("alloc,uninitialized") allocsize(1) "enzyme_no_escaping_allocation" "enzymejl_world"="31504" }
""")

Enzyme.Compiler.removeDeadArgs!(mod, Enzyme.Compiler.JIT.get_tm())
callfn = LLVM.functions(mod)["caller"]
@test length(collect(instructions(first(blocks(callfn))))) == 1
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ end

include("abi.jl")
include("typetree.jl")
include("passes.jl")
include("optimize.jl")
include("make_zero.jl")

Expand Down

0 comments on commit 48ac1f7

Please sign in to comment.