diff --git a/pate_binja/pate.py b/pate_binja/pate.py index aa003f85..90156ed0 100644 --- a/pate_binja/pate.py +++ b/pate_binja/pate.py @@ -1413,42 +1413,56 @@ def pprint_node_event_trace_original(trace, pre: str = '', out: IO = sys.stdout) out.write(f'{pre}No trace\n') return - pprint_event_trace(trace.get('traces', {}).get('original'), pre, out) + pprint_event_trace(trace.get('traces', {}).get('original'), + other_et=trace.get('traces', {}).get('patched'), + pre=pre, + out=out + ) def pprint_node_event_trace_patched(trace, pre: str = '', out: IO = sys.stdout): if not trace: out.write(f'{pre}No trace\n') return - pprint_event_trace(trace.get('traces', {}).get('patched'), pre, out) + pprint_event_trace(trace.get('traces', {}).get('patched'), + other_et=trace.get('traces', {}).get('original'), + pre=pre, + out=out + ) -def pprint_event_trace(et: dict, pre: str = '', out: IO = sys.stdout): +def pprint_event_trace(et: dict, pre: str = '', out: IO = sys.stdout, other_et: dict = None): if not et: out.write(f'{pre}No trace\n') return - pprint_event_trace_initial_reg(et['initial_regs'], pre, out) + + no_prune = [] + if other_et: + for reg in other_et['initial_regs']['reg_op']['map']: + ppval = get_value_id(reg['val']) + key: dict = reg['key'] + if not ppval.startswith('0x0:'): + no_prune.append(key) + + pprint_event_trace_initial_reg(et['initial_regs'], pre, out, no_prune) pprint_event_trace_instructions(et['events'], pre, out) -def pprint_event_trace_initial_reg(initial_regs: dict, pre: str = '', out: IO = sys.stdout): +def pprint_event_trace_initial_reg(initial_regs: dict, pre: str = '', out: IO = sys.stdout, no_prune: list = []): """Pretty print an event trace's initial registers.""" out.write(f'{pre}Initial Register Values (non-zero):\n') - pprint_reg_ops(initial_regs['reg_op'], pre + ' ', out, True) + pprint_reg_ops(initial_regs['reg_op'], pre + ' ', out, True, no_prune) -def pprint_reg_ops(reg_op: dict, pre: str = '', out: IO = sys.stdout, prune_zero: bool = False): +def pprint_reg_ops(reg_op: dict, pre: str = '', out: IO = sys.stdout, prune_zero: bool = False, no_prune: list = []): for reg in reg_op['map']: - pprint_reg_op(reg, pre, out, prune_zero) + pprint_reg_op(reg, pre, out, prune_zero, no_prune) -def pprint_reg_op(reg: dict, pre: str = '', out: IO = sys.stdout, prune_zero: bool = False): - val: dict = reg['val'] - ppval = get_value_id(val) +def pprint_reg_op(reg: dict, pre: str = '', out: IO = sys.stdout, prune_zero: bool = False, no_prune: list = []): + ppval = get_value_id(reg['val']) key: dict = reg['key'] - if (not isinstance(val, dict) - or not prune_zero - or not ppval.startswith('0x0:')): + if not (prune_zero and ppval.startswith('0x0:') and key not in no_prune): match key: case {'arch_reg': name}: if name == '_PC' and ppval.startswith('0x0:'):