diff --git a/.github/workflows/diff_tests.yml b/.github/workflows/diff_tests.yml index 895339bf..e9555b18 100644 --- a/.github/workflows/diff_tests.yml +++ b/.github/workflows/diff_tests.yml @@ -13,7 +13,7 @@ jobs: image: - { name: 'ubuntu', tag: '20.04', codename: 'focal' } llvm: [ '16' ] - common_base: [ 'https://github.com/lifting-bits/cxx-common/releases/latest/download' ] + common_base: [ 'https://github.com/lifting-bits/cxx-common/releases/download/v0.3.2' ] env: CC: clang-${{ matrix.llvm }} diff --git a/include/rellic/AST/InlineReferences.h b/include/rellic/AST/InlineReferences.h new file mode 100644 index 00000000..14c3ed2b --- /dev/null +++ b/include/rellic/AST/InlineReferences.h @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2023-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +#include "rellic/AST/IRToASTVisitor.h" +#include "rellic/AST/TransformVisitor.h" + +namespace rellic { + +/* + * This pass removes references to variables that can be inlined + * + * int x = 3 + y; + * if(x) { ... } + * becomes + * if(3 + y) { ... 
} + */ +class InlineReferences : public TransformVisitor { + private: + std::unordered_map refs; + std::unordered_set removable_decls; + + protected: + void RunImpl() override; + + public: + InlineReferences(DecompilationContext& dec_ctx); + + bool VisitCompoundStmt(clang::CompoundStmt* stmt); +}; + +} // namespace rellic diff --git a/include/rellic/AST/Util.h b/include/rellic/AST/Util.h index e42363e9..2f642d3e 100644 --- a/include/rellic/AST/Util.h +++ b/include/rellic/AST/Util.h @@ -52,6 +52,7 @@ clang::Expr *Clone(clang::ASTUnit &unit, clang::Expr *stmt, DecompilationContext::ExprToUseMap &provenance); std::string ClangThingToString(const clang::Stmt *stmt); +std::string ClangThingToString(const clang::Decl *decl); z3::goal ApplyTactic(const z3::tactic &tactic, z3::expr expr); diff --git a/lib/AST/IRToASTVisitor.cpp b/lib/AST/IRToASTVisitor.cpp index 5fb781ff..5645830e 100755 --- a/lib/AST/IRToASTVisitor.cpp +++ b/lib/AST/IRToASTVisitor.cpp @@ -64,7 +64,7 @@ class ExprGen : public llvm::InstVisitor { clang::Expr *IRToASTVisitor::ConvertExpr(z3::expr expr) { if (expr.decl().decl_kind() == Z3_OP_EQ) { // Equalities generated form the reaching conditions of switch instructions - // Always in the for (VAR == CONST) or (CONST == VAR) + // Always in the form (VAR == CONST) or (CONST == VAR) // VAR will uniquely identify a SwitchInst, CONST will represent the index // of the case taken CHECK_EQ(expr.num_args(), 2) << "Equalities must have 2 arguments"; @@ -1161,6 +1161,13 @@ void IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { for (auto &inst : llvm::instructions(func)) { auto &var{dec_ctx.value_decls[&inst]}; + bool used_as_branch_cond{false}; + for (auto user : inst.users()) { + if (llvm::isa(user)) { + used_as_branch_cond = true; + break; + } + } if (auto alloca = llvm::dyn_cast(&inst)) { auto name{"var" + std::to_string(GetNumDecls(fdecl))}; // TLDR: Here we discard the variable name as present in the bitcode @@ -1217,6 +1224,16 @@ void 
IRToASTVisitor::VisitFunctionDecl(llvm::Function &func) { } } } + } else if (used_as_branch_cond) { + auto name{"br_cond_" + + std::to_string(GetNumDecls(fdecl))}; + auto type{dec_ctx.GetQualType(inst.getType())}; + if (auto arrayType = clang::dyn_cast(type)) { + type = dec_ctx.ast_ctx.getPointerType(arrayType->getElementType()); + } + + var = ast.CreateVarDecl(fdecl, type, name); + fdecl->addDecl(var); } for (auto &opnd : inst.operands()) { diff --git a/lib/AST/InlineReferences.cpp b/lib/AST/InlineReferences.cpp new file mode 100644 index 00000000..b609195a --- /dev/null +++ b/lib/AST/InlineReferences.cpp @@ -0,0 +1,241 @@ +/* + * Copyright (c) 2022-present, Trail of Bits, Inc. + * All rights reserved. + * + * This source code is licensed in accordance with the terms specified in + * the LICENSE file found in the root directory of this source tree. + */ + +#include "rellic/AST/InlineReferences.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "rellic/AST/ASTBuilder.h" +#include "rellic/AST/DecompilationContext.h" +#include "rellic/AST/TransformVisitor.h" +#include "rellic/AST/Util.h" + +namespace rellic { +InlineReferences::InlineReferences(DecompilationContext& dec_ctx) + : TransformVisitor(dec_ctx) {} + +class ReferenceCounter : public clang::RecursiveASTVisitor { + private: + DecompilationContext& dec_ctx; + + public: + std::unordered_map& referenced_values; + + private: + void GetReferencedValues(z3::expr expr) { + if (expr.decl().decl_kind() == Z3_OP_EQ) { + // Equalities generated from the reaching conditions of switch + // instructions. Always in the form (VAR == CONST) or (CONST == VAR). VAR + // will uniquely identify a SwitchInst, CONST will represent the index of + // the case taken. + CHECK_EQ(expr.num_args(), 2) << "Equalities must have 2 arguments"; + auto a{expr.arg(0)}; + auto b{expr.arg(1)}; + + llvm::SwitchInst* inst{dec_ctx.z3_sw_vars_inv[a.id()]}; + unsigned case_idx{}; + + // 
GenerateAST always generates equalities in the form (VAR == CONST), but + // there is a chance that some Z3 simplification inverts the order, so + // handle that here. + if (!inst) { + inst = dec_ctx.z3_sw_vars_inv[b.id()]; + case_idx = a.get_numeral_uint(); + } else { + case_idx = b.get_numeral_uint(); + } + + for (auto sw_case : inst->cases()) { + if (sw_case.getCaseIndex() == case_idx) { + ++referenced_values[inst->getOperandUse(0)]; + return; + } + } + + LOG(FATAL) << "Couldn't find switch case"; + return; + } + + auto hash{expr.id()}; + if (dec_ctx.z3_br_edges_inv.find(hash) != dec_ctx.z3_br_edges_inv.end()) { + auto edge{dec_ctx.z3_br_edges_inv[hash]}; + CHECK(edge.second) << "Inverse map should only be populated for branches " + "taken when condition is true"; + // expr is a variable that represents the condition of a branch + // instruction. + + // FIXME(frabert): Unfortunately there is no public API in BranchInst that + // gives the operand of the condition. From reverse engineering LLVM code, + // this is the way they obtain uses internally, but it's probably not + // stable. 
+ ++referenced_values[*(edge.first->op_end() - 3)]; + return; + } + + switch (expr.decl().decl_kind()) { + case Z3_OP_TRUE: + case Z3_OP_FALSE: + CHECK_EQ(expr.num_args(), 0) << "Literals cannot have arguments"; + return; + case Z3_OP_AND: + case Z3_OP_OR: { + for (auto i{0U}; i < expr.num_args(); ++i) { + GetReferencedValues(expr.arg(i)); + } + return; + } + case Z3_OP_NOT: { + CHECK_EQ(expr.num_args(), 1) << "Not must have one argument"; + GetReferencedValues(expr.arg(0)); + return; + } + default: + LOG(FATAL) << "Invalid z3 op"; + } + } + + public: + ReferenceCounter(DecompilationContext& dec_ctx, + std::unordered_map& refs) + : dec_ctx(dec_ctx), referenced_values(refs) {} + + template + void VisitConditionedStmt(T* stmt) { + if (stmt->getCond() == dec_ctx.marker_expr) { + auto cond{dec_ctx.z3_exprs[dec_ctx.conds[stmt]]}; + GetReferencedValues(cond); + } + } + + bool VisitIfStmt(clang::IfStmt* stmt) { + VisitConditionedStmt(stmt); + return true; + } + + bool VisitWhileStmt(clang::WhileStmt* stmt) { + VisitConditionedStmt(stmt); + return true; + } + + bool VisitDoStmt(clang::DoStmt* stmt) { + VisitConditionedStmt(stmt); + return true; + } +}; + +void InlineReferences::RunImpl() { + LOG(INFO) << "Inlining references"; + TransformVisitor::RunImpl(); + changed = false; + refs.clear(); + removable_decls.clear(); + ReferenceCounter counter{dec_ctx, refs}; + + for (auto decl : dec_ctx.ast_ctx.getTranslationUnitDecl()->decls()) { + if (auto fdecl = clang::dyn_cast(decl)) { + if (Stopped()) { + return; + } + + if (fdecl->hasBody()) { + counter.TraverseStmt(fdecl->getBody()); + } + } + } + + for (auto& [value, count_refs] : refs) { + if (count_refs > 2) { + continue; + } + auto it{dec_ctx.value_decls.find(value)}; + if (it == dec_ctx.value_decls.end()) { + continue; + } + removable_decls.insert(it->second); + dec_ctx.value_decls.erase(it); + } + + TraverseDecl(dec_ctx.ast_ctx.getTranslationUnitDecl()); +} + +bool InlineReferences::VisitCompoundStmt(clang::CompoundStmt* 
stmt) { + std::vector new_body; + bool should_substitute{false}; + for (auto child : stmt->body()) { + if (Stopped()) { + break; + } + + bool add_to_new_body{true}; + do { + auto bop{clang::dyn_cast(child)}; + if (!bop) { + break; + } + + if (bop->getOpcode() != clang::BO_Assign) { + LOG(INFO) << ClangThingToString(bop) << " not an assignment"; + break; + } + + auto declref{clang::dyn_cast(bop->getLHS())}; + if (!declref) { + LOG(INFO) << ClangThingToString(bop->getLHS()) << " not a declref"; + break; + } + + if (!removable_decls.count(declref->getDecl())) { + LOG(INFO) << ClangThingToString(declref->getDecl()) + << " not a removable decl"; + break; + } + + add_to_new_body = false; + } while (false); + + do { + auto declstmt{clang::dyn_cast(child)}; + if (!declstmt) { + break; + } + + auto vardecl{clang::dyn_cast(declstmt->getSingleDecl())}; + if (!vardecl) { + break; + } + + if (!removable_decls.count(vardecl)) { + break; + } + + add_to_new_body = false; + } while (false); + + if (add_to_new_body) { + new_body.push_back(child); + } else { + should_substitute = true; + } + } + + if (!Stopped() && should_substitute) { + substitutions[stmt] = dec_ctx.ast.CreateCompoundStmt(new_body); + } + return !Stopped(); +} +} // namespace rellic \ No newline at end of file diff --git a/lib/AST/Util.cpp b/lib/AST/Util.cpp index e6c1da99..05ae27ce 100644 --- a/lib/AST/Util.cpp +++ b/lib/AST/Util.cpp @@ -329,6 +329,13 @@ std::string ClangThingToString(const clang::Stmt *stmt) { return s; } +std::string ClangThingToString(const clang::Decl *decl) { + std::string s; + llvm::raw_string_ostream os(s); + decl->print(os); + return s; +} + z3::goal ApplyTactic(const z3::tactic &tactic, z3::expr expr) { z3::goal goal(tactic.ctx()); goal.add(expr.simplify()); diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index b8bf2089..e7a80e3b 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -69,6 +69,7 @@ set(AST_SOURCES AST/CondBasedRefine.cpp AST/ExprCombine.cpp AST/GenerateAST.cpp + 
AST/InlineReferences.cpp AST/IRToASTVisitor.cpp AST/LocalDeclRenamer.cpp AST/LoopRefine.cpp diff --git a/lib/Decompiler.cpp b/lib/Decompiler.cpp index 7bcbe5dd..8f91fa68 100644 --- a/lib/Decompiler.cpp +++ b/lib/Decompiler.cpp @@ -27,6 +27,7 @@ #include "rellic/AST/ExprCombine.h" #include "rellic/AST/GenerateAST.h" #include "rellic/AST/IRToASTVisitor.h" +#include "rellic/AST/InlineReferences.h" #include "rellic/AST/LocalDeclRenamer.h" #include "rellic/AST/LoopRefine.h" #include "rellic/AST/MaterializeConds.h" @@ -118,6 +119,7 @@ Result Decompile( cbr_passes.push_back(std::make_unique(dec_ctx)); cbr_passes.push_back(std::make_unique(dec_ctx)); + cbr_passes.push_back(std::make_unique(dec_ctx)); while (pass_cbr.Run()) { ; @@ -130,6 +132,7 @@ Result Decompile( loop_passes.push_back(std::make_unique(dec_ctx)); loop_passes.push_back( std::make_unique(dec_ctx)); + loop_passes.push_back(std::make_unique(dec_ctx)); while (pass_loop.Run()) { ; @@ -142,6 +145,7 @@ Result Decompile( scope_passes.push_back( std::make_unique(dec_ctx)); + scope_passes.push_back(std::make_unique(dec_ctx)); while (pass_scope.Run()) { ; @@ -151,6 +155,7 @@ Result Decompile( auto& ec_passes{pass_ec.GetPasses()}; ec_passes.push_back(std::make_unique(dec_ctx)); ec_passes.push_back(std::make_unique(dec_ctx)); + ec_passes.push_back(std::make_unique(dec_ctx)); pass_ec.Run(); diff --git a/scripts/roundtrip.py b/scripts/roundtrip.py index 77ab0887..f5f71a46 100755 --- a/scripts/roundtrip.py +++ b/scripts/roundtrip.py @@ -64,6 +64,11 @@ def decompile(self, rellic, input, output, timeout): return p +def read_file(file): + with open(file, encoding='utf-8') as f: + return f.read() + + def roundtrip(self, rellic, filename, clang, timeout, translate_only, general_flags, binary_compile_flags, bitcode_compile_flags, recompile_flags): with tempfile.TemporaryDirectory() as tempdir: out1 = os.path.join(tempdir, "out1") @@ -93,9 +98,12 @@ def roundtrip(self, rellic, filename, clang, timeout, translate_only, general_fl 
# capture outputs of binary after roundtrip cp2 = run_cmd([out2], timeout) - self.assertEqual(cp1.stderr, cp2.stderr, "Different stderr") - self.assertEqual(cp1.stdout, cp2.stdout, "Different stdout") - self.assertEqual(cp1.returncode, cp2.returncode, "Different return code") + self.assertEqual(cp1.stderr, cp2.stderr, + "Different stderr\n" + read_file(rt_c)) + self.assertEqual(cp1.stdout, cp2.stdout, + "Different stdout\n" + read_file(rt_c)) + self.assertEqual(cp1.returncode, cp2.returncode, + "Different return code\n" + read_file(rt_c)) class TestRoundtrip(unittest.TestCase): diff --git a/tests/tools/decomp/issue_325_goto.c b/tests/tools/decomp/issue_325_goto.c new file mode 100644 index 00000000..648613c9 --- /dev/null +++ b/tests/tools/decomp/issue_325_goto.c @@ -0,0 +1,17 @@ +#include +int main(int argc, const char** argv) { + int i = argc; + if (i == 1) { // no args + putchar(4 + '0'); + i += 1; + goto n5; + } + if (i > 1) { // args + putchar(5 + '0'); + goto n7; + } +n5: + putchar(6 + '0'); +n7: + putchar(7 + '0'); +} diff --git a/tools/xref/Xref.cpp b/tools/xref/Xref.cpp index 56b6b185..00984726 100644 --- a/tools/xref/Xref.cpp +++ b/tools/xref/Xref.cpp @@ -47,6 +47,7 @@ #include "rellic/AST/ExprCombine.h" #include "rellic/AST/GenerateAST.h" #include "rellic/AST/IRToASTVisitor.h" +#include "rellic/AST/InlineReferences.h" #include "rellic/AST/LocalDeclRenamer.h" #include "rellic/AST/LoopRefine.h" #include "rellic/AST/MaterializeConds.h" @@ -399,6 +400,8 @@ static std::unique_ptr CreatePass( return std::make_unique(*session.DecompContext); } else if (str == "ec") { return std::make_unique(*session.DecompContext); + } else if (str == "ir") { + return std::make_unique(*session.DecompContext); } else if (str == "lr") { return std::make_unique(*session.DecompContext); } else if (str == "mc") { diff --git a/tools/xref/www/main.js b/tools/xref/www/main.js index 074e87b5..63e8cdf8 100644 --- a/tools/xref/www/main.js +++ b/tools/xref/www/main.js @@ -36,6 +36,10 @@ 
const mc = { id: "mc", label: "Materialize conditions" } +const ir = { + id: "ir", + label: "Inline references" +} Vue.component('list-comp', { props: ["items", "availableCommands", "showDelete", "title"], @@ -106,7 +110,8 @@ const app = new Vue({ rbr, lr, ec, - mc + mc, + ir ], actions: [ { @@ -313,9 +318,9 @@ const app = new Vue({ useDefaultChain() { this.commands = [ dse, - [zcs, ncp, nsc, cbr, rbr], - [lr, nsc], - [zcs, ncp, nsc], + [zcs, ncp, nsc, cbr, rbr, ir], + [lr, nsc, ir], + [zcs, ncp, nsc, ir], mc, ec ] },