Skip to content

Commit

Permalink
Restructuring to unify life and replay logic, particularly to make co…
Browse files Browse the repository at this point in the history
…nfig available for both.
  • Loading branch information
jim-carciofini committed Feb 28, 2024
1 parent 10d073d commit 30c470c
Show file tree
Hide file tree
Showing 2 changed files with 214 additions and 190 deletions.
239 changes: 156 additions & 83 deletions pate_binja/pate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import pprint
import re
import shlex
import signal
import sys
import warnings
from json import JSONDecodeError
from subprocess import Popen, PIPE, STDOUT
from subprocess import Popen, PIPE, STDOUT, TimeoutExpired
from typing import IO, Any, Optional

# TODO: Get rid of these globals
Expand All @@ -20,7 +21,7 @@

class PateUserInteraction(abc.ABC):
@abc.abstractmethod
def ask_user(self, prompt: str, choices: list[str]) -> Optional[str]:
def ask_user(self, prompt: str, choices: list[str], replay_choice: Optional[str] = None) -> Optional[str]:
pass

@abc.abstractmethod
Expand All @@ -33,17 +34,103 @@ def show_cfar_graph(self, graph: CFARGraph) -> None:


class PateWrapper:
def __init__(self, user: PateUserInteraction, pate_out: IO, pate_in: IO, trace: IO = None):
user: PateUserInteraction
filename: os.PathLike
pate_proc: Optional[Popen]
trace_file: Optional[IO]

def __init__(self, filename: os.PathLike,
user: PateUserInteraction,
config_callback=None
) -> None:
self.debug_io = False
self.debug_json = False
self.debug_cfar = False

self.filename = filename
self.user = user
self.pate_in = pate_in
self.pate_out = pate_out
self.trace_file = trace # no trace file also indicates replay mode
self.config_callback = config_callback

self.pate_proc = None
self.trace_file = None

def run(self) -> None:
if self.filename.endswith(".run-config.json"):
self._run_live()
elif self.filename.endswith(".replay"):
self._run_replay()
else:
self.user.show_message('ERROR: Unrecognized PATE run file type', self.filename)

def _run_live(self):
cwd = os.path.dirname(self.filename)
self.config = load_run_config(self.filename)
if not self.config:
self.user.show_message('ERROR: Failed to load PATE run config from', self.filename)
return
original = self.config.get('original')
patched = self.config.get('patched')
raw_args = self.config.get('args')
args = shlex.split(' '.join(raw_args))
# We use a helper script to run logic in the user's shell environment.
script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "run-pate.sh")
# Need -l to make sure user's env is fully setup (e.g. access to docker and ghc tools).
with open(os.path.join(cwd, "lastrun.replay"), "w") as trace:
with Popen(['/bin/bash', '-l', script, '-o', original, '-p', patched, '--json-toplevel'] + args,
cwd=cwd,
stdin=PIPE, stdout=PIPE,
stderr=STDOUT,
text=True, encoding='utf-8',
close_fds=True,
# Create a new process group, so we can kill it cleanly
preexec_fn=os.setsid
) as proc:

self.pate_proc = proc
self.trace_file = trace
self.user.replay = False

# Write config to replay file before adding cwd
json.dump(self.config, trace)
trace.write('\n')
self.config['cwd'] = cwd

self.command_loop()

def _run_replay(self):
cwd = os.path.dirname(self.filename)
with Popen(['cat', self.filename],
cwd=cwd, stdin=None, stdout=PIPE, text=True, encoding='utf-8',
close_fds=True,
# Create a new process group, so we can kill it cleanly
preexec_fn=os.setsid
) as proc:

self.pate_proc = proc
self.trace_file = None
self.user.replay = True

# Read config from replay file
self.config = self.next_json()
self.config['cwd'] = cwd

self.command_loop()

def cancel(self) -> None:
if self.pate_proc:
# Closing input should cause PATE process to exit
if self.pate_proc.stdin:
self.pate_proc.stdin.close()
try:
self.pate_proc.wait(3)
except TimeoutExpired:
# Orderly shutdown did not work, kill the process group
print('KILLING PATE Process')
os.killpg(self.pate_proc.pid, signal.SIGKILL)


