Skip to content

Commit

Permalink
add typing information
Browse files (browse the repository at this point in the history)
  • Loading branch information
Kriechi committed Nov 23, 2024
1 parent 3789435 commit 131b44c
Show file tree
Hide file tree
Showing 20 changed files with 94 additions and 101 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,6 @@ source =

[flake8]
max-complexity = 10
max-line-length = 120
exclude =
hpack/huffman_constants.py
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
author_email='[email protected]',
url='https://github.com/python-hyper/hpack',
packages=find_packages(where="src"),
package_data={'hpack': []},
package_data={'hpack': ['py.typed']},
package_dir={'': 'src'},
python_requires='>=3.9.0',
license='MIT License',
Expand Down
1 change: 0 additions & 1 deletion src/hpack/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
hpack
~~~~~
Expand Down
1 change: 0 additions & 1 deletion src/hpack/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
hyper/http20/exceptions
~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
98 changes: 48 additions & 50 deletions src/hpack/hpack.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# -*- coding: utf-8 -*-
"""
hpack/hpack
~~~~~~~~~~~
Implements the HPACK header compression algorithm as detailed by the IETF.
"""
import logging
from typing import Any, Generator, Union

from .table import HeaderTable, table_entry_size
from .exceptions import (
Expand All @@ -16,7 +16,7 @@
REQUEST_CODES, REQUEST_CODES_LENGTH
)
from .huffman_table import decode_huffman
from .struct import HeaderTuple, NeverIndexedHeaderTuple
from .struct import HeaderTuple, NeverIndexedHeaderTuple, Headers

log = logging.getLogger(__name__)

Expand All @@ -29,31 +29,25 @@
# as prefix numbers are not zero indexed.
_PREFIX_BIT_MAX_NUMBERS = [(2 ** i) - 1 for i in range(9)]

try: # pragma: no cover
basestring = basestring
except NameError: # pragma: no cover
basestring = (str, bytes)


# We default the maximum header list we're willing to accept to 64kB. That's a
# lot of headers, but if applications want to raise it they can do.
DEFAULT_MAX_HEADER_LIST_SIZE = 2 ** 16


def _unicode_if_needed(header, raw):
def _unicode_if_needed(header: HeaderTuple, raw: bool) -> HeaderTuple:
"""
Provides a header as a unicode string if raw is False, otherwise returns
it as a bytestring.
"""
name = bytes(header[0])
value = bytes(header[1])
name = bytes(header[0]) # type: ignore
value = bytes(header[1]) # type: ignore
if not raw:
name = name.decode('utf-8')
value = value.decode('utf-8')
return header.__class__(name, value)
return header.__class__(name.decode('utf-8'), value.decode('utf-8'))
else:
return header.__class__(name, value)


def encode_integer(integer, prefix_bits):
def encode_integer(integer: int, prefix_bits: int) -> bytearray:
"""
This encodes an integer according to the wacky integer encoding rules
defined in the HPACK spec.
Expand Down Expand Up @@ -87,7 +81,7 @@ def encode_integer(integer, prefix_bits):
return bytearray(elements)


def decode_integer(data, prefix_bits):
def decode_integer(data: bytes, prefix_bits: int) -> tuple[int, int]:
"""
This decodes an integer according to the wacky integer encoding rules
defined in the HPACK spec. Returns a tuple of the decoded integer and the
Expand Down Expand Up @@ -128,7 +122,8 @@ def decode_integer(data, prefix_bits):
return number, index


def _dict_to_iterable(header_dict):
def _dict_to_iterable(header_dict: Union[dict[bytes, bytes], dict[str, str]]) \
-> Generator[Union[tuple[bytes, bytes], tuple[str, str]], None, None]:
"""
This converts a dictionary to an iterable of two-tuples. This is a
HPACK-specific function because it pulls "special-headers" out first and
Expand All @@ -140,19 +135,19 @@ def _dict_to_iterable(header_dict):
key=lambda k: not _to_bytes(k).startswith(b':')
)
for key in keys:
yield key, header_dict[key]
yield key, header_dict[key] # type: ignore


def _to_bytes(value):
def _to_bytes(value: Union[bytes, str, Any]) -> bytes:
"""
Convert anything to bytes through a UTF-8 encoded string
"""
t = type(value)
if t is bytes:
return value
return value # type: ignore
if t is not str:
value = str(value)
return value.encode("utf-8")
return value.encode("utf-8") # type: ignore


class Encoder:
Expand All @@ -161,27 +156,29 @@ class Encoder:
HTTP/2 header blocks.
"""

