From db9b48889989cf9cf04a103f5dcb3b997c40f236 Mon Sep 17 00:00:00 2001
From: Dave
Date: Mon, 4 Mar 2024 18:02:23 +0100
Subject: [PATCH] add support for batch size zero (filepath passthrough)

---
 dlt/destinations/impl/sink/configuration.py |  2 +-
 dlt/destinations/impl/sink/sink.py          | 14 +++++++----
 tests/load/sink/test_sink.py                | 27 +++++++++++++++++++++
 3 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/dlt/destinations/impl/sink/configuration.py b/dlt/destinations/impl/sink/configuration.py
index 9a96aea98d..8d9289ff8b 100644
--- a/dlt/destinations/impl/sink/configuration.py
+++ b/dlt/destinations/impl/sink/configuration.py
@@ -12,7 +12,7 @@
 from dlt.common.configuration.exceptions import ConfigurationValueError
 
 
-TSinkCallable = Callable[[TDataItems, TTableSchema], None]
+TSinkCallable = Callable[[Union[TDataItems, str], TTableSchema], None]
 
 
 @configspec
diff --git a/dlt/destinations/impl/sink/sink.py b/dlt/destinations/impl/sink/sink.py
index 2ebfefe516..816eece079 100644
--- a/dlt/destinations/impl/sink/sink.py
+++ b/dlt/destinations/impl/sink/sink.py
@@ -42,11 +42,15 @@ def __init__(
         self._state: TLoadJobState = "running"
         self._storage_id = f"{self._parsed_file_name.table_name}.{self._parsed_file_name.file_id}"
         try:
-            current_index = destination_state.get(self._storage_id, 0)
-            for batch in self.run(current_index):
-                self.call_callable_with_items(batch)
-                current_index += len(batch)
-            destination_state[self._storage_id] = current_index
+            if self._config.batch_size == 0:
+                # on batch size zero we only call the callable with the filename
+                self.call_callable_with_items(self._file_path)
+            else:
+                current_index = destination_state.get(self._storage_id, 0)
+                for batch in self.run(current_index):
+                    self.call_callable_with_items(batch)
+                    current_index += len(batch)
+                destination_state[self._storage_id] = current_index
 
             self._state = "completed"
         except Exception as e:
diff --git a/tests/load/sink/test_sink.py b/tests/load/sink/test_sink.py
index 72dcfd5b1e..f5cf318ee8 100644
--- a/tests/load/sink/test_sink.py
+++ b/tests/load/sink/test_sink.py
@@ -318,3 +318,30 @@ def direct_sink(items, table):
         assert table["columns"]["camelCase"]["name"] == "camelCase"
 
     dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(resource())
+
+
+def test_file_batch() -> None:
+    @dlt.resource(table_name="person")
+    def resource1():
+        for i in range(100):
+            yield [{"id": i, "name": f"Name {i}"}]
+
+    @dlt.resource(table_name="address")
+    def resource2():
+        for i in range(50):
+            yield [{"id": i, "city": f"City {i}"}]
+
+    @dlt.destination(batch_size=0, loader_file_format="parquet")
+    def direct_sink(file_path, table):
+        if table["name"].startswith("_dlt"):
+            return
+        from dlt.common.libs.pyarrow import pyarrow
+
+        assert table["name"] in ["person", "address"]
+
+        with pyarrow.parquet.ParquetFile(file_path) as reader:
+            assert reader.metadata.num_rows == (100 if table["name"] == "person" else 50)
+
+    dlt.pipeline("sink_test", destination=direct_sink, full_refresh=True).run(
+        [resource1(), resource2()]
+    )