Some job client/sql client tests running on sqlite
steinitzu committed Aug 31, 2024
1 parent 0d9c75a commit 1e7fa6a
Showing 8 changed files with 61 additions and 30 deletions.
15 changes: 6 additions & 9 deletions dlt/destinations/impl/sqlalchemy/db_api_client.py
@@ -114,7 +114,8 @@ def __init__(
self.external_engine = True
else:
self.engine = sa.create_engine(
- credentials.to_url().render_as_string(hide_password=False), **(engine_args or {})
+ credentials.to_url().render_as_string(hide_password=False),
+ **(engine_args or {}),
)

self._current_connection: Optional[Connection] = None
@@ -198,7 +199,7 @@ def _sqlite_create_dataset(self, dataset_name: str) -> None:
)

statement = "ATTACH DATABASE :fn AS :name"
- self.execute(statement, fn=new_db_fn, name=dataset_name)
+ self.execute_sql(statement, fn=new_db_fn, name=dataset_name)

def _sqlite_drop_dataset(self, dataset_name: str) -> None:
"""Drop a dataset in sqlite by detaching the database file
@@ -208,10 +209,10 @@ def _sqlite_drop_dataset(self, dataset_name: str) -> None:
rows = self.execute_sql("PRAGMA database_list")
dbs = {row[1]: row[2] for row in rows} # db_name: filename
if dataset_name not in dbs:
- return
+ raise DatabaseUndefinedRelation(f"Database {dataset_name} does not exist")

statement = "DETACH DATABASE :name"
- self.execute(statement, name=dataset_name)
+ self.execute_sql(statement, name=dataset_name)

fn = dbs[dataset_name]
if not fn: # It's a memory database, nothing to do
@@ -230,11 +231,7 @@ def drop_dataset(self) -> None:
try:
self.execute_sql(sa.schema.DropSchema(self.dataset_name, cascade=True))
except DatabaseTransientException as e:
- if isinstance(e.__cause__, sa.exc.ProgrammingError):
- # May not support CASCADE
- self.execute_sql(sa.schema.DropSchema(self.dataset_name))
- else:
- raise
+ self.execute_sql(sa.schema.DropSchema(self.dataset_name))

def truncate_tables(self, *tables: str) -> None:
# TODO: alchemy doesn't have a construct for TRUNCATE TABLE
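The sqlite-specific paths above emulate datasets with ATTACH DATABASE / DETACH DATABASE, since SQLite has no CREATE SCHEMA, and PRAGMA database_list is what tells the client which datasets exist. A minimal standalone sketch of that pattern using only the standard library (the dataset name and file paths are illustrative, not taken from the commit):

    import os
    import sqlite3
    import tempfile

    # Main database file; each "dataset" becomes a separate file attached under an alias.
    work_dir = tempfile.mkdtemp()
    conn = sqlite3.connect(os.path.join(work_dir, "main.db"))

    dataset_name = "my_dataset"  # illustrative dataset name
    dataset_fn = os.path.join(work_dir, dataset_name + ".db")

    # "Create" the dataset: attach a new database file under the dataset alias.
    # The file name can be bound as a parameter; the alias cannot, so it is inlined.
    conn.execute("ATTACH DATABASE ? AS " + dataset_name, (dataset_fn,))
    conn.execute("CREATE TABLE " + dataset_name + ".events (id INTEGER)")

    # PRAGMA database_list yields (seq, name, file) rows - the same lookup the
    # client uses before detaching, so a missing dataset can be reported early.
    attached = {row[1]: row[2] for row in conn.execute("PRAGMA database_list")}
    assert dataset_name in attached

    # "Drop" the dataset: detach the alias, then delete the backing file.
    conn.execute("DETACH DATABASE " + dataset_name)
    os.remove(dataset_fn)
    conn.close()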
7 changes: 5 additions & 2 deletions tests/cases.py
@@ -216,19 +216,22 @@ def assert_all_data_types_row(
expected_rows = {key: value for key, value in expected_row.items() if key in schema}
# prepare date to be compared: convert into pendulum instance, adjust microsecond precision
if "col4" in expected_rows:
- parsed_date = pendulum.instance(db_mapping["col4"])
+ parsed_date = ensure_pendulum_datetime((db_mapping["col4"]))
db_mapping["col4"] = reduce_pendulum_datetime_precision(parsed_date, timestamp_precision)
expected_rows["col4"] = reduce_pendulum_datetime_precision(
ensure_pendulum_datetime(expected_rows["col4"]), # type: ignore[arg-type]
timestamp_precision,
)
if "col4_precision" in expected_rows:
- parsed_date = pendulum.instance(db_mapping["col4_precision"])
+ parsed_date = ensure_pendulum_datetime((db_mapping["col4_precision"]))
db_mapping["col4_precision"] = reduce_pendulum_datetime_precision(parsed_date, 3)
expected_rows["col4_precision"] = reduce_pendulum_datetime_precision(
ensure_pendulum_datetime(expected_rows["col4_precision"]), 3 # type: ignore[arg-type]
)

if "col10" in expected_rows:
db_mapping["col10"] = ensure_pendulum_date(db_mapping["col10"])

if "col11" in expected_rows:
expected_rows["col11"] = reduce_pendulum_datetime_precision(
ensure_pendulum_time(expected_rows["col11"]), timestamp_precision # type: ignore[arg-type]
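The change from pendulum.instance to ensure_pendulum_datetime matters once sqlite is in the mix: pendulum.instance only wraps existing datetime objects, while sqlite drivers typically hand timestamps back as ISO strings. A rough sketch of the distinction using pendulum directly (this is not dlt's helper, just the idea behind it):

    import datetime
    import pendulum

    db_values = [
        datetime.datetime(2024, 8, 31, 12, 0, 0, 123456),  # e.g. what a postgres driver returns
        "2024-08-31T12:00:00.123456",  # e.g. sqlite, which stores timestamps as text
    ]

    for value in db_values:
        if isinstance(value, datetime.datetime):
            # pendulum.instance() only accepts datetime objects ...
            parsed = pendulum.instance(value)
        else:
            # ... so string values have to be parsed instead; a helper like
            # ensure_pendulum_datetime (assumed behavior) hides this branching.
            parsed = pendulum.parse(value)
        # trim to millisecond precision, similar in spirit to
        # reduce_pendulum_datetime_precision(parsed, 3)
        reduced = parsed.replace(microsecond=parsed.microsecond // 1000 * 1000)
        print(reduced.isoformat())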
10 changes: 7 additions & 3 deletions tests/load/pipeline/test_arrow_loading.py
@@ -54,10 +54,12 @@ def test_load_arrow_item(
)

include_decimal = not (
- destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl"
+ destination_config.destination_type == "databricks"
+ and destination_config.file_format == "jsonl"
)
include_date = not (
- destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl"
+ destination_config.destination_type == "databricks"
+ and destination_config.file_format == "jsonl"
)

item, records, _ = arrow_table_all_data_types(
@@ -77,7 +79,9 @@ def some_data():

# use csv for postgres to get native arrow processing
file_format = (
- destination_config.file_format if destination_config.destination_type != "postgres" else "csv"
+ destination_config.file_format
+ if destination_config.destination_type != "postgres"
+ else "csv"
)

load_info = pipeline.run(some_data(), loader_file_format=file_format)
4 changes: 3 additions & 1 deletion tests/load/pipeline/test_pipelines.py
@@ -925,7 +925,9 @@ def some_source():
parse_complex_strings=destination_config.destination_type
in ["snowflake", "bigquery", "redshift"],
allow_string_binary=destination_config.destination_type == "clickhouse",
- timestamp_precision=3 if destination_config.destination_type in ("athena", "dremio") else 6,
+ timestamp_precision=(
+ 3 if destination_config.destination_type in ("athena", "dremio") else 6
+ ),
)


10 changes: 8 additions & 2 deletions tests/load/pipeline/test_stage_loading.py
@@ -279,7 +279,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non
) and destination_config.file_format in ("parquet", "jsonl"):
# Redshift copy doesn't support TIME column
exclude_types.append("time")
- if destination_config.destination_type == "synapse" and destination_config.file_format == "parquet":
+ if (
+ destination_config.destination_type == "synapse"
+ and destination_config.file_format == "parquet"
+ ):
# TIME columns are not supported for staged parquet loads into Synapse
exclude_types.append("time")
if destination_config.destination_type in (
@@ -291,7 +294,10 @@ def test_all_data_types(destination_config: DestinationTestConfiguration) -> Non
):
# Redshift can't load fixed width binary columns from parquet
exclude_columns.append("col7_precision")
- if destination_config.destination_type == "databricks" and destination_config.file_format == "jsonl":
+ if (
+ destination_config.destination_type == "databricks"
+ and destination_config.file_format == "jsonl"
+ ):
exclude_types.extend(["decimal", "binary", "wei", "complex", "date"])
exclude_columns.append("col1_precision")

32 changes: 22 additions & 10 deletions tests/load/test_job_client.py
@@ -27,7 +27,12 @@
)

from dlt.destinations.job_client_impl import SqlJobClientBase
- from dlt.common.destination.reference import StateInfo, WithStagingDataset
+ from dlt.common.destination.reference import (
+ StateInfo,
+ WithStagingDataset,
+ DestinationClientConfiguration,
+ )
+ from dlt.common.time import ensure_pendulum_datetime

from tests.cases import table_update_and_row, assert_all_data_types_row
from tests.utils import TEST_STORAGE_ROOT, autouse_test_storage
@@ -202,7 +207,7 @@ def test_complete_load(naming: str, client: SqlJobClientBase) -> None:
assert load_rows[0][2] == 0
import datetime # noqa: I251

- assert type(load_rows[0][3]) is datetime.datetime
+ assert isinstance(ensure_pendulum_datetime(load_rows[0][3]), datetime.datetime)
assert load_rows[0][4] == client.schema.version_hash
# make sure that hash in loads exists in schema versions table
versions_table = client.sql_client.make_qualified_table_name(version_table_name)
@@ -571,7 +576,7 @@ def test_load_with_all_types(
if not client.capabilities.preferred_loader_file_format:
pytest.skip("preferred loader file format not set, destination will only work with staging")
table_name = "event_test_table" + uniq_id()
- column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type)
+ column_schemas, data_row = get_columns_and_row_all_types(client.config)

# we should have identical content with all disposition types
partial = client.schema.update_table(
@@ -648,7 +653,7 @@ def test_write_dispositions(
os.environ["DESTINATION__REPLACE_STRATEGY"] = replace_strategy

table_name = "event_test_table" + uniq_id()
- column_schemas, data_row = get_columns_and_row_all_types(client.config.destination_type)
+ column_schemas, data_row = get_columns_and_row_all_types(client.config)
client.schema.update_table(
new_table(table_name, write_disposition=write_disposition, columns=column_schemas.values())
)
@@ -807,7 +812,8 @@ def test_get_stored_state(
os.environ["SCHEMA__NAMING"] = naming_convention

with cm_yield_client_with_storage(
- destination_config.destination_factory(), default_config_values={"default_schema_name": None}
+ destination_config.destination_factory(),
+ default_config_values={"default_schema_name": None},
) as client:
# event schema with event table
if not client.capabilities.preferred_loader_file_format:
@@ -871,7 +877,8 @@ def _load_something(_client: SqlJobClientBase, expected_rows: int) -> None:
assert len(db_rows) == expected_rows

with cm_yield_client_with_storage(
- destination_config.destination_factory(), default_config_values={"default_schema_name": None}
+ destination_config.destination_factory(),
+ default_config_values={"default_schema_name": None},
) as client:
# event schema with event table
if not client.capabilities.preferred_loader_file_format:
@@ -967,11 +974,16 @@ def normalize_rows(rows: List[Dict[str, Any]], naming: NamingConvention) -> None
row[naming.normalize_identifier(k)] = row.pop(k)


- def get_columns_and_row_all_types(destination_type: str):
+ def get_columns_and_row_all_types(destination_config: DestinationClientConfiguration):
+ exclude_types = []
+ if destination_config.destination_type in ["databricks", "clickhouse", "motherduck"]:
+ exclude_types.append("time")
+ if destination_config.destination_name == "sqlalchemy_sqlite":
+ exclude_types.extend(["decimal", "wei"])
return table_update_and_row(
# TIME + parquet is actually a duckdb problem: https://github.com/duckdb/duckdb/pull/13283
- exclude_types=(
- ["time"] if destination_type in ["databricks", "clickhouse", "motherduck"] else None
+ exclude_types=exclude_types, # type: ignore[arg-type]
+ exclude_columns=(
+ ["col4_precision"] if destination_config.destination_type in ["motherduck"] else None
),
- exclude_columns=["col4_precision"] if destination_type in ["motherduck"] else None,
)
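Excluding decimal and wei for sqlalchemy_sqlite lines up with SQLite's type system: there is no fixed-precision numeric type, and the stdlib driver will not even bind a Decimal without a registered adapter, so exact round-trips cannot be asserted. A small illustration of that limitation (my own example, not part of the test suite):

    import sqlite3
    from decimal import Decimal

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (amount NUMERIC)")

    try:
        # sqlite3 has no default adapter for Decimal, so binding one fails outright.
        conn.execute("INSERT INTO t VALUES (?)", (Decimal("1.01"),))
    except (sqlite3.InterfaceError, sqlite3.ProgrammingError) as exc:
        print("Decimal is not supported natively:", exc)

    # Converting to float works but silently loses exactness, which is why the
    # exact decimal/wei assertions are skipped for the sqlite destination.
    conn.execute("INSERT INTO t VALUES (?)", (float(Decimal("1.01")),))
    print(conn.execute("SELECT amount FROM t").fetchone())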
3 changes: 2 additions & 1 deletion tests/load/test_sql_client.py
@@ -149,6 +149,7 @@ def test_has_dataset(naming: str, client: SqlJobClientBase) -> None:
)
def test_create_drop_dataset(naming: str, client: SqlJobClientBase) -> None:
# client.sql_client.create_dataset()
+ # Dataset is already created in the fixture, so creating it again fails
with pytest.raises(DatabaseException):
client.sql_client.create_dataset()
client.sql_client.drop_dataset()
@@ -212,7 +213,7 @@ def test_execute_sql(client: SqlJobClientBase) -> None:
assert len(rows) == 1
# print(rows)
assert rows[0][0] == "event"
- assert isinstance(rows[0][1], datetime.datetime)
+ assert isinstance(ensure_pendulum_datetime(rows[0][1]), datetime.datetime)
assert rows[0][0] == "event"
# print(rows[0][1])
# print(type(rows[0][1]))
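The loosened assertion works because pendulum's DateTime subclasses datetime.datetime, so normalizing whatever the driver returns (a datetime, or a string on sqlite) with ensure_pendulum_datetime still satisfies isinstance, whereas an exact type check could not. A quick illustration with plain pendulum:

    import datetime
    import pendulum

    # pendulum's DateTime is a subclass of the stdlib datetime ...
    assert issubclass(pendulum.DateTime, datetime.datetime)

    # ... so a parsed string - roughly what a sqlite driver hands back for a
    # timestamp column - still passes an isinstance check.
    normalized = pendulum.parse("2024-08-31T12:00:00")
    assert isinstance(normalized, datetime.datetime)

    # An exact type check would reject it, because the concrete type is
    # pendulum.DateTime rather than datetime.datetime.
    assert type(normalized) is not datetime.datetime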
10 changes: 8 additions & 2 deletions tests/load/utils.py
@@ -309,8 +309,14 @@ def destinations_configs(
destination_type="sqlalchemy",
supports_merge=False,
supports_dbt=False,
destination_name="mysql_driver",
)
destination_name="sqlalchemy_mysql",
),
DestinationTestConfiguration(
destination_type="sqlalchemy",
supports_merge=False,
supports_dbt=False,
destination_name="sqlalchemy_sqlite",
),
]

destination_configs += [
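With the new sqlalchemy_sqlite entry the sqlalchemy destination tests can run against sqlite instead of a MySQL server. How credentials are supplied is not part of this diff; assuming dlt's usual DESTINATION__<NAME>__CREDENTIALS environment layout, a local run might be wired up roughly like this (URLs and paths are illustrative):

    import os
    import sqlalchemy as sa

    # Assumed configuration layout - the section names follow dlt's
    # DESTINATION__<NAME>__CREDENTIALS convention and are not shown in this commit.
    os.environ["DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS"] = "sqlite:///dlt_test.db"
    os.environ["DESTINATION__SQLALCHEMY_MYSQL__CREDENTIALS"] = (
        "mysql+pymysql://user:password@localhost:3306/dlt_data"  # illustrative URL
    )

    # Both values are ordinary SQLAlchemy URLs, i.e. what the client in the first
    # file passes to sa.create_engine(); the sqlite one needs no running service.
    engine = sa.create_engine(os.environ["DESTINATION__SQLALCHEMY_SQLITE__CREDENTIALS"])
    with engine.connect() as conn:
        assert conn.execute(sa.text("SELECT 1")).scalar() == 1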
