diff --git a/pate_binja/pate.py b/pate_binja/pate.py index 416a9d66..a3c6fc04 100644 --- a/pate_binja/pate.py +++ b/pate_binja/pate.py @@ -483,8 +483,18 @@ def pprint(self, pre: str = ''): #print('data:') #pp.pprint(self.data) - def pprint_node_contents(self, pre: str = '', out: IO = sys.stdout, - show_ce_trace: bool = False): + def pprint_node_contents(self, pre: str = '', out: IO = sys.stdout, show_ce_trace: bool = False): + self.pprint_node_domain(pre, out, show_ce_trace) + if show_ce_trace: + for n in self.exits: + out.write(f'{pre}Exit: {n.id}\n') + if self.exit_meta_data.get(n,{}).get('ce_event_trace'): + self.pprint_node_event_trace(self.exit_meta_data[n]['ce_event_trace'], 'Counter-Example', pre + ' ', out) + elif self.exit_meta_data.get(n, {}).get('event_trace'): + self.pprint_node_event_trace(self.exit_meta_data[n]['event_trace'], '', pre + ' ', out) + + def pprint_node_domain(self, pre: str = '', out: IO = sys.stdout, + show_ce_trace: bool = False): if self.predomain: out.write(f'{pre}Predomain:\n') pprint_domain(self.predomain, pre + ' ', out) @@ -495,27 +505,28 @@ def pprint_node_contents(self, pre: str = '', out: IO = sys.stdout, if self.external_postdomain: out.write(f'{pre}Postdomain:\n') pprint_domain(self.external_postdomain, pre + ' ', out) - if show_ce_trace: - for n in self.exits: - out.write(f'{pre}Exit: {n.id}\n') - if self.exit_meta_data.get(n,{}).get('ce_event_trace'): - self.pprint_node_event_trace(self.exit_meta_data[n]['ce_event_trace'], 'Counter-Example', pre + ' ', out) - # elif self.exit_meta_data.get(n, {}).get('event_trace'): - # self.pprint_node_event_trace(self.exit_meta_data[n]['event_trace'], '', pre + ' ', out) def pprint_node_event_trace(self, trace, label: str, pre: str = '', out: IO = sys.stdout): + self.pprint_node_event_trace_domain(trace, label, pre, out) + self.pprint_node_event_trace_original(trace, label, pre, out) + self.pprint_node_event_trace_patched(trace, label, pre, out) + + def pprint_node_event_trace_domain(self, trace, label: str, pre: str = '', out: IO = sys.stdout): if trace.get('precondition'): out.write(f'{pre}Trace Precondition:\n') pprint_eq_domain(trace['precondition'], pre + ' ', out) if trace.get('postcondition'): out.write(f'{pre}Trace Postcondition:\n') pprint_eq_domain(trace['postcondition'], pre + ' ', out) + + def pprint_node_event_trace_original(self, trace, label: str, pre: str = '', out: IO = sys.stdout): if trace.get('traces', {}).get('original'): pprint_event_trace(f'{label} Original', trace['traces']['original'], pre, out) + + def pprint_node_event_trace_patched(self, trace, label: str, pre: str = '', out: IO = sys.stdout): if trace.get('traces',{}).get('patched'): pprint_event_trace(f'{label} Patched', trace['traces']['patched'], pre, out) - class CFARGraph: nodes: dict[str, CFARNode] diff --git a/pate_binja/view.py b/pate_binja/view.py index 01201c7f..09d1f8d5 100644 --- a/pate_binja/view.py +++ b/pate_binja/view.py @@ -12,7 +12,7 @@ from binaryninja import show_graph_report, execute_on_main_thread_and_wait, BinaryView, OpenFileNameField, interaction, \ MultilineTextField from binaryninja.enums import BranchType, HighlightStandardColor -from binaryninja.flowgraph import FlowGraph, FlowGraphNode +from binaryninja.flowgraph import FlowGraph, FlowGraphNode, FlowGraphEdge from binaryninja.plugin import BackgroundTaskThread from binaryninjaui import GlobalAreaWidget, GlobalArea, UIAction, UIActionHandler, Menu, UIActionContext, \ FlowGraphWidget @@ -25,16 +25,15 @@ from . import pate + class PateWidget(QWidget): def __init__(self, parent: QWidget, filename: str) -> None: - global instance_id super().__init__(parent) self.filename = filename - self.pate_thread: PateThread= None + self.pate_thread: PateThread = None self.flow_graph_widget = MyFlowGraphWidget(self) - self.flow_graph_widget.setWindowTitle('FNORT BLORT') self.output_field = QPlainTextEdit() self.output_field.setReadOnly(True) @@ -122,8 +121,7 @@ def show_message(self, msg) -> None: execute_on_main_thread_and_wait(lambda: self.pate_widget.output_field.appendPlainText(msg)) def show_cfar_graph(self, graph: pate.CFARGraph) -> None: - flow_graph = build_pate_flow_graph(graph, self.show_ce_trace) - execute_on_main_thread_and_wait(lambda: self.pate_widget.flow_graph_widget.setGraph(flow_graph)) + execute_on_main_thread_and_wait(lambda: self.pate_widget.flow_graph_widget.build_pate_flow_graph(graph, self.show_ce_trace)) class PateThread(Thread): @@ -192,82 +190,136 @@ def _command_loop(self, proc: Popen, show_ce_trace: bool = False, trace_io=None) # pt.start() -def build_pate_flow_graph(cfar_graph: pate.CFARGraph, - show_ce_trace: bool = False): - flow_graph = FlowGraph() - - # First create all nodes - flow_nodes = {} - cfar_node: pate.CFARNode - for cfar_node in cfar_graph.nodes.values(): - flow_node = FlowGraphNode(flow_graph) +class PateCfarExitDialog(QDialog): + def __init__(self, parent=None): + super().__init__(parent) - out = io.StringIO() + self.setWindowTitle("CFAR Exit Info") - out.write(cfar_node.id.replace(' <- ','\n <- ')) - out.write('\n') + self.commonField = QPlainTextEdit() + self.commonField.setReadOnly(True) + self.commonField.setMaximumBlockCount(1000) - cfar_node.pprint_node_contents('', out, show_ce_trace) + self.originalField = QPlainTextEdit() + self.originalField.setReadOnly(True) + self.originalField.setMaximumBlockCount(1000) - flow_node.lines = out.getvalue().split('\n') - #flow_node.lines = [lines[0]] + self.patchedField = QPlainTextEdit() + self.patchedField.setReadOnly(True) + self.patchedField.setMaximumBlockCount(1000) - if cfar_node.id.find(' vs ') >= 0: - # Per discussion wit dan, it does not make sense to highlight these. - # flow_node.highlight = HighlightStandardColor.BlueHighlightColor - pass - elif cfar_node.id.find('(original') >= 0: - flow_node.highlight = HighlightStandardColor.GreenHighlightColor - elif cfar_node.id.find('(patched)') >= 0: - flow_node.highlight = HighlightStandardColor.MagentaHighlightColor + hsplitter = QSplitter() + hsplitter.setOrientation(Qt.Orientation.Horizontal) + hsplitter.addWidget(self.originalField) + hsplitter.addWidget(self.patchedField) - flow_graph.append(flow_node) - flow_nodes[cfar_node.id] = flow_node + vsplitter = QSplitter() + vsplitter.setOrientation(Qt.Orientation.Vertical) + vsplitter.addWidget(self.commonField) + vsplitter.addWidget(hsplitter) - # Add edges - cfar_node: pate.CFARNode - for cfar_node in cfar_graph.nodes.values(): - flow_node = flow_nodes[cfar_node.id] - cfar_exit: pate.CFARNode - for cfar_exit in cfar_node.exits: - flow_exit = flow_nodes[cfar_exit.id] - flow_node.add_outgoing_edge(BranchType.UnconditionalBranch, flow_exit) + QBtn = QDialogButtonBox.Ok + self.buttonBox = QDialogButtonBox(QBtn) + self.buttonBox.accepted.connect(self.accept) - return flow_graph + main_layout = QHBoxLayout() + main_layout.addWidget(vsplitter) + #main_layout.addWidget(self.buttonBox) + self.setLayout(main_layout) class MyFlowGraphWidget(FlowGraphWidget): def __init__(self, parent: QWidget, view: BinaryView=None, graph: FlowGraph=None): super().__init__(parent, view, graph) + self.flowToCfarNode: dict[FlowGraphNode, pate.CFARNode] = {} + + def build_pate_flow_graph(self, + cfar_graph: pate.CFARGraph, + show_ce_trace: bool = False): + flow_graph = FlowGraph() + + show_ce_trace = False # disable for now + + # First create all nodes + cfarToFlowNode = {} + cfar_node: pate.CFARNode + for cfar_node in cfar_graph.nodes.values(): + flow_node = FlowGraphNode(flow_graph) + + self.flowToCfarNode[flow_node] = cfar_node + + out = io.StringIO() + + out.write(cfar_node.id.replace(' <- ', '\n <- ')) + out.write('\n') + + cfar_node.pprint_node_contents('', out, show_ce_trace) + + flow_node.lines = out.getvalue().split('\n') + # flow_node.lines = [lines[0]] + + if cfar_node.id.find(' vs ') >= 0: + # Per discussion with Dan, it does not make sense to highlight these. + # flow_node.highlight = HighlightStandardColor.BlueHighlightColor + pass + elif cfar_node.id.find('(original') >= 0: + flow_node.highlight = HighlightStandardColor.GreenHighlightColor + elif cfar_node.id.find('(patched)') >= 0: + flow_node.highlight = HighlightStandardColor.MagentaHighlightColor + + flow_graph.append(flow_node) + cfarToFlowNode[cfar_node.id] = flow_node + + # Add edges + cfar_node: pate.CFARNode + for cfar_node in cfar_graph.nodes.values(): + flow_node = cfarToFlowNode[cfar_node.id] + cfar_exit: pate.CFARNode + for cfar_exit in cfar_node.exits: + flow_exit = cfarToFlowNode[cfar_exit.id] + flow_node.add_outgoing_edge(BranchType.UnconditionalBranch, flow_exit) + + self.setGraph(flow_graph) def mousePressEvent(self, event: QMouseEvent): - edge = self.getEdgeForMouseEvent(event) node = self.getNodeForMouseEvent(event) - print("Edge: ", edge) + edgeTuple = self.getEdgeForMouseEvent(event) print("Node: ", node) + print("Edge: ", edgeTuple) + if edgeTuple: + self.showExitInfo(edgeTuple) -class PateWindow(QDialog): - def __init__(self, context: UIActionContext, parent=None): - super().__init__(parent) - self.context = context # TODO: What is this for? - - g = FlowGraph() - n1 = FlowGraphNode(g) - n1.lines = ["foo"] - g.append(n1) - n2 = FlowGraphNode(g) - n2.lines = ["bar"] - g.append(n2) - n1.add_outgoing_edge(BranchType.UnconditionalBranch, n2) - - self.flow_graph_widget = MyFlowGraphWidget(self, None, g) - self.flow_graph_widget.setMinimumWidth(400) - self.flow_graph_widget.setMinimumHeight(400) - - layout = QVBoxLayout() - layout.addWidget(self.flow_graph_widget) - self.setLayout(layout) + def showExitInfo(self, edgeTuple: tuple[FlowGraphEdge, bool]) -> None: + edge = edgeTuple[0] + incoming = edgeTuple[1] + sourceCfarNode = self.flowToCfarNode[edge.source] + exitCfarNode = self.flowToCfarNode[edge.target] + + exitMetaData = sourceCfarNode.exit_meta_data.get(exitCfarNode, {}) + + ceTrace = exitMetaData.get('ce_event_trace') + trace = exitMetaData.get('event_trace') + if ceTrace: + self.showExitTraceInfo(sourceCfarNode, ceTrace, 'Counter-Example Trace') + elif trace: + self.showExitTraceInfo(sourceCfarNode, trace, 'Trace') + else: + # TODO: dialog? + print("No exit info") + + def showExitTraceInfo(self, sourceCfarNode: pate.CFARNode, trace: dict, label: str): + d = PateCfarExitDialog(parent=self) + with io.StringIO() as out: + sourceCfarNode.pprint_node_event_trace_domain(trace, label, out=out) + d.commonField.setPlainText(out.getvalue()) + with io.StringIO() as out: + sourceCfarNode.pprint_node_event_trace_original(trace, label, out=out) + d.originalField.setPlainText(out.getvalue()) + with io.StringIO() as out: + sourceCfarNode.pprint_node_event_trace_patched(trace, label, out=out) + d.patchedField.setPlainText(out.getvalue()) + d.exec() def launch_pate(context: UIActionContext):