diff --git a/dlt/common/libs/deltalake.py b/dlt/common/libs/deltalake.py index e6cd91bd0a..d98795d07c 100644 --- a/dlt/common/libs/deltalake.py +++ b/dlt/common/libs/deltalake.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict, Union +from typing import Optional, Dict, Union, List from pathlib import Path from dlt import version, Pipeline @@ -71,9 +71,13 @@ def write_delta_table( table_or_uri: Union[str, Path, DeltaTable], data: Union[pa.Table, pa.RecordBatchReader], write_disposition: TWriteDisposition, + partition_by: Optional[Union[List[str], str]] = None, storage_options: Optional[Dict[str, str]] = None, ) -> None: - """Writes in-memory Arrow table to on-disk Delta table.""" + """Writes in-memory Arrow table to on-disk Delta table. + + Thin wrapper around `deltalake.write_deltalake`. + """ # throws warning for `s3` protocol: https://github.com/delta-io/delta-rs/issues/2460 # TODO: upgrade `deltalake` lib after https://github.com/delta-io/delta-rs/pull/2500 @@ -81,6 +85,7 @@ def write_delta_table( write_deltalake( # type: ignore[call-overload] table_or_uri=table_or_uri, data=ensure_delta_compatible_arrow_data(data), + partition_by=partition_by, mode=get_delta_write_mode(write_disposition), schema_mode="merge", # enable schema evolution (adding new columns) storage_options=storage_options, diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 7009ad95ac..55bfdc5dd0 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -114,6 +114,9 @@ def run(self) -> None: storage_options = _deltalake_storage_options(self._job_client.config) dt = try_get_deltatable(dt_path, storage_options=storage_options) + # get partition columns + part_cols = get_columns_names_with_prop(self._load_table, "partition") + # explicitly check if there is data # (https://github.com/delta-io/delta-rs/issues/2686) if arrow_ds.head(1).num_rows == 0: @@ -123,6 +126,7 @@ def run(self) -> None: table_uri=dt_path, schema=ensure_delta_compatible_arrow_schema(arrow_ds.schema), mode="overwrite", + partition_by=part_cols, storage_options=storage_options, ) return @@ -158,6 +162,7 @@ def run(self) -> None: table_or_uri=dt_path if dt is None else dt, data=arrow_rbr, write_disposition=self._load_table["write_disposition"], + partition_by=part_cols, storage_options=storage_options, ) diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index 71620e889d..ec9f6fe4e9 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -437,6 +437,91 @@ def complex_table(): assert len(rows_dict["complex_table__child__grandchild"]) == 5 +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + table_format_filesystem_configs=True, + table_format="delta", + bucket_subset=(FILE_BUCKET), + ), + ids=lambda x: x.name, +) +def test_delta_table_partitioning( + destination_config: DestinationTestConfiguration, +) -> None: + """Tests partitioning for `delta` table format.""" + + from dlt.common.libs.deltalake import get_delta_tables + from tests.pipeline.utils import users_materialize_table_schema + + pipeline = destination_config.setup_pipeline("fs_pipe", dev_mode=True) + + # zero partition columns + @dlt.resource(table_format="delta") + def zero_part(): + yield {"foo": 1, "bar": 1} + + info = pipeline.run(zero_part()) + assert_load_info(info) + dt = get_delta_tables(pipeline, "zero_part")["zero_part"] + assert dt.metadata().partition_columns == [] + assert load_table_counts(pipeline, "zero_part")["zero_part"] == 1 + + # one partition column + @dlt.resource(table_format="delta", columns={"c1": {"partition": True}}) + def one_part(): + yield [ + {"c1": "foo", "c2": 1}, + {"c1": "foo", "c2": 2}, + {"c1": "bar", "c2": 3}, + {"c1": "baz", "c2": 4}, + ] + + info = pipeline.run(one_part()) + assert_load_info(info) + dt = get_delta_tables(pipeline, "one_part")["one_part"] + assert dt.metadata().partition_columns == ["c1"] + assert load_table_counts(pipeline, "one_part")["one_part"] == 4 + + # two partition columns + @dlt.resource( + table_format="delta", columns={"c1": {"partition": True}, "c2": {"partition": True}} + ) + def two_part(): + yield [ + {"c1": "foo", "c2": 1, "c3": True}, + {"c1": "foo", "c2": 2, "c3": True}, + {"c1": "bar", "c2": 1, "c3": True}, + {"c1": "baz", "c2": 1, "c3": True}, + ] + + info = pipeline.run(two_part()) + assert_load_info(info) + dt = get_delta_tables(pipeline, "two_part")["two_part"] + assert dt.metadata().partition_columns == ["c1", "c2"] + assert load_table_counts(pipeline, "two_part")["two_part"] == 4 + + # test partitioning with empty source + users_materialize_table_schema.apply_hints( + table_format="delta", + columns={"id": {"partition": True}}, + ) + info = pipeline.run(users_materialize_table_schema()) + assert_load_info(info) + dt = get_delta_tables(pipeline, "users")["users"] + assert dt.metadata().partition_columns == ["id"] + assert load_table_counts(pipeline, "users")["users"] == 0 + + # changing partitioning after initial table creation is not supported + zero_part.apply_hints(columns={"foo": {"partition": True}}) + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run(zero_part()) + assert isinstance(pip_ex.value.__context__, LoadClientJobRetry) + assert "partitioning" in pip_ex.value.__context__.retry_message + dt = get_delta_tables(pipeline, "zero_part")["zero_part"] + assert dt.metadata().partition_columns == [] + + @pytest.mark.parametrize( "destination_config", destinations_configs(