Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix partial store #2172

Merged
merged 2 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/llvm/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1539,6 +1539,13 @@ function propagate_returned!(mod::LLVM.Module)
end
else
for u in LLVM.uses(un)
u = LLVM.user(u)
if u isa LLVM.CallInst
op = LLVM.called_operand(u)
if op isa LLVM.Function && LLVM.name(op) == "llvm.enzymefakeread"
continue
end
end
hasAnyUse = true
break
end
Expand Down Expand Up @@ -1611,6 +1618,12 @@ end

function delete_writes_into_removed_args(fn::LLVM.Function, toremove::Vector{Int64}, keepret::Bool)
args = collect(parameters(fn))
if !keepret
for u in LLVM.uses(fn)
u = LLVM.user(u)
replace_uses!(u, LLVM.UndefValue(value_type(u)))
end
end
for tr in toremove
tr = tr + 1
todorep = Tuple{LLVM.Instruction, LLVM.Value}[]
Expand Down Expand Up @@ -2038,6 +2051,25 @@ function removeDeadArgs!(mod::LLVM.Module, tm::LLVM.TargetMachine)
if isempty(blocks(fn))
continue
end

rt = LLVM.return_type(LLVM.function_type(fn))
if rt isa LLVM.PointerType && addrspace(rt) == 10
for u in LLVM.uses(fn)
u = LLVM.user(u)
if isa(u, LLVM.CallInst)
B = IRBuilder()
nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
position!(B, nextInst)
cl = call!(B, funcT, rfunc, LLVM.Value[u])
LLVM.API.LLVMAddCallSiteAttribute(
cl,
LLVM.API.LLVMAttributeIndex(1),
EnumAttribute("nocapture"),
)
end
end
end

# Ensure that interprocedural optimizations do not delete the use of returnRoots (or shadows)
# if inactive sret, this will only occur on 2. If active sret, inactive retRoot, can on 3, and
# active both can occur on 4. If the original sret is removed (at index 1) we no longer need
Expand Down
80 changes: 80 additions & 0 deletions test/passes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
using Enzyme, LLVM, Test


@testset "Partial return preservation" begin
LLVM.Context() do ctx
mod = parse(LLVM.Module, """
source_filename = "start"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
target triple = "x86_64-linux-gnu"

declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}**, i64, {} addrspace(10)*) local_unnamed_addr #5

define internal fastcc nonnull {} addrspace(10)* @inner({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%newstruct = call noalias nonnull dereferenceable(16) {} addrspace(10)* @julia.gc_alloc_obj({}** null, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 129778359735376 to {}*) to {} addrspace(10)*)) #30
%a31 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*
%a32 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %a31, i64 1
store atomic {} addrspace(10)* %v1, {} addrspace(10)* addrspace(11)* %a31 release, align 8
%a33 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)*
%a34 = getelementptr inbounds i8, i8 addrspace(11)* %a33, i64 8
%a35 = bitcast i8 addrspace(11)* %a34 to {} addrspace(10)* addrspace(11)*
store atomic {} addrspace(10)* %v2, {} addrspace(10)* addrspace(11)* %a35 release, align 8
ret {} addrspace(10)* %newstruct
}

define {} addrspace(10)* @caller({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%ac = call fastcc nonnull {} addrspace(10)* @inner({} addrspace(10)* %v1, {} addrspace(10)* %v2)
%b = addrspacecast {} addrspace(10)* %ac to {} addrspace(10)* addrspace(11)*
%c = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %b unordered, align 8
ret {} addrspace(10)* %c
}

attributes #5 = { inaccessiblememonly mustprogress nofree nounwind willreturn allockind("alloc,uninitialized") allocsize(1) "enzyme_no_escaping_allocation" "enzymejl_world"="31504" }
""")

Enzyme.Compiler.removeDeadArgs!(mod, Enzyme.Compiler.JIT.get_tm())

callfn = LLVM.functions(mod)["inner"]
@test length(collect(filter(Base.Fix2(isa, LLVM.StoreInst), collect(instructions(first(blocks(callfn))))))) == 2
end
end


@testset "Dead return removal" begin
LLVM.Context() do ctx
mod = parse(LLVM.Module, """
source_filename = "start"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
target triple = "x86_64-linux-gnu"

declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj({}**, i64, {} addrspace(10)*) local_unnamed_addr #5

define internal fastcc nonnull {} addrspace(10)* @julia_MyPrognosticVars_161({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%newstruct = call noalias nonnull dereferenceable(16) {} addrspace(10)* @julia.gc_alloc_obj({}** null, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 129778359735376 to {}*) to {} addrspace(10)*)) #30
%a31 = addrspacecast {} addrspace(10)* %newstruct to {} addrspace(10)* addrspace(11)*
%a32 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %a31, i64 1
store atomic {} addrspace(10)* %v1, {} addrspace(10)* addrspace(11)* %a31 release, align 8
%a33 = addrspacecast {} addrspace(10)* %newstruct to i8 addrspace(11)*
%a34 = getelementptr inbounds i8, i8 addrspace(11)* %a33, i64 8
%a35 = bitcast i8 addrspace(11)* %a34 to {} addrspace(10)* addrspace(11)*
store atomic {} addrspace(10)* %v2, {} addrspace(10)* addrspace(11)* %a35 release, align 8
ret {} addrspace(10)* %newstruct
}

define void @caller({} addrspace(10)* %v1, {} addrspace(10)* %v2) {
top:
%ac = call fastcc nonnull {} addrspace(10)* @julia_MyPrognosticVars_161({} addrspace(10)* %v1, {} addrspace(10)* %v2)
ret void
}

attributes #5 = { inaccessiblememonly mustprogress nofree nounwind willreturn allockind("alloc,uninitialized") allocsize(1) "enzyme_no_escaping_allocation" "enzymejl_world"="31504" }
""")

Enzyme.Compiler.removeDeadArgs!(mod, Enzyme.Compiler.JIT.get_tm())
callfn = LLVM.functions(mod)["caller"]
@test length(collect(instructions(first(blocks(callfn))))) == 1
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ end

include("abi.jl")
include("typetree.jl")
include("passes.jl")
include("optimize.jl")
include("make_zero.jl")

Expand Down
Loading