diff --git a/cpp/test/roundtrip_test.cc b/cpp/test/roundtrip_test.cc index 04e19330..b5449451 100644 --- a/cpp/test/roundtrip_test.cc +++ b/cpp/test/roundtrip_test.cc @@ -459,6 +459,24 @@ TEST_P(RoundTripTests, SimpleDatasets) { tw->Close(); } +TEST_P(RoundTripTests, SimpleDatasets_Mixed) { + auto tw = CreateValidatingWriter(); + + tw->WriteIntData({}); + tw->EndIntData(); + + tw->WriteOptionalIntData({10, 11, {}, 13, 14, 15, {}, 17, 18, 19}); + tw->EndOptionalIntData(); + + tw->EndRecordWithOptionalVectorData(); + + tw->WriteFixedVector({1, 2, 3}); + tw->WriteFixedVector({4, 5, 6}); + tw->EndFixedVector(); + + tw->Close(); +} + TEST_P(RoundTripTests, SimpleDatasets_Empty) { auto tw = CreateValidatingWriter(); diff --git a/python/tests/test_protocol_roundtrip.py b/python/tests/test_protocol_roundtrip.py index b67c3701..1d27adc6 100644 --- a/python/tests/test_protocol_roundtrip.py +++ b/python/tests/test_protocol_roundtrip.py @@ -584,6 +584,7 @@ def days(): def test_simple_streams(format: Format): c = create_validating_writer_class(format, tm.StreamsWriterBase) + # non-empty streams with c() as w: w.write_int_data(range(10)) w.write_int_data(range(20)) @@ -600,6 +601,13 @@ def test_simple_streams(format: Format): ) w.write_fixed_vector(([1, 2, 3] for _ in range(4))) + # mixed empty and non-empty streams + with c() as w: + w.write_int_data(range(0)) + w.write_optional_int_data([1, 2, None, 4, 5, None, 7, 8, 9, 10]) + w.write_record_with_optional_vector_data([]) + w.write_fixed_vector(([1, 2, 3] for _ in range(4))) + # empty streams with c() as w: w.write_int_data(range(0)) diff --git a/tooling/internal/python/static_files/_binary.py b/tooling/internal/python/static_files/_binary.py index 7d67d4d0..c76ad41e 100644 --- a/tooling/internal/python/static_files/_binary.py +++ b/tooling/internal/python/static_files/_binary.py @@ -77,7 +77,7 @@ def __init__( ) -> None: self._stream = CodedInputStream(stream) magic_bytes = self._stream.read_view(len(MAGIC_BYTES)) - if magic_bytes != MAGIC_BYTES: # pyright: ignore [reportUnnecessaryComparison] + if magic_bytes != MAGIC_BYTES: # pyright: ignore [reportUnnecessaryComparison] raise RuntimeError("Invalid magic bytes") version = read_fixed_int32(self._stream) @@ -955,7 +955,7 @@ def __init__(self, element_serializer: TypeSerializer[T, T_NP]) -> None: def write(self, stream: CodedOutputStream, value: Iterable[T]) -> None: # Note that the final 0 is missing and will be added before the next protocol step # or the protocol is closed. - if isinstance(value, list): + if isinstance(value, list) and len(value) > 0: stream.write_unsigned_varint(len(value)) for element in value: self._element_serializer.write(stream, element) diff --git a/tooling/internal/python/static_files/_ndjson.py b/tooling/internal/python/static_files/_ndjson.py index 020a74e0..7cb67c1d 100644 --- a/tooling/internal/python/static_files/_ndjson.py +++ b/tooling/internal/python/static_files/_ndjson.py @@ -136,6 +136,7 @@ def _read_json_line(self, stepName: str, required: bool) -> object: return value if required: raise ValueError(f"Expected protocol step '{stepName}' not found.") + return MISSING_SENTINEL line = self._stream.readline() if line == "":