diff --git a/src/llvm/transforms.jl b/src/llvm/transforms.jl index 935aea85b5..6411f8337f 100644 --- a/src/llvm/transforms.jl +++ b/src/llvm/transforms.jl @@ -602,7 +602,7 @@ function nodecayed_phis!(mod::LLVM.Module) end v0 = v - @inline function getparent(b::LLVM.IRBuilder, @nospecialize(v::LLVM.Value), @nospecialize(offset::LLVM.Value), hasload::Bool) + @inline function getparent(b::LLVM.IRBuilder, @nospecialize(v::LLVM.Value), @nospecialize(offset::LLVM.Value), hasload::Bool, phicache::Dict{LLVM.PHIInst, Tuple{LLVM.PHIInst, LLVM.PHIInst}}) if addr == 11 && addrspace(value_type(v)) == 10 return v, offset, hasload end @@ -612,7 +612,7 @@ function nodecayed_phis!(mod::LLVM.Module) if addr == 13 && !hasload if isa(v, LLVM.LoadInst) - v2, o2, hl2 = getparent(b, operands(v)[1], LLVM.ConstantInt(offty, 0), true) + v2, o2, hl2 = getparent(b, operands(v)[1], LLVM.ConstantInt(offty, 0), true, phicache) @static if VERSION < v"1.11-" else @assert offset == LLVM.ConstantInt(offty, 0) @@ -641,7 +641,7 @@ function nodecayed_phis!(mod::LLVM.Module) cf = LLVM.called_operand(v) if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded" ld = operands(v)[2] - ld0, o0, ol0 = getparent(b, ld, LLVM.ConstantInt(offty, 0), hasload) + ld0, o0, ol0 = getparent(b, ld, LLVM.ConstantInt(offty, 0), hasload, phicache) v2 = ld0 # v2, o2, hl2 = getparent(b, operands(ld)[1], LLVM.ConstantInt(offty, 0), true) @@ -716,7 +716,7 @@ function nodecayed_phis!(mod::LLVM.Module) preop = operands(preop)[1] end v2, offset, skipload = - getparent(b, preop, offset, hasload) + getparent(b, preop, offset, hasload, phicache) v2 = const_bitcast( v2, LLVM.PointerType( @@ -730,7 +730,7 @@ function nodecayed_phis!(mod::LLVM.Module) if opcode(v) == LLVM.API.LLVMGetElementPtr v2, offset, skipload = - getparent(b, operands(v)[1], offset, hasload) + getparent(b, operands(v)[1], offset, hasload, phicache) offset = const_add( offset, API.EnzymeComputeByteOffsetOfGEP(b, v, offty), @@ -758,7 +758,7 @@ function nodecayed_phis!(mod::LLVM.Module) return v2, offset, hasload end nv, noffset, nhasload = - getparent(b, operands(v)[1], offset, hasload) + getparent(b, operands(v)[1], offset, hasload, phicache) if eltype(value_type(nv)) != eltype(value_type(v)) nv = bitcast!( b, @@ -778,7 +778,7 @@ function nodecayed_phis!(mod::LLVM.Module) preop = operands(preop)[1] end v2, offset, skipload = - getparent(b, preop, offset, hasload) + getparent(b, preop, offset, hasload, phicache) v2 = bitcast!( b, v2, @@ -796,7 +796,7 @@ function nodecayed_phis!(mod::LLVM.Module) operands(v)[2:end], ) v2, offset, skipload = - getparent(b, operands(v)[1], offset, hasload) + getparent(b, operands(v)[1], offset, hasload, phicache) v2 = bitcast!( b, v2, @@ -811,7 +811,7 @@ function nodecayed_phis!(mod::LLVM.Module) if isa(v, LLVM.GetElementPtrInst) v2, offset, skipload = - getparent(b, operands(v)[1], offset, hasload) + getparent(b, operands(v)[1], offset, hasload, phicache) offset = nuwadd!( b, offset, @@ -850,49 +850,40 @@ function nodecayed_phis!(mod::LLVM.Module) @static if VERSION < v"1.11-" else if addr == 13 && isa(v, LLVM.PHIInst) + if haskey(phicache, v) + return (phicache[v]..., hasload) + end vs = Union{LLVM.Value, Nothing}[] offs = Union{LLVM.Value, Nothing}[] blks = LLVM.BasicBlock[] + + B = LLVM.IRBuilder() + position!(B, v) + vphi = phi!(B, value_type(v)) + ophi = phi!(B, value_type(offset)) + phicache[v] = (vphi, ophi) + for (vt, bb) in LLVM.incoming(v) b2 = IRBuilder() position!(b2, terminator(bb)) - if vt == v - push!(vs, nothing) - push!(offs, nothing) - else - v2, o2, hl2 = getparent(b2, vt, offset, hasload) - push!(vs, v2) - push!(offs, o2) - end - push!(blks, bb) - end - B = LLVM.IRBuilder() - position!(B, v) - offset = if all(x->offs[1] == x, offs) - offs[1] - else - ophi = phi!(B, value_type(offs[1])) - append!(incoming(ophi), collect(zip(map(x->x isa Nothing ? ophi : x, offs), blks))) - ophi + v2, o2, hl2 = getparent(b2, vt, offset, hasload, phicache) + push!(vs, v2) + push!(offs, o2) end - nv = if all(x->vs[1] == x, vs) - v[1] - else - ophi = phi!(B, value_type(vs[1])) - append!(incoming(ophi), collect(zip(map(x->x isa Nothing ? ophi : x, vs), blks))) - ophi - end + append!(incoming(ophi), collect(zip(offs, blks))) + + append!(incoming(vphi), collect(zip(vs, blks))) - return nv, offset, hasload + return vphi, offset, hasload end end if isa(v, LLVM.SelectInst) lhs_v, lhs_offset, lhs_skipload = - getparent(b, operands(v)[2], offset, hasload) + getparent(b, operands(v)[2], offset, hasload, phicache) rhs_v, rhs_offset, rhs_skipload = - getparent(b, operands(v)[3], offset, hasload) + getparent(b, operands(v)[3], offset, hasload, phicache) if value_type(lhs_v) != value_type(rhs_v) || value_type(lhs_offset) != value_type(rhs_offset) || lhs_skipload != rhs_skipload @@ -935,7 +926,8 @@ function nodecayed_phis!(mod::LLVM.Module) b = IRBuilder() position!(b, terminator(pb)) - v, offset, hadload = getparent(b, v, LLVM.ConstantInt(offty, 0), false) + phicache = Dict{LLVM.PHIInst, Tuple{LLVM.PHIInst, LLVM.PHIInst}}() + v, offset, hadload = getparent(b, v, LLVM.ConstantInt(offty, 0), false, phicache) if addr == 13 @assert hadload