Skip to content

Commit

Permalink
Fix #142 and #143 (#144)
Browse files Browse the repository at this point in the history
Fix bugs that manifest when reading and writing empty protocol streams in Python.
  • Loading branch information
naegelejd authored Apr 2, 2024
1 parent 7a0ab26 commit d2a3ee2
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 2 deletions.
18 changes: 18 additions & 0 deletions cpp/test/roundtrip_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,24 @@ TEST_P(RoundTripTests, SimpleDatasets) {
tw->Close();
}

// Drives a StreamsWriterBase writer through a sequence that mixes empty and
// non-empty streams: an empty int batch, a populated optional-int batch, a
// record stream ended without any writes, and a fixed-vector stream written
// in two batches.
// NOTE(review): presumably the writer returned by CreateValidatingWriter
// checks the round trip of what is written here -- confirm against its
// definition, which is not visible in this hunk.
TEST_P(RoundTripTests, SimpleDatasets_Mixed) {
auto tw = CreateValidatingWriter<StreamsWriterBase>();

// IntData: a single empty batch, then the stream is ended.
tw->WriteIntData({});
tw->EndIntData();

// OptionalIntData: one non-empty batch; the `{}` entries are presumably
// empty optionals -- confirm against the element type.
tw->WriteOptionalIntData({10, 11, {}, 13, 14, 15, {}, 17, 18, 19});
tw->EndOptionalIntData();

// RecordWithOptionalVectorData: ended with no preceding Write call.
tw->EndRecordWithOptionalVectorData();

// FixedVector: two non-empty batches before the stream is ended.
tw->WriteFixedVector({1, 2, 3});
tw->WriteFixedVector({4, 5, 6});
tw->EndFixedVector();

tw->Close();
}

TEST_P(RoundTripTests, SimpleDatasets_Empty) {
auto tw = CreateValidatingWriter<StreamsWriterBase>();

Expand Down
8 changes: 8 additions & 0 deletions python/tests/test_protocol_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,7 @@ def days():
def test_simple_streams(format: Format):
c = create_validating_writer_class(format, tm.StreamsWriterBase)

# non-empty streams
with c() as w:
w.write_int_data(range(10))
w.write_int_data(range(20))
Expand All @@ -600,6 +601,13 @@ def test_simple_streams(format: Format):
)
w.write_fixed_vector(([1, 2, 3] for _ in range(4)))

# mixed empty and non-empty streams
with c() as w:
w.write_int_data(range(0))
w.write_optional_int_data([1, 2, None, 4, 5, None, 7, 8, 9, 10])
w.write_record_with_optional_vector_data([])
w.write_fixed_vector(([1, 2, 3] for _ in range(4)))

# empty streams
with c() as w:
w.write_int_data(range(0))
Expand Down
4 changes: 2 additions & 2 deletions tooling/internal/python/static_files/_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
) -> None:
self._stream = CodedInputStream(stream)
magic_bytes = self._stream.read_view(len(MAGIC_BYTES))
if magic_bytes != MAGIC_BYTES: # pyright: ignore [reportUnnecessaryComparison]
if magic_bytes != MAGIC_BYTES: # pyright: ignore [reportUnnecessaryComparison]
raise RuntimeError("Invalid magic bytes")

version = read_fixed_int32(self._stream)
Expand Down Expand Up @@ -955,7 +955,7 @@ def __init__(self, element_serializer: TypeSerializer[T, T_NP]) -> None:
def write(self, stream: CodedOutputStream, value: Iterable[T]) -> None:
# Note that the final 0 is missing and will be added before the next protocol step
# or the protocol is closed.
if isinstance(value, list):
if isinstance(value, list) and len(value) > 0:
stream.write_unsigned_varint(len(value))
for element in value:
self._element_serializer.write(stream, element)
Expand Down
1 change: 1 addition & 0 deletions tooling/internal/python/static_files/_ndjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def _read_json_line(self, stepName: str, required: bool) -> object:
return value
if required:
raise ValueError(f"Expected protocol step '{stepName}' not found.")
return MISSING_SENTINEL

line = self._stream.readline()
if line == "":
Expand Down

0 comments on commit d2a3ee2

Please sign in to comment.