From 3c6de2aac385c76d2697e2151fb9710bc8cf1b62 Mon Sep 17 00:00:00 2001 From: ObserverOfTime Date: Sun, 1 Sep 2024 13:03:59 +0300 Subject: [PATCH] feat(query): implement set_timeout_micros --- docs/classes/tree_sitter.Query.rst | 3 +++ docs/conf.py | 2 +- tree_sitter/__init__.pyi | 13 ++++++++----- tree_sitter/binding/parser.c | 9 ++------- tree_sitter/binding/query.c | 30 +++++++++++++++++++++++++++++- 5 files changed, 43 insertions(+), 14 deletions(-) diff --git a/docs/classes/tree_sitter.Query.rst b/docs/classes/tree_sitter.Query.rst index 900b752..0ec7d03 100644 --- a/docs/classes/tree_sitter.Query.rst +++ b/docs/classes/tree_sitter.Query.rst @@ -95,3 +95,6 @@ Query .. autoattribute:: pattern_count .. versionadded:: 0.23.0 + .. autoattribute:: timeout_micros + + .. versionadded:: 0.23.1 diff --git a/docs/conf.py b/docs/conf.py index 85d74ce..654c1e5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -61,7 +61,7 @@ def process_signature(_app, _what, name, _obj, _options, _signature, return_anno if name == "tree_sitter.Language": return "(ptr)", return_annotation if name == "tree_sitter.Query": - return "(language, source)", return_annotation + return "(language, source, *, timeout_micros=None)", return_annotation if name == "tree_sitter.Parser": return "(language, *, included_ranges=None, timeout_micros=None)", return_annotation if name == "tree_sitter.Range": diff --git a/tree_sitter/__init__.pyi b/tree_sitter/__init__.pyi index e882541..c42dc58 100644 --- a/tree_sitter/__init__.pyi +++ b/tree_sitter/__init__.pyi @@ -258,21 +258,24 @@ class QueryPredicate(Protocol): @final class Query: - def __init__(self, language: Language, source: str) -> None: ... + def __new__(cls, language: Language, source: str) -> Self: ... @property def pattern_count(self) -> int: ... @property def capture_count(self) -> int: ... @property + def timeout_micros(self) -> int: ... + @property def match_limit(self) -> int: ... @property def did_exceed_match_limit(self) -> bool: ... - def set_match_limit(self, match_limit: int | None) -> Self: ... - def set_max_start_depth(self, max_start_depth: int | None) -> Self: ... - def set_byte_range(self, byte_range: tuple[int, int] | None) -> Self: ... + def set_timeout_micros(self, timeout_micros: int) -> Self: ... + def set_match_limit(self, match_limit: int) -> Self: ... + def set_max_start_depth(self, max_start_depth: int) -> Self: ... + def set_byte_range(self, byte_range: tuple[int, int]) -> Self: ... def set_point_range( self, - point_range: tuple[Point | tuple[int, int], Point | tuple[int, int]] | None + point_range: tuple[Point | tuple[int, int], Point | tuple[int, int]] ) -> Self: ... def disable_pattern(self, index: int) -> Self: ... def disable_capture(self, capture: str) -> Self: ... diff --git a/tree_sitter/binding/parser.c b/tree_sitter/binding/parser.c index 8a911ea..5cf4239 100644 --- a/tree_sitter/binding/parser.c +++ b/tree_sitter/binding/parser.c @@ -205,7 +205,7 @@ int parser_set_timeout_micros(Parser *self, PyObject *arg, void *Py_UNUSED(paylo return -1; } - ts_parser_set_timeout_micros(self->parser, PyLong_AsUnsignedLong(arg)); + ts_parser_set_timeout_micros(self->parser, PyLong_AsSize_t(arg)); return 0; } @@ -312,12 +312,7 @@ int parser_set_language(Parser *self, PyObject *arg, void *Py_UNUSED(payload)) { int parser_init(Parser *self, PyObject *args, PyObject *kwargs) { ModuleState *state = GET_MODULE_STATE(self); PyObject *language = NULL, *included_ranges = NULL, *timeout_micros = NULL; - char *keywords[] = { - "language", - "included_ranges", - "timeout_micros", - NULL, - }; + char *keywords[] = {"language", "included_ranges", "timeout_micros", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O!$OO:__init__", keywords, state->language_type, &language, &included_ranges, &timeout_micros)) { diff --git a/tree_sitter/binding/query.c b/tree_sitter/binding/query.c index af84f65..48b92bd 100644 --- a/tree_sitter/binding/query.c +++ b/tree_sitter/binding/query.c @@ -611,6 +611,16 @@ PyObject *query_pattern_assertions(Query *self, PyObject *args) { return item; } +PyObject *query_set_timeout_micros(Query *self, PyObject *args) { + uint32_t timeout_micros; + if (!PyArg_ParseTuple(args, "I:set_timeout_micros", &timeout_micros)) { + return NULL; + } + ts_query_cursor_set_timeout_micros(self->cursor, timeout_micros); + Py_INCREF(self); + return (PyObject *)self; +} + PyObject *query_set_match_limit(Query *self, PyObject *args) { uint32_t match_limit; if (!PyArg_ParseTuple(args, "I:set_match_limit", &match_limit)) { @@ -730,6 +740,10 @@ PyObject *query_get_capture_count(Query *self, void *Py_UNUSED(payload)) { return PyLong_FromSize_t(ts_query_capture_count(self->query)); } +PyObject *query_get_timeout_micros(Query *self, void *Py_UNUSED(payload)) { + return PyLong_FromSize_t(ts_query_cursor_timeout_micros(self->cursor)); +} + PyObject *query_get_match_limit(Query *self, void *Py_UNUSED(payload)) { return PyLong_FromSize_t(ts_query_cursor_match_limit(self->cursor)); } @@ -738,6 +752,9 @@ PyObject *query_get_did_exceed_match_limit(Query *self, void *Py_UNUSED(payload) return PyLong_FromSize_t(ts_query_cursor_did_exceed_match_limit(self->cursor)); } +PyDoc_STRVAR(query_set_timeout_micros_doc, "set_timeout_micros(self, timeout_micros)\n--\n\n" + "Set the maximum duration in microseconds that query " + "execution should be allowed to take before halting."); PyDoc_STRVAR(query_set_match_limit_doc, "set_match_limit(self, match_limit)\n--\n\n" "Set the maximum number of in-progress matches." DOC_RAISES "ValueError\n\n If set to ``0``."); @@ -798,6 +815,12 @@ PyDoc_STRVAR(query_is_pattern_guaranteed_at_step_doc, "Check if a pattern is guaranteed to match once a given byte offset is reached."); static PyMethodDef query_methods[] = { + { + .ml_name = "set_timeout_micros", + .ml_meth = (PyCFunction)query_set_timeout_micros, + .ml_flags = METH_VARARGS, + .ml_doc = query_set_timeout_micros_doc, + }, { .ml_name = "set_match_limit", .ml_meth = (PyCFunction)query_set_match_limit, @@ -902,13 +925,18 @@ static PyGetSetDef query_accessors[] = { PyDoc_STR("The number of patterns in the query."), NULL}, {"capture_count", (getter)query_get_capture_count, NULL, PyDoc_STR("The number of captures in the query."), NULL}, + {"timeout_micros", (getter)query_get_timeout_micros, NULL, + PyDoc_STR("The maximum duration in microseconds that query " + "execution should be allowed to take before halting."), + NULL}, {"match_limit", (getter)query_get_match_limit, NULL, PyDoc_STR("The maximum number of in-progress matches."), NULL}, {"did_exceed_match_limit", (getter)query_get_did_exceed_match_limit, NULL, PyDoc_STR("Check if the query exceeded its maximum number of " "in-progress matches during its last execution."), NULL}, - {NULL}}; + {NULL}, +}; static PyType_Slot query_type_slots[] = { {Py_tp_doc, PyDoc_STR("A set of patterns that match nodes in a syntax tree." DOC_RAISES