Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type hints to pretty.py and lark.py #4225

Merged
merged 3 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
RELEASE_TYPE: patch

Add internal type hints to our pretty printer.
54 changes: 33 additions & 21 deletions hypothesis-python/src/hypothesis/extra/lark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@
from typing import Optional

import lark
from lark.grammar import NonTerminal, Terminal
from lark.grammar import NonTerminal, Rule, Symbol, Terminal
from lark.lark import Lark
from lark.lexer import TerminalDef

from hypothesis import strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture.data import ConjectureData
from hypothesis.internal.conjecture.utils import calc_label_from_name
from hypothesis.internal.validation import check_type
from hypothesis.strategies._internal.regex import IncompatibleWithAlphabet
Expand All @@ -40,7 +43,9 @@
__all__ = ["from_lark"]


def get_terminal_names(terminals, rules, ignore_names):
def get_terminal_names(
terminals: list[TerminalDef], rules: list[Rule], ignore_names: list[str]
) -> set[str]:
"""Get names of all terminals in the grammar.

The arguments are the results of calling ``Lark.grammar.compile()``,
Expand All @@ -60,13 +65,15 @@ class LarkStrategy(st.SearchStrategy):
See ``from_lark`` for details.
"""

def __init__(self, grammar, start, explicit, alphabet):
def __init__(
self,
grammar: Lark,
start: Optional[str],
explicit: dict[str, st.SearchStrategy[str]],
alphabet: st.SearchStrategy[str],
) -> None:
assert isinstance(grammar, lark.lark.Lark)
if start is None:
start = grammar.options.start
if not isinstance(start, list):
start = [start]
self.grammar = grammar
start: list[str] = grammar.options.start if start is None else [start]

# This is a total hack, but working around the changes is a nicer user
# experience than breaking for anyone who doesn't instantly update their
Expand All @@ -76,19 +83,18 @@ def __init__(self, grammar, start, explicit, alphabet):
terminals, rules, ignore_names = grammar.grammar.compile(start, ())
elif "start" in compile_args: # pragma: no cover
# Support lark <= 0.10.0, without the terminals_to_keep argument.
terminals, rules, ignore_names = grammar.grammar.compile(start)
terminals, rules, ignore_names = grammar.grammar.compile(start) # type: ignore
else: # pragma: no cover
# This branch is to support lark <= 0.7.1, without the start argument.
terminals, rules, ignore_names = grammar.grammar.compile()
terminals, rules, ignore_names = grammar.grammar.compile() # type: ignore

self.names_to_symbols = {}
self.names_to_symbols: dict[str, Symbol] = {}

for r in rules:
t = r.origin
self.names_to_symbols[t.name] = t
self.names_to_symbols[r.origin.name] = r.origin

disallowed = set()
self.terminal_strategies = {}
self.terminal_strategies: dict[str, st.SearchStrategy[str]] = {}
for t in terminals:
self.names_to_symbols[t.name] = Terminal(t.name)
s = st.from_regex(t.pattern.to_regexp(), fullmatch=True, alphabet=alphabet)
Expand Down Expand Up @@ -119,7 +125,8 @@ def __init__(self, grammar, start, explicit, alphabet):
)
self.terminal_strategies.update(explicit)

nonterminals = {}
# can in fact contain any symbol, despite its name.
nonterminals: dict[str, list[tuple[Symbol, ...]]] = {}

for rule in rules:
if disallowed.isdisjoint(r.name for r in rule.expansion):
Expand Down Expand Up @@ -149,23 +156,28 @@ def __init__(self, grammar, start, explicit, alphabet):
k: st.sampled_from(sorted(v, key=len)) for k, v in nonterminals.items()
}

self.__rule_labels = {}
self.__rule_labels: dict[str, int] = {}

def do_draw(self, data):
state = []
def do_draw(self, data: ConjectureData) -> str:
state: list[str] = []
start = data.draw(self.start)
self.draw_symbol(data, start, state)
return "".join(state)

def rule_label(self, name):
def rule_label(self, name: str) -> int:
try:
return self.__rule_labels[name]
except KeyError:
return self.__rule_labels.setdefault(
name, calc_label_from_name(f"LARK:{name}")
)

def draw_symbol(self, data, symbol, draw_state):
def draw_symbol(
self,
data: ConjectureData,
symbol: Symbol,
draw_state: list[str],
) -> None:
if isinstance(symbol, Terminal):
strategy = self.terminal_strategies[symbol.name]
draw_state.append(data.draw(strategy))
Expand All @@ -178,7 +190,7 @@ def draw_symbol(self, data, symbol, draw_state):
self.gen_ignore(data, draw_state)
data.stop_example()

def gen_ignore(self, data, draw_state):
def gen_ignore(self, data: ConjectureData, draw_state: list[str]) -> None:
if self.ignored_symbols and data.draw_boolean(1 / 4):
emit = data.draw(st.sampled_from(self.ignored_symbols))
self.draw_symbol(data, emit, draw_state)
Expand Down
Loading
Loading