diff --git a/karapace/protobuf/io.py b/karapace/protobuf/io.py index cbead3913..dbf6f2211 100644 --- a/karapace/protobuf/io.py +++ b/karapace/protobuf/io.py @@ -3,7 +3,7 @@ See LICENSE for details """ from io import BytesIO -from karapace.protobuf.encoding_variants import read_indexes, write_indexes +from karapace.protobuf.varint import read_indexes, write_indexes from karapace.protobuf.exception import IllegalArgumentException, ProtobufSchemaResolutionException, ProtobufTypeException from karapace.protobuf.message_element import MessageElement from karapace.protobuf.protobuf_to_dict import dict_to_protobuf, protobuf_to_dict diff --git a/karapace/protobuf/encoding_variants.py b/karapace/protobuf/varint.py similarity index 75% rename from karapace/protobuf/encoding_variants.py rename to karapace/protobuf/varint.py index 37e1d3cb9..cb9e2ee7f 100644 --- a/karapace/protobuf/encoding_variants.py +++ b/karapace/protobuf/varint.py @@ -5,11 +5,12 @@ # Workaround to encode/decode indexes in protobuf messages # Based on https://developers.google.com/protocol-buffers/docs/encoding#varints +from __future__ import annotations from io import BytesIO from karapace.protobuf.exception import IllegalArgumentException -from typing import List +from typing import List, Final, Sequence -ZERO_BYTE = b"\x00" +ZERO_BYTE: Final = b"\x00" def read_varint(bio: BytesIO) -> int: @@ -35,25 +36,21 @@ def read_varint(bio: BytesIO) -> int: def read_indexes(bio: BytesIO) -> List[int]: try: - size: int = read_varint(bio) + size = read_varint(bio) except EOFError: # TODO: change exception - # pylint: disable=raise-missing-from - raise IllegalArgumentException("problem with reading binary data") - if size == 0: - return [0] + raise IllegalArgumentException("problem with reading binary data") from None return [read_varint(bio) for _ in range(size)] -def write_varint(bio: BytesIO, value: int) -> int: +def write_varint(bio: BytesIO, value: int) -> None: if value < 0: raise ValueError(f"value must not be negative, got {value}") if value == 0: bio.write(ZERO_BYTE) - return 1 + return - written_bytes = 0 while value > 0: to_write = value & 0x7F value = value >> 7 @@ -61,12 +58,10 @@ def write_varint(bio: BytesIO, value: int) -> int: if value > 0: to_write |= 0x80 - bio.write(bytearray(to_write)[0]) - written_bytes += 1 + bio.write(to_write.to_bytes(1, "little")) - return written_bytes - -def write_indexes(bio: BytesIO, indexes: List[int]) -> None: +def write_indexes(bio: BytesIO, indexes: Sequence[int]) -> None: + write_varint(bio, len(indexes)) for i in indexes: write_varint(bio, i) diff --git a/tests/unit/protobuf/test_varint.py b/tests/unit/protobuf/test_varint.py new file mode 100644 index 000000000..49502bccc --- /dev/null +++ b/tests/unit/protobuf/test_varint.py @@ -0,0 +1,35 @@ +from __future__ import annotations +import io +from hypothesis import given, example +from hypothesis.strategies import integers, lists + +from karapace.protobuf.varint import write_varint, read_varint, read_indexes, \ + write_indexes + +varint_values = integers(min_value=0) + + +@given(varint_values) +@example(0) +@example(1) +def test_can_roundtrip_varint(value: int) -> None: + with io.BytesIO() as buffer: + write_varint(buffer, value) + buffer.seek(0) + result = read_varint(buffer) + assert result == value + # Assert buffer is exhausted. + assert buffer.read(1) == b"" + + +@given(lists(elements=varint_values)) +@example([]) +@example([1, 2, 3]) +def test_can_roundtrip_indexes(value: list[int]) -> None: + with io.BytesIO() as buffer: + write_indexes(buffer, value) + buffer.seek(0) + result = read_indexes(buffer) + assert result == value + # Assert buffer is exhausted. + assert buffer.read(1) == b""