Skip to content

Commit

Permalink
Nodecay (#2230)
Browse files Browse the repository at this point in the history
* Nodecayed fix

* now with phi cache

* fix
  • Loading branch information
wsmoses authored Dec 26, 2024
1 parent 5afff0b commit 9de3274
Showing 1 changed file with 30 additions and 38 deletions.
68 changes: 30 additions & 38 deletions src/llvm/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ function nodecayed_phis!(mod::LLVM.Module)
end

v0 = v
@inline function getparent(b::LLVM.IRBuilder, @nospecialize(v::LLVM.Value), @nospecialize(offset::LLVM.Value), hasload::Bool)
@inline function getparent(b::LLVM.IRBuilder, @nospecialize(v::LLVM.Value), @nospecialize(offset::LLVM.Value), hasload::Bool, phicache::Dict{LLVM.PHIInst, Tuple{LLVM.PHIInst, LLVM.PHIInst}})
if addr == 11 && addrspace(value_type(v)) == 10
return v, offset, hasload
end
Expand All @@ -612,7 +612,7 @@ function nodecayed_phis!(mod::LLVM.Module)

if addr == 13 && !hasload
if isa(v, LLVM.LoadInst)
v2, o2, hl2 = getparent(b, operands(v)[1], LLVM.ConstantInt(offty, 0), true)
v2, o2, hl2 = getparent(b, operands(v)[1], LLVM.ConstantInt(offty, 0), true, phicache)
@static if VERSION < v"1.11-"
else
@assert offset == LLVM.ConstantInt(offty, 0)
Expand Down Expand Up @@ -641,7 +641,7 @@ function nodecayed_phis!(mod::LLVM.Module)
cf = LLVM.called_operand(v)
if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded"
ld = operands(v)[2]
ld0, o0, ol0 = getparent(b, ld, LLVM.ConstantInt(offty, 0), hasload)
ld0, o0, ol0 = getparent(b, ld, LLVM.ConstantInt(offty, 0), hasload, phicache)
v2 = ld0
# v2, o2, hl2 = getparent(b, operands(ld)[1], LLVM.ConstantInt(offty, 0), true)

Expand Down Expand Up @@ -716,7 +716,7 @@ function nodecayed_phis!(mod::LLVM.Module)
preop = operands(preop)[1]
end
v2, offset, skipload =
getparent(b, preop, offset, hasload)
getparent(b, preop, offset, hasload, phicache)
v2 = const_bitcast(
v2,
LLVM.PointerType(
Expand All @@ -730,7 +730,7 @@ function nodecayed_phis!(mod::LLVM.Module)

if opcode(v) == LLVM.API.LLVMGetElementPtr
v2, offset, skipload =
getparent(b, operands(v)[1], offset, hasload)
getparent(b, operands(v)[1], offset, hasload, phicache)
offset = const_add(
offset,
API.EnzymeComputeByteOffsetOfGEP(b, v, offty),
Expand Down Expand Up @@ -758,7 +758,7 @@ function nodecayed_phis!(mod::LLVM.Module)
return v2, offset, hasload
end
nv, noffset, nhasload =
getparent(b, operands(v)[1], offset, hasload)
getparent(b, operands(v)[1], offset, hasload, phicache)
if eltype(value_type(nv)) != eltype(value_type(v))
nv = bitcast!(
b,
Expand All @@ -778,7 +778,7 @@ function nodecayed_phis!(mod::LLVM.Module)
preop = operands(preop)[1]
end
v2, offset, skipload =
getparent(b, preop, offset, hasload)
getparent(b, preop, offset, hasload, phicache)
v2 = bitcast!(
b,
v2,
Expand All @@ -796,7 +796,7 @@ function nodecayed_phis!(mod::LLVM.Module)
operands(v)[2:end],
)
v2, offset, skipload =
getparent(b, operands(v)[1], offset, hasload)
getparent(b, operands(v)[1], offset, hasload, phicache)
v2 = bitcast!(
b,
v2,
Expand All @@ -811,7 +811,7 @@ function nodecayed_phis!(mod::LLVM.Module)

if isa(v, LLVM.GetElementPtrInst)
v2, offset, skipload =
getparent(b, operands(v)[1], offset, hasload)
getparent(b, operands(v)[1], offset, hasload, phicache)
offset = nuwadd!(
b,
offset,
Expand Down Expand Up @@ -850,49 +850,40 @@ function nodecayed_phis!(mod::LLVM.Module)
@static if VERSION < v"1.11-"
else
if addr == 13 && isa(v, LLVM.PHIInst)
if haskey(phicache, v)
return (phicache[v]..., hasload)
end
vs = Union{LLVM.Value, Nothing}[]
offs = Union{LLVM.Value, Nothing}[]
blks = LLVM.BasicBlock[]

B = LLVM.IRBuilder()
position!(B, v)
vphi = phi!(B, value_type(v))
ophi = phi!(B, value_type(offset))
phicache[v] = (vphi, ophi)

for (vt, bb) in LLVM.incoming(v)
b2 = IRBuilder()
position!(b2, terminator(bb))
if vt == v
push!(vs, nothing)
push!(offs, nothing)
else
v2, o2, hl2 = getparent(b2, vt, offset, hasload)
push!(vs, v2)
push!(offs, o2)
end
push!(blks, bb)
end
B = LLVM.IRBuilder()
position!(B, v)
offset = if all(x->offs[1] == x, offs)
offs[1]
else
ophi = phi!(B, value_type(offs[1]))
append!(incoming(ophi), collect(zip(map(x->x isa Nothing ? ophi : x, offs), blks)))
ophi
v2, o2, hl2 = getparent(b2, vt, offset, hasload, phicache)
push!(vs, v2)
push!(offs, o2)
end

nv = if all(x->vs[1] == x, vs)
v[1]
else
ophi = phi!(B, value_type(vs[1]))
append!(incoming(ophi), collect(zip(map(x->x isa Nothing ? ophi : x, vs), blks)))
ophi
end
append!(incoming(ophi), collect(zip(offs, blks)))

append!(incoming(vphi), collect(zip(vs, blks)))

return nv, offset, hasload
return vphi, offset, hasload
end
end

if isa(v, LLVM.SelectInst)
lhs_v, lhs_offset, lhs_skipload =
getparent(b, operands(v)[2], offset, hasload)
getparent(b, operands(v)[2], offset, hasload, phicache)
rhs_v, rhs_offset, rhs_skipload =
getparent(b, operands(v)[3], offset, hasload)
getparent(b, operands(v)[3], offset, hasload, phicache)
if value_type(lhs_v) != value_type(rhs_v) ||
value_type(lhs_offset) != value_type(rhs_offset) ||
lhs_skipload != rhs_skipload
Expand Down Expand Up @@ -935,7 +926,8 @@ function nodecayed_phis!(mod::LLVM.Module)
b = IRBuilder()
position!(b, terminator(pb))

v, offset, hadload = getparent(b, v, LLVM.ConstantInt(offty, 0), false)
phicache = Dict{LLVM.PHIInst, Tuple{LLVM.PHIInst, LLVM.PHIInst}}()
v, offset, hadload = getparent(b, v, LLVM.ConstantInt(offty, 0), false, phicache)

if addr == 13
@assert hadload
Expand Down

0 comments on commit 9de3274

Please sign in to comment.