def next_line(self) -> str:
line = self.pate_out.readline()
line = self.pate_proc.stdout.readline()
if not line:
raise EOFError
if self.trace_file:
Expand All @@ -64,38 +151,38 @@ def next_json(self, gotoPromptAfterNonJson=False):
# Output line and continue looking for a JSON record
self.user.show_message(line.rstrip('\n'))
if gotoPromptAfterNonJson:
self.command('goto_prompt')
self._command('goto_prompt')
# Skip lines till we get json
gotoPromptAfterNonJson = False
else:
return rec

def skip_lines_till(self, s: str) -> None:
def _skip_lines_till(self, s: str) -> None:
"""Skip lines till EOF or line completely matching s (without newline)."""
while True:
line = self.next_line().rstrip('\n')
self.user.show_message(line)
if line == s:
break

def command(self, cmd: str) -> None:
def _command(self, cmd: str) -> None:
if self.debug_io:
print('Command to Pate: ', cmd)
if self.pate_in:
print(cmd, file=self.pate_in, flush=True)
if self.pate_proc.stdin:
print(cmd, file=self.pate_proc.stdin, flush=True)
if self.trace_file:
# Write line to trace file for replay
self.trace_file.write('Command: ' + cmd + '\n')
self.trace_file.flush()
else:
# Replay mode
# TODO: make cmd available to show in replay
cmd = self.pate_out.readline()
cmd = self.pate_proc.stdout.readline()
# TODO: Check that cmd is same as arg?

def extract_graph(self) -> CFARGraph:
cfar_graph = CFARGraph()

self.command('top')
self._command('top')
top = self.next_json()

self.extract_graph_rec(top, [], None, cfar_graph, None, None, None)
Expand Down Expand Up @@ -281,33 +368,45 @@ def extract_graph_rec(self,
# Look for counter-example trace
or (len(path) == 3 and rec.get('trace_node_kind') == 'blocktarget'
and child.get('trace_node_kind') == 'node')):
self.command(str(i))
self._command(str(i))
childrec = self.next_json()
# update with values from child. TODO: ask Dan about this?
childrec.update(child)
self.extract_graph_rec(childrec, path + [i], context, cfar_graph, cfar_parent, cfar_exit, tnc)
self.command('up')
self._command('up')
# Consume result of up, but do not need it
ignore = self.next_json()
# else:
# if self.debug_cfar:
# print('CFAR skip child:')
# pp.pprint(child)

def _ask_user(self, prompt_rec: dict) -> str:
def _ask_user_rec(self, prompt_rec: dict) -> str:
# Read entry point choices
prompt = prompt_rec['this']
choices = list(map(get_choice_id, prompt_rec.get('trace_node_contents', [])))
while True:
choice = self.user.ask_user(prompt, choices).strip()
if choice:
return choice
self.user.show_message("error: empty choice")
return self._ask_user(prompt, choices).strip()
# Write line to trace file for replay

def _ask_user(self, prompt: str, choices: list[str]) -> Optional[str]:
replay_choice = None
if self.trace_file is None:
replay_line = self.pate_proc.stdout.readline()
if replay_line.startswith('User choice: '):
replay_choice = replay_line[len('User choice: '):].strip()

choice = self.user.ask_user(prompt, choices, replay_choice).strip()

if self.trace_file:
self.trace_file.write('User choice: ' + choice + '\n')
self.trace_file.flush()
return choice

def command_loop(self):
if self.config_callback:
self.config_callback(self.config)
rec = self.next_json()
self.command('goto_prompt')
self._command('goto_prompt')
while self.command_step():
pass
self.user.show_message("Pate finished")
Expand Down Expand Up @@ -358,16 +457,16 @@ def process_json(self, rec):
if cfar_graph:
self.user.show_cfar_graph(cfar_graph)
# Go back to prompt
self.command('goto_prompt')
self._command('goto_prompt')
rec = self.next_json()
choice = self._ask_user(rec)
self.command(choice)
choice = self._ask_user_rec(rec)
self._command(choice)

elif isinstance(rec, list) and rec[len(rec) - 1]['content'] == {'node_kind': 'final_result'}:
# TODO: Hack to detect finish. Talk to Dan about providing a better mechanism.
choices = list(map(get_choice_id, rec))
choice = self.user.ask_user('Final Prompt:', choices)
self.command(choice)
choice = self._ask_user('Final Prompt:', choices)
self._command(choice)

