Skip to content

Commit

Permalink
improve gep inbounds
Browse files Browse the repository at this point in the history
  • Loading branch information
nunoplopes committed Jan 9, 2024
1 parent 90141d0 commit 96c44da
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 47 deletions.
12 changes: 8 additions & 4 deletions ir/instr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3668,12 +3668,14 @@ StateValue GEP::toSMT(State &s) const {
auto scalar = [&](const StateValue &ptrval,
vector<pair<uint64_t, StateValue>> &offsets) -> StateValue {
Pointer ptr(s.getMemory(), ptrval.value);
Pointer ptr_log = inbounds ? ptr.toLogical().first : ptr;
AndExpr non_poison(ptrval.non_poison);
AndExpr inbounds_np;
AndExpr idx_all_zeros;

// FIXME: this is only partially implemented for physical pointers
if (inbounds)
inbounds_np.add(ptr.inbounds());
inbounds_np.add(ptr_log.inbounds());

for (auto &[sz, idx] : offsets) {
auto &[v, np] = idx;
Expand All @@ -3687,7 +3689,7 @@ StateValue GEP::toSMT(State &s) const {
non_poison.add(val.sextOrTrunc(v.bits()) == v);
}
non_poison.add(multiplier.mul_no_soverflow(val));
non_poison.add(ptr.addNoOverflow(inc));
non_poison.add(ptr_log.addNoOverflow(inc));
}

#ifndef NDEBUG
Expand All @@ -3699,8 +3701,10 @@ StateValue GEP::toSMT(State &s) const {
ptr += inc;
non_poison.add(np);

if (inbounds)
inbounds_np.add(ptr.inbounds());
if (inbounds) {
ptr_log += inc;
inbounds_np.add(ptr_log.inbounds());
}
}

if (inbounds) {
Expand Down
19 changes: 0 additions & 19 deletions ir/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2166,25 +2166,6 @@ expr Memory::ptr2int(const expr &ptr) const {
return p.getAddress();
}

Pointer Memory::searchPointer(const expr &val0) const {
DisjointExpr<Pointer> ret;
expr val = val0.zextOrTrunc(bits_ptr_address);

auto add = [&](unsigned limit, bool local) {
for (unsigned i = 0; i != limit; ++i) {
Pointer p(*this, i, local);
Pointer p_end = p + p.blockSize();
ret.add(p + (val - p.getAddress()),
!local && i == 0 && has_null_block
? val == 0
: val.uge(p.getAddress()) && val.ult(p_end.getAddress()));
}
};
add(numLocals(), true);
add(numNonlocals(), false);
return *std::move(ret)();
}

expr Memory::int2ptr(const expr &val) const {
assert(!memory_unused() && observesAddresses());
return
Expand Down
1 change: 0 additions & 1 deletion ir/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,6 @@ class Memory {

smt::expr ptr2int(const smt::expr &ptr) const;
smt::expr int2ptr(const smt::expr &val) const;
Pointer searchPointer(const smt::expr &val) const;

std::tuple<smt::expr, Pointer, std::set<smt::expr>>
refined(const Memory &other, bool fncall,
Expand Down
63 changes: 42 additions & 21 deletions ir/pointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ expr Pointer::isLocal(bool simplify) const {
if (m.numNonlocals() == 0)
return true;

return toLogical().isLogLocal(simplify);
return toLogical().first.isLogLocal(simplify);
}

expr Pointer::isConstGlobal() const {
Expand Down Expand Up @@ -264,15 +264,15 @@ expr Pointer::getLogOffset() const {
}

expr Pointer::getBid() const {
return toLogical().getLogBid();
return toLogical().first.getLogBid();
}

expr Pointer::getShortBid() const {
return toLogical().getLogShortBid();
return toLogical().first.getLogShortBid();
}

expr Pointer::getOffset() const {
return toLogical().getLogOffset();
return toLogical().first.getLogOffset();
}

expr Pointer::getOffsetSizet() const {
Expand All @@ -281,8 +281,8 @@ expr Pointer::getOffsetSizet() const {
}

expr Pointer::getShortOffset() const {
return toLogical().p.extract(bits_for_offset + bits_for_ptrattrs - 1,
bits_for_ptrattrs + zeroBitsShortOffset());
return toLogical().first.p.extract(bits_for_offset + bits_for_ptrattrs - 1,
bits_for_ptrattrs + zeroBitsShortOffset());
}

expr Pointer::getAttrs() const {
Expand Down Expand Up @@ -608,12 +608,14 @@ expr Pointer::isHeapAllocated() const {

expr Pointer::refined(const Pointer &other) const {
bool is_asm = other.m.isAsmMode();
auto [p1l, d1] = toLogical();
auto [p2l, d2] = other.toLogical();

// This refers to a block that was malloc'ed within the function
expr local = other.isLocal();
local &= getAllocType() == other.getAllocType();
local &= blockSize() == other.blockSize();
local &= getOffset() == other.getOffset();
expr local = p2l.isLocal();
local &= p1l.getAllocType() == p2l.getAllocType();
local &= p1l.blockSize() == p2l.blockSize();
local &= p1l.getOffset() == p2l.getOffset();
// Attributes are ignored at refinement.

// TODO: this induces an infinite loop
Expand All @@ -628,9 +630,10 @@ expr Pointer::refined(const Pointer &other) const {
getAddress() == other.getAddress());

return expr::mkIf(isNull(), other.isNull(),
expr::mkIf(isLocal(), std::move(local), nonlocal) &&
expr::mkIf(p1l.isLocal(), std::move(local), nonlocal) &&
(is_asm ? expr(true)
: isBlockAlive().implies(other.isBlockAlive())));
: (d1 && p1l.isBlockAlive())
.implies(p2l.isBlockAlive())));
}

expr Pointer::fninputRefined(const Pointer &other, set<expr> &undef,
Expand Down Expand Up @@ -750,13 +753,31 @@ expr Pointer::isNull() const {
return *this == mkNullPointer(m);
}

Pointer Pointer::toLogical() const {
pair<Pointer, expr> Pointer::findLogicalPointer(const expr &addr) const {
DisjointExpr<Pointer> ret;
expr val = addr.zextOrTrunc(bits_ptr_address);

auto add = [&](unsigned limit, bool local) {
for (unsigned i = 0; i != limit; ++i) {
Pointer p(m, i, local);
Pointer p_end = p + p.blockSize();
ret.add(p + (val - p.getAddress()),
!local && i == 0 && has_null_block
? val == 0
: val.uge(p.getAddress()) && val.ult(p_end.getAddress()));
}
};
add(m.numLocals(), true);
add(m.numNonlocals(), false);
return { *std::move(ret)(), ret.domain() };
}

pair<Pointer, expr> Pointer::toLogical() const {
if (isLogical().isTrue())
return *this;
return { *this, true };

DisjointExpr<Pointer> ret;
DisjointExpr<expr> leftover;
OrExpr leftover_domain;

// Try to optimize the conversion
for (auto [e, cond] : DisjointExpr<expr>(p, 5)) {
Expand Down Expand Up @@ -793,16 +814,16 @@ Pointer Pointer::toLogical() const {
.sextOrTrunc(bits_for_offset);
ret.add(Pointer(m, bid, offset), std::move(cond));
} else {
leftover.add(std::move(e), cond);
leftover_domain.add(std::move(cond));
leftover.add(std::move(e), std::move(cond));
}
}

if (!leftover_domain.empty())
ret.add(m.searchPointer(*std::move(leftover)()),
std::move(leftover_domain)());
if (!leftover.empty()) {
auto [ptr, domain] = findLogicalPointer(*std::move(leftover)());
ret.add(std::move(ptr), leftover.domain() && domain);
}

return mkIf(isLogical(), *this, *std::move(ret)());
return { mkIf(isLogical(), *this, *std::move(ret)()), ret.domain() };
}

Pointer
Expand Down
6 changes: 4 additions & 2 deletions ir/pointer.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class Pointer {
const smt::FunctionExpr &nonlocal_fn,
const smt::expr &ret_type, bool src_name = false) const;

Pointer toLogical() const;

public:
Pointer(const Memory &m, const smt::expr &bid, const smt::expr &offset,
const smt::expr &attr);
Expand All @@ -57,6 +55,10 @@ class Pointer {
Pointer(Pointer &&other) noexcept = default;
void operator=(Pointer &&rhs) noexcept { p = std::move(rhs.p); }

// returns (log-ptr, domain of inboundness)
std::pair<Pointer, smt::expr> findLogicalPointer(const smt::expr &addr) const;
std::pair<Pointer, smt::expr> toLogical() const;

static smt::expr mkLongBid(const smt::expr &short_bid, bool local);
static smt::expr mkUndef(State &s);

Expand Down
5 changes: 5 additions & 0 deletions smt/exprs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ ostream &operator<<(ostream &os, const AndExpr &e) {
}


void OrExpr::add(const expr &e) {
if (!e.isFalse())
exprs.emplace(e);
}

void OrExpr::add(expr &&e) {
if (!e.isFalse())
exprs.insert(std::move(e));
Expand Down
10 changes: 10 additions & 0 deletions smt/exprs.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class OrExpr {
std::set<expr> exprs;

public:
void add(const expr &e);
void add(expr &&e);
void add(const OrExpr &other);
expr operator()() const;
Expand Down Expand Up @@ -117,6 +118,14 @@ class DisjointExpr {
return std::move(default_val);
}

expr domain() const {
OrExpr ret;
for (auto &[val, domain] : vals) {
ret.add(domain);
}
return std::move(ret)();
}

std::optional<T> lookup(const expr &domain) const {
for (auto &[v, d] : vals) {
if (d.eq(domain))
Expand All @@ -128,6 +137,7 @@ class DisjointExpr {
auto begin() const { return vals.begin(); }
auto end() const { return vals.end(); }
auto size() const { return vals.size(); }
bool empty() const { return vals.empty() && !default_val; }
};


Expand Down

0 comments on commit 96c44da

Please sign in to comment.