From 2af7b1fdae15c0ecb093bebe7520f13fc0d3230a Mon Sep 17 00:00:00 2001 From: ObserverOfTime Date: Sat, 31 Aug 2024 09:32:27 +0300 Subject: [PATCH] refactor(query): replace usage of private function --- tests/test_query.py | 14 ++++++++++++++ tree_sitter/binding/query.c | 30 +++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/tests/test_query.py b/tests/test_query.py index 1d43401..e7da57a 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,3 +1,4 @@ +from re import error as RegexError from unittest import TestCase import tree_sitter_python @@ -257,6 +258,19 @@ def test_text_predicates_errors(self): """ ) + with self.assertRaises(QueryError) as ctx: + self.javascript.query( + """ + ((function_declaration + name: (identifier) @function-name) + (#match? @function-name "?")) + """ + ) + self.assertEqual( + str(ctx.exception), "Invalid predicate in pattern at row 1: regular expression error" + ) + self.assertIsInstance(ctx.exception.__cause__, RegexError) + def test_point_range_captures(self): parser = Parser(self.python) source = b"def foo():\n bar()\ndef baz():\n quux()\n" diff --git a/tree_sitter/binding/query.c b/tree_sitter/binding/query.c index 163ca9d..af84f65 100644 --- a/tree_sitter/binding/query.c +++ b/tree_sitter/binding/query.c @@ -274,9 +274,33 @@ PyObject *query_new(PyTypeObject *cls, PyObject *args, PyObject *Py_UNUSED(kwarg PyObject *pattern = PyObject_CallFunction(state->re_compile, "s#", second_arg, length); if (pattern == NULL) { - _PyErr_FormatFromCause( - state->query_error, - "Invalid predicate in pattern at row %u: regular expression error", row); + const char *msg = + "Invalid predicate in pattern at row %u: regular expression error"; +#if PY_MINOR_VERSION < 12 + PyObject *etype, *cause, *exc, *trace; + PyErr_Fetch(&etype, &cause, &trace); + PyErr_NormalizeException(&etype, &cause, &trace); + if (trace != NULL) { + PyException_SetTraceback(cause, trace); + Py_DECREF(trace); + } + Py_DECREF(etype); + PyErr_Format(state->query_error, msg, row); + PyErr_Fetch(&etype, &exc, &trace); + PyErr_NormalizeException(&etype, &exc, &trace); + Py_INCREF(cause); + PyException_SetCause(exc, cause); + PyException_SetContext(exc, cause); + PyErr_Restore(etype, exc, trace); +#else + PyObject *cause = PyErr_GetRaisedException(); + PyErr_Format(state->query_error, msg, row); + PyObject *exc = PyErr_GetRaisedException(); + PyException_SetCause(exc, Py_NewRef(cause)); + PyException_SetContext(exc, Py_NewRef(cause)); + Py_DECREF(cause); + PyErr_SetRaisedException(exc); +#endif goto error; }