def __init__(self):
def __init__(self) -> None:
self.header_table = HeaderTable()
self.huffman_coder = HuffmanEncoder(
REQUEST_CODES, REQUEST_CODES_LENGTH
)
self.table_size_changes = []
self.table_size_changes: list[int] = []

@property
def header_table_size(self):
def header_table_size(self) -> int:
"""
Controls the size of the HPACK header table.
"""
return self.header_table.maxsize

@header_table_size.setter
def header_table_size(self, value):
def header_table_size(self, value: int) -> None:
self.header_table.maxsize = value
if self.header_table.resized:
self.table_size_changes.append(value)

def encode(self, headers, huffman=True):
def encode(self,
headers: Headers,
huffman: bool = True) -> bytes:
"""
Takes a set of headers and encodes them into a HPACK-encoded header
block.
Expand Down Expand Up @@ -256,13 +253,13 @@ def encode(self, headers, huffman=True):
header = (_to_bytes(header[0]), _to_bytes(header[1]))
header_block.append(self.add(header, sensitive, huffman))

header_block = b''.join(header_block)
encoded = b''.join(header_block)

log.debug("Encoded header block to %s", header_block)
log.debug("Encoded header block to %s", encoded)

return header_block
return encoded

def add(self, to_add, sensitive, huffman=False):
def add(self, to_add: tuple[bytes, bytes], sensitive: bool, huffman: bool = False) -> bytes:
"""
This function takes a header key-value tuple and serializes it.
"""
Expand Down Expand Up @@ -311,15 +308,15 @@ def add(self, to_add, sensitive, huffman=False):

return encoded

def _encode_indexed(self, index):
def _encode_indexed(self, index: int) -> bytes:
"""
Encodes a header using the indexed representation.
"""
field = encode_integer(index, 7)
field[0] |= 0x80 # we set the top bit
return bytes(field)

def _encode_literal(self, name, value, indexbit, huffman=False):
def _encode_literal(self, name: bytes, value: bytes, indexbit: bytes, huffman: bool = False) -> bytes:
"""
Encodes a header with a literal name and literal value. If ``indexing``
is True, the header will be added to the header table: otherwise it
Expand All @@ -340,7 +337,7 @@ def _encode_literal(self, name, value, indexbit, huffman=False):
[indexbit, bytes(name_len), name, bytes(value_len), value]
)

def _encode_indexed_literal(self, index, value, indexbit, huffman=False):
def _encode_indexed_literal(self, index: int, value: bytes, indexbit: bytes, huffman: bool = False) -> bytes:
"""
Encodes a header with an indexed name and a literal value and performs
incremental indexing.
Expand All @@ -362,16 +359,16 @@ def _encode_indexed_literal(self, index, value, indexbit, huffman=False):

return b''.join([bytes(prefix), bytes(value_len), value])

def _encode_table_size_change(self):
def _encode_table_size_change(self) -> bytes:
"""
Produces the encoded form of all header table size change context
updates.
"""
block = b''
for size_bytes in self.table_size_changes:
size_bytes = encode_integer(size_bytes, 5)
size_bytes[0] |= 0x20
block += bytes(size_bytes)
b = encode_integer(size_bytes, 5)
b[0] |= 0x20
block += bytes(b)
self.table_size_changes = []
return block

Expand All @@ -397,7 +394,7 @@ class Decoder:
Defaults to 64kB.
:type max_header_list_size: ``int``
"""
def __init__(self, max_header_list_size=DEFAULT_MAX_HEADER_LIST_SIZE):
def __init__(self, max_header_list_size: int = DEFAULT_MAX_HEADER_LIST_SIZE) -> None:
self.header_table = HeaderTable()

#: The maximum decompressed size we will allow for any single header
Expand Down Expand Up @@ -426,17 +423,17 @@ def __init__(self, max_header_list_size=DEFAULT_MAX_HEADER_LIST_SIZE):
self.max_allowed_table_size = self.header_table.maxsize

@property
def header_table_size(self):
def header_table_size(self) -> int:
"""
Controls the size of the HPACK header table.
"""
return self.header_table.maxsize

@header_table_size.setter
def header_table_size(self, value):
def header_table_size(self, value: int) -> None:
self.header_table.maxsize = value

