refactors writers and buffered code, improves docs
rudolfix committed Sep 8, 2024
1 parent daa1e7d commit f31e686
Showing 4 changed files with 95 additions and 80 deletions.
43 changes: 24 additions & 19 deletions dlt/common/data_writers/buffered.py
@@ -99,32 +99,17 @@ def write_data_item(self, item: TDataItems, columns: TTableSchemaColumns) -> int
# until the first chunk is written we can change the columns schema freely
if columns is not None:
self._current_columns = dict(columns)

new_rows_count: int
if isinstance(item, List):
# update row count, if item supports "num_rows" it will be used to count items
if len(item) > 0 and hasattr(item[0], "num_rows"):
new_rows_count = sum(tbl.num_rows for tbl in item)
else:
new_rows_count = len(item)
# items coming in a single list will be written together, no matter how many there are
self._buffered_items.extend(item)
else:
self._buffered_items.append(item)
# update row count, if item supports "num_rows" it will be used to count items
if hasattr(item, "num_rows"):
new_rows_count = item.num_rows
else:
new_rows_count = 1
# add item to buffer and count new rows
new_rows_count = self._buffer_items_with_row_count(item)
self._buffered_items_count += new_rows_count
# set last modification date
self._last_modified = time.time()
# flush if max buffer exceeded, the second part of the expression prevents empty data frames from piling up in the buffer
if (
self._buffered_items_count >= self.buffer_max_items
or len(self._buffered_items) >= self.buffer_max_items
):
self._flush_items()
# set last modification date
self._last_modified = time.time()
# rotate the file if max_bytes exceeded
if self._file:
# rotate on max file size
@@ -221,6 +206,26 @@ def __exit__(self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb
if not in_exception:
raise

def _buffer_items_with_row_count(self, item: TDataItems) -> int:
"""Adds `item` to in-memory buffer and counts new rows, depending in item type"""
new_rows_count: int
if isinstance(item, List):
# update row count, if item supports "num_rows" it will be used to count items
if len(item) > 0 and hasattr(item[0], "num_rows"):
new_rows_count = sum(tbl.num_rows for tbl in item)
else:
new_rows_count = len(item)
# items coming in a single list will be written together, no matter how many there are
self._buffered_items.extend(item)
else:
self._buffered_items.append(item)
# update row count, if item supports "num_rows" it will be used to count items
if hasattr(item, "num_rows"):
new_rows_count = item.num_rows
else:
new_rows_count = 1
return new_rows_count

def _rotate_file(self, allow_empty_file: bool = False) -> DataWriterMetrics:
metrics = self._flush_and_close_file(allow_empty_file)
self._file_name = (
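Aside, not part of the diff: the new helper centralizes how buffered rows are counted, so a standalone sketch of the same rule may help. The `count_new_rows` function below is hypothetical and only mirrors the logic shown above; it assumes `pyarrow` is installed for the `num_rows` case.

```python
from typing import Any

import pyarrow as pa


def count_new_rows(item: Any) -> int:
    """Hypothetical mirror of `_buffer_items_with_row_count`'s counting rule (illustration only)."""
    if isinstance(item, list):
        # items exposing "num_rows" (e.g. Arrow tables or record batches) are counted by row count
        if len(item) > 0 and hasattr(item[0], "num_rows"):
            return sum(tbl.num_rows for tbl in item)
        return len(item)
    # a single item with "num_rows" contributes its row count, anything else counts as one row
    return item.num_rows if hasattr(item, "num_rows") else 1


assert count_new_rows([{"a": 1}, {"a": 2}]) == 2  # list of plain dict items
assert count_new_rows({"a": 1}) == 1  # single object item
assert count_new_rows([pa.Table.from_pydict({"a": [1, 2, 3]})]) == 3  # list of Arrow tables
```
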
99 changes: 42 additions & 57 deletions dlt/common/data_writers/writers.py
@@ -36,7 +36,7 @@
)
from dlt.common.metrics import DataWriterMetrics
from dlt.common.schema.typing import TTableSchemaColumns
from dlt.common.typing import StrAny
from dlt.common.typing import StrAny, TDataItem


if TYPE_CHECKING:
@@ -72,18 +72,18 @@ def __init__(self, f: IO[Any], caps: DestinationCapabilitiesContext = None) -> N
def write_header(self, columns_schema: TTableSchemaColumns) -> None: # noqa
pass

def write_data(self, rows: Sequence[Any]) -> None:
self.items_count += len(rows)
def write_data(self, items: Sequence[TDataItem]) -> None:
self.items_count += len(items)

def write_footer(self) -> None: # noqa
pass

def close(self) -> None: # noqa
pass

def write_all(self, columns_schema: TTableSchemaColumns, rows: Sequence[Any]) -> None:
def write_all(self, columns_schema: TTableSchemaColumns, items: Sequence[TDataItem]) -> None:
self.write_header(columns_schema)
self.write_data(rows)
self.write_data(items)
self.write_footer()

@classmethod
@@ -156,9 +156,9 @@ def writer_spec(cls) -> FileWriterSpec:


class JsonlWriter(DataWriter):
def write_data(self, rows: Sequence[Any]) -> None:
super().write_data(rows)
for row in rows:
def write_data(self, items: Sequence[TDataItem]) -> None:
super().write_data(items)
for row in items:
json.dump(row, self._f)
self._f.write(b"\n")

@@ -175,12 +175,12 @@ def writer_spec(cls) -> FileWriterSpec:


class TypedJsonlListWriter(JsonlWriter):
def write_data(self, rows: Sequence[Any]) -> None:
def write_data(self, items: Sequence[TDataItem]) -> None:
# skip JsonlWriter when calling super
super(JsonlWriter, self).write_data(rows)
super(JsonlWriter, self).write_data(items)
# write all rows as one list which will require to write just one line
# encode types with PUA characters
json.typed_dump(rows, self._f)
json.typed_dump(items, self._f)
self._f.write(b"\n")

@classmethod
@@ -222,11 +222,11 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
if self.writer_type == "default":
self._f.write("VALUES\n")

def write_data(self, rows: Sequence[Any]) -> None:
super().write_data(rows)
def write_data(self, items: Sequence[TDataItem]) -> None:
super().write_data(items)

# do not write empty rows, such things may be produced by Arrow adapters
if len(rows) == 0:
if len(items) == 0:
return

def write_row(row: StrAny, last_row: bool = False) -> None:
@@ -244,11 +244,11 @@ def write_row(row: StrAny, last_row: bool = False) -> None:
self._f.write(self.sep)

# write rows
for row in rows[:-1]:
for row in items[:-1]:
write_row(row)

# write last row without separator so we can write footer eventually
write_row(rows[-1], last_row=True)
write_row(items[-1], last_row=True)
self._chunks_written += 1

def write_footer(self) -> None:
@@ -342,19 +342,19 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
]
self.writer = self._create_writer(self.schema)

def write_data(self, rows: Sequence[Any]) -> None:
super().write_data(rows)
def write_data(self, items: Sequence[TDataItem]) -> None:
super().write_data(items)
from dlt.common.libs.pyarrow import pyarrow

# replace complex types with json
for key in self.complex_indices:
for row in rows:
for row in items:
if (value := row.get(key)) is not None:
# TODO: make this configurable
if value is not None and not isinstance(value, str):
row[key] = json.dumps(value)

table = pyarrow.Table.from_pylist(rows, schema=self.schema)
table = pyarrow.Table.from_pylist(items, schema=self.schema)
# Write
self.writer.write_table(table, row_group_size=self.parquet_row_group_size)

@@ -423,10 +423,10 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
i for i, field in columns_schema.items() if field["data_type"] == "binary"
]

def write_data(self, rows: Sequence[Any]) -> None:
def write_data(self, items: Sequence[TDataItem]) -> None:
# convert bytes and json
if self.complex_indices or self.bytes_indices:
for row in rows:
for row in items:
for key in self.complex_indices:
if (value := row.get(key)) is not None:
row[key] = json.dumps(value)
@@ -445,9 +445,9 @@ def write_data(self, rows: Sequence[Any]) -> None:
" type as binary.",
)

self.writer.writerows(rows)
self.writer.writerows(items)
# count rows that got written
self.items_count += sum(len(row) for row in rows)
self.items_count += sum(len(row) for row in items)

def close(self) -> None:
self.writer = None
@@ -471,35 +471,20 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None:
# Schema will be written as-is from the arrow table
self._column_schema = columns_schema

def write_data(self, rows: Sequence[Any]) -> None:
from dlt.common.libs.pyarrow import pyarrow
def write_data(self, items: Sequence[TDataItem]) -> None:
from dlt.common.libs.pyarrow import concat_batches_and_tables_in_order

if not rows:
if not items:
return
# concat batches and tables into a single one, preserving order
# pyarrow writer starts a row group for each item it writes (even with 0 rows)
# it also converts batches into tables internally. by creating a single table
# we allow the user rudimentary control over row group size via max buffered items
batches = []
tables = []
for row in rows:
self.items_count += row.num_rows
if isinstance(row, pyarrow.RecordBatch):
batches.append(row)
elif isinstance(row, pyarrow.Table):
if batches:
tables.append(pyarrow.Table.from_batches(batches))
batches = []
tables.append(row)
else:
raise ValueError(f"Unsupported type {type(row)}")
if batches:
tables.append(pyarrow.Table.from_batches(batches))

table = pyarrow.concat_tables(tables, promote_options="none")
table = concat_batches_and_tables_in_order(items)
self.items_count += table.num_rows
if not self.writer:
self.writer = self._create_writer(table.schema)
# write concatenated tables, "none" options ensures 0 copy concat
# write concatenated tables
self.writer.write_table(table, row_group_size=self.parquet_row_group_size)

def write_footer(self) -> None:
@@ -544,12 +529,12 @@ def __init__(
def write_header(self, columns_schema: TTableSchemaColumns) -> None:
self._columns_schema = columns_schema

def write_data(self, rows: Sequence[Any]) -> None:
def write_data(self, items: Sequence[TDataItem]) -> None:
from dlt.common.libs.pyarrow import pyarrow
import pyarrow.csv

for row in rows:
if isinstance(row, (pyarrow.Table, pyarrow.RecordBatch)):
for item in items:
if isinstance(item, (pyarrow.Table, pyarrow.RecordBatch)):
if not self.writer:
if self.quoting == "quote_needed":
quoting = "needed"
@@ -560,14 +545,14 @@ def write_data(self, rows: Sequence[Any]) -> None:
try:
self.writer = pyarrow.csv.CSVWriter(
self._f,
row.schema,
item.schema,
write_options=pyarrow.csv.WriteOptions(
include_header=self.include_header,
delimiter=self._delimiter_b,
quoting_style=quoting,
),
)
self._first_schema = row.schema
self._first_schema = item.schema
except pyarrow.ArrowInvalid as inv_ex:
if "Unsupported Type" in str(inv_ex):
raise InvalidDataItem(
@@ -579,18 +564,18 @@ def write_data(self, rows: Sequence[Any]) -> None:
)
raise
# make sure that Schema stays the same
if not row.schema.equals(self._first_schema):
if not item.schema.equals(self._first_schema):
raise InvalidDataItem(
"csv",
"arrow",
"Arrow schema changed without rotating the file. This may be internal"
" error or misuse of the writer.\nFirst"
f" schema:\n{self._first_schema}\n\nCurrent schema:\n{row.schema}",
f" schema:\n{self._first_schema}\n\nCurrent schema:\n{item.schema}",
)

# write headers only on the first write
try:
self.writer.write(row)
self.writer.write(item)
except pyarrow.ArrowInvalid as inv_ex:
if "Invalid UTF8 payload" in str(inv_ex):
raise InvalidDataItem(
@@ -611,9 +596,9 @@ def write_data(self, rows: Sequence[Any]) -> None:
)
raise
else:
raise ValueError(f"Unsupported type {type(row)}")
raise ValueError(f"Unsupported type {type(item)}")
# count rows that got written
self.items_count += row.num_rows
self.items_count += item.num_rows

def write_footer(self) -> None:
if self.writer is None and self.include_header:
Expand Down Expand Up @@ -649,8 +634,8 @@ def writer_spec(cls) -> FileWriterSpec:
class ArrowToObjectAdapter:
"""A mixin that will convert object writer into arrow writer."""

def write_data(self, rows: Sequence[Any]) -> None:
for batch in rows:
def write_data(self, items: Sequence[TDataItem]) -> None:
for batch in items:
# convert to object data item format
super().write_data(batch.to_pylist()) # type: ignore[misc]

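Aside, not part of the diff: a minimal usage sketch of the renamed API, in which `write_data`/`write_all` now accept `items: Sequence[TDataItem]` instead of `rows`. It assumes `dlt` is installed and that `JsonlWriter` works standalone with just a binary file handle; the output path is hypothetical.

```python
from dlt.common.data_writers.writers import JsonlWriter

# hypothetical output path; JsonlWriter emits bytes, so the file is opened in binary mode
with open("items.jsonl", "wb") as f:
    writer = JsonlWriter(f)
    # the columns schema is ignored by JsonlWriter's no-op write_header
    writer.write_all({}, [{"id": 1, "name": "a"}, {"id": 2, "name": "b"}])
    print(writer.items_count)  # 2
```
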
24 changes: 24 additions & 0 deletions dlt/common/libs/pyarrow.py
@@ -500,6 +500,30 @@ def cast_arrow_schema_types(
return schema


def concat_batches_and_tables_in_order(
tables_or_batches: Iterable[Union[pyarrow.Table, pyarrow.RecordBatch]]
) -> pyarrow.Table:
"""Concatenate iterable of tables and batches into a single table, preserving row order. Zero copy is used during
concatenation so schemas must be identical.
"""
batches = []
tables = []
for item in tables_or_batches:
if isinstance(item, pyarrow.RecordBatch):
batches.append(item)
elif isinstance(item, pyarrow.Table):
if batches:
tables.append(pyarrow.Table.from_batches(batches))
batches = []
tables.append(item)
else:
raise ValueError(f"Unsupported type {type(item)}")
if batches:
tables.append(pyarrow.Table.from_batches(batches))
# "none" option ensures 0 copy concat
return pyarrow.concat_tables(tables, promote_options="none")


class NameNormalizationCollision(ValueError):
def __init__(self, reason: str) -> None:
msg = f"Arrow column name collision after input data normalization. {reason}"
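Aside, not part of the diff: a quick usage sketch of the new helper. It assumes `pyarrow` is installed and that all inputs share an identical schema, since the zero-copy concatenation does not promote types.

```python
import pyarrow as pa

from dlt.common.libs.pyarrow import concat_batches_and_tables_in_order

# identical schemas are required because the concatenation is zero copy
t1 = pa.Table.from_pydict({"id": [1, 2]})
b1 = pa.RecordBatch.from_pydict({"id": [3]})
b2 = pa.RecordBatch.from_pydict({"id": [4, 5]})
t2 = pa.Table.from_pydict({"id": [6]})

table = concat_batches_and_tables_in_order([t1, b1, b2, t2])
assert table.column("id").to_pylist() == [1, 2, 3, 4, 5, 6]
```
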
9 changes: 5 additions & 4 deletions docs/website/docs/dlt-ecosystem/file-formats/parquet.md
@@ -80,15 +80,16 @@ To our best knowledge, arrow will convert your timezone aware DateTime(s) to UTC


### Row group size
The `pyarrow` parquet writer writes each item, i.e. a table or record batch, in a separate row group.
This may lead to many small row groups, which may not be optimal for certain query engines. For example, `duckdb` parallelizes on row groups.
`dlt` allows controlling the size of the row group by
buffering and concatenating tables and batches before they are written. The concatenation is done as a zero-copy to save memory.
You can control the memory needed by setting the count of records to be buffered as follows:
[buffering and concatenating tables](../../reference/performance.md#controlling-in-memory-buffers) and batches before they are written. The concatenation is done as a zero-copy to save memory.
You can control the size of the row group by setting the maximum number of rows kept in the buffer.
```toml
[extract.data_writer]
buffer_max_items=10e6
```
Mind that `dlt` holds the tables in memory. Thus, the 10,000,000 rows in the example above may consume a significant amount of RAM.

`row_group_size` has limited utility with `pyarrow` writer. It will split large tables into many groups if set below item buffer size.
The `row_group_size` configuration setting has limited utility with the `pyarrow` writer. It may be useful when you write single, very large pyarrow tables
or when your in-memory buffer is really large.
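
To verify the effect, the row groups of a produced file can be inspected with `pyarrow` (a sketch, not part of the docs change; the file path is hypothetical):

```python
import pyarrow.parquet as pq

# hypothetical path to a parquet file produced by a dlt load
parquet_file = pq.ParquetFile("items.parquet")
print("row groups:", parquet_file.num_row_groups)
for i in range(parquet_file.num_row_groups):
    print(f"row group {i}: {parquet_file.metadata.row_group(i).num_rows} rows")
```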
