diff --git a/src/errors.jl b/src/errors.jl index 4ba18219c2..065248a845 100644 --- a/src/errors.jl +++ b/src/errors.jl @@ -605,22 +605,25 @@ end end if isa(cur, LLVM.InsertValueInst) - lhs = make_replacement(operands(cur)[1], prevbb) + B2 = IRBuilder() + position!(B2, LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(cur))) + + lhs = make_replacement(operands(cur)[1], B2) if illegal return ncur end - rhs = make_replacement(operands(cur)[2], prevbb) + rhs = make_replacement(operands(cur)[2], B2) if illegal return ncur end if lhs == operands(cur)[1] && rhs == operands(cur)[2] - return make_batched(ncur, prevbb) + return make_batched(ncur, cur) end inds = LLVM.API.LLVMGetIndices(cur.ref) ninds = LLVM.API.LLVMGetNumIndices(cur.ref) jinds = Cuint[unsafe_load(inds, i) for i = 1:ninds] if width == 1 - nv = API.EnzymeInsertValue(prevbb, lhs, rhs, jinds) + nv = API.EnzymeInsertValue(B2, lhs, rhs, jinds) push!(created, nv) seen[cur] = nv return nv @@ -630,9 +633,9 @@ end jindsv = copy(jinds) pushfirst!(jindsv, idx - 1) shadowres = API.EnzymeInsertValue( - prevbb, + B2, shadowres, - extract_value!(prevbb, rhs, idx - 1), + extract_value!(B2, rhs, idx - 1), jindsv, ) if isa(shadowres, LLVM.Instruction)