Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typing annotations to Python scripts #533

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 39 additions & 27 deletions scripts/perfect_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,16 @@

from __future__ import absolute_import, division, print_function

import sys
import random
import shutil
import string
import subprocess
import shutil
import sys
import tempfile
from collections import defaultdict
from optparse import Values
from os.path import join
from typing import Any, Sequence, TypeVar

if sys.version_info[0] == 2:
from cStringIO import StringIO
Expand All @@ -109,14 +111,14 @@ class Graph(object):
the desired edge value (mod N).
"""

def __init__(self, N):
def __init__(self, N: int):
self.N = N # number of vertices

# maps a vertex number to the list of tuples (vertex, edge value)
# to which it is connected by edges.
self.adjacent = defaultdict(list)
self.adjacent: dict[int, list[tuple[int, int]]] = defaultdict(list)

def connect(self, vertex1, vertex2, edge_value):
def connect(self, vertex1: int, vertex2: int, edge_value: int) -> None:
"""
Connect 'vertex1' and 'vertex2' with an edge, with associated
value 'edge_value'
Expand All @@ -125,7 +127,7 @@ def connect(self, vertex1, vertex2, edge_value):
self.adjacent[vertex1].append((vertex2, edge_value))
self.adjacent[vertex2].append((vertex1, edge_value))

def assign_vertex_values(self):
def assign_vertex_values(self) -> bool:
"""
Try to assign the vertex values, such that, for each edge, you can
add the values for the two vertices involved and get the desired
Expand All @@ -150,7 +152,7 @@ def assign_vertex_values(self):
self.vertex_values[root] = 0 # set arbitrarily to zero

# Stack of vertices to visit, a list of tuples (parent, vertex)
tovisit = [(None, root)]
tovisit: list[tuple[int | None, int]] = [(None, root)]
while tovisit:
parent, vertex = tovisit.pop()
visited[vertex] = True
Expand Down Expand Up @@ -184,7 +186,7 @@ def assign_vertex_values(self):
return True


class StrSaltHash(object):
class StrSaltHash:
"""
Random hash function generator.
Simple byte-level hashing: each byte is multiplied by another byte from
Expand All @@ -194,11 +196,11 @@ class StrSaltHash(object):

chars = string.ascii_letters + string.digits

def __init__(self, N):
def __init__(self, N: int):
self.N = N
self.salt = ""

def __call__(self, key):
def __call__(self, key: Sequence[str]) -> int:
# XXX: xkbcommon modification: make the salt length a power of 2
# so that the % operation in the hash is fast.
while len(self.salt) < max(len(key), 32): # add more salt as necessary
Expand All @@ -216,18 +218,18 @@ def perfect_hash(key):
"""


class IntSaltHash(object):
class IntSaltHash:
"""
Random hash function generator.
Simple byte level hashing, each byte is multiplied in sequence to a table
containing random numbers, summed up, and finally modulo NG is taken.
"""

def __init__(self, N):
self.N = N
self.salt = []
def __init__(self, N: int):
self.N: int = N
self.salt: list[int] = []

def __call__(self, key):
def __call__(self, key: Sequence[str]) -> int:
while len(self.salt) < len(key): # add more salt as necessary
self.salt.append(random.randint(1, self.N - 1))

Expand All @@ -246,7 +248,10 @@ def perfect_hash(key):
"""


def builtin_template(Hash):
H = TypeVar("H", StrSaltHash, IntSaltHash)


