Commit db9b488

add support for batch size zero (filepath passthrough)
sh-rp committed Mar 4, 2024
1 parent dbbbe7c commit db9b488
Showing 3 changed files with 37 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dlt/destinations/impl/sink/configuration.py
@@ -12,7 +12,7 @@
 from dlt.common.configuration.exceptions import ConfigurationValueError
 
 
-TSinkCallable = Callable[[TDataItems, TTableSchema], None]
+TSinkCallable = Callable[[Union[TDataItems, str], TTableSchema], None]
 
 
 @configspec
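
With this change the sink callable may receive either deserialized data items or, when batch_size is 0, the raw load file path as a string. A minimal sketch of a callable matching the widened TSinkCallable signature (the isinstance dispatch and print calls are illustrative, not part of the commit):

from typing import Union

from dlt.common.schema.typing import TTableSchema
from dlt.common.typing import TDataItems


def my_sink(items: Union[TDataItems, str], table: TTableSchema) -> None:
    if isinstance(items, str):
        # batch_size == 0: the load file path is passed through untouched
        print(f"received file {items} for table {table['name']}")
    else:
        # batch_size > 0: a batch of deserialized data items
        for item in items:
            print(f"received row for table {table['name']}: {item}")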
14 changes: 9 additions & 5 deletions dlt/destinations/impl/sink/sink.py
@@ -42,11 +42,15 @@ def __init__(
         self._state: TLoadJobState = "running"
         self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}"
         try:
-            current_index = destination_state.get(self._storage_id, 0)
-            for batch in self.run(current_index):
-                self.call_callable_with_items(batch)
-                current_index += len(batch)
-                destination_state[self._storage_id] = current_index
+            if self._config.batch_size == 0:
+                # on batch size zero we only call the callable with the filename
+                self.call_callable_with_items(self._file_path)
+            else:
+                current_index = destination_state.get(self._storage_id, 0)
+                for batch in self.run(current_index):
+                    self.call_callable_with_items(batch)
+                    current_index += len(batch)
+                    destination_state[self._storage_id] = current_index
 
             self._state = "completed"
         except Exception as e:
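
The else branch keeps the existing incremental delivery: an item-level index per load file is persisted in destination_state, so a retried load resumes after the last successfully delivered batch instead of re-sending everything. A simplified, self-contained illustration of that resume pattern (the helper and its arguments are hypothetical, not dlt APIs):

from typing import Any, Callable, Dict, List


def deliver_with_resume(
    items: List[Any],
    state: Dict[str, int],
    storage_id: str,
    callback: Callable[[List[Any]], None],
    batch_size: int,
) -> None:
    # resume after the number of items already delivered for this file
    start = state.get(storage_id, 0)
    for pos in range(start, len(items), batch_size):
        batch = items[pos : pos + batch_size]
        callback(batch)
        # record progress only after the callback succeeds, so a retry
        # re-delivers at most the batch that failed
        state[storage_id] = pos + len(batch)

With batch_size == 0 none of this bookkeeping applies, since the callable receives the whole file exactly once.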
27 changes: 27 additions & 0 deletions tests/load/sink/test_sink.py
@@ -318,3 +318,30 @@ def direct_sink(items, table):
         assert table["columns"]["camelCase"]["name"] == "camelCase"
 
     dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(resource())
+
+
+def test_file_batch() -> None:
+    @dlt.resource(table_name="person")
+    def resource1():
+        for i in range(100):
+            yield [{"id": i, "name": f"Name {i}"}]
+
+    @dlt.resource(table_name="address")
+    def resource2():
+        for i in range(50):
+            yield [{"id": i, "city": f"City {i}"}]
+
+    @dlt.destination(batch_size=0, loader_file_format="parquet")
+    def direct_sink(file_path, table):
+        if table["name"].startswith("_dlt"):
+            return
+        from dlt.common.libs.pyarrow import pyarrow
+
+        assert table["name"] in ["person", "address"]
+
+        with pyarrow.parquet.ParquetFile(file_path) as reader:
+            assert reader.metadata.num_rows == (100 if table["name"] == "person" else 50)
+
+    dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(
+        [resource1(), resource2()]
+    )