elif isinstance(rec, dict) and rec.get('trace_node_kind') == 'equivalence_result':
# Done if we got an equivalence result
Expand All @@ -376,7 +475,7 @@ def process_json(self, rec):

elif isinstance(rec, dict) and rec.get('error'):
self.show_message('error: ' + rec['error'])
self.command('goto_prompt')
self._command('goto_prompt')

else:
# Message(s)
Expand Down Expand Up @@ -1173,38 +1272,36 @@ def pprint_val_domain(v, pre: str = '', out: IO = sys.stdout):


class TtyUserInteraction(PateUserInteraction):
def __init__(self, replay: bool = False):
self.replay = replay
ask_show_cfar_graph: bool

def ask_user(self, prompt: str, choices: list[str]) -> str:
def __init__(self, ask_show_cfar_graph: bool = False):
self.ask_show_cfar_graph = ask_show_cfar_graph

def ask_user(self, prompt: str, choices: list[str], replay_choice: Optional[str] = None) -> str:
print()
print(prompt)
for i, e in enumerate(choices):
print(' {}'.format(e))

# # Hack to auto respond for nov23 target 3. Need more cases.
# if prompt == 'Control flow desynchronization found at: GraphNode segment1+0x1ad0 [ via: "RR_ReadTlmInput" (segment1+0x18e4) ]':
# print('Pate command: 3\n')
# return 3

if self.replay:
if replay_choice:
# In replay mode, response is ignored, just return anything for fast replay
print('Pate command: auto replay\n')
choice = '42'
choice = replay_choice
print(f'Pate command (replay): {choice}\n')
else:
choice = input("Pate command: ")

return choice

def show_message(self, msg: str) -> None:
print(msg)

def show_cfar_graph(self, graph: CFARGraph) -> None:
print()
if self.replay:
# In replay mode, just return true for fast replay
choice = 'y' # For fast replay
else:
if self.ask_show_cfar_graph:
print()
choice = input("Show CFAR Graph (y or n)? ")
else:
choice = 'y'

if choice == 'y':
print('\nPate CFAR Graph:\n')
graph.pprint()
Expand All @@ -1214,40 +1311,20 @@ def show_cfar_graph(self, graph: CFARGraph) -> None:
print('Prompt Node:', promptNode.id)



def test(pate_out, pate_in, trace):
user = TtyUserInteraction(trace is None)
pate = PateWrapper(user, pate_out, pate_in, trace)

#pate.debug_io = True
#pate.debug_cfar = True

pate.command_loop()


def test_live(run_fn):
with open("trace.txt", "w") as trace:
with run_fn(False) as proc:
test(proc.stdout, proc.stdin, trace)


def test_replay(run_fn):
with run_fn(True) as proc:
test(proc.stdout, proc.stdin, None)


def run_replay(file: str) -> Popen:
return Popen(
['cat', file],
stdin=None, stdout=PIPE, text=True, encoding='utf-8'
)


def get_run_config(file: os.PathLike) -> dict:
with open(file, 'r') as f:
config = json.load(f)
config['cwd'] = os.path.dirname(file)
return config
def load_run_config(file: os.PathLike) -> Optional[dict]:
try:
with open(file, 'r') as f:
config = json.load(f)
return config
except OSError:
return None


def run_config(config: dict):
Expand Down Expand Up @@ -1275,14 +1352,6 @@ def run_pate(cwd: str, original: str, patched: str, args: list[str]) -> Popen:
)


def run_pate_config_or_replay_file(f: str) -> Popen:
if f.endswith(".run-config.json"):
config = get_run_config(f)
test_live(lambda ignore: run_config(config))
elif f.endswith(".replay"):
test_replay(lambda ignore: run_replay(f))


def get_demo_files():
files = []
demos_dir = os.getenv('PATE_BINJA_DEMOS')
Expand All @@ -1295,12 +1364,16 @@ def get_demo_files():
files.append(f)
return files


def run_pate_demo():
files = get_demo_files()
print("Select PATE run configuration or replay file:")
for i, f in enumerate(files):
print(' {}: {}'.format(i, f))

choice = input("Choice: ")
file = files[int(choice)]
run_pate_config_or_replay_file(file)

replay = f.endswith('.replay')
user = TtyUserInteraction(not replay)
pate = PateWrapper(file, user)
pate.run()
Loading

0 comments on commit 30c470c

Please sign in to comment.