Commit

flushes the item buffer for empty tables
rudolfix committed Sep 3, 2024
1 parent 346b411 commit e69e7fa
Showing 3 changed files with 90 additions and 117 deletions.
13 changes: 8 additions & 5 deletions dlt/common/data_writers/buffered.py
@@ -45,7 +45,7 @@ def __init__(
         file_max_items: int = None,
         file_max_bytes: int = None,
         disable_compression: bool = False,
-        _caps: DestinationCapabilitiesContext = None
+        _caps: DestinationCapabilitiesContext = None,
     ):
         self.writer_spec = writer_spec
         if self.writer_spec.requires_destination_capabilities and not _caps:
@@ -102,13 +102,13 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> int

         new_rows_count: int
         if isinstance(item, List):
-            # items coming in single list will be written together, not matter how many are there
-            self._buffered_items.extend(item)
             # update row count, if item supports "num_rows" it will be used to count items
             if len(item) > 0 and hasattr(item[0], "num_rows"):
                 new_rows_count = sum(tbl.num_rows for tbl in item)
             else:
                 new_rows_count = len(item)
+            # items coming in a single list will be written together, no matter how many there are
+            self._buffered_items.extend(item)
         else:
             self._buffered_items.append(item)
             # update row count, if item supports "num_rows" it will be used to count items
@@ -117,8 +117,11 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> int
             else:
                 new_rows_count = 1
         self._buffered_items_count += new_rows_count
-        # flush if max buffer exceeded
-        if self._buffered_items_count >= self.buffer_max_items:
+        # flush if max buffer exceeded; the second part of the expression prevents empty data frames from piling up in the buffer
+        if (
+            self._buffered_items_count >= self.buffer_max_items
+            or len(self._buffered_items) >= self.buffer_max_items
+        ):
             self._flush_items()
         # set last modification date
         self._last_modified = time.time()
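A minimal sketch (not part of the commit; it assumes pyarrow is installed and stands in for the real buffer attributes with plain locals) of why the second half of the new flush condition is needed: empty Arrow tables and batches report num_rows == 0, so the row counter alone would never reach buffer_max_items and the empty items would pile up in the buffer.

import pyarrow as pa

schema = pa.schema([("col1", pa.int64())])
empty_batch = pa.RecordBatch.from_pylist([], schema=schema)

buffered_items = []       # stands in for BufferedDataWriter._buffered_items
buffered_items_count = 0  # stands in for BufferedDataWriter._buffered_items_count
buffer_max_items = 2

for _ in range(5):
    buffered_items.append(empty_batch)
    buffered_items_count += empty_batch.num_rows  # adds 0 every time

# old condition: never true for empty items, so the buffer keeps growing
assert not (buffered_items_count >= buffer_max_items)
# new condition: also flushes once the item list itself reaches buffer_max_items
assert buffered_items_count >= buffer_max_items or len(buffered_items) >= buffer_max_items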
83 changes: 82 additions & 1 deletion tests/libs/test_parquet_writer.py
@@ -7,7 +7,7 @@

 from dlt.common import pendulum, Decimal, json
 from dlt.common.configuration import inject_section
-from dlt.common.data_writers.writers import ParquetDataWriter
+from dlt.common.data_writers.writers import ArrowToParquetWriter, ParquetDataWriter
 from dlt.common.destination import DestinationCapabilitiesContext
 from dlt.common.schema.utils import new_column
 from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
@@ -313,3 +313,84 @@ def _assert_pq_column(col: int, prec: str) -> None:
    _assert_pq_column(1, "milliseconds")
    _assert_pq_column(2, "microseconds")
    _assert_pq_column(3, "nanoseconds")


def test_arrow_parquet_row_group_size() -> None:
    import pyarrow as pa

    c1 = {"col1": new_column("col1", "bigint")}

    id_ = -1

    def get_id_() -> int:
        nonlocal id_
        id_ += 1
        return id_

    single_elem_table = lambda: pa.Table.from_pylist([{"col1": get_id_()}])
    single_elem_batch = lambda: pa.RecordBatch.from_pylist([{"col1": get_id_()}])

    with get_writer(ArrowToParquetWriter, file_max_bytes=2**8, buffer_max_items=2) as writer:
        writer.write_data_item(single_elem_table(), columns=c1)
        writer._flush_items()
        assert writer._writer.items_count == 1

    with pa.parquet.ParquetFile(writer.closed_files[0].file_path) as reader:
        assert reader.num_row_groups == 1
        assert reader.metadata.row_group(0).num_rows == 1

    # should be packed into a single row group
    with get_writer(ArrowToParquetWriter, file_max_bytes=2**8, buffer_max_items=2) as writer:
        writer.write_data_item(
            [
                single_elem_table(),
                single_elem_batch(),
                single_elem_batch(),
                single_elem_table(),
                single_elem_batch(),
            ],
            columns=c1,
        )
        writer._flush_items()
        assert writer._writer.items_count == 5

    with pa.parquet.ParquetFile(writer.closed_files[0].file_path) as reader:
        assert reader.num_row_groups == 1
        assert reader.metadata.row_group(0).num_rows == 5

    with open(writer.closed_files[0].file_path, "rb") as f:
        table = pq.read_table(f)
        # all ids are there and in order
        assert table["col1"].to_pylist() == list(range(1, 6))

    # also pass an empty batch and make it be written with a separate call to the parquet writer (via buffer_max_items)
    with get_writer(ArrowToParquetWriter, file_max_bytes=2**8, buffer_max_items=1) as writer:
        pq_batch = single_elem_batch()
        writer.write_data_item(pq_batch, columns=c1)
        # this will also create the arrow schema
        writer.write_data_item(pa.RecordBatch.from_pylist([], schema=pq_batch.schema), columns=c1)

    with pa.parquet.ParquetFile(writer.closed_files[0].file_path) as reader:
        assert reader.num_row_groups == 2
        assert reader.metadata.row_group(0).num_rows == 1
        # row group with size 0 for an empty item
        assert reader.metadata.row_group(1).num_rows == 0


def test_empty_tables_get_flushed() -> None:
    c1 = {"col1": new_column("col1", "bigint")}
    single_elem_table = pa.Table.from_pylist([{"col1": 1}])
    empty_batch = pa.RecordBatch.from_pylist([], schema=single_elem_table.schema)

    with get_writer(ArrowToParquetWriter, file_max_bytes=2**8, buffer_max_items=2) as writer:
        writer.write_data_item(empty_batch, columns=c1)
        writer.write_data_item(empty_batch, columns=c1)
        # flushed: the item list reached buffer_max_items even though both batches are empty
        assert len(writer._buffered_items) == 0
        writer.write_data_item(empty_batch, columns=c1)
        assert len(writer._buffered_items) == 1
        writer.write_data_item(single_elem_table, columns=c1)
        assert len(writer._buffered_items) == 0
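A side note (hypothetical sketch, not part of the commit, assuming pyarrow is installed): an "empty table" at the Arrow level is a zero-row batch that still carries its schema, so flushing it can still produce a valid parquet file with the expected columns and zero rows.

import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([("col1", pa.int64())])
empty_batch = pa.RecordBatch.from_pylist([], schema=schema)

# a zero-row batch can still be written; the file keeps the column schema
pq.write_table(pa.Table.from_batches([empty_batch], schema=schema), "empty.parquet")
meta = pq.read_metadata("empty.parquet")
print(meta.num_rows)                  # 0
print(meta.schema.to_arrow_schema())  # col1: int64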
111 changes: 0 additions & 111 deletions tests/libs/test_pyarrow.py

This file was deleted.
