diff --git a/dlt/destinations/impl/sqlalchemy/db_api_client.py b/dlt/destinations/impl/sqlalchemy/db_api_client.py
index 0969fcb628..2255bc740a 100644
--- a/dlt/destinations/impl/sqlalchemy/db_api_client.py
+++ b/dlt/destinations/impl/sqlalchemy/db_api_client.py
@@ -114,7 +114,8 @@ def __init__(
             self.external_engine = True
         else:
             self.engine = sa.create_engine(
-                credentials.to_url().render_as_string(hide_password=False), **(engine_args or {})
+                credentials.to_url().render_as_string(hide_password=False),
+                **(engine_args or {}),
             )
 
         self._current_connection: Optional[Connection] = None
@@ -198,7 +199,7 @@ def _sqlite_create_dataset(self, dataset_name: str) -> None:
             )
 
         statement = "ATTACH DATABASE :fn AS :name"
-        self.execute(statement, fn=new_db_fn, name=dataset_name)
+        self.execute_sql(statement, fn=new_db_fn, name=dataset_name)
 
     def _sqlite_drop_dataset(self, dataset_name: str) -> None:
         """Drop a dataset in sqlite by detaching the database file
@@ -208,10 +209,10 @@ def _sqlite_drop_dataset(self, dataset_name: str) -> None:
         rows = self.execute_sql("PRAGMA database_list")
         dbs = {row[1]: row[2] for row in rows}  # db_name: filename
        if dataset_name not in dbs:
-            return
+            raise DatabaseUndefinedRelation(f"Database {dataset_name} does not exist")
 
         statement = "DETACH DATABASE :name"
-        self.execute(statement, name=dataset_name)
+        self.execute_sql(statement, name=dataset_name)
 
         fn = dbs[dataset_name]
         if not fn:  # It's a memory database, nothing to do
@@ -230,11 +231,7 @@ def drop_dataset(self) -> None:
         try:
             self.execute_sql(sa.schema.DropSchema(self.dataset_name, cascade=True))
         except DatabaseTransientException as e:
-            if isinstance(e.__cause__, sa.exc.ProgrammingError):
-                # May not support CASCADE
-                self.execute_sql(sa.schema.DropSchema(self.dataset_name))
-            else:
-                raise
+            self.execute_sql(sa.schema.DropSchema(self.dataset_name))
 
     def truncate_tables(self, *tables: str) -> None:
         # TODO: alchemy doesn't have a construct for TRUNCATE TABLE
diff --git a/tests/cases.py b/tests/cases.py
index 796db036a8..121d4af6e5 100644
--- a/tests/cases.py
+++ b/tests/cases.py
@@ -216,19 +216,22 @@ def assert_all_data_types_row(
     expected_rows = {key: value for key, value in expected_row.items() if key in schema}
     # prepare date to be compared: convert into pendulum instance, adjust microsecond precision
     if "col4" in expected_rows:
-        parsed_date = pendulum.instance(db_mapping["col4"])
+        parsed_date = ensure_pendulum_datetime((db_mapping["col4"]))
         db_mapping["col4"] = reduce_pendulum_datetime_precision(parsed_date, timestamp_precision)
         expected_rows["col4"] = reduce_pendulum_datetime_precision(
             ensure_pendulum_datetime(expected_rows["col4"]),  # type: ignore[arg-type]
             timestamp_precision,
         )
     if "col4_precision" in expected_rows:
-        parsed_date = pendulum.instance(db_mapping["col4_precision"])
+        parsed_date = ensure_pendulum_datetime((db_mapping["col4_precision"]))
         db_mapping["col4_precision"] = reduce_pendulum_datetime_precision(parsed_date, 3)
         expected_rows["col4_precision"] = reduce_pendulum_datetime_precision(
             ensure_pendulum_datetime(expected_rows["col4_precision"]), 3  # type: ignore[arg-type]
         )
 
+    if "col10" in expected_rows:
+        db_mapping["col10"] = ensure_pendulum_date(db_mapping["col10"])
+
     if "col11" in expected_rows:
         expected_rows["col11"] = reduce_pendulum_datetime_precision(
             ensure_pendulum_time(expected_rows["col11"]), timestamp_precision  # type: ignore[arg-type]
diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py
index 2af15c3558..59d0367a0e 100644
--- a/tests/load/pipeline/test_arrow_loading.py
+++ b/tests/load/pipeline/test_arrow_loading.py
@@ -54,10 +54,12 @@ def test_load_arrow_item(
     )
 
     include_decimal = not (
-        destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl"
+        destination_config.destination_type == "databricks"
+        and destination_config.file_format == "jsonl"
     )
     include_date = not (
-        destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl"
+        destination_config.destination_type == "databricks"
+        and destination_config.file_format == "jsonl"
     )
 
     item, records, _ = arrow_table_all_data_types(
@@ -77,7 +79,9 @@ def some_data():
 
     # use csv for postgres to get native arrow processing
     file_format = (
-        destination_config.file_format if destination_config.destination_type != "postgres" else "csv"
+        destination_config.file_format
+        if destination_config.destination_type != "postgres"
+        else "csv"
     )
     load_info = pipeline.run(some_data(), loader_file_format=file_format)
 
diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py
index a8eb613839..7f7830670f 100644
--- a/tests/load/pipeline/test_pipelines.py
+++ b/tests/load/pipeline/test_pipelines.py
@@ -925,7 +925,9 @@ def some_source():
         parse_complex_strings=destination_config.destination_type
         in ["snowflake", "bigquery", "redshift"],
         allow_string_binary=destination_config.destination_type == "clickhouse",
-        timestamp_precision=3 if destination_config.destination_type in ("athena", "dremio") else 6,
+        timestamp_precision=(
+            3 if destination_config.destination_type in ("athena", "dremio") else 6
+        ),
     )
 
 
diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py
index d8dda9e487..70c72c6aea 100644
--- a/tests/load/pipeline/test_stage_loading.py
+++ b/tests/load/pipeline/test_stage_loading.py
@@ -279,7 +279,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non
     ) and destination_config.file_format in ("parquet", "jsonl"):
         # Redshift copy doesn't support TIME column
         exclude_types.append("time")
-    if destination_config.destination_type == "synapse" and destination_config.file_format == "parquet":
+    if (
+        destination_config.destination_type == "synapse"
+        and destination_config.file_format == "parquet"
+    ):
         # TIME columns are not supported for staged parquet loads into Synapse
         exclude_types.append("time")
     if destination_config.destination_type in (
@@ -291,7 +294,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non
     ):
         # Redshift can't load fixed width binary columns from parquet
         exclude_columns.append("col7_precision")
-    if destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl":
+    if (
+        destination_config.destination_type == "databricks"
+        and destination_config.file_format == "jsonl"
+    ):
         exclude_types.extend(["decimal", "binary", "wei", "complex", "date"])
         exclude_columns.append("col1_precision")
 
diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py
index 63137f9756..370d55a099 100644
--- a/tests/load/test_job_client.py
+++ b/tests/load/test_job_client.py
@@ -27,7 +27,12 @@
 )
 from dlt.destinations.job_client_impl import SqlJobClientBase
 
-from dlt.common.destination.reference import StateInfo, WithStagingDataset
+from dlt.common.destination.reference import (
+    StateInfo,
+    WithStagingDataset,
+    DestinationClientConfiguration,
+)
+from dlt.common.time import ensure_pendulum_datetime
 
 from tests.cases import table_update_and_row, assert_all_data_types_row
 from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage
@@ -202,7 +207,7 @@ def test_complete_load(naming: str, client: SqlJobClientBase) -> None:
     assert load_rows[0][2] == 0
     import datetime  # noqa: I251
 
-    assert type(load_rows[0][3]) is datetime.datetime
+    assert isinstance(ensure_pendulum_datetime(load_rows[0][3]), datetime.datetime)
     assert load_rows[0][4] == client.schema.version_hash
     # make sure that hash in loads exists in schema versions table
     versions_table = client.sql_client.make_qualified_table_name(version_table_name)
@@ -571,7 +576,7 @@ def test_load_with_all_types(
     if not client.capabilities.preferred_loader_file_format:
         pytest.skip("preferred loader file format not set, destination will only work with staging")
     table_name = "event_test_table" + uniq_id()
-    column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type)
+    column_schemas, data_row = get_columns_and_row_all_types(client.config)
 
     # we should have identical content with all disposition types
     partial = client.schema.update_table(
@@ -648,7 +653,7 @@ def test_write_dispositions(
     os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy
 
     table_name = "event_test_table" + uniq_id()
-    column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type)
+    column_schemas, data_row = get_columns_and_row_all_types(client.config)
     client.schema.update_table(
         new_table(table_name, write_disposition=write_disposition, columns=column_schemas.values())
     )
@@ -807,7 +812,8 @@ def test_get_stored_state(
     os.environ["SCHEMA__NAMING"] = naming_convention
 
     with cm_yield_client_with_storage(
-        destination_config.destination_factory(), default_config_values={"default_schema_name": None}
+        destination_config.destination_factory(),
+        default_config_values={"default_schema_name": None},
     ) as client:
         # event schema with event table
         if not client.capabilities.preferred_loader_file_format:
@@ -871,7 +877,8 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None:
         assert len(db_rows) == expected_rows
 
     with cm_yield_client_with_storage(
-        destination_config.destination_factory(), default_config_values={"default_schema_name": None}
+        destination_config.destination_factory(),
+        default_config_values={"default_schema_name": None},
     ) as client:
         # event schema with event table
         if not client.capabilities.preferred_loader_file_format:
@@ -967,11 +974,16 @@ def normalize_rows(rows: List[Dict[str, Any]], naming: NamingConvention) -> None
             row[naming.normalize_identifier(k)] = row.pop(k)
 
 
-def get_columns_and_row_all_types(destination_type: str):
+def get_columns_and_row_all_types(destination_config: DestinationClientConfiguration):
+    exclude_types = []
+    if destination_config.destination_type in ["databricks", "clickhouse", "motherduck"]:
+        exclude_types.append("time")
+    if destination_config.destination_name == "sqlalchemy_sqlite":
+        exclude_types.extend(["decimal", "wei"])
     return table_update_and_row(
         # TIME + parquet is actually a duckdb problem: https://github.com/duckdb/duckdb/pull/13283
-        exclude_types=(
-            ["time"] if destination_type in ["databricks", "clickhouse", "motherduck"] else None
+        exclude_types=exclude_types,  # type: ignore[arg-type]
+        exclude_columns=(
+            ["col4_precision"] if destination_config.destination_type in ["motherduck"] else None
         ),
-        exclude_columns=["col4_precision"] if destination_type in ["motherduck"] else None,
     )
diff --git a/tests/load/test_sql_client.py b/tests/load/test_sql_client.py
index 9743eb8ec6..2b1097ac82 100644
--- a/tests/load/test_sql_client.py
+++ b/tests/load/test_sql_client.py
@@ -149,6 +149,7 @@ def test_has_dataset(naming: str, client: SqlJobClientBase) -> None:
 )
 def test_create_drop_dataset(naming: str, client: SqlJobClientBase) -> None:
     # client.sql_client.create_dataset()
+    # Dataset is already created in the fixture, so creating it again fails
     with pytest.raises(DatabaseException):
         client.sql_client.create_dataset()
     client.sql_client.drop_dataset()
@@ -212,7 +213,7 @@ def test_execute_sql(client: SqlJobClientBase) -> None:
     assert len(rows) == 1
     # print(rows)
     assert rows[0][0] == "event"
-    assert isinstance(rows[0][1], datetime.datetime)
+    assert isinstance(ensure_pendulum_datetime(rows[0][1]), datetime.datetime)
     assert rows[0][0] == "event"
     # print(rows[0][1])
     # print(type(rows[0][1]))
diff --git a/tests/load/utils.py b/tests/load/utils.py
index f5732ac5aa..db309a4df6 100644
--- a/tests/load/utils.py
+++ b/tests/load/utils.py
@@ -309,8 +309,14 @@ def destinations_configs(
             destination_type="sqlalchemy",
             supports_merge=False,
             supports_dbt=False,
-            destination_name="mysql_driver",
-        )
+            destination_name="sqlalchemy_mysql",
+        ),
+        DestinationTestConfiguration(
+            destination_type="sqlalchemy",
+            supports_merge=False,
+            supports_dbt=False,
+            destination_name="sqlalchemy_sqlite",
+        ),
     ]
 
     destination_configs += [