def decode(self, data, raw=False):
def decode(self, data: bytes, raw: bool = False) -> Headers:
"""
Takes an HPACK-encoded header block and decodes it into a header set.
Expand All @@ -454,7 +451,7 @@ def decode(self, data, raw=False):
log.debug("Decoding %s", data)

data_mem = memoryview(data)
headers = []
headers: list[HeaderTuple] = []
data_len = len(data)
inflated_size = 0
current_index = 0
Expand Down Expand Up @@ -501,7 +498,7 @@ def decode(self, data, raw=False):

if header:
headers.append(header)
inflated_size += table_entry_size(*header)
inflated_size += table_entry_size(header[0], header[1])

if inflated_size > self.max_header_list_size:
raise OversizedHeaderListError(
Expand All @@ -521,7 +518,7 @@ def decode(self, data, raw=False):
except UnicodeDecodeError:
raise HPACKDecodingError("Unable to decode headers as UTF-8.")

def _assert_valid_table_size(self):
def _assert_valid_table_size(self) -> None:
"""
Check that the table size set by the encoder is lower than the maximum
we expect to have.
Expand All @@ -531,7 +528,7 @@ def _assert_valid_table_size(self):
"Encoder did not shrink table size to within the max"
)

def _update_encoding_context(self, data):
def _update_encoding_context(self, data: bytes) -> int:
"""
Handles a byte that updates the encoding context.
"""
Expand All @@ -544,7 +541,7 @@ def _update_encoding_context(self, data):
self.header_table_size = new_size
return consumed

def _decode_indexed(self, data):
def _decode_indexed(self, data: bytes) -> tuple[HeaderTuple, int]:
"""
Decodes a header represented using the indexed representation.
"""
Expand All @@ -553,13 +550,13 @@ def _decode_indexed(self, data):
log.debug("Decoded %s, consumed %d", header, consumed)
return header, consumed

def _decode_literal_no_index(self, data):
def _decode_literal_no_index(self, data: bytes) -> tuple[HeaderTuple, int]:
return self._decode_literal(data, False)

def _decode_literal_index(self, data):
def _decode_literal_index(self, data: bytes) -> tuple[HeaderTuple, int]:
return self._decode_literal(data, True)

def _decode_literal(self, data, should_index):
def _decode_literal(self, data: bytes, should_index: bool) -> tuple[HeaderTuple, int]:
"""
Decodes a header represented with a literal.
"""
Expand All @@ -577,7 +574,7 @@ def _decode_literal(self, data, should_index):
high_byte = data[0]
indexed_name = high_byte & 0x0F
name_len = 4
not_indexable = high_byte & 0x10
not_indexable = bool(high_byte & 0x10)

if indexed_name:
# Indexed header name.
Expand Down Expand Up @@ -616,6 +613,7 @@ def _decode_literal(self, data, should_index):

# If we have been told never to index the header field, encode that in
# the tuple we use.
header: HeaderTuple
if not_indexable:
header = NeverIndexedHeaderTuple(name, value)
else:
Expand Down
17 changes: 8 additions & 9 deletions src/hpack/huffman.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
hpack/huffman_decoder
~~~~~~~~~~~~~~~~~~~~~
Expand All @@ -13,11 +12,11 @@ class HuffmanEncoder:
Encodes a string according to the Huffman encoding table defined in the
HPACK specification.
"""
def __init__(self, huffman_code_list, huffman_code_list_lengths):
def __init__(self, huffman_code_list: list[int], huffman_code_list_lengths: list[int]) -> None:
self.huffman_code_list = huffman_code_list
self.huffman_code_list_lengths = huffman_code_list_lengths

def encode(self, bytes_to_encode):
def encode(self, bytes_to_encode: bytes) -> bytes:
"""
Given a string of bytes, encodes them according to the HPACK Huffman
specification.
Expand Down Expand Up @@ -48,19 +47,19 @@ def encode(self, bytes_to_encode):

# Convert the number to hex and strip off the leading '0x' and the
# trailing 'L', if present.
final_num = hex(final_num)[2:].rstrip('L')
s = hex(final_num)[2:].rstrip('L')

# If this is odd, prepend a zero.
final_num = '0' + final_num if len(final_num) % 2 != 0 else final_num
s = '0' + s if len(s) % 2 != 0 else s

# This number should have twice as many digits as bytes. If not, we're
# missing some leading zeroes. Work out how many bytes we want and how
# many digits we have, then add the missing zero digits to the front.
total_bytes = (final_int_len + bits_to_be_padded) // 8
expected_digits = total_bytes * 2

if len(final_num) != expected_digits:
missing_digits = expected_digits - len(final_num)
final_num = ('0' * missing_digits) + final_num
if len(s) != expected_digits:
missing_digits = expected_digits - len(s)
s = ('0' * missing_digits) + s

return bytes.fromhex(final_num)
return bytes.fromhex(s)
1 change: 0 additions & 1 deletion src/hpack/huffman_constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
hpack/huffman_constants
~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
Loading

0 comments on commit 131b44c

Please sign in to comment.