From 2788235572de105ff01aaf5c1ebcbe4ea40b249b Mon Sep 17 00:00:00 2001 From: Akela Drissner-Schmid <32450038+akelad@users.noreply.github.com> Date: Mon, 26 Aug 2024 16:32:22 +0200 Subject: [PATCH 1/7] Update snowflake.md --- docs/website/docs/dlt-ecosystem/destinations/snowflake.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index 181d024a2f..d08578c5a2 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -136,7 +136,12 @@ If you set the [`replace` strategy](../../general-usage/full-loading.md) to `sta recreated with a [clone command](https://docs.snowflake.com/en/sql-reference/sql/create-clone) from the staging tables. ## Data loading -The data is loaded using an internal Snowflake stage. We use the `PUT` command and per-table built-in stages by default. Stage files are immediately removed (if not specified otherwise). +The data is loaded using an internal Snowflake stage. We use the `PUT` command and per-table built-in stages by default. Stage files are kept by default, unless specified otherwise via the `keep_staged_files` parameter: + +```toml +[destination.snowflake] +keep_staged_files = false +``` ## Supported file formats * [insert-values](../file-formats/insert-format.md) is used by default From e337cca079ab21742339e097eb381635eafc5de5 Mon Sep 17 00:00:00 2001 From: rudolfix Date: Tue, 27 Aug 2024 18:32:07 +0200 Subject: [PATCH 2/7] provides detail exception messages when cursor stored value cannot be coerced to data in incremental (#1748) --- .../impl/filesystem/filesystem.py | 1 + dlt/extract/incremental/exceptions.py | 26 ++++++++ dlt/extract/incremental/transform.py | 63 ++++++++++++++++--- tests/extract/test_incremental.py | 21 ++++++- 4 files changed, 101 insertions(+), 10 deletions(-) diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 05261ccb1b..62263a10b9 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -89,6 +89,7 @@ def make_remote_path(self) -> str: ) def make_remote_uri(self) -> str: + """Returns path on a remote filesystem as a full uri including scheme.""" return self._job_client.make_remote_uri(self.make_remote_path()) def metrics(self) -> Optional[LoadJobMetrics]: diff --git a/dlt/extract/incremental/exceptions.py b/dlt/extract/incremental/exceptions.py index e318a028dc..a5f94c2974 100644 --- a/dlt/extract/incremental/exceptions.py +++ b/dlt/extract/incremental/exceptions.py @@ -1,3 +1,5 @@ +from typing import Any + from dlt.extract.exceptions import PipeException from dlt.common.typing import TDataItem @@ -13,6 +15,30 @@ def __init__(self, pipe_name: str, json_path: str, item: TDataItem, msg: str = N super().__init__(pipe_name, msg) +class IncrementalCursorInvalidCoercion(PipeException): + def __init__( + self, + pipe_name: str, + cursor_path: str, + cursor_value: TDataItem, + cursor_value_type: str, + item: TDataItem, + item_type: Any, + details: str, + ) -> None: + self.cursor_path = cursor_path + self.cursor_value = cursor_value + self.cursor_value_type = cursor_value_type + self.item = item + msg = ( + f"Could not coerce {cursor_value_type} with value {cursor_value} and type" + f" {type(cursor_value)} to actual data item {item} at path {cursor_path} with type" + f" {item_type}: {details}. You need to use different data type for" + f" {cursor_value_type} or cast your data ie. by using `add_map` on this resource." + ) + super().__init__(pipe_name, msg) + + class IncrementalPrimaryKeyMissing(PipeException): def __init__(self, pipe_name: str, primary_key_column: str, item: TDataItem) -> None: self.primary_key_column = primary_key_column diff --git a/dlt/extract/incremental/transform.py b/dlt/extract/incremental/transform.py index 947e21f7b8..0ac9fdf520 100644 --- a/dlt/extract/incremental/transform.py +++ b/dlt/extract/incremental/transform.py @@ -8,6 +8,7 @@ from dlt.common.typing import TDataItem from dlt.common.jsonpath import find_values, JSONPathFields, compile_path from dlt.extract.incremental.exceptions import ( + IncrementalCursorInvalidCoercion, IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, ) @@ -158,14 +159,36 @@ def __call__( # Check whether end_value has been reached # Filter end value ranges exclusively, so in case of "max" function we remove values >= end_value - if self.end_value is not None and ( - last_value_func((row_value, self.end_value)) != self.end_value - or last_value_func((row_value,)) == self.end_value - ): - return None, False, True - + if self.end_value is not None: + try: + if ( + last_value_func((row_value, self.end_value)) != self.end_value + or last_value_func((row_value,)) == self.end_value + ): + return None, False, True + except Exception as ex: + raise IncrementalCursorInvalidCoercion( + self.resource_name, + self.cursor_path, + self.end_value, + "end_value", + row_value, + type(row_value).__name__, + str(ex), + ) from ex check_values = (row_value,) + ((last_value,) if last_value is not None else ()) - new_value = last_value_func(check_values) + try: + new_value = last_value_func(check_values) + except Exception as ex: + raise IncrementalCursorInvalidCoercion( + self.resource_name, + self.cursor_path, + last_value, + "start_value/initial_value", + row_value, + type(row_value).__name__, + str(ex), + ) from ex # new_value is "less" or equal to last_value (the actual max) if last_value == new_value: # use func to compute row_value into last_value compatible @@ -294,14 +317,36 @@ def __call__( # If end_value is provided, filter to include table rows that are "less" than end_value if self.end_value is not None: - end_value_scalar = to_arrow_scalar(self.end_value, cursor_data_type) + try: + end_value_scalar = to_arrow_scalar(self.end_value, cursor_data_type) + except Exception as ex: + raise IncrementalCursorInvalidCoercion( + self.resource_name, + cursor_path, + self.end_value, + "end_value", + "", + cursor_data_type, + str(ex), + ) from ex tbl = tbl.filter(end_compare(tbl[cursor_path], end_value_scalar)) # Is max row value higher than end value? # NOTE: pyarrow bool *always* evaluates to python True. `as_py()` is necessary end_out_of_range = not end_compare(row_value_scalar, end_value_scalar).as_py() if self.start_value is not None: - start_value_scalar = to_arrow_scalar(self.start_value, cursor_data_type) + try: + start_value_scalar = to_arrow_scalar(self.start_value, cursor_data_type) + except Exception as ex: + raise IncrementalCursorInvalidCoercion( + self.resource_name, + cursor_path, + self.start_value, + "start_value/initial_value", + "", + cursor_data_type, + str(ex), + ) from ex # Remove rows lower or equal than the last start value keep_filter = last_value_compare(tbl[cursor_path], start_value_scalar) start_out_of_range = bool(pa.compute.any(pa.compute.invert(keep_filter)).as_py()) diff --git a/tests/extract/test_incremental.py b/tests/extract/test_incremental.py index f4082a7d86..c401552fb2 100644 --- a/tests/extract/test_incremental.py +++ b/tests/extract/test_incremental.py @@ -30,6 +30,7 @@ from dlt.sources.helpers.transform import take_first from dlt.extract.incremental import IncrementalResourceWrapper, Incremental from dlt.extract.incremental.exceptions import ( + IncrementalCursorInvalidCoercion, IncrementalCursorPathMissing, IncrementalPrimaryKeyMissing, ) @@ -1303,7 +1304,7 @@ def some_data( ) # will cause invalid comparison if item_type == "object": - with pytest.raises(InvalidStepFunctionArguments): + with pytest.raises(IncrementalCursorInvalidCoercion): list(resource) else: data = data_item_to_list(item_type, list(resource)) @@ -2065,3 +2066,21 @@ def test_source(): incremental_steps = test_source_incremental().table_name._pipe._steps assert isinstance(incremental_steps[-2], ValidateItem) assert isinstance(incremental_steps[-1], IncrementalResourceWrapper) + + +@pytest.mark.parametrize("item_type", ALL_TEST_DATA_ITEM_FORMATS) +def test_cursor_date_coercion(item_type: TestDataItemFormat) -> None: + today = datetime.today().date() + + @dlt.resource() + def updated_is_int(updated_at=dlt.sources.incremental("updated_at", initial_value=today)): + data = [{"updated_at": d} for d in [1, 2, 3]] + yield data_to_item_format(item_type, data) + + pip_1_name = "test_pydantic_columns_validator_" + uniq_id() + pipeline = dlt.pipeline(pipeline_name=pip_1_name, destination="duckdb") + + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run(updated_is_int()) + assert isinstance(pip_ex.value.__cause__, IncrementalCursorInvalidCoercion) + assert pip_ex.value.__cause__.cursor_path == "updated_at" From 98ca505fd06b8146a4355c6355174abe8b45ef66 Mon Sep 17 00:00:00 2001 From: VioletM Date: Wed, 28 Aug 2024 06:28:50 -0400 Subject: [PATCH 3/7] Expose staging tables truncation to config (#1717) * Expose staging tables truncation to config * Fix comments, add tests * Fix tests * Move implementation from mixing, add tests * Fix docs grammar --- dlt/common/destination/reference.py | 8 ++- dlt/destinations/impl/athena/athena.py | 2 +- dlt/destinations/impl/bigquery/bigquery.py | 3 + .../impl/clickhouse/clickhouse.py | 3 + .../impl/databricks/databricks.py | 3 + dlt/destinations/impl/dremio/dremio.py | 3 + dlt/destinations/impl/dummy/configuration.py | 2 + dlt/destinations/impl/dummy/dummy.py | 3 + dlt/destinations/impl/redshift/redshift.py | 3 + dlt/destinations/impl/snowflake/snowflake.py | 3 + dlt/destinations/impl/synapse/synapse.py | 3 + dlt/load/utils.py | 7 +- docs/website/docs/dlt-ecosystem/staging.md | 72 ++++++++++++------- tests/load/pipeline/test_stage_loading.py | 57 ++++++++++++++- tests/load/test_dummy_client.py | 17 +++++ 15 files changed, 152 insertions(+), 37 deletions(-) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 744cbbd1f5..0944b03bea 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -269,6 +269,8 @@ class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfigura staging_config: Optional[DestinationClientStagingConfiguration] = None """configuration of the staging, if present, injected at runtime""" + truncate_tables_on_staging_destination_before_load: bool = True + """If dlt should truncate the tables on staging destination before loading data.""" TLoadJobState = Literal["ready", "running", "failed", "retry", "completed"] @@ -578,7 +580,7 @@ def with_staging_dataset(self) -> ContextManager["JobClientBase"]: return self # type: ignore -class SupportsStagingDestination: +class SupportsStagingDestination(ABC): """Adds capability to support a staging destination for the load""" def should_load_data_to_staging_dataset_on_staging_destination( @@ -586,9 +588,9 @@ def should_load_data_to_staging_dataset_on_staging_destination( ) -> bool: return False + @abstractmethod def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: - # the default is to truncate the tables on the staging destination... - return True + pass # TODO: type Destination properly diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index 0c90d171a3..b28309b930 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -531,7 +531,7 @@ def should_truncate_table_before_load_on_staging_destination(self, table: TTable if table["write_disposition"] == "replace" and not self._is_iceberg_table( self.prepare_load_table(table["name"]) ): - return True + return self.config.truncate_tables_on_staging_destination_before_load return False def should_load_data_to_staging_dataset_on_staging_destination( diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 8291415434..11326cf3ed 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -503,6 +503,9 @@ def _should_autodetect_schema(self, table_name: str) -> bool: self.schema._schema_tables, table_name, AUTODETECT_SCHEMA_HINT, allow_none=True ) or (self.config.autodetect_schema and table_name not in self.schema.dlt_table_names()) + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load + def _streaming_load( items: List[Dict[Any, Any]], table: Dict[str, Any], job_client: BigQueryClient diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 5f17a5a18c..282fbaf338 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -372,3 +372,6 @@ def _from_db_type( self, ch_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: return self.type_mapper.from_db_type(ch_t, precision, scale) + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 2f23e88ea0..38412b2608 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -325,3 +325,6 @@ def _get_storage_table_query_columns(self) -> List[str]: "full_data_type" ) return fields + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index 68a3fedc31..149d106dcd 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -210,3 +210,6 @@ def _make_add_column_sql( self, new_columns: Sequence[TColumnSchema], table_format: TTableFormat = None ) -> List[str]: return ["ADD COLUMNS (" + ", ".join(self._get_column_def_sql(c) for c in new_columns) + ")"] + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/dummy/configuration.py b/dlt/destinations/impl/dummy/configuration.py index 023b88e51a..a066479294 100644 --- a/dlt/destinations/impl/dummy/configuration.py +++ b/dlt/destinations/impl/dummy/configuration.py @@ -34,6 +34,8 @@ class DummyClientConfiguration(DestinationClientConfiguration): """raise terminal exception in job init""" fail_transiently_in_init: bool = False """raise transient exception in job init""" + truncate_tables_on_staging_destination_before_load: bool = True + """truncate tables on staging destination""" # new jobs workflows create_followup_jobs: bool = False diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index 49b55ec65d..feb09369dc 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -202,6 +202,9 @@ def complete_load(self, load_id: str) -> None: def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: return super().should_load_data_to_staging_dataset(table) + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load + @contextmanager def with_staging_dataset(self) -> Iterator[JobClientBase]: try: diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 93827c8163..0e201dc4e0 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -274,3 +274,6 @@ def _from_db_type( self, pq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: return self.type_mapper.from_db_type(pq_t, precision, scale) + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 8b4eabc961..6688b5bc17 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -325,3 +325,6 @@ def _get_column_def_sql(self, c: TColumnSchema, table_format: TTableFormat = Non return ( f"{name} {self.type_mapper.to_db_type(c)} {self._gen_not_null(c.get('nullable', True))}" ) + + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index e43e2a6dfa..750a4895f0 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -173,6 +173,9 @@ def create_load_job( ) return job + def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + return self.config.truncate_tables_on_staging_destination_before_load + class SynapseCopyFileLoadJob(CopyRemoteFileLoadJob): def __init__( diff --git a/dlt/load/utils.py b/dlt/load/utils.py index 741c01f249..e3a2ebcd79 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -179,9 +179,10 @@ def _init_dataset_and_update_schema( applied_update = job_client.update_stored_schema( only_tables=update_tables, expected_update=expected_update ) - logger.info( - f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" - ) + if truncate_tables: + logger.info( + f"Client for {job_client.config.destination_type} will truncate tables {staging_text}" + ) job_client.initialize_storage(truncate_tables=truncate_tables) return applied_update diff --git a/docs/website/docs/dlt-ecosystem/staging.md b/docs/website/docs/dlt-ecosystem/staging.md index 05e31a574b..789189b7dd 100644 --- a/docs/website/docs/dlt-ecosystem/staging.md +++ b/docs/website/docs/dlt-ecosystem/staging.md @@ -1,36 +1,33 @@ --- title: Staging -description: Configure an s3 or gcs bucket for staging before copying into the destination +description: Configure an S3 or GCS bucket for staging before copying into the destination keywords: [staging, destination] --- # Staging -The goal of staging is to bring the data closer to the database engine so the modification of the destination (final) dataset happens faster and without errors. `dlt`, when asked, creates two -staging areas: +The goal of staging is to bring the data closer to the database engine so that the modification of the destination (final) dataset happens faster and without errors. `dlt`, when asked, creates two staging areas: 1. A **staging dataset** used by the [merge and replace loads](../general-usage/incremental-loading.md#merge-incremental_loading) to deduplicate and merge data with the destination. -2. A **staging storage** which is typically a s3/gcp bucket where [loader files](file-formats/) are copied before they are loaded by the destination. +2. A **staging storage** which is typically an S3/GCP bucket where [loader files](file-formats/) are copied before they are loaded by the destination. ## Staging dataset -`dlt` creates a staging dataset when write disposition of any of the loaded resources requires it. It creates and migrates required tables exactly like for the -main dataset. Data in staging tables is truncated when load step begins and only for tables that will participate in it. -Such staging dataset has the same name as the dataset passed to `dlt.pipeline` but with `_staging` suffix in the name. Alternatively, you can provide your own staging dataset pattern or use a fixed name, identical for all the -configured datasets. +`dlt` creates a staging dataset when the write disposition of any of the loaded resources requires it. It creates and migrates required tables exactly like for the main dataset. Data in staging tables is truncated when the load step begins and only for tables that will participate in it. +Such a staging dataset has the same name as the dataset passed to `dlt.pipeline` but with a `_staging` suffix in the name. Alternatively, you can provide your own staging dataset pattern or use a fixed name, identical for all the configured datasets. ```toml [destination.postgres] staging_dataset_name_layout="staging_%s" ``` -Entry above switches the pattern to `staging_` prefix and for example for dataset with name **github_data** `dlt` will create **staging_github_data**. +The entry above switches the pattern to `staging_` prefix and for example, for a dataset with the name **github_data**, `dlt` will create **staging_github_data**. -To configure static staging dataset name, you can do the following (we use destination factory) +To configure a static staging dataset name, you can do the following (we use the destination factory) ```py import dlt dest_ = dlt.destinations.postgres(staging_dataset_name_layout="_dlt_staging") ``` -All pipelines using `dest_` as destination will use **staging_dataset** to store staging tables. Make sure that your pipelines are not overwriting each other's tables. +All pipelines using `dest_` as the destination will use the **staging_dataset** to store staging tables. Make sure that your pipelines are not overwriting each other's tables. -### Cleanup up staging dataset automatically -`dlt` does not truncate tables in staging dataset at the end of the load. Data that is left after contains all the extracted data and may be useful for debugging. +### Cleanup staging dataset automatically +`dlt` does not truncate tables in the staging dataset at the end of the load. Data that is left after contains all the extracted data and may be useful for debugging. If you prefer to truncate it, put the following line in `config.toml`: ```toml @@ -39,19 +36,23 @@ truncate_staging_dataset=true ``` ## Staging storage -`dlt` allows to chain destinations where the first one (`staging`) is responsible for uploading the files from local filesystem to the remote storage. It then generates followup jobs for the second destination that (typically) copy the files from remote storage into destination. +`dlt` allows chaining destinations where the first one (`staging`) is responsible for uploading the files from the local filesystem to the remote storage. It then generates follow-up jobs for the second destination that (typically) copy the files from remote storage into the destination. -Currently, only one destination the [filesystem](destinations/filesystem.md) can be used as a staging. Following destinations can copy remote files: -1. [Redshift.](destinations/redshift.md#staging-support) -2. [Bigquery.](destinations/bigquery.md#staging-support) -3. [Snowflake.](destinations/snowflake.md#staging-support) +Currently, only one destination, the [filesystem](destinations/filesystem.md), can be used as staging. The following destinations can copy remote files: + +1. [Azure Synapse](destinations/synapse#staging-support) +1. [Athena](destinations/athena#staging-support) +1. [Bigquery](destinations/bigquery.md#staging-support) +1. [Dremio](destinations/dremio#staging-support) +1. [Redshift](destinations/redshift.md#staging-support) +1. [Snowflake](destinations/snowflake.md#staging-support) ### How to use -In essence, you need to set up two destinations and then pass them to `dlt.pipeline`. Below we'll use `filesystem` staging with `parquet` files to load into `Redshift` destination. +In essence, you need to set up two destinations and then pass them to `dlt.pipeline`. Below we'll use `filesystem` staging with `parquet` files to load into the `Redshift` destination. -1. **Set up the s3 bucket and filesystem staging.** +1. **Set up the S3 bucket and filesystem staging.** - Please follow our guide in [filesystem destination documentation](destinations/filesystem.md). Test the staging as standalone destination to make sure that files go where you want them. In your `secrets.toml` you should now have a working `filesystem` configuration: + Please follow our guide in the [filesystem destination documentation](destinations/filesystem.md). Test the staging as a standalone destination to make sure that files go where you want them. In your `secrets.toml`, you should now have a working `filesystem` configuration: ```toml [destination.filesystem] bucket_url = "s3://[your_bucket_name]" # replace with your bucket name, @@ -63,15 +64,15 @@ In essence, you need to set up two destinations and then pass them to `dlt.pipel 2. **Set up the Redshift destination.** - Please follow our guide in [redshift destination documentation](destinations/redshift.md). In your `secrets.toml` you added: + Please follow our guide in the [redshift destination documentation](destinations/redshift.md). In your `secrets.toml`, you added: ```toml # keep it at the top of your toml file! before any section starts destination.redshift.credentials="redshift://loader:@localhost/dlt_data?connect_timeout=15" ``` -3. **Authorize Redshift cluster to access the staging bucket.** +3. **Authorize the Redshift cluster to access the staging bucket.** - By default `dlt` will forward the credentials configured for `filesystem` to the `Redshift` COPY command. If you are fine with this, move to the next step. + By default, `dlt` will forward the credentials configured for `filesystem` to the `Redshift` COPY command. If you are fine with this, move to the next step. 4. **Chain staging to destination and request `parquet` file format.** @@ -79,7 +80,7 @@ In essence, you need to set up two destinations and then pass them to `dlt.pipel ```py # Create a dlt pipeline that will load # chess player data to the redshift destination - # via staging on s3 + # via staging on S3 pipeline = dlt.pipeline( pipeline_name='chess_pipeline', destination='redshift', @@ -87,7 +88,7 @@ In essence, you need to set up two destinations and then pass them to `dlt.pipel dataset_name='player_data' ) ``` - `dlt` will automatically select an appropriate loader file format for the staging files. Below we explicitly specify `parquet` file format (just to demonstrate how to do it): + `dlt` will automatically select an appropriate loader file format for the staging files. Below we explicitly specify the `parquet` file format (just to demonstrate how to do it): ```py info = pipeline.run(chess(), loader_file_format="parquet") ``` @@ -96,4 +97,21 @@ In essence, you need to set up two destinations and then pass them to `dlt.pipel Run the pipeline script as usual. -> 💡 Please note that `dlt` does not delete loaded files from the staging storage after the load is complete. +:::tip +Please note that `dlt` does not delete loaded files from the staging storage after the load is complete, but it truncates previously loaded files. +::: + +### How to prevent staging files truncation + +Before `dlt` loads data to the staging storage, it truncates previously loaded files. To prevent it and keep the whole history +of loaded files, you can use the following parameter: + +```toml +[destination.redshift] +truncate_table_before_load_on_staging_destination=false +``` + +:::caution +The [Athena](destinations/athena#staging-support) destination only truncates not iceberg tables with `replace` merge_disposition. +Therefore, the parameter `truncate_table_before_load_on_staging_destination` only controls the truncation of corresponding files for these tables. +::: diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index a760c86526..f216fa3c05 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -1,12 +1,12 @@ import pytest -from typing import Dict, Any, List +from typing import List import dlt, os -from dlt.common import json, sleep -from copy import deepcopy +from dlt.common import json from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.utils import uniq_id from dlt.common.schema.typing import TDataType +from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from tests.load.pipeline.test_merge_disposition import github from tests.pipeline.utils import load_table_counts, assert_load_info @@ -40,6 +40,13 @@ def load_modified_issues(): yield from issues +@dlt.resource(table_name="events", write_disposition="append", primary_key="timestamp") +def event_many_load_2(): + with open("tests/normalize/cases/event.event.many_load_2.json", "r", encoding="utf-8") as f: + events = json.load(f) + yield from events + + @pytest.mark.parametrize( "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name ) @@ -183,6 +190,50 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: assert replace_counts == initial_counts +@pytest.mark.parametrize( + "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name +) +def test_truncate_staging_dataset(destination_config: DestinationTestConfiguration) -> None: + """This test checks if tables truncation on staging destination done according to the configuration. + + Test loads data to the destination three times: + * with truncation + * without truncation (after this 2 staging files should be left) + * with truncation (after this 1 staging file should be left) + """ + pipeline = destination_config.setup_pipeline( + pipeline_name="test_stage_loading", dataset_name="test_staging_load" + uniq_id() + ) + resource = event_many_load_2() + table_name: str = resource.table_name # type: ignore[assignment] + + # load the data, files stay on the stage after the load + info = pipeline.run(resource) + assert_load_info(info) + + # load the data without truncating of the staging, should see two files on staging + pipeline.destination.config_params["truncate_tables_on_staging_destination_before_load"] = False + info = pipeline.run(resource) + assert_load_info(info) + # check there are two staging files + _, staging_client = pipeline._get_destination_clients(pipeline.default_schema) + with staging_client: + assert len(staging_client.list_table_files(table_name)) == 2 # type: ignore[attr-defined] + + # load the data with truncating, so only new file is on the staging + pipeline.destination.config_params["truncate_tables_on_staging_destination_before_load"] = True + info = pipeline.run(resource) + assert_load_info(info) + # check that table exists in the destination + with pipeline.sql_client() as sql_client: + qual_name = sql_client.make_qualified_table_name + assert len(sql_client.execute_sql(f"SELECT * from {qual_name(table_name)}")) > 4 + # check there is only one staging file + _, staging_client = pipeline._get_destination_clients(pipeline.default_schema) + with staging_client: + assert len(staging_client.list_table_files(table_name)) == 1 # type: ignore[attr-defined] + + @pytest.mark.parametrize( "destination_config", destinations_configs(all_staging_configs=True), ids=lambda x: x.name ) diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 9f0bca6ac5..59b7acac15 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -548,6 +548,23 @@ def test_completed_loop_with_delete_completed() -> None: assert_complete_job(load, should_delete_completed=True) +@pytest.mark.parametrize("to_truncate", [True, False]) +def test_truncate_table_before_load_on_stanging(to_truncate) -> None: + load = setup_loader( + client_config=DummyClientConfiguration( + truncate_tables_on_staging_destination_before_load=to_truncate + ) + ) + load_id, schema = prepare_load_package(load.load_storage, NORMALIZED_FILES) + destination_client = load.get_destination_client(schema) + assert ( + destination_client.should_truncate_table_before_load_on_staging_destination( # type: ignore + schema.tables["_dlt_version"] + ) + == to_truncate + ) + + def test_retry_on_new_loop() -> None: # test job that retries sitting in new jobs load = setup_loader(client_config=DummyClientConfiguration(retry_prob=1.0)) From 4e1c6077c7ed4bbaf127e34a2cbc7d87fe48d924 Mon Sep 17 00:00:00 2001 From: rudolfix Date: Wed, 28 Aug 2024 13:17:11 +0200 Subject: [PATCH 4/7] enables external location and named credential in databricks (#1755) * allows to configure external location and named credential for databricks * fixes #1703 * normalizes 'value' when wrapping simple objects in relational, fixes #1754 * simplifies fsspec globbing and allows various url formats that are preserved when reconstituting full url, allows abfss databricks format * adds info on partially loaded packages to docs * renames remote_uri to remote_url in traces * fixes delta for abfss * adds nested tables dlt columns collision test --- .github/workflows/test_destinations.yml | 1 + .../configuration/specs/azure_credentials.py | 2 + dlt/common/libs/deltalake.py | 3 +- dlt/common/metrics.py | 2 +- dlt/common/normalizers/json/__init__.py | 4 +- dlt/common/normalizers/json/relational.py | 4 +- dlt/common/storages/configuration.py | 119 +++++++++++++----- dlt/common/storages/fsspec_filesystem.py | 58 +++++---- dlt/destinations/impl/athena/athena.py | 1 - dlt/destinations/impl/bigquery/bigquery.py | 2 +- .../impl/databricks/configuration.py | 4 + .../impl/databricks/databricks.py | 108 ++++++++++------ dlt/destinations/impl/databricks/factory.py | 6 + dlt/destinations/impl/dummy/dummy.py | 4 +- .../impl/filesystem/filesystem.py | 32 ++--- .../dlt-ecosystem/destinations/databricks.md | 33 ++++- .../dlt-ecosystem/destinations/snowflake.md | 2 +- .../docs/running-in-production/running.md | 16 ++- tests/.dlt/config.toml | 3 +- tests/common/cases/normalizers/sql_upper.py | 2 - .../common/storages/test_local_filesystem.py | 10 +- .../test_destination_name_and_config.py | 4 +- .../test_databricks_configuration.py | 50 +++++++- .../load/filesystem/test_filesystem_common.py | 54 +++++--- .../load/pipeline/test_databricks_pipeline.py | 85 +++++++++++++ .../load/pipeline/test_filesystem_pipeline.py | 18 +-- tests/load/pipeline/test_stage_loading.py | 10 +- tests/load/test_dummy_client.py | 10 +- tests/load/utils.py | 12 +- .../cases/contracts/trace.schema.yaml | 2 +- tests/pipeline/test_pipeline.py | 14 +++ tests/pipeline/test_pipeline_trace.py | 2 +- 32 files changed, 510 insertions(+), 167 deletions(-) create mode 100644 tests/load/pipeline/test_databricks_pipeline.py diff --git a/.github/workflows/test_destinations.yml b/.github/workflows/test_destinations.yml index a034ac7eb0..7fae69ff9e 100644 --- a/.github/workflows/test_destinations.yml +++ b/.github/workflows/test_destinations.yml @@ -29,6 +29,7 @@ env: # Test redshift and filesystem with all buckets # postgres runs again here so we can test on mac/windows ACTIVE_DESTINATIONS: "[\"redshift\", \"postgres\", \"duckdb\", \"filesystem\", \"dummy\"]" + # note that all buckets are enabled for testing jobs: get_docs_changes: diff --git a/dlt/common/configuration/specs/azure_credentials.py b/dlt/common/configuration/specs/azure_credentials.py index 7fa34fa00f..6794b581ce 100644 --- a/dlt/common/configuration/specs/azure_credentials.py +++ b/dlt/common/configuration/specs/azure_credentials.py @@ -32,6 +32,8 @@ def to_object_store_rs_credentials(self) -> Dict[str, str]: creds = self.to_adlfs_credentials() if creds["sas_token"] is None: creds.pop("sas_token") + if creds["account_key"] is None: + creds.pop("account_key") return creds def create_sas_token(self) -> None: diff --git a/dlt/common/libs/deltalake.py b/dlt/common/libs/deltalake.py index d4cb46c600..38b23ea27a 100644 --- a/dlt/common/libs/deltalake.py +++ b/dlt/common/libs/deltalake.py @@ -176,7 +176,8 @@ def _deltalake_storage_options(config: FilesystemConfiguration) -> Dict[str, str """Returns dict that can be passed as `storage_options` in `deltalake` library.""" creds = {} extra_options = {} - if config.protocol in ("az", "gs", "s3"): + # TODO: create a mixin with to_object_store_rs_credentials for a proper discovery + if hasattr(config.credentials, "to_object_store_rs_credentials"): creds = config.credentials.to_object_store_rs_credentials() if config.deltalake_storage_options is not None: extra_options = config.deltalake_storage_options diff --git a/dlt/common/metrics.py b/dlt/common/metrics.py index 5cccee4045..d6acf19d0d 100644 --- a/dlt/common/metrics.py +++ b/dlt/common/metrics.py @@ -64,7 +64,7 @@ class LoadJobMetrics(NamedTuple): started_at: datetime.datetime finished_at: datetime.datetime state: Optional[str] - remote_uri: Optional[str] + remote_url: Optional[str] class LoadMetrics(StepMetrics): diff --git a/dlt/common/normalizers/json/__init__.py b/dlt/common/normalizers/json/__init__.py index a13bab15f4..725f6a8355 100644 --- a/dlt/common/normalizers/json/__init__.py +++ b/dlt/common/normalizers/json/__init__.py @@ -54,9 +54,9 @@ class SupportsDataItemNormalizer(Protocol): """A class with a name DataItemNormalizer deriving from normalizers.json.DataItemNormalizer""" -def wrap_in_dict(item: Any) -> DictStrAny: +def wrap_in_dict(label: str, item: Any) -> DictStrAny: """Wraps `item` that is not a dictionary into dictionary that can be json normalized""" - return {"value": item} + return {label: item} __all__ = [ diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 1dbcec4bff..33184640f0 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -281,7 +281,7 @@ def _normalize_list( else: # list of simple types child_row_hash = DataItemNormalizer._get_child_row_hash(parent_row_id, table, idx) - wrap_v = wrap_in_dict(v) + wrap_v = wrap_in_dict(self.c_value, v) wrap_v[self.c_dlt_id] = child_row_hash e = self._link_row(wrap_v, parent_row_id, idx) DataItemNormalizer._extend_row(extend, e) @@ -387,7 +387,7 @@ def normalize_data_item( ) -> TNormalizedRowIterator: # wrap items that are not dictionaries in dictionary, otherwise they cannot be processed by the JSON normalizer if not isinstance(item, dict): - item = wrap_in_dict(item) + item = wrap_in_dict(self.c_value, item) # we will extend event with all the fields necessary to load it as root row row = cast(DictStrAny, item) # identify load id if loaded data must be processed after loading incrementally diff --git a/dlt/common/storages/configuration.py b/dlt/common/storages/configuration.py index b2bdb3a7b6..04780528c4 100644 --- a/dlt/common/storages/configuration.py +++ b/dlt/common/storages/configuration.py @@ -1,7 +1,7 @@ import os import pathlib from typing import Any, Literal, Optional, Type, get_args, ClassVar, Dict, Union -from urllib.parse import urlparse, unquote +from urllib.parse import urlparse, unquote, urlunparse from dlt.common.configuration import configspec, resolve_type from dlt.common.configuration.exceptions import ConfigurationValueError @@ -52,6 +52,53 @@ class LoadStorageConfiguration(BaseConfiguration): ] +def _make_az_url(scheme: str, fs_path: str, bucket_url: str) -> str: + parsed_bucket_url = urlparse(bucket_url) + if parsed_bucket_url.username: + # az://@.dfs.core.windows.net/ + # fs_path always starts with container + split_path = fs_path.split("/", maxsplit=1) + if len(split_path) == 1: + split_path.append("") + container, path = split_path + netloc = f"{container}@{parsed_bucket_url.hostname}" + return urlunparse(parsed_bucket_url._replace(path=path, scheme=scheme, netloc=netloc)) + return f"{scheme}://{fs_path}" + + +def _make_file_url(scheme: str, fs_path: str, bucket_url: str) -> str: + """Creates a normalized file:// url from a local path + + netloc is never set. UNC paths are represented as file://host/path + """ + p_ = pathlib.Path(fs_path) + p_ = p_.expanduser().resolve() + return p_.as_uri() + + +MAKE_URI_DISPATCH = {"az": _make_az_url, "file": _make_file_url} + +MAKE_URI_DISPATCH["adl"] = MAKE_URI_DISPATCH["az"] +MAKE_URI_DISPATCH["abfs"] = MAKE_URI_DISPATCH["az"] +MAKE_URI_DISPATCH["azure"] = MAKE_URI_DISPATCH["az"] +MAKE_URI_DISPATCH["abfss"] = MAKE_URI_DISPATCH["az"] +MAKE_URI_DISPATCH["local"] = MAKE_URI_DISPATCH["file"] + + +def make_fsspec_url(scheme: str, fs_path: str, bucket_url: str) -> str: + """Creates url from `fs_path` and `scheme` using bucket_url as an `url` template + + Args: + scheme (str): scheme of the resulting url + fs_path (str): kind of absolute path that fsspec uses to locate resources for particular filesystem. + bucket_url (str): an url template. the structure of url will be preserved if possible + """ + _maker = MAKE_URI_DISPATCH.get(scheme) + if _maker: + return _maker(scheme, fs_path, bucket_url) + return f"{scheme}://{fs_path}" + + @configspec class FilesystemConfiguration(BaseConfiguration): """A configuration defining filesystem location and access credentials. @@ -59,7 +106,7 @@ class FilesystemConfiguration(BaseConfiguration): When configuration is resolved, `bucket_url` is used to extract a protocol and request corresponding credentials class. * s3 * gs, gcs - * az, abfs, adl + * az, abfs, adl, abfss, azure * file, memory * gdrive """ @@ -72,6 +119,8 @@ class FilesystemConfiguration(BaseConfiguration): "az": AnyAzureCredentials, "abfs": AnyAzureCredentials, "adl": AnyAzureCredentials, + "abfss": AnyAzureCredentials, + "azure": AnyAzureCredentials, } bucket_url: str = None @@ -93,17 +142,21 @@ def protocol(self) -> str: else: return urlparse(self.bucket_url).scheme + @property + def is_local_filesystem(self) -> bool: + return self.protocol == "file" + def on_resolved(self) -> None: - uri = urlparse(self.bucket_url) - if not uri.path and not uri.netloc: + url = urlparse(self.bucket_url) + if not url.path and not url.netloc: raise ConfigurationValueError( "File path and netloc are missing. Field bucket_url of" - " FilesystemClientConfiguration must contain valid uri with a path or host:password" + " FilesystemClientConfiguration must contain valid url with a path or host:password" " component." ) # this is just a path in a local file system if self.is_local_path(self.bucket_url): - self.bucket_url = self.make_file_uri(self.bucket_url) + self.bucket_url = self.make_file_url(self.bucket_url) @resolve_type("credentials") def resolve_credentials_type(self) -> Type[CredentialsConfiguration]: @@ -122,44 +175,50 @@ def fingerprint(self) -> str: if self.is_local_path(self.bucket_url): return digest128("") - uri = urlparse(self.bucket_url) - return digest128(self.bucket_url.replace(uri.path, "")) + url = urlparse(self.bucket_url) + return digest128(self.bucket_url.replace(url.path, "")) + + def make_url(self, fs_path: str) -> str: + """Makes a full url (with scheme) form fs_path which is kind-of absolute path used by fsspec to identify resources. + This method will use `bucket_url` to infer the original form of the url. + """ + return make_fsspec_url(self.protocol, fs_path, self.bucket_url) def __str__(self) -> str: """Return displayable destination location""" - uri = urlparse(self.bucket_url) + url = urlparse(self.bucket_url) # do not show passwords - if uri.password: - new_netloc = f"{uri.username}:****@{uri.hostname}" - if uri.port: - new_netloc += f":{uri.port}" - return uri._replace(netloc=new_netloc).geturl() + if url.password: + new_netloc = f"{url.username}:****@{url.hostname}" + if url.port: + new_netloc += f":{url.port}" + return url._replace(netloc=new_netloc).geturl() return self.bucket_url @staticmethod - def is_local_path(uri: str) -> bool: - """Checks if `uri` is a local path, without a schema""" - uri_parsed = urlparse(uri) + def is_local_path(url: str) -> bool: + """Checks if `url` is a local path, without a schema""" + url_parsed = urlparse(url) # this prevents windows absolute paths to be recognized as schemas - return not uri_parsed.scheme or os.path.isabs(uri) + return not url_parsed.scheme or os.path.isabs(url) @staticmethod - def make_local_path(file_uri: str) -> str: + def make_local_path(file_url: str) -> str: """Gets a valid local filesystem path from file:// scheme. Supports POSIX/Windows/UNC paths Returns: str: local filesystem path """ - uri = urlparse(file_uri) - if uri.scheme != "file": - raise ValueError(f"Must be file scheme but is {uri.scheme}") - if not uri.path and not uri.netloc: + url = urlparse(file_url) + if url.scheme != "file": + raise ValueError(f"Must be file scheme but is {url.scheme}") + if not url.path and not url.netloc: raise ConfigurationValueError("File path and netloc are missing.") - local_path = unquote(uri.path) - if uri.netloc: + local_path = unquote(url.path) + if url.netloc: # or UNC file://localhost/path - local_path = "//" + unquote(uri.netloc) + local_path + local_path = "//" + unquote(url.netloc) + local_path else: # if we are on windows, strip the POSIX root from path which is always absolute if os.path.sep != local_path[0]: @@ -172,11 +231,9 @@ def make_local_path(file_uri: str) -> str: return str(pathlib.Path(local_path)) @staticmethod - def make_file_uri(local_path: str) -> str: - """Creates a normalized file:// uri from a local path + def make_file_url(local_path: str) -> str: + """Creates a normalized file:// url from a local path netloc is never set. UNC paths are represented as file://host/path """ - p_ = pathlib.Path(local_path) - p_ = p_.expanduser().resolve() - return p_.as_uri() + return make_fsspec_url("file", local_path, None) diff --git a/dlt/common/storages/fsspec_filesystem.py b/dlt/common/storages/fsspec_filesystem.py index be9ae2bbb1..7da5ebabef 100644 --- a/dlt/common/storages/fsspec_filesystem.py +++ b/dlt/common/storages/fsspec_filesystem.py @@ -21,7 +21,7 @@ ) from urllib.parse import urlparse -from fsspec import AbstractFileSystem, register_implementation +from fsspec import AbstractFileSystem, register_implementation, get_filesystem_class from fsspec.core import url_to_fs from dlt import version @@ -32,7 +32,11 @@ AzureCredentials, ) from dlt.common.exceptions import MissingDependencyException -from dlt.common.storages.configuration import FileSystemCredentials, FilesystemConfiguration +from dlt.common.storages.configuration import ( + FileSystemCredentials, + FilesystemConfiguration, + make_fsspec_url, +) from dlt.common.time import ensure_pendulum_datetime from dlt.common.typing import DictStrAny @@ -65,18 +69,20 @@ class FileItem(TypedDict, total=False): MTIME_DISPATCH["gs"] = MTIME_DISPATCH["gcs"] MTIME_DISPATCH["s3a"] = MTIME_DISPATCH["s3"] MTIME_DISPATCH["abfs"] = MTIME_DISPATCH["az"] +MTIME_DISPATCH["abfss"] = MTIME_DISPATCH["az"] # Map of protocol to a filesystem type CREDENTIALS_DISPATCH: Dict[str, Callable[[FilesystemConfiguration], DictStrAny]] = { "s3": lambda config: cast(AwsCredentials, config.credentials).to_s3fs_credentials(), - "adl": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), "az": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), - "gcs": lambda config: cast(GcpCredentials, config.credentials).to_gcs_credentials(), "gs": lambda config: cast(GcpCredentials, config.credentials).to_gcs_credentials(), "gdrive": lambda config: {"credentials": cast(GcpCredentials, config.credentials)}, - "abfs": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), - "azure": lambda config: cast(AzureCredentials, config.credentials).to_adlfs_credentials(), } +CREDENTIALS_DISPATCH["adl"] = CREDENTIALS_DISPATCH["az"] +CREDENTIALS_DISPATCH["abfs"] = CREDENTIALS_DISPATCH["az"] +CREDENTIALS_DISPATCH["azure"] = CREDENTIALS_DISPATCH["az"] +CREDENTIALS_DISPATCH["abfss"] = CREDENTIALS_DISPATCH["az"] +CREDENTIALS_DISPATCH["gcs"] = CREDENTIALS_DISPATCH["gs"] def fsspec_filesystem( @@ -90,7 +96,7 @@ def fsspec_filesystem( Please supply credentials instance corresponding to the protocol. The `protocol` is just the code name of the filesystem i.e.: * s3 - * az, abfs + * az, abfs, abfss, adl, azure * gcs, gs also see filesystem_from_config @@ -136,7 +142,7 @@ def fsspec_from_config(config: FilesystemConfiguration) -> Tuple[AbstractFileSys Authenticates following filesystems: * s3 - * az, abfs + * az, abfs, abfss, adl, azure * gcs, gs All other filesystems are not authenticated @@ -146,8 +152,14 @@ def fsspec_from_config(config: FilesystemConfiguration) -> Tuple[AbstractFileSys fs_kwargs = prepare_fsspec_args(config) try: + # first get the class to check the protocol + fs_cls = get_filesystem_class(config.protocol) + if fs_cls.protocol == "abfs": + # if storage account is present in bucket_url and in credentials, az fsspec will fail + if urlparse(config.bucket_url).username: + fs_kwargs.pop("account_name") return url_to_fs(config.bucket_url, **fs_kwargs) # type: ignore - except ModuleNotFoundError as e: + except ImportError as e: raise MissingDependencyException( "filesystem", [f"{version.DLT_PKG_NAME}[{config.protocol}]"] ) from e @@ -291,10 +303,8 @@ def glob_files( """ is_local_fs = "file" in fs_client.protocol if is_local_fs and FilesystemConfiguration.is_local_path(bucket_url): - bucket_url = FilesystemConfiguration.make_file_uri(bucket_url) - bucket_url_parsed = urlparse(bucket_url) - else: - bucket_url_parsed = urlparse(bucket_url) + bucket_url = FilesystemConfiguration.make_file_url(bucket_url) + bucket_url_parsed = urlparse(bucket_url) if is_local_fs: root_dir = FilesystemConfiguration.make_local_path(bucket_url) @@ -302,7 +312,8 @@ def glob_files( files = glob.glob(str(pathlib.Path(root_dir).joinpath(file_glob)), recursive=True) glob_result = {file: fs_client.info(file) for file in files} else: - root_dir = bucket_url_parsed._replace(scheme="", query="").geturl().lstrip("/") + # convert to fs_path + root_dir = fs_client._strip_protocol(bucket_url) filter_url = posixpath.join(root_dir, file_glob) glob_result = fs_client.glob(filter_url, detail=True) if isinstance(glob_result, list): @@ -314,20 +325,23 @@ def glob_files( for file, md in glob_result.items(): if md["type"] != "file": continue + scheme = bucket_url_parsed.scheme + # relative paths are always POSIX if is_local_fs: - rel_path = pathlib.Path(file).relative_to(root_dir).as_posix() - file_url = FilesystemConfiguration.make_file_uri(file) + # use OS pathlib for local paths + loc_path = pathlib.Path(file) + file_name = loc_path.name + rel_path = loc_path.relative_to(root_dir).as_posix() + file_url = FilesystemConfiguration.make_file_url(file) else: - rel_path = posixpath.relpath(file.lstrip("/"), root_dir) - file_url = bucket_url_parsed._replace( - path=posixpath.join(bucket_url_parsed.path, rel_path) - ).geturl() + file_name = posixpath.basename(file) + rel_path = posixpath.relpath(file, root_dir) + file_url = make_fsspec_url(scheme, file, bucket_url) - scheme = bucket_url_parsed.scheme mime_type, encoding = guess_mime_type(rel_path) yield FileItem( - file_name=posixpath.basename(rel_path), + file_name=file_name, relative_path=rel_path, file_url=file_url, mime_type=mime_type, diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index b28309b930..b3b2fbcf0f 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -34,7 +34,6 @@ from dlt.common import logger from dlt.common.exceptions import TerminalValueError -from dlt.common.storages.fsspec_filesystem import fsspec_from_config from dlt.common.utils import uniq_id, without_none from dlt.common.schema import TColumnSchema, Schema, TTableSchema from dlt.common.schema.typing import ( diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 11326cf3ed..1dd4c727be 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -432,7 +432,7 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load # append to table for merge loads (append to stage) and regular appends. table_name = table["name"] - # determine whether we load from local or uri + # determine whether we load from local or url bucket_path = None ext: str = os.path.splitext(file_path)[1][1:] if ReferenceFollowupJobRequest.is_reference_job(file_path): diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 3bd2d12a5a..789dbedae9 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -43,6 +43,10 @@ def to_connector_params(self) -> Dict[str, Any]: class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration): destination_type: Final[str] = dataclasses.field(default="databricks", init=False, repr=False, compare=False) # type: ignore[misc] credentials: DatabricksCredentials = None + staging_credentials_name: Optional[str] = None + "If set, credentials with given name will be used in copy command" + is_staging_external_location: bool = False + """If true, the temporary credentials are not propagated to the COPY command""" def __str__(self) -> str: """Return displayable destination location""" diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 38412b2608..614e6e97c5 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, Optional, Sequence, Tuple, List, Any, Iterable, Type, cast +from typing import Optional, Sequence, List, cast from urllib.parse import urlparse, urlunparse from dlt import config @@ -6,20 +6,17 @@ from dlt.common.destination.reference import ( HasFollowupJobs, FollowupJobRequest, - TLoadJobState, RunnableLoadJob, - CredentialsConfiguration, SupportsStagingDestination, LoadJob, ) from dlt.common.configuration.specs import ( AwsCredentialsWithoutDefaults, - AzureCredentials, AzureCredentialsWithoutDefaults, ) from dlt.common.exceptions import TerminalValueError from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns +from dlt.common.schema import TColumnSchema, Schema from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat from dlt.common.schema.utils import table_schema_has_type from dlt.common.storages import FilesystemConfiguration, fsspec_from_config @@ -35,6 +32,9 @@ from dlt.destinations.type_mapping import TypeMapper +AZURE_BLOB_STORAGE_PROTOCOLS = ["az", "abfss", "abfs"] + + class DatabricksTypeMapper(TypeMapper): sct_to_unbound_dbt = { "complex": "STRING", # Databricks supports complex types like ARRAY @@ -137,41 +137,51 @@ def run(self) -> None: if bucket_path: bucket_url = urlparse(bucket_path) bucket_scheme = bucket_url.scheme - # referencing an staged files via a bucket URL requires explicit AWS credentials - if bucket_scheme == "s3" and isinstance( - staging_credentials, AwsCredentialsWithoutDefaults - ): - s3_creds = staging_credentials.to_session_credentials() - credentials_clause = f"""WITH(CREDENTIAL( - AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}', - AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}', - - AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}' - )) - """ - from_clause = f"FROM '{bucket_path}'" - elif bucket_scheme in ["az", "abfs"] and isinstance( - staging_credentials, AzureCredentialsWithoutDefaults - ): - # Explicit azure credentials are needed to load from bucket without a named stage - credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" - # Converts an az:/// to abfss://@.dfs.core.windows.net/ - # as required by snowflake - _path = bucket_url.path - bucket_path = urlunparse( - bucket_url._replace( - scheme="abfss", - netloc=f"{bucket_url.netloc}@{staging_credentials.azure_storage_account_name}.dfs.core.windows.net", - path=_path, - ) - ) - from_clause = f"FROM '{bucket_path}'" - else: + + if bucket_scheme not in AZURE_BLOB_STORAGE_PROTOCOLS + ["s3"]: raise LoadJobTerminalException( self._file_path, f"Databricks cannot load data from staging bucket {bucket_path}. Only s3 and" " azure buckets are supported", ) + + if self._job_client.config.is_staging_external_location: + # just skip the credentials clause for external location + # https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location + pass + elif self._job_client.config.staging_credentials_name: + # add named credentials + credentials_clause = ( + f"WITH(CREDENTIAL {self._job_client.config.staging_credentials_name} )" + ) + else: + # referencing an staged files via a bucket URL requires explicit AWS credentials + if bucket_scheme == "s3": + assert isinstance(staging_credentials, AwsCredentialsWithoutDefaults) + s3_creds = staging_credentials.to_session_credentials() + credentials_clause = f"""WITH(CREDENTIAL( + AWS_ACCESS_KEY='{s3_creds["aws_access_key_id"]}', + AWS_SECRET_KEY='{s3_creds["aws_secret_access_key"]}', + + AWS_SESSION_TOKEN='{s3_creds["aws_session_token"]}' + )) + """ + elif bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: + assert isinstance(staging_credentials, AzureCredentialsWithoutDefaults) + # Explicit azure credentials are needed to load from bucket without a named stage + credentials_clause = f"""WITH(CREDENTIAL(AZURE_SAS_TOKEN='{staging_credentials.azure_storage_sas_token}'))""" + bucket_path = self.ensure_databricks_abfss_url( + bucket_path, staging_credentials.azure_storage_account_name + ) + + if bucket_scheme in AZURE_BLOB_STORAGE_PROTOCOLS: + assert isinstance(staging_credentials, AzureCredentialsWithoutDefaults) + bucket_path = self.ensure_databricks_abfss_url( + bucket_path, staging_credentials.azure_storage_account_name + ) + + # always add FROM clause + from_clause = f"FROM '{bucket_path}'" else: raise LoadJobTerminalException( self._file_path, @@ -231,6 +241,34 @@ def run(self) -> None: """ self._sql_client.execute_sql(statement) + @staticmethod + def ensure_databricks_abfss_url( + bucket_path: str, azure_storage_account_name: str = None + ) -> str: + bucket_url = urlparse(bucket_path) + # Converts an az:/// to abfss://@.dfs.core.windows.net/ + if bucket_url.username: + # has the right form, ensure abfss schema + return urlunparse(bucket_url._replace(scheme="abfss")) + + if not azure_storage_account_name: + raise TerminalValueError( + f"Could not convert azure blob storage url {bucket_path} into form required by" + " Databricks" + " (abfss://@.dfs.core.windows.net/)" + " because storage account name is not known. Please use Databricks abfss://" + " canonical url as bucket_url in staging credentials" + ) + # as required by databricks + _path = bucket_url.path + return urlunparse( + bucket_url._replace( + scheme="abfss", + netloc=f"{bucket_url.netloc}@{azure_storage_account_name}.dfs.core.windows.net", + path=_path, + ) + ) + class DatabricksMergeJob(SqlMergeFollowupJob): @classmethod diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index 409d3bc4be..6108b69da9 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -54,6 +54,8 @@ def client_class(self) -> t.Type["DatabricksClient"]: def __init__( self, credentials: t.Union[DatabricksCredentials, t.Dict[str, t.Any], str] = None, + is_staging_external_location: t.Optional[bool] = False, + staging_credentials_name: t.Optional[str] = None, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -65,10 +67,14 @@ def __init__( Args: credentials: Credentials to connect to the databricks database. Can be an instance of `DatabricksCredentials` or a connection string in the format `databricks://user:password@host:port/database` + is_staging_external_location: If true, the temporary credentials are not propagated to the COPY command + staging_credentials_name: If set, credentials with given name will be used in copy command **kwargs: Additional arguments passed to the destination config """ super().__init__( credentials=credentials, + is_staging_external_location=is_staging_external_location, + staging_credentials_name=staging_credentials_name, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index feb09369dc..fc87faaf5a 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -90,9 +90,9 @@ def run(self) -> None: def metrics(self) -> Optional[LoadJobMetrics]: m = super().metrics() - # add remote uri if there's followup job + # add remote url if there's followup job if self.config.create_followup_jobs: - m = m._replace(remote_uri=self._file_name) + m = m._replace(remote_url=self._file_name) return m diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index 62263a10b9..ac5ffb9ef3 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -56,7 +56,7 @@ def __init__( self._job_client: FilesystemClient = None def run(self) -> None: - self.__is_local_filesystem = self._job_client.config.protocol == "file" + self.__is_local_filesystem = self._job_client.config.is_local_filesystem # We would like to avoid failing for local filesystem where # deeply nested directory will not exist before writing a file. # It `auto_mkdir` is disabled by default in fsspec so we made some @@ -88,13 +88,13 @@ def make_remote_path(self) -> str: path_utils.normalize_path_sep(pathlib, destination_file_name), ) - def make_remote_uri(self) -> str: - """Returns path on a remote filesystem as a full uri including scheme.""" - return self._job_client.make_remote_uri(self.make_remote_path()) + def make_remote_url(self) -> str: + """Returns path on a remote filesystem as a full url including scheme.""" + return self._job_client.make_remote_url(self.make_remote_path()) def metrics(self) -> Optional[LoadJobMetrics]: m = super().metrics() - return m._replace(remote_uri=self.make_remote_uri()) + return m._replace(remote_url=self.make_remote_url()) class DeltaLoadFilesystemJob(FilesystemLoadJob): @@ -112,7 +112,7 @@ def make_remote_path(self) -> str: return self._job_client.get_table_dir(self.load_table_name) def run(self) -> None: - logger.info(f"Will copy file(s) {self.file_paths} to delta table {self.make_remote_uri()}") + logger.info(f"Will copy file(s) {self.file_paths} to delta table {self.make_remote_url()}") from dlt.common.libs.deltalake import write_delta_table, merge_delta_table @@ -133,7 +133,7 @@ def run(self) -> None: else: write_delta_table( table_or_uri=( - self.make_remote_uri() if self._delta_table is None else self._delta_table + self.make_remote_url() if self._delta_table is None else self._delta_table ), data=arrow_rbr, write_disposition=self._load_table["write_disposition"], @@ -151,7 +151,7 @@ def _storage_options(self) -> Dict[str, str]: def _delta_table(self) -> Optional["DeltaTable"]: # type: ignore[name-defined] # noqa: F821 from dlt.common.libs.deltalake import try_get_deltatable - return try_get_deltatable(self.make_remote_uri(), storage_options=self._storage_options) + return try_get_deltatable(self.make_remote_url(), storage_options=self._storage_options) @property def _partition_columns(self) -> List[str]: @@ -166,7 +166,7 @@ def _create_or_evolve_delta_table(self) -> None: if self._delta_table is None: DeltaTable.create( - table_uri=self.make_remote_uri(), + table_uri=self.make_remote_url(), schema=ensure_delta_compatible_arrow_schema(self.arrow_ds.schema), mode="overwrite", partition_by=self._partition_columns, @@ -185,7 +185,7 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRe elif final_state == "completed": ref_job = ReferenceFollowupJobRequest( original_file_name=self.file_name(), - remote_paths=[self._job_client.make_remote_uri(self.make_remote_path())], + remote_paths=[self._job_client.make_remote_url(self.make_remote_path())], ) jobs.append(ref_job) return jobs @@ -208,7 +208,7 @@ def __init__( ) -> None: super().__init__(schema, config, capabilities) self.fs_client, fs_path = fsspec_from_config(config) - self.is_local_filesystem = config.protocol == "file" + self.is_local_filesystem = config.is_local_filesystem self.bucket_path = ( config.make_local_path(config.bucket_url) if self.is_local_filesystem else fs_path ) @@ -319,7 +319,7 @@ def get_table_dir(self, table_name: str, remote: bool = False) -> str: table_prefix = self.get_table_prefix(table_name) table_dir: str = self.pathlib.dirname(table_prefix) if remote: - table_dir = self.make_remote_uri(table_dir) + table_dir = self.make_remote_url(table_dir) return table_dir def get_table_prefix(self, table_name: str) -> str: @@ -353,7 +353,7 @@ def list_files_with_prefixes(self, table_dir: str, prefixes: List[str]) -> List[ # we fallback to our own glob implementation that is tested to return consistent results for # filesystems we support. we were not able to use `find` or `walk` because they were selecting # files wrongly (on azure walk on path1/path2/ would also select files from path1/path2_v2/ but returning wrong dirs) - for details in glob_files(self.fs_client, self.make_remote_uri(table_dir), "**"): + for details in glob_files(self.fs_client, self.make_remote_url(table_dir), "**"): file = details["file_name"] filepath = self.pathlib.join(table_dir, details["relative_path"]) # skip INIT files @@ -388,12 +388,12 @@ def create_load_job( cls = FilesystemLoadJobWithFollowup if self.config.as_staging else FilesystemLoadJob return cls(file_path) - def make_remote_uri(self, remote_path: str) -> str: + def make_remote_url(self, remote_path: str) -> str: """Returns uri to the remote filesystem to which copy the file""" if self.is_local_filesystem: - return self.config.make_file_uri(remote_path) + return self.config.make_file_url(remote_path) else: - return f"{self.config.protocol}://{remote_path}" + return self.config.make_url(remote_path) def __enter__(self) -> "FilesystemClient": return self diff --git a/docs/website/docs/dlt-ecosystem/destinations/databricks.md b/docs/website/docs/dlt-ecosystem/destinations/databricks.md index 6cd5767dcb..ddb82c95b2 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/databricks.md +++ b/docs/website/docs/dlt-ecosystem/destinations/databricks.md @@ -117,6 +117,8 @@ access_token = "MY_ACCESS_TOKEN" catalog = "my_catalog" ``` +See [staging support](#staging-support) for authentication options when `dlt` copies files from buckets. + ## Write disposition All write dispositions are supported @@ -166,6 +168,11 @@ pipeline = dlt.pipeline( Refer to the [Azure Blob Storage filesystem documentation](./filesystem.md#azure-blob-storage) for details on connecting your Azure Blob Storage container with the bucket_url and credentials. +Databricks requires that you use ABFS urls in following format: +**abfss://container_name@storage_account_name.dfs.core.windows.net/path** + +`dlt` is able to adapt the other representation (ie **az://container-name/path**') still we recommend that you use the correct form. + Example to set up Databricks with Azure as a staging destination: ```py @@ -175,10 +182,34 @@ Example to set up Databricks with Azure as a staging destination: pipeline = dlt.pipeline( pipeline_name='chess_pipeline', destination='databricks', - staging=dlt.destinations.filesystem('az://your-container-name'), # add this to activate the staging location + staging=dlt.destinations.filesystem('abfss://dlt-ci-data@dltdata.dfs.core.windows.net'), # add this to activate the staging location dataset_name='player_data' ) + ``` + +### Use external locations and stored credentials +`dlt` forwards bucket credentials to `COPY INTO` SQL command by default. You may prefer to use [external locations or stored credentials instead](https://docs.databricks.com/en/sql/language-manual/sql-ref-external-locations.html#external-location) that are stored on the Databricks side. + +If you set up external location for your staging path, you can tell `dlt` to use it: +```toml +[destination.databricks] +is_staging_external_location=true +``` + +If you set up Databricks credential named ie. **credential_x**, you can tell `dlt` to use it: +```toml +[destination.databricks] +staging_credentials_name="credential_x" +``` + +Both options are available from code: +```py +import dlt + +bricks = dlt.destinations.databricks(staging_credentials_name="credential_x") +``` + ### dbt support This destination [integrates with dbt](../transformations/dbt/dbt.md) via [dbt-databricks](https://github.com/databricks/dbt-databricks) diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index d08578c5a2..57e6db311d 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -176,7 +176,7 @@ Note that we ignore missing columns `ERROR_ON_COLUMN_COUNT_MISMATCH = FALSE` and Snowflake supports the following [column hints](https://dlthub.com/docs/general-usage/schema#tables-and-columns): * `cluster` - creates a cluster column(s). Many columns per table are supported and only when a new table is created. -### Table and column identifiers +## Table and column identifiers Snowflake supports both case sensitive and case insensitive identifiers. All unquoted and uppercase identifiers resolve case-insensitively in SQL statements. Case insensitive [naming conventions](../../general-usage/naming-convention.md#case-sensitive-and-insensitive-destinations) like the default **snake_case** will generate case insensitive identifiers. Case sensitive (like **sql_cs_v1**) will generate case sensitive identifiers that must be quoted in SQL statements. diff --git a/docs/website/docs/running-in-production/running.md b/docs/website/docs/running-in-production/running.md index 3b5762612c..cc089a1393 100644 --- a/docs/website/docs/running-in-production/running.md +++ b/docs/website/docs/running-in-production/running.md @@ -271,7 +271,7 @@ load_info.raise_on_failed_jobs() ``` You may also abort the load package with `LoadClientJobFailed` (terminal exception) on a first -failed job. Such package is immediately moved to completed but its load id is not added to the +failed job. Such package is will be completed but its load id is not added to the `_dlt_loads` table. All the jobs that were running in parallel are completed before raising. The dlt state, if present, will not be visible to `dlt`. Here's example `config.toml` to enable this option: @@ -282,6 +282,20 @@ load.workers=1 load.raise_on_failed_jobs=true ``` +:::caution +Note that certain write dispositions will irreversibly modify your data +1. `replace` write disposition with the default `truncate-and-insert` [strategy](../general-usage/full-loading.md) will truncate tables before loading. +2. `merge` write disposition will merge staging dataset tables into the destination dataset. This will happen only when all data for this table (and nested tables) got loaded. + +Here's what you can do to deal with partially loaded packages: +1. Retry the load step in case of transient errors +2. Use replace strategy with staging dataset so replace happens only when data for the table (and all nested tables) was fully loaded and is atomic operation (if possible) +3. Use only "append" write disposition. When your load package fails you are able to use `_dlt_load_id` to remove all unprocessed data. +4. Use "staging append" (`merge` disposition without primary key and merge key defined). + +::: + + ### What `run` does inside Before adding retry to pipeline steps, note how `run` method actually works: diff --git a/tests/.dlt/config.toml b/tests/.dlt/config.toml index ba86edf417..292175569b 100644 --- a/tests/.dlt/config.toml +++ b/tests/.dlt/config.toml @@ -6,7 +6,8 @@ bucket_url_gs="gs://ci-test-bucket" bucket_url_s3="s3://dlt-ci-test-bucket" bucket_url_file="_storage" bucket_url_az="az://dlt-ci-test-bucket" +bucket_url_abfss="abfss://dlt-ci-test-bucket@dltdata.dfs.core.windows.net" bucket_url_r2="s3://dlt-ci-test-bucket" # use "/" as root path bucket_url_gdrive="gdrive://15eC3e5MNew2XAIefWNlG8VlEa0ISnnaG" -memory="memory://m" \ No newline at end of file +memory="memory:///m" \ No newline at end of file diff --git a/tests/common/cases/normalizers/sql_upper.py b/tests/common/cases/normalizers/sql_upper.py index f2175f06ad..eb88775f95 100644 --- a/tests/common/cases/normalizers/sql_upper.py +++ b/tests/common/cases/normalizers/sql_upper.py @@ -1,5 +1,3 @@ -from typing import Any, Sequence - from dlt.common.normalizers.naming.naming import NamingConvention as BaseNamingConvention diff --git a/tests/common/storages/test_local_filesystem.py b/tests/common/storages/test_local_filesystem.py index 14e3cc23d4..1bfe6c0b5b 100644 --- a/tests/common/storages/test_local_filesystem.py +++ b/tests/common/storages/test_local_filesystem.py @@ -45,7 +45,7 @@ ) def test_local_path_win_configuration(bucket_url: str, file_url: str) -> None: assert FilesystemConfiguration.is_local_path(bucket_url) is True - assert FilesystemConfiguration.make_file_uri(bucket_url) == file_url + assert FilesystemConfiguration.make_file_url(bucket_url) == file_url c = resolve_configuration(FilesystemConfiguration(bucket_url)) assert c.protocol == "file" @@ -66,7 +66,7 @@ def test_local_path_win_configuration(bucket_url: str, file_url: str) -> None: def test_local_user_win_path_configuration(bucket_url: str) -> None: file_url = "file:///" + pathlib.Path(bucket_url).expanduser().as_posix().lstrip("/") assert FilesystemConfiguration.is_local_path(bucket_url) is True - assert FilesystemConfiguration.make_file_uri(bucket_url) == file_url + assert FilesystemConfiguration.make_file_url(bucket_url) == file_url c = resolve_configuration(FilesystemConfiguration(bucket_url)) assert c.protocol == "file" @@ -99,7 +99,7 @@ def test_file_win_configuration() -> None: ) def test_file_posix_configuration(bucket_url: str, file_url: str) -> None: assert FilesystemConfiguration.is_local_path(bucket_url) is True - assert FilesystemConfiguration.make_file_uri(bucket_url) == file_url + assert FilesystemConfiguration.make_file_url(bucket_url) == file_url c = resolve_configuration(FilesystemConfiguration(bucket_url)) assert c.protocol == "file" @@ -117,7 +117,7 @@ def test_file_posix_configuration(bucket_url: str, file_url: str) -> None: def test_local_user_posix_path_configuration(bucket_url: str) -> None: file_url = "file:///" + pathlib.Path(bucket_url).expanduser().as_posix().lstrip("/") assert FilesystemConfiguration.is_local_path(bucket_url) is True - assert FilesystemConfiguration.make_file_uri(bucket_url) == file_url + assert FilesystemConfiguration.make_file_url(bucket_url) == file_url c = resolve_configuration(FilesystemConfiguration(bucket_url)) assert c.protocol == "file" @@ -166,7 +166,7 @@ def test_file_filesystem_configuration( assert FilesystemConfiguration.make_local_path(bucket_url) == str( pathlib.Path(local_path).resolve() ) - assert FilesystemConfiguration.make_file_uri(local_path) == norm_bucket_url + assert FilesystemConfiguration.make_file_url(local_path) == norm_bucket_url if local_path == "": with pytest.raises(ConfigurationValueError): diff --git a/tests/destinations/test_destination_name_and_config.py b/tests/destinations/test_destination_name_and_config.py index 11de706722..1e432a7803 100644 --- a/tests/destinations/test_destination_name_and_config.py +++ b/tests/destinations/test_destination_name_and_config.py @@ -60,7 +60,7 @@ def test_set_name_and_environment() -> None: def test_preserve_destination_instance() -> None: dummy1 = dummy(destination_name="dummy1", environment="dev/null/1") filesystem1 = filesystem( - FilesystemConfiguration.make_file_uri(TEST_STORAGE_ROOT), + FilesystemConfiguration.make_file_url(TEST_STORAGE_ROOT), destination_name="local_fs", environment="devel", ) @@ -210,7 +210,7 @@ def test_destination_config_in_name(environment: DictStrStr) -> None: with pytest.raises(ConfigFieldMissingException): p.destination_client() - environment["DESTINATION__FILESYSTEM-PROD__BUCKET_URL"] = FilesystemConfiguration.make_file_uri( + environment["DESTINATION__FILESYSTEM-PROD__BUCKET_URL"] = FilesystemConfiguration.make_file_url( "_storage" ) assert p._fs_client().dataset_path.endswith(p.dataset_name) diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index f6a06180c9..bb989a887c 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -3,9 +3,12 @@ pytest.importorskip("databricks") +from dlt.common.exceptions import TerminalValueError +from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob +from dlt.common.configuration import resolve_configuration +from dlt.destinations import databricks from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration -from dlt.common.configuration import resolve_configuration # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -34,3 +37,48 @@ def test_databricks_credentials_to_connector_params(): assert params["extra_a"] == "a" assert params["extra_b"] == "b" assert params["_socket_timeout"] == credentials.socket_timeout + + +def test_databricks_configuration() -> None: + bricks = databricks() + config = bricks.configuration(None, accept_partial=True) + assert config.is_staging_external_location is False + assert config.staging_credentials_name is None + + os.environ["IS_STAGING_EXTERNAL_LOCATION"] = "true" + os.environ["STAGING_CREDENTIALS_NAME"] = "credential" + config = bricks.configuration(None, accept_partial=True) + assert config.is_staging_external_location is True + assert config.staging_credentials_name == "credential" + + # explicit params + bricks = databricks(is_staging_external_location=None, staging_credentials_name="credential2") + config = bricks.configuration(None, accept_partial=True) + assert config.staging_credentials_name == "credential2" + assert config.is_staging_external_location is None + + +def test_databricks_abfss_converter() -> None: + with pytest.raises(TerminalValueError): + DatabricksLoadJob.ensure_databricks_abfss_url("az://dlt-ci-test-bucket") + + abfss_url = DatabricksLoadJob.ensure_databricks_abfss_url( + "az://dlt-ci-test-bucket", "my_account" + ) + assert abfss_url == "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net" + + abfss_url = DatabricksLoadJob.ensure_databricks_abfss_url( + "az://dlt-ci-test-bucket/path/to/file.parquet", "my_account" + ) + assert ( + abfss_url + == "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet" + ) + + abfss_url = DatabricksLoadJob.ensure_databricks_abfss_url( + "az://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet" + ) + assert ( + abfss_url + == "abfss://dlt-ci-test-bucket@my_account.dfs.core.windows.net/path/to/file.parquet" + ) diff --git a/tests/load/filesystem/test_filesystem_common.py b/tests/load/filesystem/test_filesystem_common.py index 3cad7dda2c..29ca1a2b57 100644 --- a/tests/load/filesystem/test_filesystem_common.py +++ b/tests/load/filesystem/test_filesystem_common.py @@ -3,8 +3,8 @@ from typing import Tuple, Union, Dict from urllib.parse import urlparse - -from fsspec import AbstractFileSystem +from fsspec import AbstractFileSystem, get_filesystem_class, register_implementation +from fsspec.core import filesystem as fs_filesystem import pytest from tenacity import retry, stop_after_attempt, wait_fixed @@ -15,6 +15,7 @@ from dlt.common.configuration.inject import with_config from dlt.common.configuration.specs import AnyAzureCredentials from dlt.common.storages import fsspec_from_config, FilesystemConfiguration +from dlt.common.storages.configuration import make_fsspec_url from dlt.common.storages.fsspec_filesystem import MTIME_DISPATCH, glob_files from dlt.common.utils import custom_environ, uniq_id from dlt.destinations import filesystem @@ -22,11 +23,12 @@ FilesystemDestinationClientConfiguration, ) from dlt.destinations.impl.filesystem.typing import TExtraPlaceholders + +from tests.common.configuration.utils import environment from tests.common.storages.utils import TEST_SAMPLE_FILES, assert_sample_files -from tests.load.utils import ALL_FILESYSTEM_DRIVERS, AWS_BUCKET +from tests.load.utils import ALL_FILESYSTEM_DRIVERS, AWS_BUCKET, WITH_GDRIVE_BUCKETS from tests.utils import autouse_test_storage -from .utils import self_signed_cert -from tests.common.configuration.utils import environment +from tests.load.filesystem.utils import self_signed_cert # mark all tests as essential, do not remove @@ -53,6 +55,24 @@ def test_filesystem_configuration() -> None: } +@pytest.mark.parametrize("bucket_url", WITH_GDRIVE_BUCKETS) +def test_remote_url(bucket_url: str) -> None: + # make absolute urls out of paths + scheme = urlparse(bucket_url).scheme + if not scheme: + scheme = "file" + bucket_url = FilesystemConfiguration.make_file_url(bucket_url) + if scheme == "gdrive": + from dlt.common.storages.fsspecs.google_drive import GoogleDriveFileSystem + + register_implementation("gdrive", GoogleDriveFileSystem, "GoogleDriveFileSystem") + + fs_class = get_filesystem_class(scheme) + fs_path = fs_class._strip_protocol(bucket_url) + # reconstitute url + assert make_fsspec_url(scheme, fs_path, bucket_url) == bucket_url + + def test_filesystem_instance(with_gdrive_buckets_env: str) -> None: @retry(stop=stop_after_attempt(10), wait=wait_fixed(1), reraise=True) def check_file_exists(filedir_: str, file_url_: str): @@ -72,10 +92,8 @@ def check_file_changed(file_url_: str): bucket_url = os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] config = get_config() # we do not add protocol to bucket_url (we need relative path) - assert bucket_url.startswith(config.protocol) or config.protocol == "file" + assert bucket_url.startswith(config.protocol) or config.is_local_filesystem filesystem, url = fsspec_from_config(config) - if config.protocol != "file": - assert bucket_url.endswith(url) # do a few file ops now = pendulum.now() filename = f"filesystem_common_{uniq_id()}" @@ -113,7 +131,9 @@ def test_glob_overlapping_path_files(with_gdrive_buckets_env: str) -> None: # "standard_source/sample" overlaps with a real existing "standard_source/samples". walk operation on azure # will return all files from "standard_source/samples" and report the wrong "standard_source/sample" path to the user # here we test we do not have this problem with out glob - bucket_url, _, filesystem = glob_test_setup(bucket_url, "standard_source/sample") + bucket_url, config, filesystem = glob_test_setup(bucket_url, "standard_source/sample") + if config.protocol in ["file"]: + pytest.skip(f"{config.protocol} not supported in this test") # use glob to get data all_file_items = list(glob_files(filesystem, bucket_url)) assert len(all_file_items) == 0 @@ -272,18 +292,18 @@ def glob_test_setup( config = get_config() # enable caches config.read_only = True - if config.protocol in ["file"]: - pytest.skip(f"{config.protocol} not supported in this test") # may contain query string - bucket_url_parsed = urlparse(bucket_url) - bucket_url = bucket_url_parsed._replace( - path=posixpath.join(bucket_url_parsed.path, glob_folder) - ).geturl() - filesystem, _ = fsspec_from_config(config) + filesystem, fs_path = fsspec_from_config(config) + bucket_url = make_fsspec_url(config.protocol, posixpath.join(fs_path, glob_folder), bucket_url) if config.protocol == "memory": - mem_path = os.path.join("m", "standard_source") + mem_path = os.path.join("/m", "standard_source") if not filesystem.isdir(mem_path): filesystem.mkdirs(mem_path) filesystem.upload(TEST_SAMPLE_FILES, mem_path, recursive=True) + if config.protocol == "file": + file_path = os.path.join("_storage", "standard_source") + if not filesystem.isdir(file_path): + filesystem.mkdirs(file_path) + filesystem.upload(TEST_SAMPLE_FILES, file_path, recursive=True) return bucket_url, config, filesystem diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py new file mode 100644 index 0000000000..5f8641f9fa --- /dev/null +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -0,0 +1,85 @@ +import pytest +import os + +from dlt.common.utils import uniq_id +from tests.load.utils import DestinationTestConfiguration, destinations_configs, AZ_BUCKET +from tests.pipeline.utils import assert_load_info + + +# mark all tests as essential, do not remove +pytestmark = pytest.mark.essential + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, bucket_subset=(AZ_BUCKET), subset=("databricks",) + ), + ids=lambda x: x.name, +) +def test_databricks_external_location(destination_config: DestinationTestConfiguration) -> None: + # do not interfere with state + os.environ["RESTORE_FROM_DESTINATION"] = "False" + dataset_name = "test_databricks_external_location" + uniq_id() + + from dlt.destinations import databricks, filesystem + from dlt.destinations.impl.databricks.databricks import DatabricksLoadJob + + abfss_bucket_url = DatabricksLoadJob.ensure_databricks_abfss_url(AZ_BUCKET, "dltdata") + stage = filesystem(abfss_bucket_url) + + # should load abfss formatted url just fine + bricks = databricks(is_staging_external_location=False) + pipeline = destination_config.setup_pipeline( + "test_databricks_external_location", + dataset_name=dataset_name, + destination=bricks, + staging=stage, + ) + info = pipeline.run([1, 2, 3], table_name="digits") + assert_load_info(info) + # get metrics + metrics = info.metrics[info.loads_ids[0]][0] + remote_url = list(metrics["job_metrics"].values())[0].remote_url + # abfss form was preserved + assert remote_url.startswith(abfss_bucket_url) + + # should fail on internal config error as external location is not configured + bricks = databricks(is_staging_external_location=True) + pipeline = destination_config.setup_pipeline( + "test_databricks_external_location", + dataset_name=dataset_name, + destination=bricks, + staging=stage, + ) + info = pipeline.run([1, 2, 3], table_name="digits") + assert info.has_failed_jobs is True + assert ( + "Invalid configuration value detected" + in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message + ) + + # should fail on non existing stored credentials + bricks = databricks(is_staging_external_location=False, staging_credentials_name="CREDENTIAL_X") + pipeline = destination_config.setup_pipeline( + "test_databricks_external_location", + dataset_name=dataset_name, + destination=bricks, + staging=stage, + ) + info = pipeline.run([1, 2, 3], table_name="digits") + assert info.has_failed_jobs is True + assert ( + "credential_x" in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message + ) + + # should fail on non existing stored credentials + # auto stage with regular az:// used + pipeline = destination_config.setup_pipeline( + "test_databricks_external_location", dataset_name=dataset_name, destination=bricks + ) + info = pipeline.run([1, 2, 3], table_name="digits") + assert info.has_failed_jobs is True + assert ( + "credential_x" in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message + ) diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index d88eba7c06..bc6cbd9848 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -300,16 +300,16 @@ def data_types(): assert len(rows) == 10 assert_all_data_types_row(rows[0], schema=column_schemas) - # make sure remote_uri is in metrics + # make sure remote_url is in metrics metrics = info.metrics[info.loads_ids[0]][0] - # TODO: only final copy job has remote_uri. not the initial (empty) job for particular files - # we could implement an empty job for delta that generates correct remote_uri - remote_uri = list(metrics["job_metrics"].values())[-1].remote_uri - assert remote_uri.endswith("data_types") - bucket_uri = destination_config.bucket_url - if FilesystemConfiguration.is_local_path(bucket_uri): - bucket_uri = FilesystemConfiguration.make_file_uri(bucket_uri) - assert remote_uri.startswith(bucket_uri) + # TODO: only final copy job has remote_url. not the initial (empty) job for particular files + # we could implement an empty job for delta that generates correct remote_url + remote_url = list(metrics["job_metrics"].values())[-1].remote_url + assert remote_url.endswith("data_types") + bucket_url = destination_config.bucket_url + if FilesystemConfiguration.is_local_path(bucket_url): + bucket_url = FilesystemConfiguration.make_file_url(bucket_url) + assert remote_url.startswith(bucket_url) # another run should append rows to the table info = pipeline.run(data_types()) diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index f216fa3c05..42dee5fc8f 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -57,17 +57,17 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: info = pipeline.run(github(), loader_file_format=destination_config.file_format) assert_load_info(info) - # checks if remote_uri is set correctly on copy jobs + # checks if remote_url is set correctly on copy jobs metrics = info.metrics[info.loads_ids[0]][0] for job_metrics in metrics["job_metrics"].values(): - remote_uri = job_metrics.remote_uri + remote_url = job_metrics.remote_url job_ext = os.path.splitext(job_metrics.job_id)[1] if job_ext not in (".reference", ".sql"): - assert remote_uri.endswith(job_ext) + assert remote_url.endswith(job_ext) bucket_uri = destination_config.bucket_url if FilesystemConfiguration.is_local_path(bucket_uri): - bucket_uri = FilesystemConfiguration.make_file_uri(bucket_uri) - assert remote_uri.startswith(bucket_uri) + bucket_uri = FilesystemConfiguration.make_file_url(bucket_uri) + assert remote_url.startswith(bucket_uri) package_info = pipeline.get_load_package_info(info.loads_ids[0]) assert package_info.state == "loaded" diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 59b7acac15..72c5772668 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -1012,17 +1012,17 @@ def assert_complete_job( if state == "failed_jobs" else "completed" ) - remote_uri = job_metrics.remote_uri + remote_url = job_metrics.remote_url if load.initial_client_config.create_followup_jobs: # type: ignore - assert remote_uri.endswith(job.file_name()) + assert remote_url.endswith(job.file_name()) elif load.is_staging_destination_job(job.file_name()): # staging destination should contain reference to remote filesystem assert ( - FilesystemConfiguration.make_file_uri(REMOTE_FILESYSTEM) - in remote_uri + FilesystemConfiguration.make_file_url(REMOTE_FILESYSTEM) + in remote_url ) else: - assert remote_uri is None + assert remote_url is None else: assert job_metrics is None diff --git a/tests/load/utils.py b/tests/load/utils.py index 086109de8b..15b1e1575e 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -70,6 +70,7 @@ AWS_BUCKET = dlt.config.get("tests.bucket_url_s3", str) GCS_BUCKET = dlt.config.get("tests.bucket_url_gs", str) AZ_BUCKET = dlt.config.get("tests.bucket_url_az", str) +ABFS_BUCKET = dlt.config.get("tests.bucket_url_abfss", str) GDRIVE_BUCKET = dlt.config.get("tests.bucket_url_gdrive", str) FILE_BUCKET = dlt.config.get("tests.bucket_url_file", str) R2_BUCKET = dlt.config.get("tests.bucket_url_r2", str) @@ -79,6 +80,7 @@ "s3", "gs", "az", + "abfss", "gdrive", "file", "memory", @@ -86,7 +88,15 @@ ] # Filter out buckets not in all filesystem drivers -WITH_GDRIVE_BUCKETS = [GCS_BUCKET, AWS_BUCKET, FILE_BUCKET, MEMORY_BUCKET, AZ_BUCKET, GDRIVE_BUCKET] +WITH_GDRIVE_BUCKETS = [ + GCS_BUCKET, + AWS_BUCKET, + FILE_BUCKET, + MEMORY_BUCKET, + ABFS_BUCKET, + AZ_BUCKET, + GDRIVE_BUCKET, +] WITH_GDRIVE_BUCKETS = [ bucket for bucket in WITH_GDRIVE_BUCKETS diff --git a/tests/pipeline/cases/contracts/trace.schema.yaml b/tests/pipeline/cases/contracts/trace.schema.yaml index 89831977c0..c324818338 100644 --- a/tests/pipeline/cases/contracts/trace.schema.yaml +++ b/tests/pipeline/cases/contracts/trace.schema.yaml @@ -562,7 +562,7 @@ tables: finished_at: data_type: timestamp nullable: true - remote_uri: + remote_url: data_type: text nullable: true parent: trace__steps diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index b6a7feffc1..027a2b4e72 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -2600,6 +2600,20 @@ def ids(_id=dlt.sources.incremental("_id", initial_value=2)): assert pipeline.last_trace.last_normalize_info.row_counts["_ids"] == 2 +def test_dlt_columns_nested_table_collisions() -> None: + # we generate all identifiers in upper case to test for a bug where dlt columns for nested tables were hardcoded to + # small caps. they got normalized to upper case after the first run and then added again as small caps + # generating duplicate columns and raising collision exception as duckdb is ci destination + duck = duckdb(naming_convention="tests.common.cases.normalizers.sql_upper") + pipeline = dlt.pipeline("test_dlt_columns_child_table_collisions", destination=duck) + customers = [ + {"id": 1, "name": "dave", "orders": [1, 2, 3]}, + ] + assert_load_info(pipeline.run(customers, table_name="CUSTOMERS")) + # this one would fail without bugfix + assert_load_info(pipeline.run(customers, table_name="CUSTOMERS")) + + def test_access_pipeline_in_resource() -> None: pipeline = dlt.pipeline("test_access_pipeline_in_resource", destination="duckdb") diff --git a/tests/pipeline/test_pipeline_trace.py b/tests/pipeline/test_pipeline_trace.py index 4e52d2aa29..d2bb035a17 100644 --- a/tests/pipeline/test_pipeline_trace.py +++ b/tests/pipeline/test_pipeline_trace.py @@ -315,7 +315,7 @@ def data(): return data() - # create pipeline with staging to get remote_uri in load step job_metrics + # create pipeline with staging to get remote_url in load step job_metrics dummy_dest = dummy(completed_prob=1.0) pipeline = dlt.pipeline( pipeline_name="test_trace_schema", From 63f89542678c7af51089f94365aa6834ccca90e7 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 28 Aug 2024 13:20:16 +0200 Subject: [PATCH 5/7] bumps dlt version to 0.5.4 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 74161f5ccc..d32285572f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dlt" -version = "0.5.4a0" +version = "0.5.4" description = "dlt is an open-source python-first scalable data loading library that does not require any backend to run." authors = ["dltHub Inc. "] maintainers = [ "Marcin Rudolf ", "Adrian Brudaru ", "Anton Burnashev ", "David Scharf " ] From b48c7c3e7db9fb4ff321b668b9b22553b7882b31 Mon Sep 17 00:00:00 2001 From: rudolfix Date: Wed, 28 Aug 2024 19:11:56 +0200 Subject: [PATCH 6/7] runs staging tests on athena (#1764) * always truncates staging tables on athena + replace without iceberg * adds athena staging configs to all staging configs * updates athena tests for staging destination --- dlt/common/destination/reference.py | 11 +++++ dlt/destinations/impl/athena/athena.py | 2 +- tests/load/pipeline/test_stage_loading.py | 23 ++++++++++- tests/load/utils.py | 49 ++++++++++++++--------- 4 files changed, 62 insertions(+), 23 deletions(-) diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 0944b03bea..e7bba266df 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -586,10 +586,21 @@ class SupportsStagingDestination(ABC): def should_load_data_to_staging_dataset_on_staging_destination( self, table: TTableSchema ) -> bool: + """If set to True, and staging destination is configured, the data will be loaded to staging dataset on staging destination + instead of a regular dataset on staging destination. Currently it is used by Athena Iceberg which uses staging dataset + on staging destination to copy data to iceberg tables stored on regular dataset on staging destination. + The default is to load data to regular dataset on staging destination from where warehouses like Snowflake (that have their + own storage) will copy data. + """ return False @abstractmethod def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + """If set to True, data in `table` will be truncated on staging destination (regular dataset). This is the default behavior which + can be changed with a config flag. + For Athena + Iceberg this setting is always False - Athena uses regular dataset to store Iceberg tables and we avoid touching it. + For Athena we truncate those tables only on "replace" write disposition. + """ pass diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index b3b2fbcf0f..a5a8ae2562 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -530,7 +530,7 @@ def should_truncate_table_before_load_on_staging_destination(self, table: TTable if table["write_disposition"] == "replace" and not self._is_iceberg_table( self.prepare_load_table(table["name"]) ): - return self.config.truncate_tables_on_staging_destination_before_load + return True return False def should_load_data_to_staging_dataset_on_staging_destination( diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index 42dee5fc8f..3bfa050fd7 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -218,7 +218,18 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati # check there are two staging files _, staging_client = pipeline._get_destination_clients(pipeline.default_schema) with staging_client: - assert len(staging_client.list_table_files(table_name)) == 2 # type: ignore[attr-defined] + # except Athena + Iceberg which does not store tables in staging dataset + if ( + destination_config.destination == "athena" + and destination_config.table_format == "iceberg" + ): + table_count = 0 + # but keeps them in staging dataset on staging destination - but only the last one + with staging_client.with_staging_dataset(): # type: ignore[attr-defined] + assert len(staging_client.list_table_files(table_name)) == 1 # type: ignore[attr-defined] + else: + table_count = 2 + assert len(staging_client.list_table_files(table_name)) == table_count # type: ignore[attr-defined] # load the data with truncating, so only new file is on the staging pipeline.destination.config_params["truncate_tables_on_staging_destination_before_load"] = True @@ -231,7 +242,15 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati # check there is only one staging file _, staging_client = pipeline._get_destination_clients(pipeline.default_schema) with staging_client: - assert len(staging_client.list_table_files(table_name)) == 1 # type: ignore[attr-defined] + # except for Athena which does not delete staging destination tables + if destination_config.destination == "athena": + if destination_config.table_format == "iceberg": + table_count = 0 + else: + table_count = 3 + else: + table_count = 1 + assert len(staging_client.list_table_files(table_name)) == table_count # type: ignore[attr-defined] @pytest.mark.parametrize( diff --git a/tests/load/utils.py b/tests/load/utils.py index 15b1e1575e..5427904d52 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -257,6 +257,27 @@ def destinations_configs( # build destination configs destination_configs: List[DestinationTestConfiguration] = [] + # default sql configs that are also default staging configs + default_sql_configs_with_staging = [ + # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though. + DestinationTestConfiguration( + destination="athena", + file_format="parquet", + supports_merge=False, + bucket_url=AWS_BUCKET, + ), + DestinationTestConfiguration( + destination="athena", + file_format="parquet", + bucket_url=AWS_BUCKET, + force_iceberg=True, + supports_merge=True, + supports_dbt=False, + table_format="iceberg", + extra_info="iceberg", + ), + ] + # default non staging sql based configs, one per destination if default_sql_configs: destination_configs += [ @@ -268,26 +289,10 @@ def destinations_configs( DestinationTestConfiguration(destination="duckdb", file_format="parquet"), DestinationTestConfiguration(destination="motherduck", file_format="insert_values"), ] - # Athena needs filesystem staging, which will be automatically set; we have to supply a bucket url though. - destination_configs += [ - DestinationTestConfiguration( - destination="athena", - file_format="parquet", - supports_merge=False, - bucket_url=AWS_BUCKET, - ) - ] - destination_configs += [ - DestinationTestConfiguration( - destination="athena", - file_format="parquet", - bucket_url=AWS_BUCKET, - force_iceberg=True, - supports_merge=True, - supports_dbt=False, - extra_info="iceberg", - ) - ] + + # add Athena staging configs + destination_configs += default_sql_configs_with_staging + destination_configs += [ DestinationTestConfiguration( destination="clickhouse", file_format="jsonl", supports_dbt=False @@ -332,6 +337,10 @@ def destinations_configs( DestinationTestConfiguration(destination="qdrant", extra_info="server"), ] + if (default_sql_configs or all_staging_configs) and not default_sql_configs: + # athena default configs not added yet + destination_configs += default_sql_configs_with_staging + if default_staging_configs or all_staging_configs: destination_configs += [ DestinationTestConfiguration( From e9c9ecfa8a644fdb516dd74aabca3bf75bafb154 Mon Sep 17 00:00:00 2001 From: Marcin Rudolf Date: Wed, 28 Aug 2024 21:45:16 +0200 Subject: [PATCH 7/7] fixes staging tests for athena --- tests/load/pipeline/test_stage_loading.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index 3bfa050fd7..6c4f6dfec8 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -74,8 +74,14 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: assert len(package_info.jobs["failed_jobs"]) == 0 # we have 4 parquet and 4 reference jobs plus one merge job - num_jobs = 4 + 4 + 1 if destination_config.supports_merge else 4 + 4 - assert len(package_info.jobs["completed_jobs"]) == num_jobs + num_jobs = 4 + 4 + num_sql_jobs = 0 + if destination_config.supports_merge: + num_sql_jobs += 1 + # sql job is used to copy parquet to Athena Iceberg table (_dlt_pipeline_state) + if destination_config.destination == "athena" and destination_config.table_format == "iceberg": + num_sql_jobs += 1 + assert len(package_info.jobs["completed_jobs"]) == num_jobs + num_sql_jobs assert ( len( [ @@ -110,7 +116,7 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: if x.job_file_info.file_format == "sql" ] ) - == 1 + == num_sql_jobs ) initial_counts = load_table_counts(