From 60b21a49b8d8c1e0f49dad2e7a634a8327c9f958 Mon Sep 17 00:00:00 2001 From: Yurii Kostyukov Date: Tue, 12 Sep 2023 12:34:41 +0300 Subject: [PATCH] [feat] Z3 Tree incremental solver --- include/klee/ADT/Incremental.h | 358 +++++++++ include/klee/Expr/Constraints.h | 1 + include/klee/Solver/IncompleteSolver.h | 1 + include/klee/Solver/Solver.h | 9 + include/klee/Solver/SolverCmdLine.h | 3 + include/klee/Solver/SolverImpl.h | 2 + include/klee/Solver/SolverUtil.h | 35 +- lib/Core/ExecutionState.cpp | 8 +- lib/Core/ExecutionState.h | 7 +- lib/Core/Executor.cpp | 11 +- lib/Core/ImpliedValue.cpp | 7 +- lib/Core/TimingSolver.cpp | 58 +- lib/Core/TimingSolver.h | 4 + lib/Expr/Constraints.cpp | 3 + lib/Solver/AssignmentValidatingSolver.cpp | 15 +- lib/Solver/CachingSolver.cpp | 5 + lib/Solver/CexCachingSolver.cpp | 5 + lib/Solver/ConcretizingSolver.cpp | 20 +- lib/Solver/CoreSolver.cpp | 11 +- lib/Solver/DummySolver.cpp | 3 + lib/Solver/IncompleteSolver.cpp | 11 +- lib/Solver/IndependentSolver.cpp | 12 +- lib/Solver/MetaSMTSolver.cpp | 1 + lib/Solver/QueryLoggingSolver.cpp | 10 +- lib/Solver/QueryLoggingSolver.h | 1 + lib/Solver/STPSolver.cpp | 1 + lib/Solver/Solver.cpp | 19 +- lib/Solver/SolverCmdLine.cpp | 18 +- lib/Solver/SolverImpl.cpp | 7 +- lib/Solver/ValidatingSolver.cpp | 12 +- lib/Solver/Z3Builder.cpp | 4 +- lib/Solver/Z3Builder.h | 14 +- lib/Solver/Z3Solver.cpp | 854 +++++++++++++++++----- lib/Solver/Z3Solver.h | 13 +- test/Solver/CrosscheckZ3AndZ3TreeInc.c | 11 + tools/kleaver/main.cpp | 35 +- unittests/Solver/SolverTest.cpp | 6 +- unittests/Solver/Z3SolverTest.cpp | 2 +- 38 files changed, 1261 insertions(+), 336 deletions(-) create mode 100644 include/klee/ADT/Incremental.h create mode 100644 test/Solver/CrosscheckZ3AndZ3TreeInc.c diff --git a/include/klee/ADT/Incremental.h b/include/klee/ADT/Incremental.h new file mode 100644 index 00000000000..b2e84173505 --- /dev/null +++ b/include/klee/ADT/Incremental.h @@ -0,0 +1,358 @@ +//===---- Incremental.h -----------------------------------------*- C++ -*-===// +// +// The KLEE Symbolic Virtual Machine +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#ifndef KLEE_INCREMENTAL_H +#define KLEE_INCREMENTAL_H + +#include +#include +#include +#include +#include +#include + +#include "klee/Expr/ExprUtil.h" + +namespace klee { + +template +void extend(std::vector<_Tp, _Alloc> &ths, + const std::vector<_Tp, _Alloc> &other) { + ths.reserve(ths.size() + other.size()); + ths.insert(ths.end(), other.begin(), other.end()); +} + +template > +class inc_vector { +public: + using vec = std::vector<_Tp, _Alloc>; + using frame_size_it = std::vector::const_iterator; + using frame_it = typename vec::const_iterator; + + /// It is public, so that all vector operations are supported + /// Everything pushed to v is pushed to the last frame + vec v; + + std::vector frame_sizes; + +private: + // v.size() == sum(frame_sizes) + size of the fresh frame + + size_t freshFrameSize() const { + return v.size() - + std::accumulate(frame_sizes.begin(), frame_sizes.end(), 0); + } + + void take(size_t n, size_t &frames_count, size_t &frame_index) const { + size_t i = 0; + size_t c = n; + for (; i < frame_sizes.size(); i++) { + if (frame_sizes[i] > c) + break; + c -= frame_sizes[i]; + } + frames_count = c; + frame_index = i; + } + +public: + inc_vector() {} + inc_vector(const std::vector<_Tp> &constraints) : v(constraints) {} + + void clear() { + v.clear(); + frame_sizes.clear(); + } + + frame_size_it begin() const { return frame_sizes.cbegin(); } + frame_size_it end() const { return frame_sizes.cend(); } + size_t framesSize() const { return frame_sizes.size() + 1; } + + frame_it begin(int frame_index) const { + assert(-(long long)framesSize() <= (long long)frame_index && + (long long)frame_index <= (long long)framesSize()); + if (frame_index < 0) + frame_index += framesSize(); + if ((long long)frame_index == (long long)framesSize()) + return v.end(); + auto fend = frame_sizes.begin() + frame_index; + auto shift = std::accumulate(frame_sizes.begin(), fend, 0); + return v.begin() + shift; + } + frame_it end(int frame_index) const { return begin(frame_index + 1); } + size_t size(size_t frame_index) const { + assert(frame_index < framesSize()); + if (frame_index == framesSize() - 1) // last frame + return freshFrameSize(); + return frame_sizes[frame_index]; + } + + void pop(size_t popFrames) { + assert(freshFrameSize() == 0); + if (popFrames == 0) + return; + size_t toPop = + std::accumulate(frame_sizes.end() - popFrames, frame_sizes.end(), 0); + v.resize(v.size() - toPop); + frame_sizes.resize(frame_sizes.size() - popFrames); + } + + void push() { + auto freshSize = freshFrameSize(); + frame_sizes.push_back(freshSize); + assert(freshFrameSize() == 0); + } + + /// ensures that last frame is empty + void extend(const std::vector<_Tp, _Alloc> &other) { + assert(freshFrameSize() == 0); + // push(); + klee::extend(v, other); + push(); + } + + /// ensures that last frame is empty + void extend(const inc_vector<_Tp, _Alloc> &other) { + assert(freshFrameSize() == 0); + for (size_t i = 0, e = other.framesSize(); i < e; i++) { + v.reserve(v.size() + other.size(i)); + v.insert(v.end(), other.begin(i), other.end(i)); + push(); + } + } + + void takeAfter(size_t n, inc_vector<_Tp, _Alloc> &result) const { + size_t frames_count, frame_index; + take(n, frames_count, frame_index); + result = *this; + std::vector<_Tp, _Alloc>(result.v.begin() + n, result.v.end()) + .swap(result.v); + std::vector(result.frame_sizes.begin() + frame_index, + result.frame_sizes.end()) + .swap(result.frame_sizes); + if (frames_count) + result.frame_sizes[0] -= frames_count; + } + + void butLast(inc_vector<_Tp, _Alloc> &result) const { + assert(!v.empty() && "butLast of empty vector"); + assert(freshFrameSize() && "butLast of empty fresh frame"); + result = *this; + result.v.pop_back(); + } + + void takeBefore(size_t n, size_t &toPop, size_t &takeFromOther) const { + take(n, takeFromOther, toPop); + toPop = frame_sizes.size() - toPop; + } +}; + +using FrameId = size_t; +using FrameIds = std::unordered_set; + +template , + typename _Pred = std::equal_to<_Value>, + typename _Alloc = std::allocator<_Value>> +class inc_uset { +private: + class MinFrameIds { + FrameIds ids; + FrameId min = std::numeric_limits::max(); + + public: + bool empty() const { return ids.empty(); } + + bool hasMin(FrameId other) const { return min == other && !ids.empty(); } + + void insert(FrameId i) { + ids.insert(i); + if (i < min) + min = i; + } + + MinFrameIds bound(FrameId upperBound) { + MinFrameIds result; + std::copy_if(ids.begin(), ids.end(), + std::inserter(result.ids, result.ids.begin()), + [upperBound](FrameId i) { return i <= upperBound; }); + auto min_it = std::min_element(result.ids.begin(), result.ids.end()); + if (min_it == result.ids.end()) + result.min = std::numeric_limits::max(); + else + result.min = *min_it; + return result; + } + }; + + using idMap = std::unordered_map<_Value, MinFrameIds, _Hash, _Pred, _Alloc>; + using citerator = typename idMap::const_iterator; + idMap ids; + FrameId current_frame = 0; + +public: + size_t framesSize() const { return current_frame + 1; } + + void clear() { + ids.clear(); + current_frame = 0; + } + + class frame_it + : public std::iterator, int> { + citerator set_it; + const citerator set_ite; + const FrameId frame_index = 0; + + void gotoNext() { + while (set_it != set_ite && !set_it->second.hasMin(frame_index)) + set_it++; + } + + public: + using value_type = _Value; + + explicit frame_it(const idMap &ids) + : set_it(ids.end()), set_ite(ids.end()) {} + explicit frame_it(const idMap &ids, FrameId frame_index) + : set_it(ids.begin()), set_ite(ids.end()), frame_index(frame_index) { + gotoNext(); + } + + bool operator!=(const frame_it &other) const { + return set_it != other.set_it; + } + + const _Value &operator*() const { return set_it->first; } + + frame_it &operator++() { + if (set_it != set_ite) { + set_it++; + gotoNext(); + } + return *this; + } + }; + + class all_it + : public std::iterator, int> { + citerator set_it; + + public: + using value_type = _Value; + + explicit all_it(citerator set_it) : set_it(set_it) {} + + bool operator!=(const all_it &other) const { + return set_it != other.set_it; + } + + const _Value &operator*() const { return set_it->first; } + + all_it &operator++() { + set_it++; + return *this; + } + }; + + all_it begin() const { return all_it(ids.begin()); } + all_it end() const { return all_it(ids.end()); } + + frame_it begin(int frame_index) const { + assert(-(long long)framesSize() <= (long long)frame_index && + (long long)frame_index <= (long long)framesSize()); + if (frame_index < 0) + frame_index += framesSize(); + return frame_it(ids, frame_index); + } + frame_it end(int frame_index) const { return frame_it(ids); } + + void insert(const _Value &v) { ids[v].insert(current_frame); } + + template + void insert(_InputIterator __first, _InputIterator __last) { + for (; __first != __last; __first++) + ids[*__first].insert(current_frame); + } + + void pop(size_t popFrames) { + current_frame -= popFrames; + idMap newIdMap; + for (auto &keyAndIds : ids) { + MinFrameIds newIds = keyAndIds.second.bound(current_frame); + if (!newIds.empty()) + newIdMap.insert(std::make_pair(keyAndIds.first, newIds)); + } + ids = newIdMap; + } + + void push() { current_frame++; } +}; + +template , + typename _Pred = std::equal_to<_Key>, + typename _Alloc = std::allocator>> +class inc_umap { +private: + std::unordered_map<_Key, _Tp, _Hash, _Pred, _Alloc> map; + using idMap = std::unordered_map<_Key, FrameIds, _Hash, _Pred, _Alloc>; + idMap ids; + FrameId current_frame = 0; + +public: + void clear() { + map.clear(); + ids.clear(); + current_frame = 0; + } + + void insert(const std::pair<_Key, _Tp> &pair) { + map.insert(pair); + ids[pair.first].insert(current_frame); + } + + _Tp &operator[](const _Key &key) { + ids[key].insert(current_frame); + return map[key]; + } + + size_t count(const _Key &key) const { return map.count(key); } + + const _Tp &at(_Key &key) const { return map.at(key); } + + void pop(size_t popFrames) { + current_frame -= popFrames; + idMap newIdMap; + for (auto &keyAndIds : ids) { + FrameIds newIds; + for (auto id : keyAndIds.second) + if (id <= current_frame) + newIds.insert(id); + if (newIds.empty()) + map.erase(keyAndIds.first); + else + newIdMap.insert(std::make_pair(keyAndIds.first, newIds)); + } + ids = newIdMap; + } + + void push() { current_frame++; } + + void dump() const { + for (auto kv : map) { + kv.first.dump(); + llvm::errs() << "----->\n"; + kv.second.dump(); + llvm::errs() << "\n;;;;;;;;;\n"; + } + } +}; + +} // namespace klee + +#endif /* KLEE_INCREMENTAL_H */ diff --git a/include/klee/Expr/Constraints.h b/include/klee/Expr/Constraints.h index a14b9e191e7..93acd1e17bf 100644 --- a/include/klee/Expr/Constraints.h +++ b/include/klee/Expr/Constraints.h @@ -35,6 +35,7 @@ class ConstraintSet { public: ConstraintSet(constraints_ty cs, symcretes_ty symcretes, Assignment concretization); + explicit ConstraintSet(constraints_ty cs); ConstraintSet(); void addConstraint(ref e, const Assignment &delta); diff --git a/include/klee/Solver/IncompleteSolver.h b/include/klee/Solver/IncompleteSolver.h index 4f47c1cdd61..65dac30c4be 100644 --- a/include/klee/Solver/IncompleteSolver.h +++ b/include/klee/Solver/IncompleteSolver.h @@ -80,6 +80,7 @@ class StagedSolverImpl : public SolverImpl { SolverRunStatus getOperationStatusCode(); char *getConstraintLog(const Query &); void setCoreSolverTimeout(time::Span timeout); + void notifyStateTermination(std::uint32_t id); }; } // namespace klee diff --git a/include/klee/Solver/Solver.h b/include/klee/Solver/Solver.h index 30b344efecb..c39d8a2ae39 100644 --- a/include/klee/Solver/Solver.h +++ b/include/klee/Solver/Solver.h @@ -191,6 +191,10 @@ class Solver { virtual char *getConstraintLog(const Query &query); virtual void setCoreSolverTimeout(time::Span timeout); + + /// @brief Notify the solver that the state with specified id has been + /// terminated + void notifyStateTermination(std::uint32_t id); }; /* *** */ @@ -264,6 +268,11 @@ std::unique_ptr createCoreSolver(CoreSolverType cst); std::unique_ptr createConcretizingSolver(std::unique_ptr s, AddressGenerator *addressGenerator); + +/// Return a list of all unique symbolic objects referenced by the +/// given Query. +void findSymbolicObjects(const Query &query, + std::vector &results); } // namespace klee #endif /* KLEE_SOLVER_H */ diff --git a/include/klee/Solver/SolverCmdLine.h b/include/klee/Solver/SolverCmdLine.h index 7fbfac03940..6bd0c285bbd 100644 --- a/include/klee/Solver/SolverCmdLine.h +++ b/include/klee/Solver/SolverCmdLine.h @@ -50,6 +50,8 @@ extern llvm::cl::opt CoreSolverOptimizeDivides; extern llvm::cl::opt UseAssignmentValidatingSolver; +extern llvm::cl::opt MaxSolversApproxTreeInc; + /// The different query logging solvers that can be switched on/off enum QueryLoggingSolverType { ALL_KQUERY, ///< Log all queries in .kquery (KQuery) format @@ -65,6 +67,7 @@ enum CoreSolverType { METASMT_SOLVER, DUMMY_SOLVER, Z3_SOLVER, + Z3_TREE_SOLVER, NO_SOLVER }; diff --git a/include/klee/Solver/SolverImpl.h b/include/klee/Solver/SolverImpl.h index 3e24a16ed7f..da8e510f220 100644 --- a/include/klee/Solver/SolverImpl.h +++ b/include/klee/Solver/SolverImpl.h @@ -119,6 +119,8 @@ class SolverImpl { } virtual void setCoreSolverTimeout(time::Span timeout){}; + + virtual void notifyStateTermination(std::uint32_t id) = 0; }; } // namespace klee diff --git a/include/klee/Solver/SolverUtil.h b/include/klee/Solver/SolverUtil.h index c1fdb2832c6..3daedc5bdb6 100644 --- a/include/klee/Solver/SolverUtil.h +++ b/include/klee/Solver/SolverUtil.h @@ -41,6 +41,9 @@ enum class Validity { True = 1, False = -1, Unknown = 0 }; struct SolverQueryMetaData { /// @brief Costs for all queries issued for this state time::Span queryCost; + + /// @brief Caller state id + std::uint32_t id = 0; }; struct Query { @@ -48,25 +51,36 @@ struct Query { const ConstraintSet constraints; ref expr; - Query(const ConstraintSet &_constraints, ref _expr) - : constraints(_constraints), expr(_expr) {} + /// @brief id of the state initiated this query + const std::uint32_t id; + + Query(const ConstraintSet &_constraints, ref _expr, std::uint32_t _id) + : constraints(_constraints), expr(_expr), id(_id) {} + + /// This constructor should be used *only* if + /// this query is created *not* from some known ExecutionState + /// Otherwise consider using the above constructor + Query(const constraints_ty &cs, ref e) + : Query(ConstraintSet(cs), e, 0) {} Query(const Query &query) - : constraints(query.constraints), expr(query.expr) {} + : constraints(query.constraints), expr(query.expr), id(query.id) {} /// withExpr - Return a copy of the query with the given expression. - Query withExpr(ref _expr) const { return Query(constraints, _expr); } + Query withExpr(ref _expr) const { + return Query(constraints, _expr, id); + } /// withFalse - Return a copy of the query with a false expression. Query withFalse() const { - return Query(constraints, ConstantExpr::alloc(0, Expr::Bool)); + return Query(constraints, ConstantExpr::alloc(0, Expr::Bool), id); } /// negateExpr - Return a copy of the query with the expression negated. Query negateExpr() const { return withExpr(Expr::createIsZero(expr)); } Query withConstraints(const ConstraintSet &_constraints) const { - return Query(_constraints, expr); + return Query(_constraints, expr, id); } /// Get all arrays that figure in the query std::vector gatherArrays() const; @@ -97,11 +111,8 @@ struct ValidityCore { ValidityCore(const constraints_typ &_constraints, ref _expr) : constraints(_constraints), expr(_expr) {} - ValidityCore(const ExprHashSet &_constraints, ref _expr) : expr(_expr) { - for (auto e : _constraints) { - constraints.insert(e); - } - } + ValidityCore(const ExprHashSet &_constraints, ref _expr) + : constraints(_constraints.begin(), _constraints.end()), expr(_expr) {} /// withExpr - Return a copy of the validity core with the given expression. ValidityCore withExpr(ref _expr) const { @@ -117,6 +128,8 @@ struct ValidityCore { /// negated. ValidityCore negateExpr() const { return withExpr(Expr::createIsZero(expr)); } + Query toQuery() const; + /// Dump validity core void dump() const; diff --git a/lib/Core/ExecutionState.cpp b/lib/Core/ExecutionState.cpp index ddec6ec63b6..0690a40bd41 100644 --- a/lib/Core/ExecutionState.cpp +++ b/lib/Core/ExecutionState.cpp @@ -182,7 +182,9 @@ ExecutionState::ExecutionState(const ExecutionState &state) returnValue(state.returnValue), gepExprBases(state.gepExprBases), prevTargets_(state.prevTargets_), targets_(state.targets_), prevHistory_(state.prevHistory_), history_(state.history_), - isTargeted_(state.isTargeted_) {} + isTargeted_(state.isTargeted_) { + queryMetaData.id = state.id; +} ExecutionState *ExecutionState::branch() { depth++; @@ -496,3 +498,7 @@ bool ExecutionState::reachedTarget(ref target) const { return pc == target->getBlock()->getFirstInstruction(); } } + +Query ExecutionState::toQuery() const { + return Query(constraints.cs(), Expr::createFalse(), id); +} diff --git a/lib/Core/ExecutionState.h b/lib/Core/ExecutionState.h index c0bfeb7301d..7fc32904ff6 100644 --- a/lib/Core/ExecutionState.h +++ b/lib/Core/ExecutionState.h @@ -436,12 +436,17 @@ class ExecutionState { void addConstraint(ref e, const Assignment &c); void addCexPreference(const ref &cond); + Query toQuery() const; + void dumpStack(llvm::raw_ostream &out) const; bool visited(KBlock *block) const; std::uint32_t getID() const { return id; }; - void setID() { id = nextID++; }; + void setID() { + id = nextID++; + queryMetaData.id = id; + }; llvm::BasicBlock *getInitPCBlock() const; llvm::BasicBlock *getPrevPCBlock() const; llvm::BasicBlock *getPCBlock() const; diff --git a/lib/Core/Executor.cpp b/lib/Core/Executor.cpp index 064935305dd..01bd38fd81b 100644 --- a/lib/Core/Executor.cpp +++ b/lib/Core/Executor.cpp @@ -4507,6 +4507,7 @@ void Executor::terminateState(ExecutionState &state, interpreterHandler->incPathsExplored(); state.pc = state.prevPC; targetCalculator->update(state); + solver->notifyStateTermination(state.id); removedStates.push_back(&state); } @@ -5420,8 +5421,8 @@ MemoryObject *Executor::allocate(ExecutionState &state, ref size, ZExtExpr::create(size, pointerWidthInBits)}; constraints_ty required; - IndependentElementSet eltsClosure = getIndependentConstraints( - Query(state.constraints.cs(), ZExtExpr::create(size, pointerWidthInBits)), + auto eltsClosure = getIndependentConstraints( + state.toQuery().withExpr(ZExtExpr::create(size, pointerWidthInBits)), required); /* Collect dependent size symcretes. */ for (ref symcrete : eltsClosure.symcretes) { @@ -6826,7 +6827,7 @@ void Executor::getConstraintLog(const ExecutionState &state, std::string &res, switch (logFormat) { case STP: { - Query query(state.constraints.cs(), ConstantExpr::alloc(0, Expr::Bool)); + auto query = state.toQuery(); char *log = solver->getConstraintLog(query); res = std::string(log); free(log); @@ -6844,7 +6845,7 @@ void Executor::getConstraintLog(const ExecutionState &state, std::string &res, llvm::raw_string_ostream info(Str); ExprSMTLIBPrinter printer; printer.setOutput(info); - Query query(state.constraints.cs(), ConstantExpr::alloc(0, Expr::Bool)); + auto query = state.toQuery(); printer.setQuery(query); printer.generateOutput(); res = info.str(); @@ -6998,7 +6999,7 @@ Assignment Executor::computeConcretization(const ConstraintSet &constraints, ref condition, SolverQueryMetaData &queryMetaData) { Assignment concretization; - if (Query(constraints, condition).containsSymcretes()) { + if (Query(constraints, condition, queryMetaData.id).containsSymcretes()) { ref response; solver->setTimeout(coreSolverTimeout); bool success = solver->getResponse( diff --git a/lib/Core/ImpliedValue.cpp b/lib/Core/ImpliedValue.cpp index 8841ffab3c7..cf0e8c003e6 100644 --- a/lib/Core/ImpliedValue.cpp +++ b/lib/Core/ImpliedValue.cpp @@ -203,8 +203,8 @@ void ImpliedValue::checkForImpliedValues(Solver *S, ref e, std::set> readsSet(reads.begin(), reads.end()); reads = std::vector>(readsSet.begin(), readsSet.end()); - ConstraintSet assumption; - assumption.addConstraint(EqExpr::create(e, value), {}); + constraints_ty assumption; + assumption.insert(EqExpr::create(e, value)); // obscure... we need to make sure that all the read indices are // bounds checked. if we don't do this we can end up constructing @@ -215,8 +215,7 @@ void ImpliedValue::checkForImpliedValues(Solver *S, ref e, for (std::vector>::iterator i = reads.begin(), ie = reads.end(); i != ie; ++i) { ReadExpr *re = i->get(); - assumption.addConstraint(UltExpr::create(re->index, re->updates.root->size), - {}); + assumption.insert(UltExpr::create(re->index, re->updates.root->size)); } for (const auto &var : reads) { diff --git a/lib/Core/TimingSolver.cpp b/lib/Core/TimingSolver.cpp index edae9a3a5b9..91d1d57e9cd 100644 --- a/lib/Core/TimingSolver.cpp +++ b/lib/Core/TimingSolver.cpp @@ -44,11 +44,11 @@ bool TimingSolver::evaluate(const ConstraintSet &constraints, ref expr, ref queryResult; ref negatedQueryResult; + Query query(constraints, expr, metaData.id); bool success = produceValidityCore - ? solver->evaluate(Query(constraints, expr), queryResult, - negatedQueryResult) - : solver->evaluate(Query(constraints, expr), result); + ? solver->evaluate(query, queryResult, negatedQueryResult) + : solver->evaluate(query, result); if (success && produceValidityCore) { if (isa(queryResult) && @@ -91,12 +91,12 @@ bool TimingSolver::tryGetUnique(const ConstraintSet &constraints, ref e, e = optimizer.optimizeExpr(e, true); TimerStatIncrementer timer(stats::solverTime); - if (!solver->getValue(Query(constraints, e), value)) { + if (!solver->getValue(Query(constraints, e, metaData.id), value)) { return false; } ref cond = EqExpr::create(e, value); cond = optimizer.optimizeExpr(cond, false); - if (!solver->mustBeTrue(Query(constraints, cond), isTrue)) { + if (!solver->mustBeTrue(Query(constraints, cond, metaData.id), isTrue)) { return false; } if (isTrue) { @@ -125,11 +125,11 @@ bool TimingSolver::mustBeTrue(const ConstraintSet &constraints, ref expr, expr = Simplificator::simplifyExpr(constraints, expr).simplified; ValidityCore validityCore; + Query query(constraints, expr, metaData.id); bool success = produceValidityCore - ? solver->getValidityCore(Query(constraints, expr), - validityCore, result) - : solver->mustBeTrue(Query(constraints, expr), result); + ? solver->getValidityCore(query, validityCore, result) + : solver->mustBeTrue(query, result); metaData.queryCost += timer.delta(); @@ -178,7 +178,8 @@ bool TimingSolver::getValue(const ConstraintSet &constraints, ref expr, if (simplifyExprs) expr = Simplificator::simplifyExpr(constraints, expr).simplified; - bool success = solver->getValue(Query(constraints, expr), result); + bool success = + solver->getValue(Query(constraints, expr, metaData.id), result); metaData.queryCost += timer.delta(); @@ -201,8 +202,8 @@ bool TimingSolver::getMinimalUnsignedValue(const ConstraintSet &constraints, if (simplifyExprs) expr = Simplificator::simplifyExpr(constraints, expr).simplified; - bool success = - solver->getMinimalUnsignedValue(Query(constraints, expr), result); + bool success = solver->getMinimalUnsignedValue( + Query(constraints, expr, metaData.id), result); metaData.queryCost += timer.delta(); @@ -220,15 +221,11 @@ bool TimingSolver::getInitialValues( TimerStatIncrementer timer(stats::solverTime); ref queryResult; + Query query(constraints, Expr::createFalse(), metaData.id); - bool success = - produceValidityCore - ? solver->check( - Query(constraints, ConstantExpr::alloc(0, Expr::Bool)), - queryResult) - : solver->getInitialValues( - Query(constraints, ConstantExpr::alloc(0, Expr::Bool)), objects, - result); + bool success = produceValidityCore + ? solver->check(query, queryResult) + : solver->getInitialValues(query, objects, result); if (success && produceValidityCore && isa(queryResult)) { success = queryResult->tryGetInitialValuesFor(objects, result); @@ -251,22 +248,22 @@ bool TimingSolver::evaluate(const ConstraintSet &constraints, ref expr, auto simplified = simplification.simplified; auto dependency = simplification.dependency; if (auto CE = dyn_cast(simplified)) { + Query query(constraints, simplified, metaData.id); if (CE->isTrue()) { queryResult = new ValidResponse(ValidityCore(dependency, expr)); - return solver->check(Query(constraints, simplified).negateExpr(), - negatedQueryResult); + return solver->check(query.negateExpr(), negatedQueryResult); } else { negatedQueryResult = new ValidResponse( ValidityCore(dependency, Expr::createIsZero(expr))); - return solver->check(Query(constraints, simplified), queryResult); + return solver->check(query, queryResult); } } else { expr = simplified; } } - bool success = solver->evaluate(Query(constraints, expr), queryResult, - negatedQueryResult); + bool success = solver->evaluate(Query(constraints, expr, metaData.id), + queryResult, negatedQueryResult); metaData.queryCost += timer.delta(); @@ -300,8 +297,8 @@ bool TimingSolver::getValidityCore(const ConstraintSet &constraints, } } - bool success = - solver->getValidityCore(Query(constraints, expr), validityCore, result); + bool success = solver->getValidityCore(Query(constraints, expr, metaData.id), + validityCore, result); metaData.queryCost += timer.delta(); @@ -334,7 +331,8 @@ bool TimingSolver::getResponse(const ConstraintSet &constraints, ref expr, } } - bool success = solver->check(Query(constraints, expr), queryResult); + bool success = + solver->check(Query(constraints, expr, metaData.id), queryResult); metaData.queryCost += timer.delta(); @@ -346,8 +344,12 @@ TimingSolver::getRange(const ConstraintSet &constraints, ref expr, SolverQueryMetaData &metaData, time::Span timeout) { ++stats::queries; TimerStatIncrementer timer(stats::solverTime); - auto query = Query(constraints, expr); + Query query(constraints, expr, metaData.id); auto result = solver->getRange(query, timeout); metaData.queryCost += timer.delta(); return result; } + +void TimingSolver::notifyStateTermination(std::uint32_t id) { + solver->notifyStateTermination(id); +} diff --git a/lib/Core/TimingSolver.h b/lib/Core/TimingSolver.h index fbf65c4a56c..7692d503778 100644 --- a/lib/Core/TimingSolver.h +++ b/lib/Core/TimingSolver.h @@ -49,6 +49,10 @@ class TimingSolver { return solver->getConstraintLog(query); } + /// @brief Notify the solver that the state with specified id has been + /// terminated + void notifyStateTermination(std::uint32_t id); + bool evaluate(const ConstraintSet &, ref, PartialValidity &result, SolverQueryMetaData &metaData, bool produceValidityCore = false); diff --git a/lib/Expr/Constraints.cpp b/lib/Expr/Constraints.cpp index 802a75bea4d..8e3b04afedb 100644 --- a/lib/Expr/Constraints.cpp +++ b/lib/Expr/Constraints.cpp @@ -184,6 +184,9 @@ ConstraintSet::ConstraintSet(constraints_ty cs, symcretes_ty symcretes, : _constraints(cs), _symcretes(symcretes), _concretization(concretization) { } +ConstraintSet::ConstraintSet(constraints_ty cs) + : _constraints(cs), _symcretes({}), _concretization(true) {} + ConstraintSet::ConstraintSet() : _concretization(Assignment(true)) {} void ConstraintSet::addConstraint(ref e, const Assignment &delta) { diff --git a/lib/Solver/AssignmentValidatingSolver.cpp b/lib/Solver/AssignmentValidatingSolver.cpp index dde0f1a23d5..cdff8fec1b8 100644 --- a/lib/Solver/AssignmentValidatingSolver.cpp +++ b/lib/Solver/AssignmentValidatingSolver.cpp @@ -44,6 +44,7 @@ class AssignmentValidatingSolver : public SolverImpl { SolverRunStatus getOperationStatusCode(); char *getConstraintLog(const Query &); void setCoreSolverTimeout(time::Span timeout); + void notifyStateTermination(std::uint32_t id); }; // TODO: use computeInitialValues for all queries for more stress testing @@ -142,14 +143,8 @@ bool AssignmentValidatingSolver::check(const Query &query, return true; } - ExprHashSet expressions; - assert(!query.containsSymcretes()); - expressions.insert(query.constraints.cs().begin(), - query.constraints.cs().end()); - expressions.insert(query.expr); - std::vector objects; - findSymbolicObjects(expressions.begin(), expressions.end(), objects); + findSymbolicObjects(query, objects); std::vector> values; assert(isa(result)); @@ -177,7 +172,7 @@ void AssignmentValidatingSolver::dumpAssignmentQuery( for (const auto &constraint : query.constraints.cs()) constraints.addConstraint(constraint, {}); - Query augmentedQuery(constraints, query.expr); + Query augmentedQuery = query.withConstraints(constraints); // Ask the solver for the log for this query. char *logText = solver->getConstraintLog(augmentedQuery); @@ -198,6 +193,10 @@ void AssignmentValidatingSolver::setCoreSolverTimeout(time::Span timeout) { return solver->impl->setCoreSolverTimeout(timeout); } +void AssignmentValidatingSolver::notifyStateTermination(std::uint32_t id) { + solver->impl->notifyStateTermination(id); +} + std::unique_ptr createAssignmentValidatingSolver(std::unique_ptr s) { return std::make_unique( diff --git a/lib/Solver/CachingSolver.cpp b/lib/Solver/CachingSolver.cpp index de43695d5e1..76b27baaa3a 100644 --- a/lib/Solver/CachingSolver.cpp +++ b/lib/Solver/CachingSolver.cpp @@ -90,6 +90,7 @@ class CachingSolver : public SolverImpl { SolverRunStatus getOperationStatusCode(); char *getConstraintLog(const Query &); void setCoreSolverTimeout(time::Span timeout); + void notifyStateTermination(std::uint32_t id); }; /** @returns the canonical version of the given query. The reference @@ -388,6 +389,10 @@ void CachingSolver::setCoreSolverTimeout(time::Span timeout) { solver->impl->setCoreSolverTimeout(timeout); } +void CachingSolver::notifyStateTermination(std::uint32_t id) { + solver->impl->notifyStateTermination(id); +} + /// std::unique_ptr diff --git a/lib/Solver/CexCachingSolver.cpp b/lib/Solver/CexCachingSolver.cpp index bdbc73a4dbc..0633a554a8f 100644 --- a/lib/Solver/CexCachingSolver.cpp +++ b/lib/Solver/CexCachingSolver.cpp @@ -112,6 +112,7 @@ class CexCachingSolver : public SolverImpl { SolverRunStatus getOperationStatusCode(); char *getConstraintLog(const Query &query); void setCoreSolverTimeout(time::Span timeout); + void notifyStateTermination(std::uint32_t id); }; /// @@ -420,6 +421,10 @@ void CexCachingSolver::setCoreSolverTimeout(time::Span timeout) { solver->impl->setCoreSolverTimeout(timeout); } +void CexCachingSolver::notifyStateTermination(std::uint32_t id) { + solver->impl->notifyStateTermination(id); +} + /// std::unique_ptr diff --git a/lib/Solver/ConcretizingSolver.cpp b/lib/Solver/ConcretizingSolver.cpp index 36f5a81191c..cd3298badf6 100644 --- a/lib/Solver/ConcretizingSolver.cpp +++ b/lib/Solver/ConcretizingSolver.cpp @@ -52,6 +52,7 @@ class ConcretizingSolver : public SolverImpl { SolverRunStatus getOperationStatusCode(); char *getConstraintLog(const Query &); void setCoreSolverTimeout(time::Span timeout); + void notifyStateTermination(std::uint32_t id); private: bool assertConcretization(const Query &query, const Assignment &assign) const; @@ -66,7 +67,7 @@ Query ConcretizingSolver::constructConcretizedQuery(const Query &query, for (auto e : query.constraints.cs()) { constraints.addConstraint(e, {}); } - return Query(constraints, query.expr); + return query.withConstraints(constraints); } bool ConcretizingSolver::assertConcretization(const Query &query, @@ -200,9 +201,8 @@ bool ConcretizingSolver::relaxSymcreteConstraints(const Query &query, for (const ref &symcrete : currentlyBrokenSymcretes) { constraints_ty required; IndependentElementSet eltsClosure = getIndependentConstraints( - Query(query.constraints, - AndExpr::create(query.expr, - Expr::createIsZero(symcrete->symcretized))), + query.withExpr(AndExpr::create( + query.expr, Expr::createIsZero(symcrete->symcretized))), required); for (ref symcrete : eltsClosure.symcretes) { currentlyBrokenSymcretes.insert(symcrete); @@ -255,7 +255,8 @@ bool ConcretizingSolver::relaxSymcreteConstraints(const Query &query, UgtExpr::create( symbolicSizesSum, ConstantExpr::create(SymbolicAllocationThreshold, - symbolicSizesSum->getWidth()))), + symbolicSizesSum->getWidth())), + query.id), response)) { return false; } @@ -264,14 +265,15 @@ bool ConcretizingSolver::relaxSymcreteConstraints(const Query &query, ref minimalValueOfSum; /* Receive model with a smallest sum as possible. */ if (!solver->impl->computeMinimalUnsignedValue( - Query(queryConstraints, symbolicSizesSum), minimalValueOfSum)) { + Query(queryConstraints, symbolicSizesSum, query.id), + minimalValueOfSum)) { return false; } bool hasSolution = false; if (!solver->impl->computeInitialValues( Query(queryConstraints, - EqExpr::create(symbolicSizesSum, minimalValueOfSum)) + EqExpr::create(symbolicSizesSum, minimalValueOfSum), query.id) .negateExpr(), objects, brokenSymcretizedValues, hasSolution)) { return false; @@ -596,6 +598,10 @@ void ConcretizingSolver::setCoreSolverTimeout(time::Span timeout) { solver->setCoreSolverTimeout(timeout); } +void ConcretizingSolver::notifyStateTermination(std::uint32_t id) { + solver->impl->notifyStateTermination(id); +} + std::unique_ptr createConcretizingSolver(std::unique_ptr s, AddressGenerator *addressGenerator) { diff --git a/lib/Solver/CoreSolver.cpp b/lib/Solver/CoreSolver.cpp index b6de1024b86..e44e8c37a1b 100644 --- a/lib/Solver/CoreSolver.cpp +++ b/lib/Solver/CoreSolver.cpp @@ -28,6 +28,7 @@ DISABLE_WARNING_POP namespace klee { std::unique_ptr createCoreSolver(CoreSolverType cst) { + bool isTreeSolver = false; switch (cst) { case STP_SOLVER: #ifdef ENABLE_STP @@ -54,16 +55,22 @@ std::unique_ptr createCoreSolver(CoreSolverType cst) { #endif case DUMMY_SOLVER: return createDummySolver(); + case Z3_TREE_SOLVER: + isTreeSolver = true; case Z3_SOLVER: #ifdef ENABLE_Z3 klee_message("Using Z3 solver backend"); + Z3BuilderType type; #ifdef ENABLE_FP klee_message("Using Z3 bitvector builder"); - return std::make_unique(KLEE_BITVECTOR); + type = KLEE_BITVECTOR; #else klee_message("Using Z3 core builder"); - return std::make_unique(KLEE_CORE); + type = KLEE_CORE; #endif + if (isTreeSolver) + return std::make_unique(type, MaxSolversApproxTreeInc); + return std::make_unique(type); #else klee_message("Not compiled with Z3 support"); return NULL; diff --git a/lib/Solver/DummySolver.cpp b/lib/Solver/DummySolver.cpp index b039581ab15..89f44298968 100644 --- a/lib/Solver/DummySolver.cpp +++ b/lib/Solver/DummySolver.cpp @@ -30,6 +30,7 @@ class DummySolverImpl : public SolverImpl { bool computeValidityCore(const Query &query, ValidityCore &validityCore, bool &isValid); SolverRunStatus getOperationStatusCode(); + void notifyStateTermination(std::uint32_t id); }; DummySolverImpl::DummySolverImpl() {} @@ -79,6 +80,8 @@ SolverImpl::SolverRunStatus DummySolverImpl::getOperationStatusCode() { return SOLVER_RUN_STATUS_FAILURE; } +void DummySolverImpl::notifyStateTermination(std::uint32_t id) {} + std::unique_ptr createDummySolver() { return std::make_unique(std::make_unique()); } diff --git a/lib/Solver/IncompleteSolver.cpp b/lib/Solver/IncompleteSolver.cpp index 5fd7c74b3b2..85ad5a8d6d2 100644 --- a/lib/Solver/IncompleteSolver.cpp +++ b/lib/Solver/IncompleteSolver.cpp @@ -119,13 +119,8 @@ bool StagedSolverImpl::computeInitialValues( } bool StagedSolverImpl::check(const Query &query, ref &result) { - ExprHashSet expressions; - expressions.insert(query.constraints.cs().begin(), - query.constraints.cs().end()); - expressions.insert(query.expr); - std::vector objects; - findSymbolicObjects(expressions.begin(), expressions.end(), objects); + findSymbolicObjects(query, objects); std::vector> values; bool hasSolution; @@ -157,3 +152,7 @@ char *StagedSolverImpl::getConstraintLog(const Query &query) { void StagedSolverImpl::setCoreSolverTimeout(time::Span timeout) { secondary->impl->setCoreSolverTimeout(timeout); } + +void StagedSolverImpl::notifyStateTermination(std::uint32_t id) { + secondary->impl->notifyStateTermination(id); +} diff --git a/lib/Solver/IndependentSolver.cpp b/lib/Solver/IndependentSolver.cpp index b7fe116123c..901f6f798ce 100644 --- a/lib/Solver/IndependentSolver.cpp +++ b/lib/Solver/IndependentSolver.cpp @@ -60,6 +60,7 @@ class IndependentSolver : public SolverImpl { SolverRunStatus getOperationStatusCode(); char *getConstraintLog(const Query &); void setCoreSolverTimeout(time::Span timeout); + void notifyStateTermination(std::uint32_t id); }; bool IndependentSolver::computeValidity(const Query &query, @@ -179,8 +180,9 @@ bool IndependentSolver::computeInitialValues( query.constraints.concretization().part(it->symcretes)); ref factorExpr = ConstantExpr::alloc(0, Expr::Bool); std::vector> tempValues; - if (!solver->impl->computeInitialValues( - Query(tmp, factorExpr), arraysInFactor, tempValues, hasSolution)) { + if (!solver->impl->computeInitialValues(Query(tmp, factorExpr, query.id), + arraysInFactor, tempValues, + hasSolution)) { values.clear(); delete factors; return false; @@ -273,7 +275,7 @@ bool IndependentSolver::check(const Query &query, ref &result) { ref tempResult; std::vector> tempValues; - if (!solver->impl->check(Query(tmp, factorExpr), tempResult)) { + if (!solver->impl->check(Query(tmp, factorExpr, query.id), tempResult)) { delete factors; return false; } else if (isa(tempResult)) { @@ -342,6 +344,10 @@ void IndependentSolver::setCoreSolverTimeout(time::Span timeout) { solver->impl->setCoreSolverTimeout(timeout); } +void IndependentSolver::notifyStateTermination(std::uint32_t id) { + solver->impl->notifyStateTermination(id); +} + std::unique_ptr klee::createIndependentSolver(std::unique_ptr s) { return std::make_unique( diff --git a/lib/Solver/MetaSMTSolver.cpp b/lib/Solver/MetaSMTSolver.cpp index 17e0972904b..f162013864c 100644 --- a/lib/Solver/MetaSMTSolver.cpp +++ b/lib/Solver/MetaSMTSolver.cpp @@ -95,6 +95,7 @@ template class MetaSMTSolverImpl : public SolverImpl { char *getConstraintLog(const Query &); void setCoreSolverTimeout(time::Span timeout) { _timeout = timeout; } + void notifyStateTermination(std::uint32_t id) {} bool computeTruth(const Query &, bool &isValid); bool computeValue(const Query &, ref &result); diff --git a/lib/Solver/QueryLoggingSolver.cpp b/lib/Solver/QueryLoggingSolver.cpp index e35ca0daaa3..030cbcf8c4b 100644 --- a/lib/Solver/QueryLoggingSolver.cpp +++ b/lib/Solver/QueryLoggingSolver.cpp @@ -271,8 +271,7 @@ bool QueryLoggingSolver::check(const Query &query, result->tryGetValidityCore(validityCore); logBuffer << queryCommentSign << " ValidityCore:\n"; - printQuery(Query(ConstraintSet(validityCore.constraints, {}, {true}), - validityCore.expr)); + printQuery(validityCore.toQuery()); } } logBuffer << "\n"; @@ -300,8 +299,7 @@ bool QueryLoggingSolver::computeValidityCore(const Query &query, if (isValid) { logBuffer << queryCommentSign << " ValidityCore:\n"; - printQuery(Query(ConstraintSet(validityCore.constraints, {}, {true}), - validityCore.expr)); + printQuery(validityCore.toQuery()); } logBuffer << "\n"; @@ -322,3 +320,7 @@ char *QueryLoggingSolver::getConstraintLog(const Query &query) { void QueryLoggingSolver::setCoreSolverTimeout(time::Span timeout) { solver->impl->setCoreSolverTimeout(timeout); } + +void QueryLoggingSolver::notifyStateTermination(std::uint32_t id) { + solver->impl->notifyStateTermination(id); +} diff --git a/lib/Solver/QueryLoggingSolver.h b/lib/Solver/QueryLoggingSolver.h index 16e50f55462..3dd2f1a3574 100644 --- a/lib/Solver/QueryLoggingSolver.h +++ b/lib/Solver/QueryLoggingSolver.h @@ -81,6 +81,7 @@ class QueryLoggingSolver : public SolverImpl { SolverRunStatus getOperationStatusCode(); char *getConstraintLog(const Query &); void setCoreSolverTimeout(time::Span timeout); + void notifyStateTermination(std::uint32_t id); }; #endif /* KLEE_QUERYLOGGINGSOLVER_H */ diff --git a/lib/Solver/STPSolver.cpp b/lib/Solver/STPSolver.cpp index 48931159267..5927a90038f 100644 --- a/lib/Solver/STPSolver.cpp +++ b/lib/Solver/STPSolver.cpp @@ -104,6 +104,7 @@ class STPSolverImpl : public SolverImpl { void setCoreSolverTimeout(time::Span timeout) override { this->timeout = timeout; } + void notifyStateTermination(std::uint32_t id) override {} bool computeTruth(const Query &, bool &isValid) override; bool computeValue(const Query &, ref &result) override; diff --git a/lib/Solver/Solver.cpp b/lib/Solver/Solver.cpp index ec30aa9b19f..f6b06522354 100644 --- a/lib/Solver/Solver.cpp +++ b/lib/Solver/Solver.cpp @@ -162,6 +162,10 @@ bool Solver::check(const Query &query, ref &queryResult) { return impl->check(query, queryResult); } +void Solver::notifyStateTermination(std::uint32_t id) { + impl->notifyStateTermination(id); +} + static std::pair, ref> getDefaultRange() { return std::make_pair(ConstantExpr::create(0, 64), ConstantExpr::create(0, 64)); @@ -327,6 +331,15 @@ bool Query::containsSizeSymcretes() const { return false; } +void klee::findSymbolicObjects(const Query &query, + std::vector &results) { + ExprHashSet expressions; + expressions.insert(query.constraints.cs().begin(), + query.constraints.cs().end()); + expressions.insert(query.expr); + findSymbolicObjects(expressions.begin(), expressions.end(), results); +} + void Query::dump() const { constraints.dump(); llvm::errs() << "Query [\n"; @@ -334,6 +347,6 @@ void Query::dump() const { llvm::errs() << "]\n"; } -void ValidityCore::dump() const { - Query(ConstraintSet(constraints, {}, {true}), expr).dump(); -} +Query ValidityCore::toQuery() const { return Query(constraints, expr); } + +void ValidityCore::dump() const { toQuery().dump(); } diff --git a/lib/Solver/SolverCmdLine.cpp b/lib/Solver/SolverCmdLine.cpp index 0c96fb12634..0f51525d535 100644 --- a/lib/Solver/SolverCmdLine.cpp +++ b/lib/Solver/SolverCmdLine.cpp @@ -123,6 +123,13 @@ cl::opt UseAssignmentValidatingSolver( cl::desc("Debug the correctness of generated assignments (default=false)"), cl::cat(SolvingCat)); +cl::opt + MaxSolversApproxTreeInc("max-solvers-approx-tree-inc", + cl::desc("Maximum size of the Z3 solver pool for " + "approximating tree incrementality." + " Set to 0 to disable (default=0)"), + cl::init(0), cl::cat(SolvingCat)); + void KCommandLine::HideOptions(llvm::cl::OptionCategory &Category) { StringMap &map = cl::getRegisteredOptions(); @@ -196,11 +203,12 @@ cl::opt MetaSMTBackend( cl::opt CoreSolverToUse( "solver-backend", cl::desc("Specifiy the core solver backend to use"), - cl::values(clEnumValN(STP_SOLVER, "stp", "STP" STP_IS_DEFAULT_STR), - clEnumValN(METASMT_SOLVER, "metasmt", - "metaSMT" METASMT_IS_DEFAULT_STR), - clEnumValN(DUMMY_SOLVER, "dummy", "Dummy solver"), - clEnumValN(Z3_SOLVER, "z3", "Z3" Z3_IS_DEFAULT_STR)), + cl::values( + clEnumValN(STP_SOLVER, "stp", "STP" STP_IS_DEFAULT_STR), + clEnumValN(METASMT_SOLVER, "metasmt", "metaSMT" METASMT_IS_DEFAULT_STR), + clEnumValN(DUMMY_SOLVER, "dummy", "Dummy solver"), + clEnumValN(Z3_SOLVER, "z3", "Z3" Z3_IS_DEFAULT_STR), + clEnumValN(Z3_TREE_SOLVER, "z3-tree", "Z3 tree-incremental solver")), cl::init(DEFAULT_CORE_SOLVER), cl::cat(SolvingCat)); cl::opt DebugCrossCheckCoreSolverWith( diff --git a/lib/Solver/SolverImpl.cpp b/lib/Solver/SolverImpl.cpp index 033f1d6d192..2768ce3ce9a 100644 --- a/lib/Solver/SolverImpl.cpp +++ b/lib/Solver/SolverImpl.cpp @@ -42,13 +42,8 @@ bool SolverImpl::computeValidity(const Query &query, } bool SolverImpl::check(const Query &query, ref &result) { - ExprHashSet expressions; - expressions.insert(query.constraints.cs().begin(), - query.constraints.cs().end()); - expressions.insert(query.expr); - std::vector objects; - findSymbolicObjects(expressions.begin(), expressions.end(), objects); + findSymbolicObjects(query, objects); std::vector> values; bool hasSolution; diff --git a/lib/Solver/ValidatingSolver.cpp b/lib/Solver/ValidatingSolver.cpp index ddae26f2e3f..bc8ef025a5e 100644 --- a/lib/Solver/ValidatingSolver.cpp +++ b/lib/Solver/ValidatingSolver.cpp @@ -43,6 +43,7 @@ class ValidatingSolver : public SolverImpl { SolverRunStatus getOperationStatusCode(); char *getConstraintLog(const Query &); void setCoreSolverTimeout(time::Span timeout); + void notifyStateTermination(std::uint32_t id); }; bool ValidatingSolver::computeTruth(const Query &query, bool &isValid) { @@ -128,7 +129,8 @@ bool ValidatingSolver::computeInitialValues( for (auto const &constraint : query.constraints.cs()) constraints = AndExpr::create(constraints, constraint); - if (!oracle->impl->computeTruth(Query(bindings, constraints), answer)) + if (!oracle->impl->computeTruth(Query(bindings, constraints, query.id), + answer)) return false; if (!answer) assert(0 && "invalid solver result (computeInitialValues)"); @@ -184,7 +186,8 @@ bool ValidatingSolver::check(const Query &query, ref &result) { for (auto const &constraint : query.constraints.cs()) constraints = AndExpr::create(constraints, constraint); - if (!oracle->impl->computeTruth(Query(bindings, constraints), banswer)) + if (!oracle->impl->computeTruth(Query(bindings, constraints, query.id), + banswer)) return false; if (!banswer) assert(0 && "invalid solver result (computeInitialValues)"); @@ -229,6 +232,11 @@ void ValidatingSolver::setCoreSolverTimeout(time::Span timeout) { solver->impl->setCoreSolverTimeout(timeout); } +void ValidatingSolver::notifyStateTermination(std::uint32_t id) { + solver->impl->notifyStateTermination(id); + oracle->impl->notifyStateTermination(id); +} + std::unique_ptr createValidatingSolver(std::unique_ptr s, Solver *oracle, bool ownsOracle) { diff --git a/lib/Solver/Z3Builder.cpp b/lib/Solver/Z3Builder.cpp index f74bf27a542..e80feae88c3 100644 --- a/lib/Solver/Z3Builder.cpp +++ b/lib/Solver/Z3Builder.cpp @@ -30,14 +30,14 @@ using namespace klee; namespace klee { // Declared here rather than `Z3Builder.h` so they can be called in gdb. -template <> void Z3NodeHandle::dump() { +template <> void Z3NodeHandle::dump() const { llvm::errs() << "Z3SortHandle:\n" << ::Z3_sort_to_string(context, node) << "\n"; } template <> unsigned Z3NodeHandle::hash() { return Z3_get_ast_hash(context, as_ast()); } -template <> void Z3NodeHandle::dump() { +template <> void Z3NodeHandle::dump() const { llvm::errs() << "Z3ASTHandle:\n" << ::Z3_ast_to_string(context, as_ast()) << "\n"; } diff --git a/lib/Solver/Z3Builder.h b/lib/Solver/Z3Builder.h index 0ad50c151c0..7b407c6e3da 100644 --- a/lib/Solver/Z3Builder.h +++ b/lib/Solver/Z3Builder.h @@ -31,7 +31,7 @@ template class Z3NodeHandle { private: // To be specialised - inline ::Z3_ast as_ast(); + inline ::Z3_ast as_ast() const; public: Z3NodeHandle() : node(NULL), context(NULL) {} @@ -73,7 +73,7 @@ template class Z3NodeHandle { return *this; } // To be specialised - void dump(); + void dump() const; operator T() const { return node; } @@ -82,19 +82,21 @@ template class Z3NodeHandle { }; // Specialise for Z3_sort -template <> inline ::Z3_ast Z3NodeHandle::as_ast() { +template <> inline ::Z3_ast Z3NodeHandle::as_ast() const { // In Z3 internally this call is just a cast. We could just do that // instead to simplify our implementation but this seems cleaner. return ::Z3_sort_to_ast(context, node); } typedef Z3NodeHandle Z3SortHandle; -template <> void Z3NodeHandle::dump() __attribute__((used)); +template <> void Z3NodeHandle::dump() const __attribute__((used)); template <> unsigned Z3NodeHandle::hash() __attribute__((used)); // Specialise for Z3_ast -template <> inline ::Z3_ast Z3NodeHandle::as_ast() { return node; } +template <> inline ::Z3_ast Z3NodeHandle::as_ast() const { + return node; +} typedef Z3NodeHandle Z3ASTHandle; -template <> void Z3NodeHandle::dump() __attribute__((used)); +template <> void Z3NodeHandle::dump() const __attribute__((used)); template <> unsigned Z3NodeHandle::hash() __attribute__((used)); struct Z3ASTHandleHash { diff --git a/lib/Solver/Z3Solver.cpp b/lib/Solver/Z3Solver.cpp index 3a92ae7fe1a..f9df3b8ef18 100644 --- a/lib/Solver/Z3Solver.cpp +++ b/lib/Solver/Z3Solver.cpp @@ -8,11 +8,6 @@ //===----------------------------------------------------------------------===// #include "klee/Config/config.h" -#include "klee/Support/ErrorHandling.h" -#include "klee/Support/FileHandling.h" -#include "klee/Support/OptionCategories.h" - -#include #ifdef ENABLE_Z3 @@ -21,17 +16,20 @@ #include "Z3CoreBuilder.h" #include "Z3Solver.h" +#include "klee/ADT/Incremental.h" #include "klee/ADT/SparseStorage.h" #include "klee/Expr/Assignment.h" #include "klee/Expr/Constraints.h" #include "klee/Expr/ExprUtil.h" #include "klee/Solver/Solver.h" #include "klee/Solver/SolverImpl.h" +#include "klee/Support/ErrorHandling.h" +#include "klee/Support/FileHandling.h" +#include "klee/Support/OptionCategories.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/raw_ostream.h" -#include -#include +#include namespace { // NOTE: Very useful for debugging Z3 behaviour. These files can be given to @@ -67,30 +65,207 @@ DISABLE_WARNING_POP namespace klee { -class Z3SolverImpl : public SolverImpl { +using ConstraintFrames = inc_vector>; +using ExprIncMap = + inc_umap, Z3ASTHandleHash, Z3ASTHandleCmp>; +using Z3ASTIncMap = + inc_umap; +using ExprIncSet = + inc_uset, klee::util::ExprHash, klee::util::ExprCmp>; +using Z3ASTIncSet = inc_uset; + +void dump(const ConstraintFrames &frames) { + llvm::errs() << "frame sizes:"; + for (auto size : frames.frame_sizes) { + llvm::errs() << " " << size; + } + llvm::errs() << "\n"; + llvm::errs() << "frames:\n"; + for (auto &x : frames.v) { + llvm::errs() << x->toString() << "\n"; + } +} + +class ConstraintQuery { private: + // this should be used when only query is needed, se comment below + ref expr; + +public: + // KLEE Queries are validity queries i.e. + // ∀ X Constraints(X) → query(X) + // but Z3 works in terms of satisfiability so instead we ask the + // negation of the equivalent i.e. + // ∃ X Constraints(X) ∧ ¬ query(X) + // so this `constraints` field contains: Constraints(X) ∧ ¬ query(X) + ConstraintFrames constraints; + + explicit ConstraintQuery() {} + + ConstraintQuery(ConstraintFrames &frames) : constraints(frames) { + if (frames.v.size() == 0) + expr = Expr::createFalse(); + else + expr = Expr::createIsZero(frames.v.back()); + } + + explicit ConstraintQuery(const Query &q, bool incremental) : expr(q.expr) { + if (incremental) { + for (auto &constraint : q.constraints.cs()) { + constraints.v.push_back(constraint); + constraints.push(); + } + } else { + const auto &other = q.constraints.cs(); + constraints.v.reserve(other.size()); + constraints.v.insert(constraints.v.end(), other.begin(), other.end()); + } + if (!q.expr->isFalse()) + constraints.v.push_back(Expr::createIsZero(q.expr)); + } + + size_t size() const { return constraints.v.size(); } + + ref getOriginalQueryExpr() const { return expr; } + + ConstraintQuery withFalse() const { + if (constraints.v.empty()) + return *this; + ConstraintFrames newFrames; + constraints.butLast(newFrames); + return ConstraintQuery(newFrames); + } + + std::vector gatherArrays() const { + std::vector arrays; + findObjects(constraints.v.begin(), constraints.v.end(), arrays); + return arrays; + } +}; + +enum class ObjectAssignment { + NotNeeded, + NeededForObjectsFromEnv, + NeededForObjectsFromQuery +}; + +struct Z3SolverEnv { + using arr_vec = std::vector; + inc_vector objects; + arr_vec objectsForGetModel; + inc_vector z3_ast_expr_constraints; + ExprIncMap z3_ast_expr_to_klee_expr; + Z3ASTIncMap expr_to_track; + inc_umap usedArrayBytes; + ExprIncSet symbolicObjects; + + explicit Z3SolverEnv(){}; + explicit Z3SolverEnv(const arr_vec &objects); + + void pop(size_t popSize); + void push(); + void clear(); + + const arr_vec *getObjectsForGetModel(ObjectAssignment oa) const; +}; + +Z3SolverEnv::Z3SolverEnv(const std::vector &objects) + : objects(objects) {} + +void Z3SolverEnv::pop(size_t popSize) { + if (popSize == 0) + return; + objects.pop(popSize); + objectsForGetModel.clear(); + z3_ast_expr_constraints.pop(popSize); + z3_ast_expr_to_klee_expr.pop(popSize); + expr_to_track.pop(popSize); + usedArrayBytes.pop(popSize); + symbolicObjects.pop(popSize); +} + +void Z3SolverEnv::push() { + objects.push(); + z3_ast_expr_constraints.push(); + z3_ast_expr_to_klee_expr.push(); + expr_to_track.push(); + usedArrayBytes.push(); + symbolicObjects.push(); +} + +void Z3SolverEnv::clear() { + objects.clear(); + objectsForGetModel.clear(); + z3_ast_expr_constraints.clear(); + z3_ast_expr_to_klee_expr.clear(); + expr_to_track.clear(); + usedArrayBytes.clear(); + symbolicObjects.clear(); +} + +const Z3SolverEnv::arr_vec * +Z3SolverEnv::getObjectsForGetModel(ObjectAssignment oa) const { + switch (oa) { + case ObjectAssignment::NotNeeded: + return nullptr; + case ObjectAssignment::NeededForObjectsFromEnv: + return &objectsForGetModel; + case ObjectAssignment::NeededForObjectsFromQuery: + return &objects.v; + } +} + +class Z3SolverImpl : public SolverImpl { +protected: std::unique_ptr builder; + ::Z3_params solverParameters; + +private: Z3BuilderType builderType; time::Span timeout; - SolverRunStatus runStatusCode; + SolverImpl::SolverRunStatus runStatusCode; std::unique_ptr dumpedQueriesFile; - ::Z3_params solverParameters; // Parameter symbols ::Z3_symbol timeoutParamStrSymbol; ::Z3_symbol unsatCoreParamStrSymbol; - bool internalRunSolver(const Query &, - const std::vector *objects, + bool internalRunSolver(const ConstraintQuery &query, Z3SolverEnv &env, + ObjectAssignment needObjects, std::vector> *values, ValidityCore *validityCore, bool &hasSolution); + bool validateZ3Model(::Z3_solver &theSolver, ::Z3_model &theModel); -public: + SolverImpl::SolverRunStatus + handleSolverResponse(::Z3_solver theSolver, ::Z3_lbool satisfiable, + const Z3SolverEnv &env, ObjectAssignment needObjects, + std::vector> *values, + bool &hasSolution); + +protected: Z3SolverImpl(Z3BuilderType type); ~Z3SolverImpl(); - char *getConstraintLog(const Query &); - void setCoreSolverTimeout(time::Span _timeout) { + virtual Z3_solver initNativeZ3(const ConstraintQuery &query, + Z3ASTIncSet &assertions) = 0; + virtual void deinitNativeZ3(Z3_solver theSolver) = 0; + virtual void push(Z3_context c, Z3_solver s) = 0; + + bool computeTruth(const ConstraintQuery &, Z3SolverEnv &env, bool &isValid); + bool computeValue(const ConstraintQuery &, Z3SolverEnv &env, + ref &result); + bool computeInitialValues(const ConstraintQuery &, Z3SolverEnv &env, + std::vector> &values, + bool &hasSolution); + bool check(const ConstraintQuery &query, Z3SolverEnv &env, + ref &result); + bool computeValidityCore(const ConstraintQuery &query, Z3SolverEnv &env, + ValidityCore &validityCore, bool &isValid); + +public: + char *getConstraintLog(const Query &) final; + SolverImpl::SolverRunStatus getOperationStatusCode() final; + void setCoreSolverTimeout(time::Span _timeout) final { timeout = _timeout; auto timeoutInMilliSeconds = @@ -109,26 +284,20 @@ class Z3SolverImpl : public SolverImpl { Z3_FALSE); } - bool computeTruth(const Query &, bool &isValid); - bool computeValue(const Query &, ref &result); - bool computeInitialValues(const Query &, - const std::vector &objects, - std::vector> &values, - bool &hasSolution); - bool check(const Query &query, ref &result); - bool computeValidityCore(const Query &query, ValidityCore &validityCore, - bool &isValid); - SolverRunStatus handleSolverResponse( - ::Z3_solver theSolver, ::Z3_lbool satisfiable, - const std::vector *objects, - std::vector> *values, - const std::unordered_map &usedArrayBytes, - bool &hasSolution); - SolverRunStatus getOperationStatusCode(); + // pass virtual functions to children + using SolverImpl::check; + using SolverImpl::computeInitialValues; + using SolverImpl::computeTruth; + using SolverImpl::computeValidityCore; + using SolverImpl::computeValue; }; +void deleteNativeZ3(Z3_context ctx, Z3_solver theSolver) { + Z3_solver_dec_ref(ctx, theSolver); +} + Z3SolverImpl::Z3SolverImpl(Z3BuilderType type) - : builderType(type), runStatusCode(SOLVER_RUN_STATUS_FAILURE) { + : builderType(type), runStatusCode(SolverImpl::SOLVER_RUN_STATUS_FAILURE) { switch (type) { case KLEE_CORE: builder = std::unique_ptr(new Z3CoreBuilder( @@ -185,17 +354,6 @@ Z3SolverImpl::~Z3SolverImpl() { Z3_params_dec_ref(builder->ctx, solverParameters); } -Z3Solver::Z3Solver(Z3BuilderType type) - : Solver(std::make_unique(type)) {} - -char *Z3Solver::getConstraintLog(const Query &query) { - return impl->getConstraintLog(query); -} - -void Z3Solver::setCoreSolverTimeout(time::Span timeout) { - impl->setCoreSolverTimeout(timeout); -} - char *Z3SolverImpl::getConstraintLog(const Query &query) { std::vector assumptions; // We use a different builder here because we don't want to interfere @@ -273,82 +431,77 @@ char *Z3SolverImpl::getConstraintLog(const Query &query) { return strdup(result); } -bool Z3SolverImpl::computeTruth(const Query &query, bool &isValid) { +bool Z3SolverImpl::computeTruth(const ConstraintQuery &query, Z3SolverEnv &env, + bool &isValid) { bool hasSolution = false; // to remove compiler warning - bool status = internalRunSolver(query, /*objects=*/NULL, /*values=*/NULL, + bool status = internalRunSolver(query, /*env=*/env, + ObjectAssignment::NotNeeded, /*values=*/NULL, /*validityCore=*/NULL, hasSolution); isValid = !hasSolution; return status; } -bool Z3SolverImpl::computeValue(const Query &query, ref &result) { - std::vector objects; +bool Z3SolverImpl::computeValue(const ConstraintQuery &query, Z3SolverEnv &env, + ref &result) { std::vector> values; bool hasSolution; // Find the object used in the expression, and compute an assignment // for them. - findSymbolicObjects(query.expr, objects); - if (!computeInitialValues(query.withFalse(), objects, values, hasSolution)) + findSymbolicObjects(query.getOriginalQueryExpr(), env.objectsForGetModel); + if (!computeInitialValues(query.withFalse(), env, values, hasSolution)) return false; assert(hasSolution && "state has invalid constraint set"); // Evaluate the expression with the computed assignment. - Assignment a(objects, values); - result = a.evaluate(query.expr); + Assignment a(env.objectsForGetModel, values); + result = a.evaluate(query.getOriginalQueryExpr()); return true; } bool Z3SolverImpl::computeInitialValues( - const Query &query, const std::vector &objects, + const ConstraintQuery &query, Z3SolverEnv &env, std::vector> &values, bool &hasSolution) { - return internalRunSolver(query, &objects, &values, /*validityCore=*/NULL, - hasSolution); + return internalRunSolver(query, env, + ObjectAssignment::NeededForObjectsFromEnv, &values, + /*validityCore=*/NULL, hasSolution); } -bool Z3SolverImpl::check(const Query &query, ref &result) { - ExprHashSet expressions; - assert(!query.containsSymcretes()); - expressions.insert(query.constraints.cs().begin(), - query.constraints.cs().end()); - expressions.insert(query.expr); - - std::vector objects; - findSymbolicObjects(expressions.begin(), expressions.end(), objects); +bool Z3SolverImpl::check(const ConstraintQuery &query, Z3SolverEnv &env, + ref &result) { std::vector> values; - ValidityCore validityCore; - bool hasSolution = false; - bool status = - internalRunSolver(query, &objects, &values, &validityCore, hasSolution); + internalRunSolver(query, env, ObjectAssignment::NeededForObjectsFromQuery, + &values, &validityCore, hasSolution); if (status) { result = hasSolution - ? (SolverResponse *)new InvalidResponse(objects, values) + ? (SolverResponse *)new InvalidResponse(env.objects.v, values) : (SolverResponse *)new ValidResponse(validityCore); } return status; } -bool Z3SolverImpl::computeValidityCore(const Query &query, +bool Z3SolverImpl::computeValidityCore(const ConstraintQuery &query, + Z3SolverEnv &env, ValidityCore &validityCore, bool &isValid) { bool hasSolution = false; // to remove compiler warning - bool status = internalRunSolver(query, /*objects=*/NULL, /*values=*/NULL, - &validityCore, hasSolution); + bool status = + internalRunSolver(query, /*env=*/env, ObjectAssignment::NotNeeded, + /*values=*/NULL, &validityCore, hasSolution); isValid = !hasSolution; return status; } bool Z3SolverImpl::internalRunSolver( - const Query &query, const std::vector *objects, + const ConstraintQuery &query, Z3SolverEnv &env, + ObjectAssignment needObjects, std::vector> *values, ValidityCore *validityCore, bool &hasSolution) { - assert(!query.containsSymcretes()); - if (ProduceUnsatCore && validityCore) { enableUnsatCore(); } else { @@ -363,107 +516,86 @@ bool Z3SolverImpl::internalRunSolver( // TODO: Investigate using a custom tactic as described in // https://github.com/klee/klee/issues/653 - Z3_goal goal = Z3_mk_goal(builder->ctx, false, false, false); - Z3_goal_inc_ref(builder->ctx, goal); - - // TODO: make a RAII - Z3_probe probe = Z3_mk_probe(builder->ctx, "is-qfaufbv"); - Z3_probe_inc_ref(builder->ctx, probe); - - runStatusCode = SOLVER_RUN_STATUS_FAILURE; + runStatusCode = SolverImpl::SOLVER_RUN_STATUS_FAILURE; + + std::unordered_set all_constant_arrays_in_query; + Z3ASTIncSet exprs; + + for (size_t i = 0; i < query.constraints.framesSize(); + i++, env.push(), exprs.push()) { + ConstantArrayFinder constant_arrays_in_query; + env.symbolicObjects.insert(query.constraints.begin(i), + query.constraints.end(i)); + // FIXME: findSymbolicObjects template does not support inc_uset::iterator + // findSymbolicObjects(env.symbolicObjects.begin(-1), + // env.symbolicObjects.end(-1), env.objects.v); + std::vector> tmp(env.symbolicObjects.begin(-1), + env.symbolicObjects.end(-1)); + findSymbolicObjects(tmp.begin(), tmp.end(), env.objects.v); + for (auto cs_it = query.constraints.begin(i), + cs_ite = query.constraints.end(i); + cs_it != cs_ite; cs_it++) { + const auto &constraint = *cs_it; + Z3ASTHandle z3Constraint = builder->construct(constraint); + if (ProduceUnsatCore && validityCore) { + Z3ASTHandle p = builder->buildFreshBoolConst(); + env.z3_ast_expr_to_klee_expr.insert({p, constraint}); + env.z3_ast_expr_constraints.v.push_back(p); + env.expr_to_track[z3Constraint] = p; + } - ConstantArrayFinder constant_arrays_in_query; - std::vector z3_ast_expr_constraints; - std::unordered_map, Z3ASTHandleHash, Z3ASTHandleCmp> - z3_ast_expr_to_klee_expr; + exprs.insert(z3Constraint); - std::unordered_map - expr_to_track; - std::unordered_set exprs; + constant_arrays_in_query.visit(constraint); - for (auto const &constraint : query.constraints.cs()) { - Z3ASTHandle z3Constraint = builder->construct(constraint); - if (ProduceUnsatCore && validityCore) { - Z3ASTHandle p = builder->buildFreshBoolConst(); - z3_ast_expr_to_klee_expr.insert({p, constraint}); - z3_ast_expr_constraints.push_back(p); - expr_to_track[z3Constraint] = p; + std::vector> reads; + findReads(constraint, true, reads); + for (const auto &readExpr : reads) { + auto readFromArray = readExpr->updates.root; + assert(readFromArray); + env.usedArrayBytes[readFromArray].insert(readExpr->index); + } } - Z3_goal_assert(builder->ctx, goal, z3Constraint); - exprs.insert(z3Constraint); + for (auto constant_array : constant_arrays_in_query.results) { + assert(builder->constant_array_assertions.count(constant_array) == 1 && + "Constant array found in query, but not handled by Z3Builder"); + if (all_constant_arrays_in_query.count(constant_array)) + continue; + all_constant_arrays_in_query.insert(constant_array); + const auto &cas = builder->constant_array_assertions[constant_array]; + exprs.insert(cas.begin(), cas.end()); + } - constant_arrays_in_query.visit(constraint); + // Assert an generated side constraints we have to this last so that all + // other constraints have been traversed so we have all the side constraints + // needed. + exprs.insert(builder->sideConstraints.begin(), + builder->sideConstraints.end()); } + exprs.pop(1); // drop last empty frame + ++stats::solverQueries; - if (objects) + if (!env.objects.v.empty()) ++stats::queryCounterexamples; - Z3ASTHandle z3QueryExpr = - Z3ASTHandle(builder->construct(query.expr), builder->ctx); - constant_arrays_in_query.visit(query.expr); - - for (auto const &constant_array : constant_arrays_in_query.results) { - assert(builder->constant_array_assertions.count(constant_array) == 1 && - "Constant array found in query, but not handled by Z3Builder"); - for (auto const &arrayIndexValueExpr : - builder->constant_array_assertions[constant_array]) { - Z3_goal_assert(builder->ctx, goal, arrayIndexValueExpr); - exprs.insert(arrayIndexValueExpr); - } - } - - // KLEE Queries are validity queries i.e. - // ∀ X Constraints(X) → query(X) - // but Z3 works in terms of satisfiability so instead we ask the - // negation of the equivalent i.e. - // ∃ X Constraints(X) ∧ ¬ query(X) - Z3ASTHandle z3NotQueryExpr = - Z3ASTHandle(Z3_mk_not(builder->ctx, z3QueryExpr), builder->ctx); - Z3_goal_assert(builder->ctx, goal, z3NotQueryExpr); - - // Assert an generated side constraints we have to this last so that all other - // constraints have been traversed so we have all the side constraints needed. - for (std::vector::iterator it = builder->sideConstraints.begin(), - ie = builder->sideConstraints.end(); - it != ie; ++it) { - Z3ASTHandle sideConstraint = *it; - Z3_goal_assert(builder->ctx, goal, sideConstraint); - exprs.insert(sideConstraint); - } - - std::vector arrays = query.gatherArrays(); - bool forceTactic = true; - for (const Array *array : arrays) { - if (isa(array->source)) { - forceTactic = false; - break; - } - } - - Z3_solver theSolver; - if (forceTactic && Z3_probe_apply(builder->ctx, probe, goal)) { - theSolver = Z3_mk_solver_for_logic( - builder->ctx, Z3_mk_string_symbol(builder->ctx, "QF_AUFBV")); - } else { - theSolver = Z3_mk_solver(builder->ctx); - } - Z3_solver_inc_ref(builder->ctx, theSolver); - Z3_solver_set_params(builder->ctx, theSolver, solverParameters); - - for (std::unordered_set::iterator it = exprs.begin(), - ie = exprs.end(); - it != ie; ++it) { - Z3ASTHandle expr = *it; - if (expr_to_track.count(expr)) { - Z3_solver_assert_and_track(builder->ctx, theSolver, expr, - expr_to_track[expr]); - } else { - Z3_solver_assert(builder->ctx, theSolver, expr); + Z3_solver theSolver = initNativeZ3(query, exprs); + + for (size_t i = 0; i < exprs.framesSize(); i++) { + push(builder->ctx, theSolver); + for (auto it = exprs.begin(i), ie = exprs.end(i); it != ie; ++it) { + Z3ASTHandle expr = *it; + if (env.expr_to_track.count(expr)) { + Z3_solver_assert_and_track(builder->ctx, theSolver, expr, + env.expr_to_track[expr]); + } else { + Z3_solver_assert(builder->ctx, theSolver, expr); + } } } - Z3_solver_assert(builder->ctx, theSolver, z3NotQueryExpr); + assert(!Z3_solver_get_num_scopes(builder->ctx, theSolver) || + Z3_solver_get_num_scopes(builder->ctx, theSolver) + 1 == + env.objects.framesSize()); if (dumpedQueriesFile) { *dumpedQueriesFile << "; start Z3 query\n"; @@ -476,22 +608,9 @@ bool Z3SolverImpl::internalRunSolver( dumpedQueriesFile->flush(); } - constraints_ty allConstraints = query.constraints.cs(); - allConstraints.insert(query.expr); - std::unordered_map usedArrayBytes; - for (auto constraint : allConstraints) { - std::vector> reads; - findReads(constraint, true, reads); - for (auto readExpr : reads) { - const Array *readFromArray = readExpr->updates.root; - assert(readFromArray); - usedArrayBytes[readFromArray].insert(readExpr->index); - } - } - ::Z3_lbool satisfiable = Z3_solver_check(builder->ctx, theSolver); - runStatusCode = handleSolverResponse(theSolver, satisfiable, objects, values, - usedArrayBytes, hasSolution); + runStatusCode = handleSolverResponse(theSolver, satisfiable, env, needObjects, + values, hasSolution); if (ProduceUnsatCore && validityCore && satisfiable == Z3_L_FALSE) { constraints_ty unsatCore; Z3_ast_vector z3_unsat_core = @@ -508,15 +627,15 @@ bool Z3SolverImpl::internalRunSolver( z3_ast_expr_unsat_core.insert(constraint); } - for (auto &z3_constraint : z3_ast_expr_constraints) { + for (const auto &z3_constraint : env.z3_ast_expr_constraints.v) { if (z3_ast_expr_unsat_core.find(z3_constraint) != z3_ast_expr_unsat_core.end()) { - ref constraint = z3_ast_expr_to_klee_expr[z3_constraint]; + ref constraint = env.z3_ast_expr_to_klee_expr[z3_constraint]; unsatCore.insert(constraint); } } assert(validityCore && "validityCore cannot be nullptr"); - *validityCore = ValidityCore(unsatCore, query.expr); + *validityCore = ValidityCore(unsatCore, query.getOriginalQueryExpr()); Z3_ast_vector assertions = Z3_solver_get_assertions(builder->ctx, theSolver); @@ -531,9 +650,7 @@ bool Z3SolverImpl::internalRunSolver( Z3_ast_vector_dec_ref(builder->ctx, assertions); } - Z3_goal_dec_ref(builder->ctx, goal); - Z3_probe_dec_ref(builder->ctx, probe); - Z3_solver_dec_ref(builder->ctx, theSolver); + deinitNativeZ3(theSolver); // Clear the builder's cache to prevent memory usage exploding. // By using ``autoClearConstructCache=false`` and clearning now @@ -558,17 +675,16 @@ bool Z3SolverImpl::internalRunSolver( } SolverImpl::SolverRunStatus Z3SolverImpl::handleSolverResponse( - ::Z3_solver theSolver, ::Z3_lbool satisfiable, - const std::vector *objects, - std::vector> *values, - const std::unordered_map &usedArrayBytes, - bool &hasSolution) { + ::Z3_solver theSolver, ::Z3_lbool satisfiable, const Z3SolverEnv &env, + ObjectAssignment needObjects, + std::vector> *values, bool &hasSolution) { switch (satisfiable) { case Z3_L_TRUE: { hasSolution = true; + auto objects = env.getObjectsForGetModel(needObjects); if (!objects) { // No assignment is needed - assert(values == NULL); + assert(!values); return SolverImpl::SOLVER_RUN_STATUS_SUCCESS_SOLVABLE; } assert(values && "values cannot be nullptr"); @@ -576,10 +692,7 @@ SolverImpl::SolverRunStatus Z3SolverImpl::handleSolverResponse( assert(theModel && "Failed to retrieve model"); Z3_model_inc_ref(builder->ctx, theModel); values->reserve(objects->size()); - for (std::vector::const_iterator it = objects->begin(), - ie = objects->end(); - it != ie; ++it) { - const Array *array = *it; + for (auto array : *objects) { SparseStorage data; ::Z3_ast arraySizeExpr; @@ -596,9 +709,9 @@ SolverImpl::SolverRunStatus Z3SolverImpl::handleSolverResponse( assert(success && "Failed to get size"); data.resize(arraySize); - if (usedArrayBytes.count(array)) { + if (env.usedArrayBytes.count(array)) { std::unordered_set offsetValues; - for (ref offsetExpr : usedArrayBytes.at(array)) { + for (const ref &offsetExpr : env.usedArrayBytes.at(array)) { ::Z3_ast arrayElementOffsetExpr; Z3_model_eval(builder->ctx, theModel, builder->construct(offsetExpr), Z3_TRUE, &arrayElementOffsetExpr); @@ -734,5 +847,354 @@ bool Z3SolverImpl::validateZ3Model(::Z3_solver &theSolver, SolverImpl::SolverRunStatus Z3SolverImpl::getOperationStatusCode() { return runStatusCode; } + +class Z3NonIncSolverImpl final : public Z3SolverImpl { +private: +public: + Z3NonIncSolverImpl(Z3BuilderType type) : Z3SolverImpl(type) {} + + /// implementation of Z3SolverImpl interface + Z3_solver initNativeZ3(const ConstraintQuery &query, + Z3ASTIncSet &assertions) override { + Z3_solver theSolver = nullptr; + auto arrays = query.gatherArrays(); + bool forceTactic = true; + for (auto array : arrays) { + if (isa(array->source)) { + forceTactic = false; + break; + } + } + + auto ctx = builder->ctx; + if (forceTactic) { + Z3_probe probe = Z3_mk_probe(ctx, "is-qfaufbv"); + Z3_probe_inc_ref(ctx, probe); + Z3_goal goal = Z3_mk_goal(ctx, false, false, false); + Z3_goal_inc_ref(ctx, goal); + + for (auto constraint : assertions) + Z3_goal_assert(ctx, goal, constraint); + if (Z3_probe_apply(ctx, probe, goal)) + theSolver = + Z3_mk_solver_for_logic(ctx, Z3_mk_string_symbol(ctx, "QF_AUFBV")); + Z3_goal_dec_ref(ctx, goal); + Z3_probe_dec_ref(ctx, probe); + } + if (!theSolver) + theSolver = Z3_mk_solver(ctx); + Z3_solver_inc_ref(ctx, theSolver); + Z3_solver_set_params(ctx, theSolver, solverParameters); + return theSolver; + } + void deinitNativeZ3(Z3_solver theSolver) override { + deleteNativeZ3(builder->ctx, theSolver); + } + void push(Z3_context c, Z3_solver s) override {} + + /// implementation of the SolverImpl interface + bool computeTruth(const Query &query, bool &isValid) override { + Z3SolverEnv env; + return Z3SolverImpl::computeTruth(ConstraintQuery(query, false), env, + isValid); + } + bool computeValue(const Query &query, ref &result) override { + Z3SolverEnv env; + return Z3SolverImpl::computeValue(ConstraintQuery(query, false), env, + result); + } + bool computeInitialValues(const Query &query, + const std::vector &objects, + std::vector> &values, + bool &hasSolution) override { + Z3SolverEnv env(objects); + return Z3SolverImpl::computeInitialValues(ConstraintQuery(query, false), + env, values, hasSolution); + } + bool check(const Query &query, ref &result) override { + Z3SolverEnv env; + return Z3SolverImpl::check(ConstraintQuery(query, false), env, result); + } + bool computeValidityCore(const Query &query, ValidityCore &validityCore, + bool &isValid) override { + Z3SolverEnv env; + return Z3SolverImpl::computeValidityCore(ConstraintQuery(query, false), env, + validityCore, isValid); + } + void notifyStateTermination(std::uint32_t id) override {} +}; + +Z3Solver::Z3Solver(Z3BuilderType type) + : Solver(std::make_unique(type)) {} + +struct ConstraintDistance { + size_t toPopSize = 0; + ConstraintQuery toPush; + + explicit ConstraintDistance() {} + ConstraintDistance(const ConstraintQuery &q) : toPush(q) {} + explicit ConstraintDistance(size_t toPopSize, const ConstraintQuery &q) + : toPopSize(toPopSize), toPush(q) {} + + size_t getDistance() const { return toPopSize + toPush.size(); } + + bool isOnlyPush() const { return toPopSize == 0; } + + void dump() const { + llvm::errs() << "ConstraintDistance: pop: " << toPopSize << "; push:\n"; + klee::dump(toPush.constraints); + } +}; + +class Z3IncNativeSolver { +private: + Z3_solver nativeSolver = nullptr; + Z3_context ctx; + Z3_params solverParameters; + /// underlying solver frames + /// saved only for calculating distances from next queries + ConstraintFrames frames; + + void pop(size_t popSize); + void push(); + +public: + Z3SolverEnv env; + std::uint32_t stateID = 0; + bool isRecycled = false; + + Z3IncNativeSolver(Z3_context ctx, Z3_params solverParameters) + : ctx(ctx), solverParameters(solverParameters) {} + ~Z3IncNativeSolver(); + + void clear(); + + void distance(const ConstraintQuery &query, ConstraintDistance &delta) const; + + void popPush(ConstraintDistance &delta); + + Z3_solver getOrInit(); + + bool isConsistent() const { + auto num_scopes = + nativeSolver ? Z3_solver_get_num_scopes(ctx, nativeSolver) : 0; + bool consistentWithZ3 = num_scopes + 1 == frames.framesSize(); + assert(consistentWithZ3); + bool constistentItself = frames.framesSize() == env.objects.framesSize(); + assert(constistentItself); + return consistentWithZ3 && constistentItself; + } + + void dump() const { ::klee::dump(frames); } +}; + +void Z3IncNativeSolver::pop(size_t popSize) { + if (!nativeSolver || !popSize) + return; + Z3_solver_pop(ctx, nativeSolver, popSize); +} + +void Z3IncNativeSolver::popPush(ConstraintDistance &delta) { + env.pop(delta.toPopSize); + pop(delta.toPopSize); + frames.pop(delta.toPopSize); + frames.extend(delta.toPush.constraints); +} + +Z3_solver Z3IncNativeSolver::getOrInit() { + if (nativeSolver == nullptr) { + nativeSolver = Z3_mk_solver(ctx); + Z3_solver_inc_ref(ctx, nativeSolver); + Z3_solver_set_params(ctx, nativeSolver, solverParameters); + } + return nativeSolver; +} + +Z3IncNativeSolver::~Z3IncNativeSolver() { + if (nativeSolver != nullptr) + deleteNativeZ3(ctx, nativeSolver); +} + +void Z3IncNativeSolver::clear() { + if (!nativeSolver) + return; + env.clear(); + frames.clear(); + Z3_solver_reset(ctx, nativeSolver); + isRecycled = false; +} + +void Z3IncNativeSolver::distance(const ConstraintQuery &query, + ConstraintDistance &delta) const { + auto sit = frames.v.begin(); + auto site = frames.v.end(); + auto qit = query.constraints.v.begin(); + auto qite = query.constraints.v.end(); + auto it = frames.begin(); + auto ite = frames.end(); + size_t intersect = 0; + for (; it != ite && sit != site && qit != qite && *sit == *qit; it++) { + size_t frame_size = *it; + for (size_t i = 0; + i < frame_size && sit != site && qit != qite && *sit == *qit; + i++, sit++, qit++, intersect++) { + } + } + for (; sit != site && qit != qite && *sit == *qit; + sit++, qit++, intersect++) { + } + size_t toPop, extraTakeFromOther; + ConstraintFrames d; + if (sit == site) { // solver frames ended + toPop = 0; + extraTakeFromOther = 0; + } else { + frames.takeBefore(intersect, toPop, extraTakeFromOther); + } + query.constraints.takeAfter(intersect - extraTakeFromOther, d); + delta = ConstraintDistance(toPop, d); +} + +class Z3TreeSolverImpl final : public Z3SolverImpl { +private: + using solvers_ty = std::vector>; + using solvers_it = solvers_ty::iterator; + + const size_t maxSolvers; + std::unique_ptr currentSolver = nullptr; + solvers_ty solvers; + + void findSuitableSolver(const ConstraintQuery &query, + ConstraintDistance &delta); + void setSolver(solvers_it &it, bool recycle = false); + ConstraintQuery prepare(const Query &q); + +public: + Z3TreeSolverImpl(Z3BuilderType type, size_t maxSolvers) + : Z3SolverImpl(type), maxSolvers(maxSolvers){}; + + /// implementation of Z3SolverImpl interface + Z3_solver initNativeZ3(const ConstraintQuery &query, + Z3ASTIncSet &assertions) override { + return currentSolver->getOrInit(); + } + void deinitNativeZ3(Z3_solver theSolver) override { + assert(currentSolver->isConsistent()); + solvers.push_back(std::move(currentSolver)); + } + void push(Z3_context c, Z3_solver s) override { Z3_solver_push(c, s); } + + /// implementation of the SolverImpl interface + bool computeTruth(const Query &query, bool &isValid) override; + bool computeValue(const Query &query, ref &result) override; + bool computeInitialValues(const Query &query, + const std::vector &objects, + std::vector> &values, + bool &hasSolution) override; + bool check(const Query &query, ref &result) override; + bool computeValidityCore(const Query &query, ValidityCore &validityCore, + bool &isValid) override; + + void notifyStateTermination(std::uint32_t id) override; +}; + +void Z3TreeSolverImpl::setSolver(solvers_it &it, bool recycle) { + assert(it != solvers.end()); + currentSolver = std::move(*it); + solvers.erase(it); + currentSolver->isRecycled = false; + if (recycle) + currentSolver->clear(); +} + +void Z3TreeSolverImpl::findSuitableSolver(const ConstraintQuery &query, + ConstraintDistance &delta) { + ConstraintDistance min_delta; + auto min_distance = std::numeric_limits::max(); + auto min_it = solvers.end(); + auto free_it = solvers.end(); + for (auto it = solvers.begin(), ite = min_it; it != ite; it++) { + if ((*it)->isRecycled) + free_it = it; + (*it)->distance(query, delta); + if (delta.isOnlyPush()) { + setSolver(it); + return; + } + auto distance = delta.getDistance(); + if (distance < min_distance) { + min_delta = delta; + min_distance = distance; + min_it = it; + } + } + if (solvers.size() < maxSolvers) { + delta = ConstraintDistance(query); + if (delta.getDistance() < min_distance) { + // it is cheaper to create new solver + if (free_it == solvers.end()) + currentSolver = + std::make_unique(builder->ctx, solverParameters); + else + setSolver(free_it, /*recycle=*/true); + return; + } + } + assert(min_it != solvers.end()); + delta = min_delta; + setSolver(min_it); +} + +ConstraintQuery Z3TreeSolverImpl::prepare(const Query &q) { + ConstraintDistance delta; + ConstraintQuery query(q, true); + findSuitableSolver(query, delta); + assert(currentSolver->isConsistent()); + currentSolver->stateID = q.id; + currentSolver->popPush(delta); + return delta.toPush; +} + +bool Z3TreeSolverImpl::computeTruth(const Query &query, bool &isValid) { + auto q = prepare(query); + return Z3SolverImpl::computeTruth(q, currentSolver->env, isValid); +} + +bool Z3TreeSolverImpl::computeValue(const Query &query, ref &result) { + auto q = prepare(query); + return Z3SolverImpl::computeValue(q, currentSolver->env, result); +} + +bool Z3TreeSolverImpl::computeInitialValues( + const Query &query, const std::vector &objects, + std::vector> &values, bool &hasSolution) { + auto q = prepare(query); + currentSolver->env.objectsForGetModel = objects; + return Z3SolverImpl::computeInitialValues(q, currentSolver->env, values, + hasSolution); +} + +bool Z3TreeSolverImpl::check(const Query &query, ref &result) { + auto q = prepare(query); + return Z3SolverImpl::check(q, currentSolver->env, result); +} + +bool Z3TreeSolverImpl::computeValidityCore(const Query &query, + ValidityCore &validityCore, + bool &isValid) { + auto q = prepare(query); + return Z3SolverImpl::computeValidityCore(q, currentSolver->env, validityCore, + isValid); +} + +void Z3TreeSolverImpl::notifyStateTermination(std::uint32_t id) { + for (auto &s : solvers) + if (s->stateID == id) + s->isRecycled = true; +} + +Z3TreeSolver::Z3TreeSolver(Z3BuilderType type, unsigned maxSolvers) + : Solver(std::make_unique(type, maxSolvers)) {} + } // namespace klee #endif // ENABLE_Z3 diff --git a/lib/Solver/Z3Solver.h b/lib/Solver/Z3Solver.h index 0189dec08f1..6b4aca126be 100644 --- a/lib/Solver/Z3Solver.h +++ b/lib/Solver/Z3Solver.h @@ -23,16 +23,13 @@ class Z3Solver : public Solver { public: /// Z3Solver - Construct a new Z3Solver. Z3Solver(Z3BuilderType type); +}; - /// Get the query in SMT-LIBv2 format. - /// \return A C-style string. The caller is responsible for freeing this. - virtual char *getConstraintLog(const Query &); - - /// setCoreSolverTimeout - Set constraint solver timeout delay to the given - /// value; 0 - /// is off. - virtual void setCoreSolverTimeout(time::Span timeout); +class Z3TreeSolver : public Solver { +public: + Z3TreeSolver(Z3BuilderType type, unsigned maxSolvers); }; + } // namespace klee #endif /* KLEE_Z3SOLVER_H */ diff --git a/test/Solver/CrosscheckZ3AndZ3TreeInc.c b/test/Solver/CrosscheckZ3AndZ3TreeInc.c new file mode 100644 index 00000000000..9c4c499a263 --- /dev/null +++ b/test/Solver/CrosscheckZ3AndZ3TreeInc.c @@ -0,0 +1,11 @@ +// REQUIRES: z3 +// RUN: %clang %s -emit-llvm %O0opt -c -o %t1.bc +// RUN: rm -rf %t.klee-out +// RUN: %klee --output-dir=%t.klee-out --search=bfs --solver-backend=z3-tree --max-solvers-approx-tree-inc=4 --debug-crosscheck-core-solver=z3 --debug-z3-validate-models --debug-assignment-validating-solver --use-cex-cache=false %t1.bc 2>&1 | FileCheck %s +// RUN: rm -rf %t.klee-out +// RUN: %klee --output-dir=%t.klee-out --search=dfs --solver-backend=z3-tree --max-solvers-approx-tree-inc=64 --debug-crosscheck-core-solver=z3 --debug-z3-validate-models --debug-assignment-validating-solver --use-cex-cache=false %t1.bc 2>&1 | FileCheck %s + +#include "ExerciseSolver.c.inc" + +// CHECK: KLEE: done: completed paths = 10 +// CHECK: KLEE: done: partially completed paths = 4 diff --git a/tools/kleaver/main.cpp b/tools/kleaver/main.cpp index 441e7ffacf3..f99b9e88c26 100644 --- a/tools/kleaver/main.cpp +++ b/tools/kleaver/main.cpp @@ -224,13 +224,9 @@ static bool EvaluateInputAST(const char *Filename, const llvm::MemoryBuffer *MB, assert("FIXME: Support counterexample query commands!"); if (QC->Values.empty() && QC->Objects.empty()) { bool result; - constraints_ty constraints; - for (auto i : QC->Constraints) { - constraints.insert(i); - } - if (S->mustBeTrue( - Query(ConstraintSet(constraints, {}, {true}), QC->Query), - result)) { + constraints_ty constraints(QC->Constraints.begin(), + QC->Constraints.end()); + if (S->mustBeTrue(Query(constraints, QC->Query), result)) { llvm::outs() << (result ? "VALID" : "INVALID"); } else { llvm::outs() << "FAIL (reason: " @@ -250,9 +246,7 @@ static bool EvaluateInputAST(const char *Filename, const llvm::MemoryBuffer *MB, for (auto i : QC->Constraints) { constraints.insert(i); } - if (S->getValue( - Query(ConstraintSet(constraints, {}, {true}), QC->Values[0]), - result)) { + if (S->getValue(Query(constraints, QC->Values[0]), result)) { llvm::outs() << "INVALID\n"; llvm::outs() << "\tExpr 0:\t" << result; } else { @@ -264,14 +258,11 @@ static bool EvaluateInputAST(const char *Filename, const llvm::MemoryBuffer *MB, } else { std::vector> result; - constraints_ty constraints; - for (auto i : QC->Constraints) { - constraints.insert(i); - } + constraints_ty constraints(QC->Constraints.begin(), + QC->Constraints.end()); - if (S->getInitialValues( - Query(ConstraintSet(constraints, {}, {true}), QC->Query), - QC->Objects, result)) { + if (S->getInitialValues(Query(constraints, QC->Query), QC->Objects, + result)) { llvm::outs() << "INVALID\n"; Assignment solutionAssugnment(QC->Objects, result); for (unsigned i = 0, e = result.size(); i != e; ++i) { @@ -378,13 +369,9 @@ static bool printInputAsSMTLIBv2(const char *Filename, * constraint in the constraint set is set to NULL and * will later cause a NULL pointer dereference. */ - constraints_ty constraints; - for (auto i : QC->Constraints) { - constraints.insert(i); - } - - ConstraintSet constraintM(constraints, {}, {true}); - Query query(constraintM, QC->Query); + constraints_ty constraints(QC->Constraints.begin(), + QC->Constraints.end()); + Query query(constraints, QC->Query); printer.setQuery(query); if (!QC->Objects.empty()) diff --git a/unittests/Solver/SolverTest.cpp b/unittests/Solver/SolverTest.cpp index 3206bd1a18a..a208376c7db 100644 --- a/unittests/Solver/SolverTest.cpp +++ b/unittests/Solver/SolverTest.cpp @@ -82,9 +82,9 @@ void testOperation(Solver &solver, int value, Expr::Width operandWidth, ref queryExpr = EqExpr::create(fullySymbolicExpr, partiallyConstantExpr); - ConstraintSet constraints; - constraints.addConstraint( - Simplificator::simplifyExpr(ConstraintSet(), expr).simplified, {}); + constraints_ty constraints; + constraints.insert( + Simplificator::simplifyExpr(ConstraintSet(), expr).simplified); bool res; bool success = solver.mustBeTrue(Query(constraints, queryExpr), res); EXPECT_EQ(true, success) << "Constraint solving failed"; diff --git a/unittests/Solver/Z3SolverTest.cpp b/unittests/Solver/Z3SolverTest.cpp index 8c8d84e9be8..d10a7b853c7 100644 --- a/unittests/Solver/Z3SolverTest.cpp +++ b/unittests/Solver/Z3SolverTest.cpp @@ -37,7 +37,7 @@ class Z3SolverTest : public ::testing::Test { }; TEST_F(Z3SolverTest, GetConstraintLog) { - ConstraintSet Constraints; + constraints_ty Constraints; const std::vector ConstantValues{1, 2, 3, 4}; std::vector> ConstantExpressions;