From 74bd584a3fb92c57bacab9b9798c3d86bd639ec0 Mon Sep 17 00:00:00 2001 From: Nuno Lopes Date: Mon, 29 Apr 2024 09:34:53 +0100 Subject: [PATCH] memset doesn't extend the deref check to a multiple of alignment According to LangRef, although this is incompatible with load/store semantics --- ir/instr.cpp | 2 +- ir/memory.cpp | 2 +- ir/pointer.cpp | 13 ++++++++----- ir/pointer.h | 6 ++++-- tests/alive-tv/memory/memset-align2.srctgt.ll | 15 +++++++++++++++ 5 files changed, 29 insertions(+), 9 deletions(-) create mode 100644 tests/alive-tv/memory/memset-align2.srctgt.ll diff --git a/ir/instr.cpp b/ir/instr.cpp index 19a722b5a..0357f8210 100644 --- a/ir/instr.cpp +++ b/ir/instr.cpp @@ -3992,7 +3992,7 @@ DEFINE_AS_RETZEROALIGN(Memset, getMaxAllocSize); DEFINE_AS_RETZERO(Memset, getMaxGEPOffset); uint64_t Memset::getMaxAccessSize() const { - return round_up(getIntOr(*bytes, UINT64_MAX), align); + return getIntOr(*bytes, UINT64_MAX); } MemInstr::ByteAccessInfo Memset::getByteAccessInfo() const { diff --git a/ir/memory.cpp b/ir/memory.cpp index a6fe9be1d..de2ae92d6 100644 --- a/ir/memory.cpp +++ b/ir/memory.cpp @@ -2022,7 +2022,7 @@ void Memory::memset(const expr &p, const StateValue &val, const expr &bytesize, unsigned bytesz = bits_byte / 8; Pointer ptr(*this, p); if (deref_check) - state->addUB(ptr.isDereferenceable(bytesize, align, true)); + state->addUB(ptr.isDereferenceable(bytesize, align, true, false, false)); auto wval = val; for (unsigned i = 1; i < bytesz; ++i) { diff --git a/ir/pointer.cpp b/ir/pointer.cpp index 9b029790c..dadd66d1b 100644 --- a/ir/pointer.cpp +++ b/ir/pointer.cpp @@ -427,9 +427,11 @@ static pair is_dereferenceable(Pointer &p, // When bytes is 0, pointer is always dereferenceable pair Pointer::isDereferenceable(const expr &bytes0, uint64_t align, - bool iswrite, bool ignore_accessability) { - expr bytes = bytes0.zextOrTrunc(bits_size_t) - .round_up(expr::mkUInt(align, bits_size_t)); + bool iswrite, bool ignore_accessability, + bool round_size_to_align) { + expr bytes = bytes0.zextOrTrunc(bits_size_t); + if (round_size_to_align) + bytes = bytes.round_up(expr::mkUInt(align, bits_size_t)); expr bytes_off = bytes.zextOrTrunc(bits_for_offset); DisjointExpr UB(expr(false)), is_aligned(expr(false)), all_ptrs; @@ -466,9 +468,10 @@ Pointer::isDereferenceable(const expr &bytes0, uint64_t align, pair Pointer::isDereferenceable(uint64_t bytes, uint64_t align, - bool iswrite, bool ignore_accessability) { + bool iswrite, bool ignore_accessability, + bool round_size_to_align) { return isDereferenceable(expr::mkUInt(bytes, bits_size_t), align, iswrite, - ignore_accessability); + ignore_accessability, round_size_to_align); } // This function assumes that both begin + len don't overflow diff --git a/ir/pointer.h b/ir/pointer.h index 8fbbc9abc..73c00982d 100644 --- a/ir/pointer.h +++ b/ir/pointer.h @@ -102,10 +102,12 @@ class Pointer { smt::expr isAligned(const smt::expr &align); std::pair isDereferenceable(uint64_t bytes, uint64_t align, bool iswrite = false, - bool ignore_accessability = false); + bool ignore_accessability = false, + bool round_size_to_align = true); std::pair isDereferenceable(const smt::expr &bytes, uint64_t align, bool iswrite, - bool ignore_accessability = false); + bool ignore_accessability = false, + bool round_size_to_align = true); void isDisjointOrEqual(const smt::expr &len1, const Pointer &ptr2, const smt::expr &len2) const; diff --git a/tests/alive-tv/memory/memset-align2.srctgt.ll b/tests/alive-tv/memory/memset-align2.srctgt.ll new file mode 100644 index 000000000..db9fb6d7e --- /dev/null +++ b/tests/alive-tv/memory/memset-align2.srctgt.ll @@ -0,0 +1,15 @@ +define void @src(ptr %P) { + %arrayidx = getelementptr inbounds i32, ptr %P, i64 1 + store i32 0, ptr %arrayidx, align 4 + %add.ptr = getelementptr inbounds i32, ptr %P, i64 2 + tail call void @llvm.memset.p0.i64(ptr %add.ptr, i8 0, i64 11, i1 false) + ret void +} + +define void @tgt(ptr %P) { + %arrayidx = getelementptr inbounds i8, ptr %P, i64 4 + tail call void @llvm.memset.p0.i64(ptr align(4) %arrayidx, i8 0, i64 15, i1 false) + ret void +} + +declare void @llvm.memset.p0.i64(ptr, i8, i64, i1)