diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py
index 0944b03bea..e7bba266df 100644
--- a/dlt/common/destination/reference.py
+++ b/dlt/common/destination/reference.py
@@ -586,10 +586,21 @@ class SupportsStagingDestination(ABC):
     def should_load_data_to_staging_dataset_on_staging_destination(
         self, table: TTableSchema
     ) -> bool:
+        """If set to True and a staging destination is configured, data will be loaded to the staging dataset on the staging
+        destination instead of the regular dataset on the staging destination. Currently this is used by Athena Iceberg, which
+        copies data from the staging dataset on the staging destination into Iceberg tables stored in the regular dataset.
+        The default is to load data to the regular dataset on the staging destination, from which warehouses with their own
+        storage (like Snowflake) copy the data.
+        """
         return False
 
     @abstractmethod
     def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool:
+        """If set to True, data in `table` will be truncated on the staging destination (regular dataset). This is the default
+        behavior, which can be changed with a config flag.
+        For Athena + Iceberg this is always False - Athena stores Iceberg tables in the regular dataset and we avoid touching it.
+        For regular Athena tables we truncate only on the "replace" write disposition.
+        """
         pass
diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py
index b3b2fbcf0f..a5a8ae2562 100644
--- a/dlt/destinations/impl/athena/athena.py
+++ b/dlt/destinations/impl/athena/athena.py
@@ -530,7 +530,7 @@ def should_truncate_table_before_load_on_staging_destination(self, table: TTable
         if table["write_disposition"] == "replace" and not self._is_iceberg_table(
             self.prepare_load_table(table["name"])
         ):
-            return True
+            return self.config.truncate_tables_on_staging_destination_before_load
         return False
 
     def should_load_data_to_staging_dataset_on_staging_destination(
diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py
index 42dee5fc8f..3bfa050fd7 100644
--- a/tests/load/pipeline/test_stage_loading.py
+++ b/tests/load/pipeline/test_stage_loading.py
@@ -218,7 +218,18 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati
     # check there are two staging files
     _, staging_client = pipeline._get_destination_clients(pipeline.default_schema)
     with staging_client:
-        assert len(staging_client.list_table_files(table_name)) == 2  # type: ignore[attr-defined]
+        # except Athena + Iceberg, which does not store tables in the staging dataset
+        if (
+            destination_config.destination == "athena"
+            and destination_config.table_format == "iceberg"
+        ):
+            table_count = 0
+            # it does keep them in the staging dataset on the staging destination, but only the last one
+            with staging_client.with_staging_dataset():  # type: ignore[attr-defined]
+                assert len(staging_client.list_table_files(table_name)) == 1  # type: ignore[attr-defined]
+        else:
+            table_count = 2
+        assert len(staging_client.list_table_files(table_name)) == table_count  # type: ignore[attr-defined]
 
     # load the data with truncating, so only new file is on the staging
     pipeline.destination.config_params["truncate_tables_on_staging_destination_before_load"] = True
@@ -231,7 +242,15 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati
     # check there is only one staging file
     _, staging_client = pipeline._get_destination_clients(pipeline.default_schema)
     with staging_client:
-        assert len(staging_client.list_table_files(table_name)) == 1  # type: ignore[attr-defined]
+        # except for Athena, which does not delete staging destination tables
+        if destination_config.destination == "athena":
+            if destination_config.table_format == "iceberg":
+                table_count = 0
+            else:
+                table_count = 3
+        else:
+            table_count = 1
+        assert len(staging_client.list_table_files(table_name)) == table_count  # type: ignore[attr-defined]
 
 
 @pytest.mark.parametrize(
diff --git a/tests/load/utils.py b/tests/load/utils.py
index 15b1e1575e..5427904d52 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -257,6 +257,27 @@ def destinations_configs(
     # build destination configs
     destination_configs: List[DestinationTestConfiguration] = []
 
+    # default sql configs that are also default staging configs
+    default_sql_configs_with_staging = [
+        # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though.
+        DestinationTestConfiguration(
+            destination="athena",
+            file_format="parquet",
+            supports_merge=False,
+            bucket_url=AWS_BUCKET,
+        ),
+        DestinationTestConfiguration(
+            destination="athena",
+            file_format="parquet",
+            bucket_url=AWS_BUCKET,
+            force_iceberg=True,
+            supports_merge=True,
+            supports_dbt=False,
+            table_format="iceberg",
+            extra_info="iceberg",
+        ),
+    ]
+
     # default non staging sql based configs, one per destination
     if default_sql_configs:
         destination_configs += [
@@ -268,26 +289,10 @@
             DestinationTestConfiguration(destination="duckdb", file_format="parquet"),
             DestinationTestConfiguration(destination="motherduck", file_format="insert_values"),
         ]
-        # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though.
-        destination_configs += [
-            DestinationTestConfiguration(
-                destination="athena",
-                file_format="parquet",
-                supports_merge=False,
-                bucket_url=AWS_BUCKET,
-            )
-        ]
-        destination_configs += [
-            DestinationTestConfiguration(
-                destination="athena",
-                file_format="parquet",
-                bucket_url=AWS_BUCKET,
-                force_iceberg=True,
-                supports_merge=True,
-                supports_dbt=False,
-                extra_info="iceberg",
-            )
-        ]
+
+        # add Athena staging configs
+        destination_configs += default_sql_configs_with_staging
+
         destination_configs += [
             DestinationTestConfiguration(
                 destination="clickhouse", file_format="jsonl", supports_dbt=False
@@ -332,6 +337,10 @@
             DestinationTestConfiguration(destination="qdrant", extra_info="server"),
         ]
 
+    if (default_sql_configs or all_staging_configs) and not default_sql_configs:
+        # Athena default configs not added yet, so add them here
+        destination_configs += default_sql_configs_with_staging
+
     if default_staging_configs or all_staging_configs:
         destination_configs += [
             DestinationTestConfiguration(