From c77e6bba783d3ae69cc9bed7aa3f9ddf1c3cea58 Mon Sep 17 00:00:00 2001 From: Nuno Lopes Date: Sun, 24 Sep 2023 18:26:46 +0100 Subject: [PATCH] Add a function ASM mode, which enables assembly semantics for that function Right now it enables int2ptr and int->ptr load type punning --- ir/attrs.cpp | 2 ++ ir/attrs.h | 2 +- ir/memory.cpp | 31 ++++++++++++++++++++++--------- ir/state_value.cpp | 8 ++++++++ ir/state_value.h | 3 +++ llvm_util/cmd_args_def.h | 2 +- llvm_util/cmd_args_list.h | 6 +++--- tools/transform.cpp | 3 +++ util/config.cpp | 2 +- util/config.h | 2 +- 10 files changed, 45 insertions(+), 16 deletions(-) diff --git a/ir/attrs.cpp b/ir/attrs.cpp index 2105e021a..72dbb1ee8 100644 --- a/ir/attrs.cpp +++ b/ir/attrs.cpp @@ -130,6 +130,8 @@ ostream& operator<<(ostream &os, const FnAttrs &attr) { attr.fp_denormal.print(os); if (attr.fp_denormal32) attr.fp_denormal32->print(os, true); + if (attr.has(FnAttrs::Asm)) + os << " asm"; return os << attr.mem; } diff --git a/ir/attrs.h b/ir/attrs.h index e2e5b80c2..0ec586b6c 100644 --- a/ir/attrs.h +++ b/ir/attrs.h @@ -123,7 +123,7 @@ class FnAttrs final { DereferenceableOrNull = 1 << 10, NullPointerIsValid = 1 << 11, AllocSize = 1 << 12, ZeroExt = 1<<13, - SignExt = 1<<14, NoFPClass = 1<<15, }; + SignExt = 1<<14, NoFPClass = 1<<15, Asm = 1<<16 }; FnAttrs(unsigned bits = None) : bits(bits) {} diff --git a/ir/memory.cpp b/ir/memory.cpp index ce7a3b246..f2a20c19d 100644 --- a/ir/memory.cpp +++ b/ir/memory.cpp @@ -2,6 +2,7 @@ // Distributed under the MIT license that can be found in the LICENSE file. #include "ir/memory.h" +#include "ir/function.h" #include "ir/globals.h" #include "ir/state.h" #include "ir/value.h" @@ -514,16 +515,16 @@ static StateValue bytesToValue(const Memory &m, const vector &bytes, // if bits of loaded ptr are a subset of the non-ptr value, // we know they must be zero otherwise the value is poison. // Therefore we obtain a null pointer for free. - expr _, value; + expr _; unsigned low, high, low2, high2; if (loaded_ptr.isExtract(_, high, low) && bytes[0].nonptrValue().isExtract(_, high2, low2) && high2 >= high && low2 <= low) { - value = std::move(loaded_ptr); + // do nothing } else { - value = expr::mkIf(is_ptr, loaded_ptr, Pointer::mkNullPointer(m)()); + loaded_ptr = expr::mkIf(is_ptr, loaded_ptr, Pointer::mkNullPointer(m)()); } - return { std::move(value), std::move(non_poison) }; + return { std::move(loaded_ptr), std::move(non_poison) }; } else { assert(!toType.isAggregateType() || isNonPtrVector(toType)); @@ -541,7 +542,18 @@ static StateValue bytesToValue(const Memory &m, const vector &bytes, val = first ? std::move(v) : v.concat(val); first = false; } - return toType.fromInt(val.trunc(bitsize, toType.np_bits(true))); + val = toType.fromInt(val.trunc(bitsize, toType.np_bits(true))); + + // allow ptr->int type punning in Assembly mode + if (bitsize == bits_ptr_address && + m.getState().getFn().getFnAttrs().has(FnAttrs::Asm)) { + StateValue ptr_val = bytesToValue(m, bytes, PtrType(0)); + ptr_val.value = Pointer(m, ptr_val.value).getAddress(); + expr is_ptr = bytes[0].isPtr(); + val + = StateValue::mkIf(is_ptr, ptr_val.subst(is_ptr, true).simplify(), val); + } + return val; } } @@ -1232,8 +1244,10 @@ void Memory::mkAxioms(const Memory &tgt) const { auto sz = p1.blockSize().zextOrTrunc(bits_ptr_address); auto align = p1.blockAlignment(); - if (!has_null_block || bid != 0) + if (!has_null_block || bid != 0) { state->addAxiom(addr != 0); + state->addAxiom(p1.blockSize() != 0); + } // address must be properly aligned state->addAxiom( @@ -2036,8 +2050,7 @@ expr Memory::ptr2int(const expr &ptr) const { expr Memory::int2ptr(const expr &val) const { assert(!memory_unused() && observesAddresses()); - // TODO: missing pointer escaping - if (config::enable_approx_int2ptr) { + if (state->getFn().getFnAttrs().has(FnAttrs::Asm)) { DisjointExpr ret(expr{}); expr valx = val.zextOrTrunc(bits_program_pointer); @@ -2046,7 +2059,7 @@ expr Memory::int2ptr(const expr &val) const { Pointer p(*this, i, local); Pointer p_end = p + p.blockSize(); ret.add((p + (valx - p.getAddress())).release(), - valx.uge(p.getAddress()) && valx.ule(p_end.getAddress())); + valx.uge(p.getAddress()) && valx.ult(p_end.getAddress())); } }; add(numLocals(), true); diff --git a/ir/state_value.cpp b/ir/state_value.cpp index 3b1f8e211..20a38851c 100644 --- a/ir/state_value.cpp +++ b/ir/state_value.cpp @@ -61,6 +61,10 @@ set StateValue::vars() const { return expr::vars({ &value, &non_poison }); } +StateValue StateValue::subst(const expr &from, const expr &to) const { + return { value.subst(from, to), non_poison.subst(from, to) }; +} + StateValue StateValue::subst(const vector> &repls) const { if (!value.isValid() || !non_poison.isValid()) return { value.subst(repls), non_poison.subst(repls) }; @@ -81,6 +85,10 @@ StateValue StateValue::subst(const vector> &repls) const { return { then.eq(v1) ? std::move(els) : std::move(then), std::move(np) }; } +StateValue StateValue::simplify() const { + return { value.simplify(), non_poison.simplify() }; +} + ostream& operator<<(ostream &os, const StateValue &val) { return os << val.value << " / " << val.non_poison; } diff --git a/ir/state_value.h b/ir/state_value.h index 05e2bbeae..a667e3320 100644 --- a/ir/state_value.h +++ b/ir/state_value.h @@ -34,9 +34,12 @@ struct StateValue { bool eq(const StateValue &other) const; std::set vars() const; + StateValue subst(const smt::expr &from, const smt::expr &to) const; StateValue subst(const std::vector> &repls) const; + StateValue simplify() const; + auto operator<=>(const StateValue &rhs) const = default; friend std::ostream& operator<<(std::ostream &os, const StateValue &val); diff --git a/llvm_util/cmd_args_def.h b/llvm_util/cmd_args_def.h index 45a585229..689cef515 100644 --- a/llvm_util/cmd_args_def.h +++ b/llvm_util/cmd_args_def.h @@ -10,7 +10,7 @@ config::src_unroll_cnt = opt_unrolling_factor; #endif config::disable_undef_input = opt_disable_undef; config::disable_poison_input = opt_disable_poison; -config::enable_approx_int2ptr = opt_enable_approx_int2ptr; +config::tgt_is_asm = opt_tgt_is_asm; config::check_if_src_is_ub = opt_check_if_src_is_ub; config::symexec_print_each_value = opt_se_verbose; smt::set_query_timeout(to_string(opt_smt_to)); diff --git a/llvm_util/cmd_args_list.h b/llvm_util/cmd_args_list.h index 99dd4557f..d63fc3d45 100644 --- a/llvm_util/cmd_args_list.h +++ b/llvm_util/cmd_args_list.h @@ -45,9 +45,9 @@ llvm::cl::opt opt_check_if_src_is_ub( llvm::cl::desc("Check if source function is always UB (default=false)"), llvm::cl::init(false), llvm::cl::cat(alive_cmdargs)); -llvm::cl::opt opt_enable_approx_int2ptr( - LLVM_ARGS_PREFIX "enable-approx-int2ptr", - llvm::cl::desc("Enable unsound approximation of int2ptr (default=false)"), +llvm::cl::opt opt_tgt_is_asm( + LLVM_ARGS_PREFIX "tgt-is-asm", + llvm::cl::desc("Target uses assembly semantics (default=false)"), llvm::cl::init(false), llvm::cl::cat(alive_cmdargs)); llvm::cl::opt opt_error_fatal(LLVM_ARGS_PREFIX "exit-on-error", diff --git a/tools/transform.cpp b/tools/transform.cpp index d9b2d222b..4eed80012 100644 --- a/tools/transform.cpp +++ b/tools/transform.cpp @@ -1589,6 +1589,9 @@ static void optimize_ptrcmp(Function &f) { } void Transform::preprocess() { + if (config::tgt_is_asm) + tgt.getFnAttrs().set(FnAttrs::Asm); + remove_unreachable_bbs(src); remove_unreachable_bbs(tgt); diff --git a/util/config.cpp b/util/config.cpp index ad6399c8d..e5b40c027 100644 --- a/util/config.cpp +++ b/util/config.cpp @@ -15,7 +15,7 @@ bool skip_smt = false; string smt_benchmark_dir; bool disable_poison_input = false; bool disable_undef_input = false; -bool enable_approx_int2ptr = false; +bool tgt_is_asm = false; bool check_if_src_is_ub = false; bool disallow_ub_exploitation = false; bool debug = false; diff --git a/util/config.h b/util/config.h index e21dc46ee..d4b694315 100644 --- a/util/config.h +++ b/util/config.h @@ -19,7 +19,7 @@ extern bool disable_poison_input; extern bool disable_undef_input; -extern bool enable_approx_int2ptr; +extern bool tgt_is_asm; extern bool check_if_src_is_ub;