def builtin_template(Hash: type[H]) -> str:
return (
"""\
# =======================================================================
Expand All @@ -272,7 +277,9 @@ class TooManyInterationsError(Exception):
pass


def generate_hash(keys, Hash=StrSaltHash):
def generate_hash(
keys: list[str], Hash: type[H] = StrSaltHash
) -> tuple[H, H, list[int]]:
"""
Return hash functions f1 and f2, and G for a perfect minimal hash.
Input is an iterable of 'keys', whose indices are the desired hash values.
Expand Down Expand Up @@ -349,17 +356,17 @@ def generate_hash(keys, Hash=StrSaltHash):


class Format(object):
def __init__(self, width=76, indent=4, delimiter=", "):
def __init__(self, width: int = 76, indent: int = 4, delimiter: str = ", "):
self.width = width
self.indent = indent
self.delimiter = delimiter

def print_format(self):
def print_format(self) -> None:
print("Format options:")
for name in "width", "indent", "delimiter":
print(" %s: %r" % (name, getattr(self, name)))

def __call__(self, data, quote=False):
def __call__(self, data: Any, quote: bool = False) -> str:
if not isinstance(data, (list, tuple)):
return str(data)

Expand All @@ -384,7 +391,12 @@ def __call__(self, data, quote=False):
return "\n".join(l.rstrip() for l in aux.getvalue().split("\n"))


def generate_code(keys, Hash=StrSaltHash, template=None, options=None):
def generate_code(
keys: list[str],
Hash: type = StrSaltHash,
wismill marked this conversation as resolved.
Show resolved Hide resolved
template: str | None = None,
options: Values | None = None,
) -> str:
"""
Takes a list of key value pairs and inserts the generated parameter
lists into the 'template' string. 'Hash' is the random hash function
Expand Down Expand Up @@ -424,7 +436,7 @@ def generate_code(keys, Hash=StrSaltHash, template=None, options=None):
)


def read_table(filename, options):
def read_table(filename: str, options: Values) -> list[str]:
"""
Reads keys and desired hash value pairs from a file. If no column
for the hash value is specified, a sequence of hash values is generated,
Expand Down Expand Up @@ -455,7 +467,7 @@ def read_table(filename, options):
row = [col.strip() for col in line.split(options.splitby)]

try:
key = row[options.keycol - 1]
key: str = row[options.keycol - 1]
except IndexError:
sys.exit(
"%s:%d: Error: Cannot read key, not enough columns." % (filename, n + 1)
Expand All @@ -471,7 +483,7 @@ def read_table(filename, options):
return keys


def read_template(filename):
def read_template(filename: str) -> str:
if verbose:
print("Reading template from file `%s'" % filename)
try:
Expand All @@ -481,7 +493,7 @@ def read_template(filename):
sys.exit("Error: Could not open `%s' for reading." % filename)


def run_code(code):
def run_code(code: str) -> None:
tmpdir = tempfile.mkdtemp()
path = join(tmpdir, "t.py")
with open(path, "w") as fo:
Expand All @@ -494,7 +506,7 @@ def run_code(code):
shutil.rmtree(tmpdir)


def main():
def main() -> None:
from optparse import OptionParser

usage = "usage: %prog [options] KEYS_FILE [TMPL_FILE]"
Expand Down Expand Up @@ -642,7 +654,7 @@ def main():
parser.error("template filename does not contain 'tmpl'")

if options.hft == 1:
Hash = StrSaltHash
Hash: type = StrSaltHash
elif options.hft == 2:
Hash = IntSaltHash
else:
Expand Down
2 changes: 1 addition & 1 deletion scripts/update-headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def generate(
data: dict[str, Any],
root: Path,
file: Path,
):
) -> None:
"""Generate a file from its Jinja2 template"""
template_path = file.with_suffix(f"{file.suffix}.jinja")
template = env.get_template(str(template_path))
Expand Down
4 changes: 2 additions & 2 deletions scripts/update-message-registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Example:
after: str | None

@classmethod
def parse(cls, entry: Any) -> Example:
def parse(cls, entry: dict[str, Any]) -> Example:
name = entry.get("name")
assert name, entry

Expand Down Expand Up @@ -89,7 +89,7 @@ class Entry:
"""

@classmethod
def parse(cls, entry: Any) -> Entry:
def parse(cls, entry: dict[str, Any]) -> Entry:
code = entry.get("code")
assert code is not None and isinstance(code, int) and code > 0, entry

Expand Down
16 changes: 8 additions & 8 deletions scripts/update-unicode.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,11 +621,11 @@ def __iadd__(self, x):
return NotImplemented

@classmethod
def from_singleton(cls, chunk: tuple[T, ...]):
def from_singleton(cls, chunk: tuple[T, ...]) -> Self:
return cls(data=chunk, offsets={chunk: 0})

@classmethod
def from_pair(cls, pair: DeltasPair):
def from_pair(cls, pair: DeltasPair) -> Self:
return cls(
data=pair.d1 + pair.d2[pair.overlap :],
offsets={
Expand All @@ -635,7 +635,7 @@ def from_pair(cls, pair: DeltasPair):
)

@classmethod
def from_iterable(cls, ts: Iterable[tuple[T, ...]]):
def from_iterable(cls, ts: Iterable[tuple[T, ...]]) -> Self:
return reduce(lambda s, t: s.add(t), ts, cls((), {}))

@classmethod
Expand Down Expand Up @@ -1228,7 +1228,7 @@ def stats(self, int_size) -> Stats:
offsets2_int_size=0,
)

@classmethod
@staticmethod
def test(cls):
wismill marked this conversation as resolved.
Show resolved Hide resolved
c1 = (1, 2, 3, 4)
c2 = (2, 3)
Expand All @@ -1238,9 +1238,9 @@ def test(cls):
s += c2
s += c3
s += c4
groups = {c1: [0, 3], c2: [4], c3: [1, 2]}
a = cls.from_overlapped_sequences(s, groups)
assert a == cls(
groups: Groups[int] = {c1: [0, 3], c2: [4], c3: [1, 2]}
a = CompressedArray.from_overlapped_sequences(s, groups)
assert a == CompressedArray(
data=s.data, offsets=(0, 2, 2, 0, 1), chunk_offsets=s.offsets
), a

Expand Down Expand Up @@ -1275,7 +1275,7 @@ def test_compression(cls):
c4 = (3, 4, 5)
c5 = (0, 1, 2)
c6 = (2, 3, 5)
groups = {
groups: Groups[int] = {
c1: [0],
c2: [1],
c3: [2],
Expand Down
Loading