Skip to content

Commit

Permalink
fix: GIL deadlock on main python script early abort; improved tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ndrewh committed Jul 9, 2024
1 parent d60b2cd commit 971e1fe
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 111 deletions.
2 changes: 0 additions & 2 deletions lib/pyda/hacks/gdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

# This is a compatibility layer for pwndbg

print("import gdb!")

PARAM_BOOLEAN = 133701
PARAM_ZINTEGER = 133702
PARAM_STRING = 133703
Expand Down
3 changes: 3 additions & 0 deletions pyda_core/pyda_core.c
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ pyda_thread* pyda_mk_thread(pyda_process *proc) {

void pyda_process_destroy(pyda_process *p) {
// We must be holding the GIL lock so we can drop the refs
if (PyGILState_Check()) {
DEBUG_PRINTF("pyda_process_destroy already holds GIL.")
}
PyGILState_STATE gstate = PyGILState_Ensure();

DEBUG_PRINTF("pyda_process_destroy\n");
Expand Down
4 changes: 2 additions & 2 deletions pyda_core/pyda_threads.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ int pyda_cond_init(pthread_cond_t *condvar, const pthread_condattr_t *attr) {
return res;
}
// Thin wrapper: forwards condvar, mutex and the absolute timeout to
// pthread_cond_timedwait() and returns that call's result unchanged.
int pyda_cond_timedwait(pthread_cond_t *condvar, pthread_mutex_t *mutex, const struct timespec *abstime) {
// Disabled debug traces.
// DEBUG_PRINTF("pthread_cond_timedwait %p %p\n", condvar, mutex);
// DEBUG_PRINTF("pthread_cond_timedwait %p %p ids %d\n", condvar, mutex, getpid());
// NOTE(review): the dr_set_safe_for_sync() toggles around the wait are
// commented out -- confirm that leaving them disabled is intentional.
// dr_set_safe_for_sync(false);
int result = pthread_cond_timedwait(condvar, mutex, abstime);
// dr_set_safe_for_sync(true);
return result;
}
// Thin wrapper: forwards condvar to pthread_cond_signal() and returns that
// call's result unchanged.
int pyda_cond_signal(pthread_cond_t *condvar) {
// Disabled debug traces.
// DEBUG_PRINTF("pthread_cond_signal %p\n", condvar);
// DEBUG_PRINTF("pthread_cond_signal %p ids %d\n", condvar, getpid());
return pthread_cond_signal(condvar);
}

Expand Down
50 changes: 27 additions & 23 deletions pyda_core/tool.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ void thread_init_event(void *drcontext) {
__ctype_init();

if (global_proc->main_thread->python_exited) {
pyda_thread_destroy(t); // decrement refcount, immediately exit

// WARN: This must use drcontext passed in.
drmgr_set_tls_field(drcontext, g_pyda_tls_idx, (void*)NULL);
return;
t->errored = 1;
}

// Every thread has its own corresponding python thread
Expand Down Expand Up @@ -143,7 +139,7 @@ void thread_exit_event(void *drcontext) {

DEBUG_PRINTF("thread_exit_event: %p thread id %d\n", t, dr_get_thread_id(drcontext));

if (t->proc->main_thread == t && !t->python_exited) {
if (t->proc->main_thread == t) {
pyda_break_noblock(t);
} else {
// TODO: thread exit hook?
Expand Down Expand Up @@ -294,8 +290,10 @@ void python_main_thread(void *arg) {

python_init();

PyGILState_STATE gstate;
gstate = PyGILState_Ensure();
if (!PyGILState_Check()) {
fprintf(stderr, "[Pyda] Error: GIL expected\n");
dr_abort();
}

DEBUG_PRINTF("Running script...\n");

Expand All @@ -319,26 +317,30 @@ void python_main_thread(void *arg) {
python_exit:
DEBUG_PRINTF("Script exited...\n");
t->python_exited = 1;
t->errored = 1;

// dr_client_thread_set_suspendable(true);
DEBUG_PRINTF("After script exit, GIL status %d\n", PyGILState_Check());
PyEval_SaveThread(); // release GIL

if (t->yield_count == 0) {
dr_fprintf(STDERR, "[PYDA] WARN: Did you forget to call p.run()?\n");
PyGILState_Release(gstate);
dr_fprintf(STDERR, "[Pyda] ERROR: Did you forget to call p.run()?\n");
pyda_yield(t); // unblock (note: blocking)
} else {
// This call will block until the main thread is the last.
DEBUG_PRINTF("python_main_thread destroy\n");
pyda_thread_destroy_last(t);
DEBUG_PRINTF("python_main_thread destroy done\n");

DEBUG_PRINTF("Py_FinalizeEx in thread %d\n", dr_get_thread_id(drcontext));
if (Py_FinalizeEx() < 0) {
DEBUG_PRINTF("WARN: Python finalization failed\n");
}
DEBUG_PRINTF("Py_FinalizeEx done\n");
DEBUG_PRINTF("Implicit pyda_yield finished\n");
}

// This call will block until the main thread is the last.
DEBUG_PRINTF("python_main_thread destroy\n");
pyda_thread_destroy_last(t);
DEBUG_PRINTF("python_main_thread destroy done\n");

DEBUG_PRINTF("Py_FinalizeEx in thread %d\n", dr_get_thread_id(drcontext));
PyGILState_STATE gstate = PyGILState_Ensure();
if (Py_FinalizeEx() < 0) {
DEBUG_PRINTF("WARN: Python finalization failed\n");
}
DEBUG_PRINTF("Py_FinalizeEx done\n");

dr_thread_free(drcontext, tls, sizeof(void*) * 130);
DEBUG_PRINTF("python_main_thread return\n");
}
Expand All @@ -348,13 +350,15 @@ void python_aux_thread(void *arg) {
void *drcontext = dr_get_current_drcontext();
void *tls = python_thread_init(t);

DEBUG_PRINTF("python_aux_thread id %d\n", dr_get_thread_id(drcontext));

PyGILState_STATE gstate;
gstate = PyGILState_Ensure();

DEBUG_PRINTF("python_aux_thread id %d\n", dr_get_thread_id(drcontext));
DEBUG_PRINTF("python_aux_thread id %d locked\n", dr_get_thread_id(drcontext));

// We just call the thread init hook, if one exists
if (t->proc->thread_init_hook) {
if (t->proc->thread_init_hook && !t->errored) {
DEBUG_PRINTF("Calling thread_init_hook\n");
PyObject *result = PyObject_CallFunctionObjArgs(t->proc->thread_init_hook, t->proc->py_obj, NULL);
if (result == NULL) {
Expand Down
166 changes: 82 additions & 84 deletions tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from argparse import ArgumentParser
from dataclasses import field
from pathlib import Path
from tempfile import TemporaryDirectory

@dataclass
class ExpectedResult:
    """Describe what a single test run is expected to produce.

    Each instance gets its own empty ``checkers`` list by default, so a
    default-constructed instance can be iterated safely and instances never
    share checker state.
    """

    # Expected return code of the run.
    # NOTE(review): None presumably means "do not check" -- confirm in run_test.
    retcode: Optional[int] = None

    # checker(stdout, stderr) -> bool
    checkers: list[Callable[[bytes, bytes], bool]] = field(default_factory=list)

Res = ExpectedResult

def output_checker(stdout: bytes, stderr: bytes) -> bool:
try:
stdout.decode()
Expand All @@ -22,91 +22,89 @@ def output_checker(stdout: bytes, stderr: bytes) -> bool:

return True

# Test table. Each entry is (name, C source file, Pyda script, expected result);
# main() unpacks entries in exactly this order and looks tests up by name.
TESTS = [
    # tests whether we can handle a large number of threads with concurrent hooks
    ("threads_concurrent_hooks", "thread_1000.c", "../examples/ltrace_multithreaded.py", ExpectedResult(
        retcode=0,
        checkers=[
            output_checker,
            lambda o, e: o.count(b"malloc") == 20000,
            lambda o, e: o.count(b"free") == 20000,
            lambda o, e: all((o.count(f"[thread {i}]".encode('utf-8')) == 40 for i in range(2, 1002))),
        ]
    )),

    # tests whether we can handle a large number of threads that do not get waited on
    ("threads_nojoin", "thread_nojoin.c", "../examples/ltrace_multithreaded.py", ExpectedResult(
        retcode=0,
        checkers=[
            output_checker,
            lambda o, e: o.count(b"malloc") > 15000,
            lambda o, e: o.count(b"free") > 15000,
            lambda o, e: all((o.count(f"[thread {i}]".encode('utf-8')) == 40 for i in range(2, 100))),
        ]
    )),

    # hook throws an exception
    ("err_hook_throw", "thread_1000.c", "err_hook.py", ExpectedResult(
        retcode=0,
        checkers=[
            output_checker,
            lambda o, e: e.count(b"[Pyda] ERROR:") == 1,
        ]
    )),

    # thread entry hook throws an exception
    ("err_thread_entry_throw", "thread_1000.c", "err_thread_entry.py", ExpectedResult(
        retcode=0,
        checkers=[
            output_checker,
            lambda o, e: e.count(b"[Pyda] ERROR:") == 1,
        ]
    )),

    # tests whether we can handle a simple syscall hook
    ("syscall_hooks", "simple.c", "test_syscall.py", ExpectedResult(
        retcode=0,
        checkers=[
            output_checker,
            lambda o, e: o.count(b"pre syscall") == o.count(b"post syscall") + 1,  # (+1 for exit)
            # find() rather than index(): if either marker is missing the
            # checker returns False instead of raising ValueError.
            lambda o, e: 0 <= o.find(b"pre syscall") < o.find(b"post syscall"),
        ]
    )),

    # user fails to call p.run()
    ("err_norun", "thread_1000.c", "err_norun.py", ExpectedResult(
        retcode=0,
        checkers=[
            output_checker,
            lambda o, e: e.count(b"[Pyda] ERROR:") == 1,
        ]
    ))
]

def main():
res = True

# thread_1000.c tests whether we can handle a large number of threads
# with concurrent hooks
res &= run_test(
"thread_1000.c", "../examples/ltrace_multithreaded.py",
ExpectedResult(
retcode=0,
checkers=[
output_checker,
lambda o, e: o.count(b"malloc") == 20000,
lambda o, e: o.count(b"free") == 20000,
lambda o, e: all((o.count(f"[thread {i}]".encode('utf-8')) == 40 for i in range(2, 1002))),
]
)
)

# thread_nojoin.c tests whether we can handle a large number of threads
# that do not get waited on (i.e. they are not joined). Mostly
# we just care about the return code and termination here.
res &= run_test(
"thread_nojoin.c", "../examples/ltrace_multithreaded.py",
ExpectedResult(
retcode=0,
checkers=[
output_checker,
]
)
)

# err_hook.py: hook throws an exception
# NOTE: Hooks intentionally fail 'gracefully' and do not abort
res &= run_test(
"thread_1000.c", "err_hook.py",
ExpectedResult(
retcode=0,
checkers=[
output_checker,
lambda o, e: e.count(b"[Pyda] ERROR:") == 1,
]
)
)

# err_thread_entry.py: thread entry hook throws an exception
# NOTE: Hooks intentionally fail 'gracefully' and do not abort
res &= run_test(
"thread_1000.c", "err_thread_entry.py",
ExpectedResult(
retcode=0,
checkers=[
output_checker,
lambda o, e: e.count(b"[Pyda] ERROR:") == 1,
]
)
)

res &= run_test(
"simple.c", "test_syscall.py",
ExpectedResult(
retcode=0,
checkers=[
output_checker,
lambda o, e: o.count(b"pre syscall") == o.count(b"post syscall") + 1, # (+1 for exit)
lambda o, e: o.index(b"pre syscall") < o.index(b"post syscall"),
]
)
)

# err_norun.py: user fails to call p.run()
res &= run_test(
"thread_1000.c", "err_norun.py",
ExpectedResult(
retcode=0,
checkers=[
output_checker,
lambda o, e: e.count(b"[Pyda] ERROR:") == 1,
]
)
)
ap = ArgumentParser()
ap.add_argument("--test", help="Run a specific test", default=None)
args = ap.parse_args()

if args.test is None:
res = True
for (name, c_file, python_file, expected_result) in TESTS:
res &= run_test(c_file, python_file, expected_result, name)
else:
test = next((t for t in TESTS if t[0] == args.test), None)
if test is None:
print(f"Test {args.test} not found")
exit(1)

name, c_file, python_file, expected_result = test
res = run_test(c_file, python_file, expected_result, name)

if not res:
exit(1)

def run_test(c_file, python_file, expected_result):
def run_test(c_file, python_file, expected_result, test_name):
# Compile to temporary directory
with TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
Expand Down Expand Up @@ -140,11 +138,11 @@ def run_test(c_file, python_file, expected_result):


if len(result_str) > 0:
print(f"[FAIL] {c_file} {python_file}")
print(f"[FAIL] {test_name} ({python_file} {c_file})")
print(result_str)
return False
else:
print(f"[OK] {c_file} {python_file}")
print(f"[OK] {test_name} ({python_file} {c_file})")
return True


Expand Down

0 comments on commit 971e1fe

Please sign in to comment.