From fb8398b96449fbd0cf303663a4bd7582d2d00384 Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Wed, 18 Dec 2024 13:37:28 -0600 Subject: [PATCH] Fix insertion point of select replacement --- src/errors.jl | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/errors.jl b/src/errors.jl index 4ba18219c2..84a4cb0652 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -605,22 +605,25 @@ end end if isa(cur, LLVM.InsertValueInst) - lhs = make_replacement(operands(cur)[1], prevbb) + B2 = IRBuilder() + position!(B2, LLVM.API.LLVMGetNextInstruction(cur)) + + lhs = make_replacement(operands(cur)[1], B2) if illegal return ncur end - rhs = make_replacement(operands(cur)[2], prevbb) + rhs = make_replacement(operands(cur)[2], B2) if illegal return ncur end if lhs == operands(cur)[1] && rhs == operands(cur)[2] - return make_batched(ncur, prevbb) + return make_batched(ncur, cur) end inds = LLVM.API.LLVMGetIndices(cur.ref) ninds = LLVM.API.LLVMGetNumIndices(cur.ref) jinds = Cuint[unsafe_load(inds, i) for i = 1:ninds] if width == 1 - nv = API.EnzymeInsertValue(prevbb, lhs, rhs, jinds) + nv = API.EnzymeInsertValue(B2, lhs, rhs, jinds) push!(created, nv) seen[cur] = nv return nv @@ -630,9 +633,9 @@ end jindsv = copy(jinds) pushfirst!(jindsv, idx - 1) shadowres = API.EnzymeInsertValue( - prevbb, + B2, shadowres, - extract_value!(prevbb, rhs, idx - 1), + extract_value!(B2, rhs, idx - 1), jindsv, ) if isa(shadowres, LLVM.Instruction)