diff --git a/dlt/sources/sql_database/README.md b/dlt/sources/sql_database/README.md
new file mode 100644
index 0000000000..dfa4b5e161
--- /dev/null
+++ b/dlt/sources/sql_database/README.md
@@ -0,0 +1,205 @@
+# SQL Database
+A SQL (Structured Query Language) database is a type of database management system (DBMS) that stores and manages data in a structured format. The SQL Database `dlt` verified source and pipeline example make it easy to load data from your SQL database to a destination of your choice. It offers the flexibility to load either the entire database or specific tables to the target.
+
+## Initialize the pipeline with SQL Database verified source
+```bash
+dlt init sql_database bigquery
+```
+Here, we chose BigQuery as the destination. Alternatively, you can also choose redshift, duckdb, or any of the other [destinations](https://dlthub.com/docs/dlt-ecosystem/destinations/).
+
+## Setup verified source
+
+To set up the SQL Database verified source, read the [full documentation here](https://dlthub.com/docs/dlt-ecosystem/verified-sources/sql_database).
+
+## Add credentials
+1. Open `.dlt/secrets.toml`.
+2. In order to continue, we will use the supplied connection URL to establish credentials. The connection URL is associated with a public database and looks like this:
+   ```bash
+   connection_url = "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam"
+   ```
+   Here's what the `secrets.toml` looks like:
+   ```toml
+   # Put your secret values and credentials here. Do not share this file and do not upload it to GitHub.
+   # We will set up creds with the following connection URL, which is a public database.
+
+   # The credentials are as follows
+   drivername = "mysql+pymysql" # driver name for the database
+   database = "Rfam" # database name
+   username = "rfamro" # username associated with the database
+   host = "mysql-rfam-public.ebi.ac.uk" # host address
+   port = "4497" # port required for connection
+   ```
+3. Enter credentials for your chosen destination as per the [docs](https://dlthub.com/docs/dlt-ecosystem/destinations/).
+
+## Running the pipeline example
+
+1. Install the required dependencies by running the following command:
+   ```bash
+   pip install -r requirements.txt
+   ```
+
+2. Now you can run the pipeline example by using the command:
+   ```bash
+   python3 sql_database_pipeline.py
+   ```
+
+3. To ensure that everything loads as expected, use the command:
+   ```bash
+   dlt pipeline <pipeline_name> show
+   ```
+
+   For example, the pipeline_name for the above pipeline example is `rfam`; you can use any custom name instead.
+
+
+## Pick the right table backend
+Table backends convert a stream of rows from database tables into batches in various formats. The default backend **sqlalchemy** follows the standard `dlt` behavior of
+extracting and normalizing Python dictionaries. We recommend it for smaller tables, initial development work, and when minimal dependencies or a pure Python environment is required. It is also the slowest.
+Database tables are structured data, and the other backends speed up dealing with such data significantly. The **pyarrow** backend converts rows into `arrow` tables, has
+good performance, and preserves exact database types; we recommend it for large tables.
+
+### **sqlalchemy** backend
+
+**sqlalchemy** (the default) yields table data as lists of Python dictionaries. This data goes through the regular extract
+and normalize steps and does not require additional dependencies to be installed. It is the most robust option (works with any destination, correctly represents data types) but also the slowest. You can use `detect_precision_hints` to pass exact database types to the `dlt` schema.
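+
+For illustration, here is a minimal sketch of loading two tables from the public Rfam database with the default **sqlalchemy** backend. The pipeline and dataset names below are illustrative, and duckdb is just one possible destination:
+
+```py
+import dlt
+from dlt.sources.sql_database import sql_database
+
+# rows are yielded as lists of Python dictionaries and go through the regular
+# dlt extract and normalize steps
+source = sql_database(
+    "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam",
+    backend="sqlalchemy",
+).with_resources("family", "genome")
+
+pipeline = dlt.pipeline(
+    pipeline_name="rfam_sqlalchemy", destination="duckdb", dataset_name="rfam_data"
+)
+info = pipeline.run(source)
+print(info)
+```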
+
+### **pyarrow** backend
+
+**pyarrow** yields data as Arrow tables. It uses **SqlAlchemy** to read rows in batches but then immediately converts them into an `ndarray`, transposes it, and uses it to set the columns of an arrow table. This backend always fully
+reflects the database table and preserves original types, i.e. **decimal** / **numeric** will be extracted without loss of precision. If the destination loads parquet files, this backend will skip the `dlt` normalizer and you can gain a significant (20x - 30x) speed increase.
+
+Note that if **pandas** is installed, we'll use it to convert SqlAlchemy tuples into an **ndarray**, as it seems to be 20-30% faster than using **numpy** directly.
+
+```py
+import dlt
+import sqlalchemy as sa
+from dlt.sources.sql_database import sql_database
+
+pipeline = dlt.pipeline(
+    pipeline_name="rfam_cx", destination="postgres", dataset_name="rfam_data_arrow"
+)
+
+def _double_as_decimal_adapter(table: sa.Table) -> None:
+    """Return doubles as doubles, not decimals; this is a MySQL thing"""
+    for column in table.columns.values():
+        if isinstance(column.type, sa.Double):
+            column.type.asdecimal = False
+
+sql_alchemy_source = sql_database(
+    "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam?&binary_prefix=true",
+    backend="pyarrow",
+    table_adapter_callback=_double_as_decimal_adapter
+).with_resources("family", "genome")
+
+info = pipeline.run(sql_alchemy_source)
+print(info)
+```
+
+### **pandas** backend
+
+The **pandas** backend yields data as data frames using the `pandas.io.sql` module. `dlt` uses **pyarrow** dtypes by default as they generate more stable typing.
+
+With default settings, several database types will be coerced to dtypes in the yielded data frame:
+* **decimal** types are mapped to doubles, so it is possible to lose precision.
+* **date** and **time** are mapped to strings.
+* all types are nullable.
+
+Note: `dlt` will still use the reflected source database types to create destination tables. It is up to the destination to reconcile / parse
+type differences. Most of the destinations will be able to parse date/time strings and convert doubles into decimals (please note that you'll still lose precision on decimals with default settings). **However, we strongly suggest
+not using the pandas backend if your source tables contain date, time, or decimal columns.**
+
+
+Example: Use `backend_kwargs` to pass [backend-specific settings](https://pandas.pydata.org/docs/reference/api/pandas.read_sql_table.html), e.g. `coerce_float`. Internally, `dlt` uses `pandas.io.sql._wrap_result` to generate pandas data frames.
+
+```py
+import dlt
+import sqlalchemy as sa
+from dlt.sources.sql_database import sql_database
+
+pipeline = dlt.pipeline(
+    pipeline_name="rfam_cx", destination="postgres", dataset_name="rfam_data_pandas_2"
+)
+
+def _double_as_decimal_adapter(table: sa.Table) -> None:
+    """Emits decimals instead of floats."""
+    for column in table.columns.values():
+        if isinstance(column.type, sa.Float):
+            column.type.asdecimal = True
+
+sql_alchemy_source = sql_database(
+    "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam?&binary_prefix=true",
+    backend="pandas",
+    table_adapter_callback=_double_as_decimal_adapter,
+    chunk_size=100000,
+    # set coerce_float to False to represent decimals as strings
+    backend_kwargs={"coerce_float": False, "dtype_backend": "numpy_nullable"},
+).with_resources("family", "genome")
+
+info = pipeline.run(sql_alchemy_source)
+print(info)
+```
+
+### **connectorx** backend
+The [connectorx](https://sfu-db.github.io/connector-x/intro.html) backend completely skips **sqlalchemy** when reading table rows, in favor of doing that in Rust. It is claimed to be significantly faster than any other method (confirmed only on postgres - see the notes below). With the default settings it will emit **pyarrow** tables, but you can configure this via **backend_kwargs**.
+
+There are certain limitations when using this backend:
+* it will ignore `chunk_size`. **connectorx** cannot yield data in batches.
+* in many cases it requires a connection string that differs from the **sqlalchemy** connection string. Use the `conn` argument in **backend_kwargs** to set it up.
+* it will convert **decimals** to **doubles**, so you will lose precision.
+* nullability of the columns is ignored (always true).
+* it uses different database type mappings for each database system. [Check here for more details.](https://sfu-db.github.io/connector-x/databases.html)
+* JSON fields (at least those coming from postgres) are double-wrapped in strings. Here's a transform, to be added with `add_map`, that will unwrap them (a usage sketch follows at the end of this section):
+
+```py
+from dlt.sources.sql_database.helpers import unwrap_json_connector_x
+```
+
+Note: `dlt` will still use the reflected source database types to create destination tables. It is up to the destination to reconcile / parse type differences. Please note that you'll still lose precision on decimals with default settings.
+
+```py
+"""Uses unsw_flow dataset (~2mln rows, 25+ columns) to test connectorx speed"""
+import os
+
+import dlt
+from dlt.destinations import filesystem
+from dlt.sources.sql_database import sql_table
+
+unsw_table = sql_table(
+    "postgresql://loader:loader@localhost:5432/dlt_data",
+    "unsw_flow_7",
+    "speed_test",
+    # this is ignored by connectorx
+    chunk_size=100000,
+    backend="connectorx",
+    # keep source data types
+    detect_precision_hints=True,
+    # just to demonstrate how to set up a separate connection string for connectorx
+    backend_kwargs={"conn": "postgresql://loader:loader@localhost:5432/dlt_data"}
+)
+
+pipeline = dlt.pipeline(
+    pipeline_name="unsw_download",
+    destination=filesystem(os.path.abspath("../_storage/unsw")),
+    progress="log",
+    full_refresh=True,
+)
+
+info = pipeline.run(
+    unsw_table,
+    dataset_name="speed_test",
+    table_name="unsw_flow",
+    loader_file_format="parquet",
+)
+print(info)
+```
+With the dataset above and a local postgres instance, connectorx is 2x faster than the pyarrow backend.
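+
+As mentioned in the limitations above, here is a minimal sketch of attaching the JSON-unwrapping transform to a resource read via connectorx. The connection string reuses the local postgres example above; the `events` table and its `payload` JSON column are purely illustrative:
+
+```py
+import dlt
+from dlt.sources.sql_database import sql_table
+from dlt.sources.sql_database.helpers import unwrap_json_connector_x
+
+# hypothetical table with a JSON column named "payload"
+events = sql_table(
+    "postgresql://loader:loader@localhost:5432/dlt_data",
+    "events",
+    backend="connectorx",
+)
+# unwrap the double-wrapped JSON strings that connectorx yields for this column
+events.add_map(unwrap_json_connector_x("payload"))
+
+pipeline = dlt.pipeline(pipeline_name="events_cx", destination="duckdb")
+print(pipeline.run(events))
+```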
+
+## Notes on source databases
+
+### Oracle
+1. When using the **oracledb** dialect in thin mode we are getting protocol errors. Use thick mode or the **cx_oracle** (old) client.
+2. Mind that **sqlalchemy** translates Oracle identifiers into lower case! Keep the default `dlt` naming convention (`snake_case`) when loading data. We'll support more naming conventions soon.
+3. Connectorx is for some reason slower for Oracle than the `pyarrow` backend.
+
+### DB2
+1. Mind that **sqlalchemy** translates DB2 identifiers into lower case! Keep the default `dlt` naming convention (`snake_case`) when loading data. We'll support more naming conventions soon.
+2. The DB2 `DOUBLE` type is mapped to the `Numeric` SqlAlchemy type with default precision, yet `float` Python types are returned. That requires `dlt` to perform additional casts. The cost of the cast, however, is minuscule compared to the cost of reading rows from the database.
+
+### MySQL
+1. The SqlAlchemy dialect converts doubles to decimals; we disable that behavior via a table adapter in our demo pipeline.
+
+### Postgres / MSSQL
+No issues found. Postgres is the only backend where we observed a 2x speedup with connectorx. On other database systems it performs the same as the `pyarrow` backend or slower.
+
+## Learn more
+šŸ’” To explore additional customizations for this pipeline, we recommend referring to the official `dlt` SQL Database verified source documentation. It provides comprehensive information and guidance on how to further customize and tailor the pipeline to suit your specific needs. You can find the documentation in the [Setup Guide: SQL Database](https://dlthub.com/docs/dlt-ecosystem/verified-sources/sql_database).
diff --git a/dlt/sources/sql_database/__init__.py b/dlt/sources/sql_database/__init__.py
new file mode 100644
index 0000000000..75172b5bd9
--- /dev/null
+++ b/dlt/sources/sql_database/__init__.py
@@ -0,0 +1,213 @@
+"""Source that loads tables from any SQLAlchemy supported database, supports batching requests and incremental loads."""
+
+from typing import Callable, Dict, List, Optional, Union, Iterable, Any
+from sqlalchemy import MetaData, Table
+from sqlalchemy.engine import Engine
+
+import dlt
+from dlt.sources import DltResource
+
+
+from dlt.sources.credentials import ConnectionStringCredentials
+from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
+
+from .helpers import (
+    table_rows,
+    engine_from_credentials,
+    TableBackend,
+    SqlDatabaseTableConfiguration,
+    SqlTableResourceConfiguration,
+    _detect_precision_hints_deprecated,
+    TQueryAdapter,
+)
+from .schema_types import (
+    default_table_adapter,
+    table_to_columns,
+    get_primary_key,
+    ReflectionLevel,
+    TTypeAdapter,
+)
+
+
+@dlt.source
+def sql_database(
+    credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value,
+    schema: Optional[str] = dlt.config.value,
+    metadata: Optional[MetaData] = None,
+    table_names: Optional[List[str]] = dlt.config.value,
+    chunk_size: int = 50000,
+    backend: TableBackend = "sqlalchemy",
+    detect_precision_hints: Optional[bool] = False,
+    reflection_level: Optional[ReflectionLevel] = "full",
+    defer_table_reflect: Optional[bool] = None,
+    table_adapter_callback: Callable[[Table], None] = None,
+    backend_kwargs: Dict[str, Any] = None,
+    include_views: bool = False,
+    type_adapter_callback: Optional[TTypeAdapter] = None,
+    query_adapter_callback: Optional[TQueryAdapter] = None,
+) -> Iterable[DltResource]:
+    """
+    A dlt source which loads data from an SQL database using SQLAlchemy.
+    Resources are automatically created for each table in the schema or from the given list of tables.
+
+    Args:
+        credentials (Union[ConnectionStringCredentials, Engine, str]): Database credentials or an `sqlalchemy.Engine` instance.
+ schema (Optional[str]): Name of the database schema to load (if different from default). + metadata (Optional[MetaData]): Optional `sqlalchemy.MetaData` instance. `schema` argument is ignored when this is used. + table_names (Optional[List[str]]): A list of table names to load. By default, all tables in the schema are loaded. + chunk_size (int): Number of rows yielded in one batch. SQL Alchemy will create additional internal rows buffer twice the chunk size. + backend (TableBackend): Type of backend to generate table data. One of: "sqlalchemy", "pyarrow", "pandas" and "connectorx". + "sqlalchemy" yields batches as lists of Python dictionaries, "pyarrow" and "connectorx" yield batches as arrow tables, "pandas" yields panda frames. + "sqlalchemy" is the default and does not require additional dependencies, "pyarrow" creates stable destination schemas with correct data types, + "connectorx" is typically the fastest but ignores the "chunk_size" so you must deal with large tables yourself. + detect_precision_hints (bool): Deprecated. Use `reflection_level`. Set column precision and scale hints for supported data types in the target schema based on the columns in the source tables. + This is disabled by default. + reflection_level: (ReflectionLevel): Specifies how much information should be reflected from the source database schema. + "minimal": Only table names, nullability and primary keys are reflected. Data types are inferred from the data. + "full": Data types will be reflected on top of "minimal". `dlt` will coerce the data into reflected types if necessary. This is the default option. + "full_with_precision": Sets precision and scale on supported data types (ie. decimal, text, binary). Creates big and regular integer types. + defer_table_reflect (bool): Will connect and reflect table schema only when yielding data. Requires table_names to be explicitly passed. + Enable this option when running on Airflow. Available on dlt 0.4.4 and later. + table_adapter_callback: (Callable): Receives each reflected table. May be used to modify the list of columns that will be selected. + backend_kwargs (**kwargs): kwargs passed to table backend ie. "conn" is used to pass specialized connection string to connectorx. + include_views (bool): Reflect views as well as tables. Note view names included in `table_names` are always included regardless of this setting. + type_adapter_callback(Optional[Callable]): Callable to override type inference when reflecting columns. + Argument is a single sqlalchemy data type (`TypeEngine` instance) and it should return another sqlalchemy data type, or `None` (type will be inferred from data) + query_adapter_callback(Optional[Callable[Select, Table], Select]): Callable to override the SELECT query used to fetch data from the table. + The callback receives the sqlalchemy `Select` and corresponding `Table` objects and should return the modified `Select`. + + Returns: + Iterable[DltResource]: A list of DLT resources for each table to be loaded. 
+ """ + # detect precision hints is deprecated + _detect_precision_hints_deprecated(detect_precision_hints) + + if detect_precision_hints: + reflection_level = "full_with_precision" + else: + reflection_level = reflection_level or "minimal" + + # set up alchemy engine + engine = engine_from_credentials(credentials) + engine.execution_options(stream_results=True, max_row_buffer=2 * chunk_size) + metadata = metadata or MetaData(schema=schema) + + # use provided tables or all tables + if table_names: + tables = [ + Table(name, metadata, autoload_with=None if defer_table_reflect else engine) + for name in table_names + ] + else: + if defer_table_reflect: + raise ValueError("You must pass table names to defer table reflection") + metadata.reflect(bind=engine, views=include_views) + tables = list(metadata.tables.values()) + + for table in tables: + yield sql_table( + credentials=credentials, + table=table.name, + schema=table.schema, + metadata=metadata, + chunk_size=chunk_size, + backend=backend, + reflection_level=reflection_level, + defer_table_reflect=defer_table_reflect, + table_adapter_callback=table_adapter_callback, + backend_kwargs=backend_kwargs, + type_adapter_callback=type_adapter_callback, + query_adapter_callback=query_adapter_callback, + ) + + +@dlt.resource(name=lambda args: args["table"], standalone=True, spec=SqlTableResourceConfiguration) +def sql_table( + credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value, + table: str = dlt.config.value, + schema: Optional[str] = dlt.config.value, + metadata: Optional[MetaData] = None, + incremental: Optional[dlt.sources.incremental[Any]] = None, + chunk_size: int = 50000, + backend: TableBackend = "sqlalchemy", + detect_precision_hints: Optional[bool] = None, + reflection_level: Optional[ReflectionLevel] = "full", + defer_table_reflect: Optional[bool] = None, + table_adapter_callback: Callable[[Table], None] = None, + backend_kwargs: Dict[str, Any] = None, + type_adapter_callback: Optional[TTypeAdapter] = None, + included_columns: Optional[List[str]] = None, + query_adapter_callback: Optional[TQueryAdapter] = None, +) -> DltResource: + """ + A dlt resource which loads data from an SQL database table using SQLAlchemy. + + Args: + credentials (Union[ConnectionStringCredentials, Engine, str]): Database credentials or an `Engine` instance representing the database connection. + table (str): Name of the table or view to load. + schema (Optional[str]): Optional name of the schema the table belongs to. + metadata (Optional[MetaData]): Optional `sqlalchemy.MetaData` instance. If provided, the `schema` argument is ignored. + incremental (Optional[dlt.sources.incremental[Any]]): Option to enable incremental loading for the table. + E.g., `incremental=dlt.sources.incremental('updated_at', pendulum.parse('2022-01-01T00:00:00Z'))` + chunk_size (int): Number of rows yielded in one batch. SQL Alchemy will create additional internal rows buffer twice the chunk size. + backend (TableBackend): Type of backend to generate table data. One of: "sqlalchemy", "pyarrow", "pandas" and "connectorx". + "sqlalchemy" yields batches as lists of Python dictionaries, "pyarrow" and "connectorx" yield batches as arrow tables, "pandas" yields panda frames. + "sqlalchemy" is the default and does not require additional dependencies, "pyarrow" creates stable destination schemas with correct data types, + "connectorx" is typically the fastest but ignores the "chunk_size" so you must deal with large tables yourself. 
+ reflection_level: (ReflectionLevel): Specifies how much information should be reflected from the source database schema. + "minimal": Only table names, nullability and primary keys are reflected. Data types are inferred from the data. + "full": Data types will be reflected on top of "minimal". `dlt` will coerce the data into reflected types if necessary. This is the default option. + "full_with_precision": Sets precision and scale on supported data types (ie. decimal, text, binary). Creates big and regular integer types. + detect_precision_hints (bool): Deprecated. Use `reflection_level`. Set column precision and scale hints for supported data types in the target schema based on the columns in the source tables. + This is disabled by default. + defer_table_reflect (bool): Will connect and reflect table schema only when yielding data. Enable this option when running on Airflow. Available + on dlt 0.4.4 and later + table_adapter_callback: (Callable): Receives each reflected table. May be used to modify the list of columns that will be selected. + backend_kwargs (**kwargs): kwargs passed to table backend ie. "conn" is used to pass specialized connection string to connectorx. + type_adapter_callback(Optional[Callable]): Callable to override type inference when reflecting columns. + Argument is a single sqlalchemy data type (`TypeEngine` instance) and it should return another sqlalchemy data type, or `None` (type will be inferred from data) + included_columns (Optional[List[str]): List of column names to select from the table. If not provided, all columns are loaded. + query_adapter_callback(Optional[Callable[Select, Table], Select]): Callable to override the SELECT query used to fetch data from the table. + The callback receives the sqlalchemy `Select` and corresponding `Table` objects and should return the modified `Select`. + + Returns: + DltResource: The dlt resource for loading data from the SQL database table. 
+ """ + _detect_precision_hints_deprecated(detect_precision_hints) + + if detect_precision_hints: + reflection_level = "full_with_precision" + else: + reflection_level = reflection_level or "minimal" + + engine = engine_from_credentials(credentials, may_dispose_after_use=True) + engine.execution_options(stream_results=True, max_row_buffer=2 * chunk_size) + metadata = metadata or MetaData(schema=schema) + + table_obj = metadata.tables.get("table") or Table( + table, metadata, autoload_with=None if defer_table_reflect else engine + ) + if not defer_table_reflect: + default_table_adapter(table_obj, included_columns) + if table_adapter_callback: + table_adapter_callback(table_obj) + + return dlt.resource( + table_rows, + name=table_obj.name, + primary_key=get_primary_key(table_obj), + columns=table_to_columns(table_obj, reflection_level, type_adapter_callback), + )( + engine, + table_obj, + chunk_size, + backend, + incremental=incremental, + reflection_level=reflection_level, + defer_table_reflect=defer_table_reflect, + table_adapter_callback=table_adapter_callback, + backend_kwargs=backend_kwargs, + type_adapter_callback=type_adapter_callback, + included_columns=included_columns, + query_adapter_callback=query_adapter_callback, + ) diff --git a/dlt/sources/sql_database/arrow_helpers.py b/dlt/sources/sql_database/arrow_helpers.py new file mode 100644 index 0000000000..898d8c3280 --- /dev/null +++ b/dlt/sources/sql_database/arrow_helpers.py @@ -0,0 +1,150 @@ +from typing import Any, Sequence, Optional + +from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common import logger, json +from dlt.common.configuration import with_config +from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.json import custom_encode, map_nested_in_place + +from .schema_types import RowAny + + +@with_config +def columns_to_arrow( + columns_schema: TTableSchemaColumns, + caps: DestinationCapabilitiesContext = None, + tz: str = "UTC", +) -> Any: + """Converts `column_schema` to arrow schema using `caps` and `tz`. `caps` are injected from the container - which + is always the case if run within the pipeline. This will generate arrow schema compatible with the destination. + Otherwise generic capabilities are used + """ + from dlt.common.libs.pyarrow import pyarrow as pa, get_py_arrow_datatype + from dlt.common.destination.capabilities import DestinationCapabilitiesContext + + return pa.schema( + [ + pa.field( + name, + get_py_arrow_datatype( + schema_item, + caps or DestinationCapabilitiesContext.generic_capabilities(), + tz, + ), + nullable=schema_item.get("nullable", True), + ) + for name, schema_item in columns_schema.items() + if schema_item.get("data_type") is not None + ] + ) + + +def row_tuples_to_arrow(rows: Sequence[RowAny], columns: TTableSchemaColumns, tz: str) -> Any: + """Converts the rows to an arrow table using the columns schema. + Columns missing `data_type` will be inferred from the row data. + Columns with object types not supported by arrow are excluded from the resulting table. 
+ """ + from dlt.common.libs.pyarrow import pyarrow as pa + import numpy as np + + try: + from pandas._libs import lib + + pivoted_rows = lib.to_object_array_tuples(rows).T + except ImportError: + logger.info( + "Pandas not installed, reverting to numpy.asarray to create a table which is slower" + ) + pivoted_rows = np.asarray(rows, dtype="object", order="k").T # type: ignore[call-overload] + + columnar = { + col: dat.ravel() for col, dat in zip(columns, np.vsplit(pivoted_rows, len(columns))) + } + columnar_known_types = { + col["name"]: columnar[col["name"]] + for col in columns.values() + if col.get("data_type") is not None + } + columnar_unknown_types = { + col["name"]: columnar[col["name"]] + for col in columns.values() + if col.get("data_type") is None + } + + arrow_schema = columns_to_arrow(columns, tz=tz) + + for idx in range(0, len(arrow_schema.names)): + field = arrow_schema.field(idx) + py_type = type(rows[0][idx]) + # cast double / float ndarrays to decimals if type mismatch, looks like decimals and floats are often mixed up in dialects + if pa.types.is_decimal(field.type) and issubclass(py_type, (str, float)): + logger.warning( + f"Field {field.name} was reflected as decimal type, but rows contains" + f" {py_type.__name__}. Additional cast is required which may slow down arrow table" + " generation." + ) + float_array = pa.array(columnar_known_types[field.name], type=pa.float64()) + columnar_known_types[field.name] = float_array.cast(field.type, safe=False) + if issubclass(py_type, (dict, list)): + logger.warning( + f"Field {field.name} was reflected as JSON type and needs to be serialized back to" + " string to be placed in arrow table. This will slow data extraction down. You" + " should cast JSON field to STRING in your database system ie. by creating and" + " extracting an SQL VIEW that selects with cast." + ) + json_str_array = pa.array( + [None if s is None else json.dumps(s) for s in columnar_known_types[field.name]] + ) + columnar_known_types[field.name] = json_str_array + + # If there are unknown type columns, first create a table to infer their types + if columnar_unknown_types: + new_schema_fields = [] + for key in list(columnar_unknown_types): + arrow_col: Optional[pa.Array] = None + try: + arrow_col = pa.array(columnar_unknown_types[key]) + if pa.types.is_null(arrow_col.type): + logger.warning( + f"Column {key} contains only NULL values and data type could not be" + " inferred. This column is removed from a arrow table" + ) + continue + + except pa.ArrowInvalid as e: + # Try coercing types not supported by arrow to a json friendly format + # E.g. dataclasses -> dict, UUID -> str + try: + arrow_col = pa.array( + map_nested_in_place(custom_encode, list(columnar_unknown_types[key])) + ) + logger.warning( + f"Column {key} contains a data type which is not supported by pyarrow and" + f" got converted into {arrow_col.type}. This slows down arrow table" + " generation." + ) + except (pa.ArrowInvalid, TypeError): + logger.warning( + f"Column {key} contains a data type which is not supported by pyarrow. This" + f" column will be ignored. 
Error: {e}" + ) + if arrow_col is not None: + columnar_known_types[key] = arrow_col + new_schema_fields.append( + pa.field( + key, + arrow_col.type, + nullable=columns[key]["nullable"], + ) + ) + + # New schema + column_order = {name: idx for idx, name in enumerate(columns)} + arrow_schema = pa.schema( + sorted( + list(arrow_schema) + new_schema_fields, + key=lambda x: column_order[x.name], + ) + ) + + return pa.Table.from_pydict(columnar_known_types, schema=arrow_schema) diff --git a/dlt/sources/sql_database/helpers.py b/dlt/sources/sql_database/helpers.py new file mode 100644 index 0000000000..f9a8470e9b --- /dev/null +++ b/dlt/sources/sql_database/helpers.py @@ -0,0 +1,313 @@ +"""SQL database source helpers""" + +import warnings +from typing import ( + Callable, + Any, + Dict, + List, + Literal, + Optional, + Iterator, + Union, +) +import operator + +import dlt +from dlt.common.configuration.specs import BaseConfiguration, configspec +from dlt.common.exceptions import MissingDependencyException +from dlt.common.schema import TTableSchemaColumns +from dlt.common.typing import TDataItem, TSortOrder + +from dlt.sources.credentials import ConnectionStringCredentials + +from .arrow_helpers import row_tuples_to_arrow +from .schema_types import ( + default_table_adapter, + table_to_columns, + get_primary_key, + Table, + SelectAny, + ReflectionLevel, + TTypeAdapter, +) + +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.exc import CompileError + + +TableBackend = Literal["sqlalchemy", "pyarrow", "pandas", "connectorx"] +TQueryAdapter = Callable[[SelectAny, Table], SelectAny] + + +class TableLoader: + def __init__( + self, + engine: Engine, + backend: TableBackend, + table: Table, + columns: TTableSchemaColumns, + chunk_size: int = 1000, + incremental: Optional[dlt.sources.incremental[Any]] = None, + query_adapter_callback: Optional[TQueryAdapter] = None, + ) -> None: + self.engine = engine + self.backend = backend + self.table = table + self.columns = columns + self.chunk_size = chunk_size + self.query_adapter_callback = query_adapter_callback + self.incremental = incremental + if incremental: + try: + self.cursor_column = table.c[incremental.cursor_path] + except KeyError as e: + raise KeyError( + f"Cursor column '{incremental.cursor_path}' does not exist in table" + f" '{table.name}'" + ) from e + self.last_value = incremental.last_value + self.end_value = incremental.end_value + self.row_order: TSortOrder = self.incremental.row_order + else: + self.cursor_column = None + self.last_value = None + self.end_value = None + self.row_order = None + + def _make_query(self) -> SelectAny: + table = self.table + query = table.select() + if not self.incremental: + return query # type: ignore[no-any-return] + last_value_func = self.incremental.last_value_func + + # generate where + if last_value_func is max: # Query ordered and filtered according to last_value function + filter_op = operator.ge + filter_op_end = operator.lt + elif last_value_func is min: + filter_op = operator.le + filter_op_end = operator.gt + else: # Custom last_value, load everything and let incremental handle filtering + return query # type: ignore[no-any-return] + + if self.last_value is not None: + query = query.where(filter_op(self.cursor_column, self.last_value)) + if self.end_value is not None: + query = query.where(filter_op_end(self.cursor_column, self.end_value)) + + # generate order by from declared row order + order_by = None + if (self.row_order == "asc" and last_value_func is 
max) or ( + self.row_order == "desc" and last_value_func is min + ): + order_by = self.cursor_column.asc() + elif (self.row_order == "asc" and last_value_func is min) or ( + self.row_order == "desc" and last_value_func is max + ): + order_by = self.cursor_column.desc() + if order_by is not None: + query = query.order_by(order_by) + + return query # type: ignore[no-any-return] + + def make_query(self) -> SelectAny: + if self.query_adapter_callback: + return self.query_adapter_callback(self._make_query(), self.table) + return self._make_query() + + def load_rows(self, backend_kwargs: Dict[str, Any] = None) -> Iterator[TDataItem]: + # make copy of kwargs + backend_kwargs = dict(backend_kwargs or {}) + query = self.make_query() + if self.backend == "connectorx": + yield from self._load_rows_connectorx(query, backend_kwargs) + else: + yield from self._load_rows(query, backend_kwargs) + + def _load_rows(self, query: SelectAny, backend_kwargs: Dict[str, Any]) -> TDataItem: + with self.engine.connect() as conn: + result = conn.execution_options(yield_per=self.chunk_size).execute(query) + # NOTE: cursor returns not normalized column names! may be quite useful in case of Oracle dialect + # that normalizes columns + # columns = [c[0] for c in result.cursor.description] + columns = list(result.keys()) + for partition in result.partitions(size=self.chunk_size): + if self.backend == "sqlalchemy": + yield [dict(row._mapping) for row in partition] + elif self.backend == "pandas": + from dlt.common.libs.pandas_sql import _wrap_result + + df = _wrap_result( + partition, + columns, + **{"dtype_backend": "pyarrow", **backend_kwargs}, + ) + yield df + elif self.backend == "pyarrow": + yield row_tuples_to_arrow( + partition, self.columns, tz=backend_kwargs.get("tz", "UTC") + ) + + def _load_rows_connectorx( + self, query: SelectAny, backend_kwargs: Dict[str, Any] + ) -> Iterator[TDataItem]: + try: + import connectorx as cx + except ImportError: + raise MissingDependencyException("Connector X table backend", ["connectorx"]) + + # default settings + backend_kwargs = { + "return_type": "arrow2", + "protocol": "binary", + **backend_kwargs, + } + conn = backend_kwargs.pop( + "conn", + self.engine.url._replace( + drivername=self.engine.url.get_backend_name() + ).render_as_string(hide_password=False), + ) + try: + query_str = str(query.compile(self.engine, compile_kwargs={"literal_binds": True})) + except CompileError as ex: + raise NotImplementedError( + f"Query for table {self.table.name} could not be compiled to string to execute it" + " on ConnectorX. 
If you are on SQLAlchemy 1.4.x the causing exception is due to" + f" literals that cannot be rendered, upgrade to 2.x: {str(ex)}" + ) from ex + df = cx.read_sql(conn, query_str, **backend_kwargs) + yield df + + +def table_rows( + engine: Engine, + table: Table, + chunk_size: int, + backend: TableBackend, + incremental: Optional[dlt.sources.incremental[Any]] = None, + defer_table_reflect: bool = False, + table_adapter_callback: Callable[[Table], None] = None, + reflection_level: ReflectionLevel = "minimal", + backend_kwargs: Dict[str, Any] = None, + type_adapter_callback: Optional[TTypeAdapter] = None, + included_columns: Optional[List[str]] = None, + query_adapter_callback: Optional[TQueryAdapter] = None, +) -> Iterator[TDataItem]: + columns: TTableSchemaColumns = None + if defer_table_reflect: + table = Table(table.name, table.metadata, autoload_with=engine, extend_existing=True) # type: ignore[attr-defined] + default_table_adapter(table, included_columns) + if table_adapter_callback: + table_adapter_callback(table) + columns = table_to_columns(table, reflection_level, type_adapter_callback) + + # set the primary_key in the incremental + if incremental and incremental.primary_key is None: + primary_key = get_primary_key(table) + if primary_key is not None: + incremental.primary_key = primary_key + + # yield empty record to set hints + yield dlt.mark.with_hints( + [], + dlt.mark.make_hints( + primary_key=get_primary_key(table), + columns=columns, + ), + ) + else: + # table was already reflected + columns = table_to_columns(table, reflection_level, type_adapter_callback) + + loader = TableLoader( + engine, + backend, + table, + columns, + incremental=incremental, + chunk_size=chunk_size, + query_adapter_callback=query_adapter_callback, + ) + try: + yield from loader.load_rows(backend_kwargs) + finally: + # dispose the engine if created for this particular table + # NOTE: database wide engines are not disposed, not externally provided + if getattr(engine, "may_dispose_after_use", False): + engine.dispose() + + +def engine_from_credentials( + credentials: Union[ConnectionStringCredentials, Engine, str], + may_dispose_after_use: bool = False, + **backend_kwargs: Any, +) -> Engine: + if isinstance(credentials, Engine): + return credentials + if isinstance(credentials, ConnectionStringCredentials): + credentials = credentials.to_native_representation() + engine = create_engine(credentials, **backend_kwargs) + setattr(engine, "may_dispose_after_use", may_dispose_after_use) # noqa + return engine # type: ignore[no-any-return] + + +def unwrap_json_connector_x(field: str) -> TDataItem: + """Creates a transform function to be added with `add_map` that will unwrap JSON columns + ingested via connectorx. Such columns are additionally quoted and translate SQL NULL to json "null" + """ + import pyarrow.compute as pc + import pyarrow as pa + + def _unwrap(table: TDataItem) -> TDataItem: + col_index = table.column_names.index(field) + # remove quotes + column = pc.replace_substring_regex(table[field], '"(.*)"', "\\1") + # convert json null to null + column = pc.replace_with_mask( + column, + pc.equal(column, "null").combine_chunks(), + pa.scalar(None, pa.large_string()), + ) + return table.set_column(col_index, table.schema.field(col_index), column) + + return _unwrap + + +def _detect_precision_hints_deprecated(value: Optional[bool]) -> None: + if value is None: + return + + msg = ( + "`detect_precision_hints` argument is deprecated and will be removed in a future release. 
" + ) + if value: + msg += "Use `reflection_level='full_with_precision'` which has the same effect instead." + + warnings.warn( + msg, + DeprecationWarning, + ) + + +@configspec +class SqlDatabaseTableConfiguration(BaseConfiguration): + incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg] + included_columns: Optional[List[str]] = None + + +@configspec +class SqlTableResourceConfiguration(BaseConfiguration): + credentials: Union[ConnectionStringCredentials, Engine, str] = None + table: str = None + schema: Optional[str] = None + incremental: Optional[dlt.sources.incremental] = None # type: ignore[type-arg] + chunk_size: int = 50000 + backend: TableBackend = "sqlalchemy" + detect_precision_hints: Optional[bool] = None + defer_table_reflect: Optional[bool] = False + reflection_level: Optional[ReflectionLevel] = "full" + included_columns: Optional[List[str]] = None diff --git a/dlt/sources/sql_database/schema_types.py b/dlt/sources/sql_database/schema_types.py new file mode 100644 index 0000000000..7a6e0a3daa --- /dev/null +++ b/dlt/sources/sql_database/schema_types.py @@ -0,0 +1,157 @@ +from typing import ( + Optional, + Any, + Type, + TYPE_CHECKING, + Literal, + List, + Callable, + Union, +) +from typing_extensions import TypeAlias +from sqlalchemy import Table, Column +from sqlalchemy.engine import Row +from sqlalchemy.sql import sqltypes, Select +from sqlalchemy.sql.sqltypes import TypeEngine + +from dlt.common import logger +from dlt.common.schema.typing import TColumnSchema, TTableSchemaColumns + +ReflectionLevel = Literal["minimal", "full", "full_with_precision"] + + +# optionally create generics with any so they can be imported by dlt importer +if TYPE_CHECKING: + SelectAny: TypeAlias = Select[Any] # type: ignore[type-arg] + ColumnAny: TypeAlias = Column[Any] # type: ignore[type-arg] + RowAny: TypeAlias = Row[Any] # type: ignore[type-arg] + TypeEngineAny = TypeEngine[Any] # type: ignore[type-arg] +else: + SelectAny: TypeAlias = Type[Any] + ColumnAny: TypeAlias = Type[Any] + RowAny: TypeAlias = Type[Any] + TypeEngineAny = Type[Any] + + +TTypeAdapter = Callable[[TypeEngineAny], Optional[Union[TypeEngineAny, Type[TypeEngineAny]]]] + + +def default_table_adapter(table: Table, included_columns: Optional[List[str]]) -> None: + """Default table adapter being always called before custom one""" + if included_columns is not None: + # Delete columns not included in the load + for col in list(table._columns): # type: ignore[attr-defined] + if col.name not in included_columns: + table._columns.remove(col) # type: ignore[attr-defined] + for col in table._columns: # type: ignore[attr-defined] + sql_t = col.type + # if isinstance(sql_t, sqltypes.Uuid): # in sqlalchemy 2.0 uuid type is available + # emit uuids as string by default + sql_t.as_uuid = False + + +def sqla_col_to_column_schema( + sql_col: ColumnAny, + reflection_level: ReflectionLevel, + type_adapter_callback: Optional[TTypeAdapter] = None, +) -> Optional[TColumnSchema]: + """Infer dlt schema column type from an sqlalchemy type. + + If `add_precision` is set, precision and scale is inferred from that types that support it, + such as numeric, varchar, int, bigint. Numeric (decimal) types have always precision added. 
+ """ + col: TColumnSchema = { + "name": sql_col.name, + "nullable": sql_col.nullable, + } + if reflection_level == "minimal": + return col + + sql_t = sql_col.type + + if type_adapter_callback: + sql_t = type_adapter_callback(sql_t) + # Check if sqla type class rather than instance is returned + if sql_t is not None and isinstance(sql_t, type): + sql_t = sql_t() + + if sql_t is None: + # Column ignored by callback + return col + + add_precision = reflection_level == "full_with_precision" + + # if isinstance(sql_t, sqltypes.Uuid): + # # we represent UUID as text by default, see default_table_adapter + # col["data_type"] = "text" + if isinstance(sql_t, sqltypes.Numeric): + # check for Numeric type first and integer later, some numeric types (ie. Oracle) + # derive from both + # all Numeric types that are returned as floats will assume "double" type + # and returned as decimals will assume "decimal" type + if sql_t.asdecimal is False: + col["data_type"] = "double" + else: + col["data_type"] = "decimal" + if sql_t.precision is not None: + col["precision"] = sql_t.precision + # must have a precision for any meaningful scale + if sql_t.scale is not None: + col["scale"] = sql_t.scale + elif sql_t.decimal_return_scale is not None: + col["scale"] = sql_t.decimal_return_scale + elif isinstance(sql_t, sqltypes.SmallInteger): + col["data_type"] = "bigint" + if add_precision: + col["precision"] = 32 + elif isinstance(sql_t, sqltypes.Integer): + col["data_type"] = "bigint" + elif isinstance(sql_t, sqltypes.String): + col["data_type"] = "text" + if add_precision and sql_t.length: + col["precision"] = sql_t.length + elif isinstance(sql_t, sqltypes._Binary): + col["data_type"] = "binary" + if add_precision and sql_t.length: + col["precision"] = sql_t.length + elif isinstance(sql_t, sqltypes.DateTime): + col["data_type"] = "timestamp" + elif isinstance(sql_t, sqltypes.Date): + col["data_type"] = "date" + elif isinstance(sql_t, sqltypes.Time): + col["data_type"] = "time" + elif isinstance(sql_t, sqltypes.JSON): + col["data_type"] = "complex" + elif isinstance(sql_t, sqltypes.Boolean): + col["data_type"] = "bool" + else: + logger.warning( + f"A column with name {sql_col.name} contains unknown data type {sql_t} which cannot be" + " mapped to `dlt` data type. When using sqlalchemy backend such data will be passed to" + " the normalizer. In case of `pyarrow` and `pandas` backend, data types are detected" + " from numpy ndarrays. In case of other backends, the behavior is backend-specific." 
+ ) + + return {key: value for key, value in col.items() if value is not None} # type: ignore[return-value] + + +def get_primary_key(table: Table) -> Optional[List[str]]: + """Create primary key or return None if no key defined""" + primary_key = [c.name for c in table.primary_key] + return primary_key if len(primary_key) > 0 else None + + +def table_to_columns( + table: Table, + reflection_level: ReflectionLevel = "full", + type_conversion_fallback: Optional[TTypeAdapter] = None, +) -> TTableSchemaColumns: + """Convert an sqlalchemy table to a dlt table schema.""" + return { + col["name"]: col + for col in ( + sqla_col_to_column_schema(c, reflection_level, type_conversion_fallback) + for c in table.columns + ) + if col is not None + } diff --git a/poetry.lock b/poetry.lock index bcff46e77a..8338a9dba3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5197,6 +5197,17 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "mimesis" +version = "7.1.0" +description = "Mimesis: Fake Data Generator." +optional = false +python-versions = ">=3.8,<4.0" +files = [ + {file = "mimesis-7.1.0-py3-none-any.whl", hash = "sha256:da65bea6d6d5d5d87d5c008e6b23ef5f96a49cce436d9f8708dabb5152da0290"}, + {file = "mimesis-7.1.0.tar.gz", hash = "sha256:c83b55d35536d7e9b9700a596b7ccfb639a740e3e1fb5e08062e8ab2a67dcb37"}, +] + [[package]] name = "minimal-snowplow-tracker" version = "0.0.2" @@ -9684,4 +9695,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "65dce0b89ae712f21b53d826bc90d29e53fd9d093628d4bbb8b7ab2dd9b7528a" +content-hash = "388e7501b69b8468ae030474f97c468a645452877a11f87c3d5226cb318bec7c" diff --git a/pyproject.toml b/pyproject.toml index 2e7c7a971f..ab4f839271 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -156,6 +156,12 @@ pyjwt = "^2.8.0" pytest-mock = "^3.14.0" types-regex = "^2024.5.15.20240519" flake8-print = "^5.0.0" +mimesis = "^7.0.0" + +[tool.poetry.group.sql_database.dependencies] +sqlalchemy = ">=1.4" +pymysql = "^1.0.3" +connectorx = ">=0.3.1" [tool.poetry.group.pipeline] optional = true diff --git a/tests/.example.env b/tests/.example.env index 50eee33bd5..175544218c 100644 --- a/tests/.example.env +++ b/tests/.example.env @@ -19,6 +19,6 @@ DESTINATION__REDSHIFT__CREDENTIALS__USERNAME=loader DESTINATION__REDSHIFT__CREDENTIALS__HOST=3.73.90.3 DESTINATION__REDSHIFT__CREDENTIALS__PASSWORD=set-me-up -DESTINATION__POSTGRES__CREDENTIALS=postgres://loader:loader@localhost:5432/dlt_data +DESTINATION__POSTGRES__CREDENTIALS=postgresql://loader:loader@localhost:5432/dlt_data DESTINATION__DUCKDB__CREDENTIALS=duckdb:///_storage/test_quack.duckdb -RUNTIME__SENTRY_DSN=https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 \ No newline at end of file +RUNTIME__SENTRY_DSN=https://6f6f7b6f8e0f458a89be4187603b55fe@o1061158.ingest.sentry.io/4504819859914752 diff --git a/tests/load/sources/sql_database/__init__.py b/tests/load/sources/sql_database/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/load/sources/sql_database/conftest.py b/tests/load/sources/sql_database/conftest.py new file mode 100644 index 0000000000..1372663663 --- /dev/null +++ b/tests/load/sources/sql_database/conftest.py @@ -0,0 +1 @@ +from tests.sources.sql_database.conftest import * # noqa: F403 diff --git a/tests/load/sources/sql_database/test_sql_database.py b/tests/load/sources/sql_database/test_sql_database.py new 
file mode 100644 index 0000000000..303030cf82 --- /dev/null +++ b/tests/load/sources/sql_database/test_sql_database.py @@ -0,0 +1,331 @@ +import os +from typing import Any, List + +import humanize +import pytest + +import dlt +from dlt.sources import DltResource +from dlt.sources.credentials import ConnectionStringCredentials +from dlt.sources.sql_database import TableBackend, sql_database, sql_table +from tests.load.utils import ( + DestinationTestConfiguration, + destinations_configs, +) +from tests.pipeline.utils import ( + assert_load_info, + load_table_counts, +) +from tests.sources.sql_database.sql_source import SQLAlchemySourceDB +from tests.sources.sql_database.test_helpers import mock_json_column +from tests.sources.sql_database.test_sql_database_source import ( + assert_row_counts, + convert_time_to_us, + default_test_callback, +) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_schema_loads_all_tables( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + backend=backend, + reflection_level="minimal", + type_adapter_callback=default_test_callback(destination_config.destination, backend), + ) + + if destination_config.destination == "bigquery" and backend == "connectorx": + # connectorx generates nanoseconds time which bigquery cannot load + source.has_precision.add_map(convert_time_to_us) + source.has_precision_nullable.add_map(convert_time_to_us) + + if backend != "sqlalchemy": + # always use mock json + source.has_precision.add_map(mock_json_column("json_col")) + source.has_precision_nullable.add_map(mock_json_column("json_col")) + + assert "chat_message_view" not in source.resources # Views are not reflected by default + + load_info = pipeline.run(source) + print(humanize.precisedelta(pipeline.last_trace.finished_at - pipeline.last_trace.started_at)) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_schema_loads_all_tables_parallel( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + backend=backend, + reflection_level="minimal", + type_adapter_callback=default_test_callback(destination_config.destination, backend), + ).parallelize() + + if destination_config.destination == "bigquery" and backend == "connectorx": + # connectorx generates nanoseconds time which bigquery cannot load + source.has_precision.add_map(convert_time_to_us) + source.has_precision_nullable.add_map(convert_time_to_us) + + if backend != "sqlalchemy": + # always use mock json + source.has_precision.add_map(mock_json_column("json_col")) + source.has_precision_nullable.add_map(mock_json_column("json_col")) + + load_info = 
pipeline.run(source) + print(humanize.precisedelta(pipeline.last_trace.finished_at - pipeline.last_trace.started_at)) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_table_names( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + tables = ["chat_channel", "chat_message"] + load_info = pipeline.run( + sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table_names=tables, + reflection_level="minimal", + backend=backend, + ) + ) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db, tables) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_table_incremental( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + """Run pipeline twice. Insert more rows after first run + and ensure only those rows are stored after the second run. + """ + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = "updated_at" + + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + tables = ["chat_message"] + + def make_source(): + return sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table_names=tables, + reflection_level="minimal", + backend=backend, + ) + + load_info = pipeline.run(make_source()) + assert_load_info(load_info) + sql_source_db.fake_messages(n=100) + load_info = pipeline.run(make_source()) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db, tables) + + +@pytest.mark.skip(reason="Skipping this test temporarily") +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_mysql_data_load( + destination_config: DestinationTestConfiguration, backend: TableBackend, request: Any +) -> None: + # reflect a database + credentials = ConnectionStringCredentials( + "mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam" + ) + database = sql_database(credentials) + assert "family" in database.resources + + if backend == "connectorx": + # connector-x has different connection string format + backend_kwargs = {"conn": "mysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam"} + else: + backend_kwargs = {} + + # no longer needed: asdecimal used to infer decimal or not + # def _double_as_decimal_adapter(table: sa.Table) -> sa.Table: + # for column in table.columns.values(): + # if isinstance(column.type, sa.Double): + # column.type.asdecimal = False + + # load a single table + family_table = sql_table( + credentials="mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam", + table="family", + backend=backend, + reflection_level="minimal", + backend_kwargs=backend_kwargs, + # table_adapter_callback=_double_as_decimal_adapter, + ) + + pipeline = 
destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = pipeline.run(family_table, write_disposition="merge") + assert_load_info(load_info) + counts_1 = load_table_counts(pipeline, "family") + + # load again also with merge + family_table = sql_table( + credentials="mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam", + table="family", + backend=backend, + reflection_level="minimal", + # we also try to remove dialect automatically + backend_kwargs={}, + # table_adapter_callback=_double_as_decimal_adapter, + ) + load_info = pipeline.run(family_table, write_disposition="merge") + assert_load_info(load_info) + counts_2 = load_table_counts(pipeline, "family") + # no duplicates + assert counts_1 == counts_2 + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +def test_load_sql_table_resource_loads_data( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + @dlt.source + def sql_table_source() -> List[DltResource]: + return [ + sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + reflection_level="minimal", + backend=backend, + ) + ] + + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = pipeline.run(sql_table_source()) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_load_sql_table_resource_incremental( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + @dlt.source + def sql_table_source() -> List[DltResource]: + return [ + sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + incremental=dlt.sources.incremental("updated_at"), + reflection_level="minimal", + backend=backend, + ) + ] + + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = pipeline.run(sql_table_source()) + assert_load_info(load_info) + sql_source_db.fake_messages(n=100) + load_info = pipeline.run(sql_table_source()) + assert_load_info(load_info) + + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True), + ids=lambda x: x.name, +) +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_load_sql_table_resource_incremental_initial_value( + sql_source_db: SQLAlchemySourceDB, + destination_config: DestinationTestConfiguration, + backend: TableBackend, + request: Any, +) -> None: + @dlt.source + def sql_table_source() -> List[DltResource]: + return [ + sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + incremental=dlt.sources.incremental( + "updated_at", + sql_source_db.table_infos["chat_message"]["created_at"].start_value, + ), + reflection_level="minimal", + backend=backend, + ) + ] + + pipeline = destination_config.setup_pipeline(request.node.name, dev_mode=True) + load_info = 
pipeline.run(sql_table_source()) + assert_load_info(load_info) + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index dfb5f3f82d..1523ace9e5 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Callable, Sequence +from typing import Any, Dict, List, Set, Callable, Sequence import pytest import random from os import environ @@ -6,16 +6,16 @@ import dlt from dlt.common import json, sleep -from dlt.common.destination.exceptions import DestinationUndefinedEntity +from dlt.common.data_types import py_type_to_sc_type from dlt.common.pipeline import LoadInfo from dlt.common.schema.utils import get_table_format from dlt.common.typing import DictStrAny from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from dlt.destinations.fs_client import FSClientBase -from dlt.pipeline.exceptions import SqlClientNotAvailable -from dlt.common.storages import FileStorage from dlt.destinations.exceptions import DatabaseUndefinedRelation +from dlt.common.schema.typing import TTableSchema + PIPELINE_TEST_CASES_PATH = "./tests/pipeline/cases/" @@ -420,3 +420,66 @@ def assert_query_data( # the second is load id if info: assert row[1] in info.loads_ids + + +def assert_schema_on_data( + table_schema: TTableSchema, + rows: List[Dict[str, Any]], + requires_nulls: bool, + check_complex: bool, +) -> None: + """Asserts that `rows` conform to `table_schema`. Fields and their order must conform to columns. Null values and + python data types are checked. + """ + table_columns = table_schema["columns"] + columns_with_nulls: Set[str] = set() + for row in rows: + # check columns + assert set(table_schema["columns"].keys()) == set(row.keys()) + # check column order + assert list(table_schema["columns"].keys()) == list(row.keys()) + # check data types + for key, value in row.items(): + if value is None: + assert table_columns[key][ + "nullable" + ], f"column {key} must be nullable: value is None" + # next value. we cannot validate data type + columns_with_nulls.add(key) + continue + expected_dt = table_columns[key]["data_type"] + # allow complex strings + if expected_dt == "complex": + if check_complex: + # NOTE: we expect a dict or a list here. 
simple types of null will fail the test + value = json.loads(value) + else: + # skip checking complex types + continue + actual_dt = py_type_to_sc_type(type(value)) + assert actual_dt == expected_dt + + if requires_nulls: + # make sure that all nullable columns in table received nulls + assert ( + set(col["name"] for col in table_columns.values() if col["nullable"]) + == columns_with_nulls + ), "Some columns didn't receive NULLs which is required" + + +def load_table_distinct_counts( + p: dlt.Pipeline, distinct_column: str, *table_names: str +) -> DictStrAny: + """Returns counts of distinct values for column `distinct_column` for `table_names` as dict""" + with p.sql_client() as c: + query = "\nUNION ALL\n".join( + [ + f"SELECT '{name}' as name, COUNT(DISTINCT {distinct_column}) as c FROM" + f" {c.make_qualified_table_name(name)}" + for name in table_names + ] + ) + + with c.execute_query(query) as cur: + rows = list(cur.fetchall()) + return {r[0]: r[1] for r in rows} diff --git a/tests/sources/sql_database/__init__.py b/tests/sources/sql_database/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/sources/sql_database/conftest.py b/tests/sources/sql_database/conftest.py new file mode 100644 index 0000000000..d107216f1c --- /dev/null +++ b/tests/sources/sql_database/conftest.py @@ -0,0 +1,36 @@ +from typing import Iterator + +import pytest + +import dlt +from dlt.sources.credentials import ConnectionStringCredentials +from tests.sources.sql_database.sql_source import SQLAlchemySourceDB + + +def _create_db(**kwargs) -> Iterator[SQLAlchemySourceDB]: + # TODO: parametrize the fixture so it takes the credentials for all destinations + credentials = dlt.secrets.get( + "destination.postgres.credentials", expected_type=ConnectionStringCredentials + ) + + db = SQLAlchemySourceDB(credentials, **kwargs) + db.create_schema() + try: + db.create_tables() + db.insert_data() + yield db + finally: + db.drop_schema() + + +@pytest.fixture(scope="package") +def sql_source_db(request: pytest.FixtureRequest) -> Iterator[SQLAlchemySourceDB]: + # Without unsupported types so we can test full schema load with connector-x + yield from _create_db(with_unsupported_types=False) + + +@pytest.fixture(scope="package") +def sql_source_db_unsupported_types( + request: pytest.FixtureRequest, +) -> Iterator[SQLAlchemySourceDB]: + yield from _create_db(with_unsupported_types=True) diff --git a/tests/sources/sql_database/sql_source.py b/tests/sources/sql_database/sql_source.py new file mode 100644 index 0000000000..2fb1fc3489 --- /dev/null +++ b/tests/sources/sql_database/sql_source.py @@ -0,0 +1,369 @@ +import random +from copy import deepcopy +from typing import Dict, List, TypedDict +from uuid import uuid4 + +import mimesis +from sqlalchemy import ( + ARRAY, + BigInteger, + Boolean, + Column, + Date, + DateTime, + Float, + ForeignKey, + Integer, + MetaData, + Numeric, + SmallInteger, + String, + Table, + Text, + Time, + create_engine, + func, + text, +) +from sqlalchemy import ( + schema as sqla_schema, +) + +# Uuid, # requires sqlalchemy 2.0. 
Use String(length=36) for lower versions +from sqlalchemy.dialects.postgresql import DATERANGE, JSONB + +from dlt.common.pendulum import pendulum, timedelta +from dlt.common.utils import chunks, uniq_id +from dlt.sources.credentials import ConnectionStringCredentials + + +class SQLAlchemySourceDB: + def __init__( + self, + credentials: ConnectionStringCredentials, + schema: str = None, + with_unsupported_types: bool = False, + ) -> None: + self.credentials = credentials + self.database_url = credentials.to_native_representation() + self.schema = schema or "my_dlt_source" + uniq_id() + self.engine = create_engine(self.database_url) + self.metadata = MetaData(schema=self.schema) + self.table_infos: Dict[str, TableInfo] = {} + self.with_unsupported_types = with_unsupported_types + + def create_schema(self) -> None: + with self.engine.begin() as conn: + conn.execute(sqla_schema.CreateSchema(self.schema, if_not_exists=True)) + + def drop_schema(self) -> None: + with self.engine.begin() as conn: + conn.execute(sqla_schema.DropSchema(self.schema, cascade=True, if_exists=True)) + + def get_table(self, name: str) -> Table: + return self.metadata.tables[f"{self.schema}.{name}"] + + def create_tables(self) -> None: + Table( + "app_user", + self.metadata, + Column("id", Integer(), primary_key=True, autoincrement=True), + Column("email", Text(), nullable=False, unique=True), + Column("display_name", Text(), nullable=False), + Column( + "created_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + Column( + "updated_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + ) + Table( + "chat_channel", + self.metadata, + Column("id", Integer(), primary_key=True, autoincrement=True), + Column( + "created_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + Column("name", Text(), nullable=False), + Column("active", Boolean(), nullable=False, server_default=text("true")), + Column( + "updated_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + ) + Table( + "chat_message", + self.metadata, + Column("id", Integer(), primary_key=True, autoincrement=True), + Column( + "created_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + Column("content", Text(), nullable=False), + Column( + "user_id", + Integer(), + ForeignKey("app_user.id"), + nullable=False, + index=True, + ), + Column( + "channel_id", + Integer(), + ForeignKey("chat_channel.id"), + nullable=False, + index=True, + ), + Column( + "updated_at", + DateTime(timezone=True), + nullable=False, + server_default=func.now(), + ), + ) + Table( + "has_composite_key", + self.metadata, + Column("a", Integer(), primary_key=True), + Column("b", Integer(), primary_key=True), + Column("c", Integer(), primary_key=True), + ) + + def _make_precision_table(table_name: str, nullable: bool) -> None: + Table( + table_name, + self.metadata, + Column("int_col", Integer(), nullable=nullable), + Column("bigint_col", BigInteger(), nullable=nullable), + Column("smallint_col", SmallInteger(), nullable=nullable), + Column("numeric_col", Numeric(precision=10, scale=2), nullable=nullable), + Column("numeric_default_col", Numeric(), nullable=nullable), + Column("string_col", String(length=10), nullable=nullable), + Column("string_default_col", String(), nullable=nullable), + Column("datetime_tz_col", DateTime(timezone=True), nullable=nullable), + Column("datetime_ntz_col", DateTime(timezone=False), nullable=nullable), + 
Column("date_col", Date, nullable=nullable), + Column("time_col", Time, nullable=nullable), + Column("float_col", Float, nullable=nullable), + Column("json_col", JSONB, nullable=nullable), + Column("bool_col", Boolean, nullable=nullable), + Column("uuid_col", String(length=36), nullable=nullable), + ) + + _make_precision_table("has_precision", False) + _make_precision_table("has_precision_nullable", True) + + if self.with_unsupported_types: + Table( + "has_unsupported_types", + self.metadata, + Column("unsupported_daterange_1", DATERANGE, nullable=False), + Column("supported_text", Text, nullable=False), + Column("supported_int", Integer, nullable=False), + Column("unsupported_array_1", ARRAY(Integer), nullable=False), + Column("supported_datetime", DateTime(timezone=True), nullable=False), + ) + + self.metadata.create_all(bind=self.engine) + + # Create a view + q = f""" + CREATE VIEW {self.schema}.chat_message_view AS + SELECT + cm.id, + cm.content, + cm.created_at as _created_at, + cm.updated_at as _updated_at, + au.email as user_email, + au.display_name as user_display_name, + cc.name as channel_name, + CAST(NULL as TIMESTAMP) as _null_ts + FROM {self.schema}.chat_message cm + JOIN {self.schema}.app_user au ON cm.user_id = au.id + JOIN {self.schema}.chat_channel cc ON cm.channel_id = cc.id + """ + with self.engine.begin() as conn: + conn.execute(text(q)) + + def _fake_users(self, n: int = 8594) -> List[int]: + person = mimesis.Person() + user_ids: List[int] = [] + table = self.metadata.tables[f"{self.schema}.app_user"] + info = self.table_infos.setdefault( + "app_user", + dict(row_count=0, ids=[], created_at=IncrementingDate(), is_view=False), + ) + dt = info["created_at"] + for chunk in chunks(range(n), 5000): + rows = [ + dict( + email=person.email(unique=True), + display_name=person.name(), + created_at=next(dt), + updated_at=next(dt), + ) + for i in chunk + ] + with self.engine.begin() as conn: + result = conn.execute(table.insert().values(rows).returning(table.c.id)) + user_ids.extend(result.scalars()) + info["row_count"] += n + info["ids"] += user_ids + return user_ids + + def _fake_channels(self, n: int = 500) -> List[int]: + _text = mimesis.Text() + dev = mimesis.Development() + table = self.metadata.tables[f"{self.schema}.chat_channel"] + channel_ids: List[int] = [] + info = self.table_infos.setdefault( + "chat_channel", + dict(row_count=0, ids=[], created_at=IncrementingDate(), is_view=False), + ) + dt = info["created_at"] + for chunk in chunks(range(n), 5000): + rows = [ + dict( + name=" ".join(_text.words()), + active=dev.boolean(), + created_at=next(dt), + updated_at=next(dt), + ) + for i in chunk + ] + with self.engine.begin() as conn: + result = conn.execute(table.insert().values(rows).returning(table.c.id)) + channel_ids.extend(result.scalars()) + info["row_count"] += n + info["ids"] += channel_ids + return channel_ids + + def fake_messages(self, n: int = 9402) -> List[int]: + user_ids = self.table_infos["app_user"]["ids"] + channel_ids = self.table_infos["chat_channel"]["ids"] + _text = mimesis.Text() + choice = mimesis.Choice() + table = self.metadata.tables[f"{self.schema}.chat_message"] + message_ids: List[int] = [] + info = self.table_infos.setdefault( + "chat_message", + dict(row_count=0, ids=[], created_at=IncrementingDate(), is_view=False), + ) + dt = info["created_at"] + for chunk in chunks(range(n), 5000): + rows = [ + dict( + content=_text.random.choice(_text.extract(["questions"])), + user_id=choice(user_ids), + channel_id=choice(channel_ids), + 
created_at=next(dt), + updated_at=next(dt), + ) + for i in chunk + ] + with self.engine.begin() as conn: + result = conn.execute(table.insert().values(rows).returning(table.c.id)) + message_ids.extend(result.scalars()) + info["row_count"] += len(message_ids) + info["ids"].extend(message_ids) + # View is the same number of rows as the table + view_info = deepcopy(info) + view_info["is_view"] = True + view_info = self.table_infos.setdefault("chat_message_view", view_info) + view_info["row_count"] = info["row_count"] + view_info["ids"] = info["ids"] + return message_ids + + def _fake_precision_data(self, table_name: str, n: int = 100, null_n: int = 0) -> None: + table = self.metadata.tables[f"{self.schema}.{table_name}"] + self.table_infos.setdefault(table_name, dict(row_count=n + null_n, is_view=False)) # type: ignore[call-overload] + + rows = [ + dict( + int_col=random.randrange(-2147483648, 2147483647), + bigint_col=random.randrange(-9223372036854775808, 9223372036854775807), + smallint_col=random.randrange(-32768, 32767), + numeric_col=random.randrange(-9999999999, 9999999999) / 100, + numeric_default_col=random.randrange(-9999999999, 9999999999) / 100, + string_col=mimesis.Text().word()[:10], + string_default_col=mimesis.Text().word(), + datetime_tz_col=mimesis.Datetime().datetime(timezone="UTC"), + datetime_ntz_col=mimesis.Datetime().datetime(), # no timezone + date_col=mimesis.Datetime().date(), + time_col=mimesis.Datetime().time(), + float_col=random.random(), + json_col={"data": [1, 2, 3]}, + bool_col=random.randint(0, 1) == 1, + uuid_col=uuid4(), + ) + for _ in range(n + null_n) + ] + for row in rows[n:]: + # all fields to None + for field in row: + row[field] = None + with self.engine.begin() as conn: + conn.execute(table.insert().values(rows)) + + def _fake_chat_data(self, n: int = 9402) -> None: + self._fake_users() + self._fake_channels() + self.fake_messages() + + def _fake_unsupported_data(self, n: int = 100) -> None: + table = self.metadata.tables[f"{self.schema}.has_unsupported_types"] + self.table_infos.setdefault("has_unsupported_types", dict(row_count=n, is_view=False)) # type: ignore[call-overload] + + rows = [ + dict( + unsupported_daterange_1="[2020-01-01, 2020-09-01)", + supported_text=mimesis.Text().word(), + supported_int=random.randint(0, 100), + unsupported_array_1=[1, 2, 3], + supported_datetime=mimesis.Datetime().datetime(timezone="UTC"), + ) + for _ in range(n) + ] + with self.engine.begin() as conn: + conn.execute(table.insert().values(rows)) + + def insert_data(self) -> None: + self._fake_chat_data() + self._fake_precision_data("has_precision") + self._fake_precision_data("has_precision_nullable", null_n=10) + if self.with_unsupported_types: + self._fake_unsupported_data() + + +class IncrementingDate: + def __init__(self, start_value: pendulum.DateTime = None) -> None: + self.started = False + self.start_value = start_value or pendulum.now() + self.current_value = self.start_value + + def __next__(self) -> pendulum.DateTime: + if not self.started: + self.started = True + return self.current_value + self.current_value += timedelta(seconds=random.randrange(0, 120)) + return self.current_value + + +class TableInfo(TypedDict): + row_count: int + ids: List[int] + created_at: IncrementingDate + is_view: bool diff --git a/tests/sources/sql_database/test_arrow_helpers.py b/tests/sources/sql_database/test_arrow_helpers.py new file mode 100644 index 0000000000..8328bed89b --- /dev/null +++ b/tests/sources/sql_database/test_arrow_helpers.py @@ -0,0 +1,114 @@ +from 
datetime import date, datetime, timezone # noqa: I251 +from uuid import uuid4 + +import pyarrow as pa +import pytest + +from dlt.sources.sql_database.arrow_helpers import row_tuples_to_arrow + + +@pytest.mark.parametrize("all_unknown", [True, False]) +def test_row_tuples_to_arrow_unknown_types(all_unknown: bool) -> None: + """Test inferring data types with pyarrow""" + + rows = [ + ( + 1, + "a", + 1.1, + True, + date.today(), + uuid4(), + datetime.now(timezone.utc), + [1, 2, 3], + ), + ( + 2, + "b", + 2.2, + False, + date.today(), + uuid4(), + datetime.now(timezone.utc), + [4, 5, 6], + ), + ( + 3, + "c", + 3.3, + True, + date.today(), + uuid4(), + datetime.now(timezone.utc), + [7, 8, 9], + ), + ] + + # Some columns don't specify data type and should be inferred + columns = { + "int_col": {"name": "int_col", "data_type": "bigint", "nullable": False}, + "str_col": {"name": "str_col", "data_type": "text", "nullable": False}, + "float_col": {"name": "float_col", "nullable": False}, + "bool_col": {"name": "bool_col", "data_type": "bool", "nullable": False}, + "date_col": {"name": "date_col", "nullable": False}, + "uuid_col": {"name": "uuid_col", "nullable": False}, + "datetime_col": { + "name": "datetime_col", + "data_type": "timestamp", + "nullable": False, + }, + "array_col": {"name": "array_col", "nullable": False}, + } + + if all_unknown: + for col in columns.values(): + col.pop("data_type", None) + + # Call the function + result = row_tuples_to_arrow(rows, columns, tz="UTC") # type: ignore[arg-type] + + # Result is arrow table containing all columns in original order with correct types + assert result.num_columns == len(columns) + result_col_names = [f.name for f in result.schema] + expected_names = list(columns) + assert result_col_names == expected_names + + assert pa.types.is_int64(result[0].type) + assert pa.types.is_string(result[1].type) + assert pa.types.is_float64(result[2].type) + assert pa.types.is_boolean(result[3].type) + assert pa.types.is_date(result[4].type) + assert pa.types.is_string(result[5].type) + assert pa.types.is_timestamp(result[6].type) + assert pa.types.is_list(result[7].type) + + +pytest.importorskip("sqlalchemy", minversion="2.0") + + +def test_row_tuples_to_arrow_detects_range_type() -> None: + from sqlalchemy.dialects.postgresql import Range # type: ignore[attr-defined] + + # Applies to NUMRANGE, DATERANGE, etc sql types. 
Sqlalchemy returns a Range dataclass + IntRange = Range + + rows = [ + (IntRange(1, 10),), + (IntRange(2, 20),), + (IntRange(3, 30),), + ] + result = row_tuples_to_arrow( + rows=rows, # type: ignore[arg-type] + columns={"range_col": {"name": "range_col", "nullable": False}}, + tz="UTC", + ) + assert result.num_columns == 1 + assert pa.types.is_struct(result[0].type) + + # Check range has all fields + range_type = result[0].type + range_fields = {f.name: f for f in range_type} + assert pa.types.is_int64(range_fields["lower"].type) + assert pa.types.is_int64(range_fields["upper"].type) + assert pa.types.is_boolean(range_fields["empty"].type) + assert pa.types.is_string(range_fields["bounds"].type) diff --git a/tests/sources/sql_database/test_helpers.py b/tests/sources/sql_database/test_helpers.py new file mode 100644 index 0000000000..a32c6c91cd --- /dev/null +++ b/tests/sources/sql_database/test_helpers.py @@ -0,0 +1,168 @@ +import pytest + +import dlt +from dlt.common.typing import TDataItem + +from dlt.sources.sql_database.helpers import TableLoader, TableBackend +from dlt.sources.sql_database.schema_types import table_to_columns + +from tests.sources.sql_database.sql_source import SQLAlchemySourceDB + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_cursor_or_unique_column_not_in_table( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + table = sql_source_db.get_table("chat_message") + + with pytest.raises(KeyError): + TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=dlt.sources.incremental("not_a_column"), + ) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_make_query_incremental_max( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + """Verify query is generated according to incremental settings""" + + class MockIncremental: + last_value = dlt.common.pendulum.now() + last_value_func = max + cursor_path = "created_at" + row_order = "asc" + end_value = None + + table = sql_source_db.get_table("chat_message") + loader = TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=MockIncremental(), # type: ignore[arg-type] + ) + + query = loader.make_query() + expected = ( + table.select() + .order_by(table.c.created_at.asc()) + .where(table.c.created_at >= MockIncremental.last_value) + ) + + assert query.compare(expected) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_make_query_incremental_min( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + class MockIncremental: + last_value = dlt.common.pendulum.now() + last_value_func = min + cursor_path = "created_at" + row_order = "desc" + end_value = None + + table = sql_source_db.get_table("chat_message") + loader = TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=MockIncremental(), # type: ignore[arg-type] + ) + + query = loader.make_query() + expected = ( + table.select() + .order_by(table.c.created_at.asc()) # `min` func swaps order + .where(table.c.created_at <= MockIncremental.last_value) + ) + + assert query.compare(expected) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_make_query_incremental_end_value( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + now = dlt.common.pendulum.now() + + class 
MockIncremental: + last_value = now + last_value_func = min + cursor_path = "created_at" + end_value = now.add(hours=1) + row_order = None + + table = sql_source_db.get_table("chat_message") + loader = TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=MockIncremental(), # type: ignore[arg-type] + ) + + query = loader.make_query() + expected = ( + table.select() + .where(table.c.created_at <= MockIncremental.last_value) + .where(table.c.created_at > MockIncremental.end_value) + ) + + assert query.compare(expected) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_make_query_incremental_any_fun( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + class MockIncremental: + last_value = dlt.common.pendulum.now() + last_value_func = lambda x: x[-1] + cursor_path = "created_at" + row_order = "asc" + end_value = dlt.common.pendulum.now() + + table = sql_source_db.get_table("chat_message") + loader = TableLoader( + sql_source_db.engine, + backend, + table, + table_to_columns(table), + incremental=MockIncremental(), # type: ignore[arg-type] + ) + + query = loader.make_query() + expected = table.select() + + assert query.compare(expected) + + +def mock_json_column(field: str) -> TDataItem: + """""" + import pyarrow as pa + import pandas as pd + + json_mock_str = '{"data": [1, 2, 3]}' + + def _unwrap(table: TDataItem) -> TDataItem: + if isinstance(table, pd.DataFrame): + table[field] = [None if s is None else json_mock_str for s in table[field]] + return table + else: + col_index = table.column_names.index(field) + json_str_array = pa.array([None if s is None else json_mock_str for s in table[field]]) + return table.set_column( + col_index, + pa.field(field, pa.string(), nullable=table.schema.field(field).nullable), + json_str_array, + ) + + return _unwrap diff --git a/tests/sources/sql_database/test_sql_database_source.py b/tests/sources/sql_database/test_sql_database_source.py new file mode 100644 index 0000000000..e26114f848 --- /dev/null +++ b/tests/sources/sql_database/test_sql_database_source.py @@ -0,0 +1,1191 @@ +import os +import re +from copy import deepcopy +from datetime import datetime # noqa: I251 +from typing import Any, Callable, cast, List, Optional, Set + +import pytest +import sqlalchemy as sa + +import dlt +from dlt.common import json +from dlt.common.configuration.exceptions import ConfigFieldMissingException +from dlt.common.schema.typing import TColumnSchema, TSortOrder, TTableSchemaColumns +from dlt.common.utils import uniq_id +from dlt.extract.exceptions import ResourceExtractionError +from dlt.sources import DltResource +from dlt.sources.sql_database import ( + ReflectionLevel, + TableBackend, + sql_database, + sql_table, +) +from dlt.sources.sql_database.helpers import unwrap_json_connector_x +from tests.pipeline.utils import ( + assert_load_info, + assert_schema_on_data, + load_tables_to_dicts, +) +from tests.sources.sql_database.sql_source import SQLAlchemySourceDB +from tests.sources.sql_database.test_helpers import mock_json_column +from tests.utils import data_item_length + + +@pytest.fixture(autouse=True) +def dispose_engines(): + yield + import gc + + # will collect and dispose all hanging engines + gc.collect() + + +@pytest.fixture(autouse=True) +def reset_os_environ(): + # Save the current state of os.environ + original_environ = deepcopy(os.environ) + yield + # Restore the original state of os.environ + os.environ.clear() + 
os.environ.update(original_environ) + + +def make_pipeline(destination_name: str) -> dlt.Pipeline: + return dlt.pipeline( + pipeline_name="sql_database", + destination=destination_name, + dataset_name="test_sql_pipeline_" + uniq_id(), + full_refresh=False, + ) + + +def convert_json_to_text(t): + if isinstance(t, sa.JSON): + return sa.Text + return t + + +def default_test_callback( + destination_name: str, backend: TableBackend +) -> Optional[Callable[[sa.types.TypeEngine], sa.types.TypeEngine]]: + if backend == "pyarrow" and destination_name == "bigquery": + return convert_json_to_text + return None + + +def convert_time_to_us(table): + """map transform converting time column to microseconds (ie. from nanoseconds)""" + import pyarrow as pa + from pyarrow import compute as pc + + time_ns_column = table["time_col"] + time_us_column = pc.cast(time_ns_column, pa.time64("us"), safe=False) + new_table = table.set_column( + table.column_names.index("time_col"), + "time_col", + time_us_column, + ) + return new_table + + +def test_pass_engine_credentials(sql_source_db: SQLAlchemySourceDB) -> None: + # verify database + database = sql_database( + sql_source_db.engine, schema=sql_source_db.schema, table_names=["chat_message"] + ) + assert len(list(database)) == sql_source_db.table_infos["chat_message"]["row_count"] + + # verify table + table = sql_table(sql_source_db.engine, table="chat_message", schema=sql_source_db.schema) + assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] + + +def test_named_sql_table_config(sql_source_db: SQLAlchemySourceDB) -> None: + # set the credentials per table name + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__CREDENTIALS"] = ( + sql_source_db.engine.url.render_as_string(False) + ) + table = sql_table(table="chat_message", schema=sql_source_db.schema) + assert table.name == "chat_message" + assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] + + with pytest.raises(ConfigFieldMissingException): + sql_table(table="has_composite_key", schema=sql_source_db.schema) + + # set backend + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__BACKEND"] = "pandas" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + # just one frame here + assert len(list(table)) == 1 + + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__CHUNK_SIZE"] = "1000" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + # now 10 frames with chunk size of 1000 + assert len(list(table)) == 10 + + # make it fail on cursor + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = "updated_at_x" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + with pytest.raises(ResourceExtractionError) as ext_ex: + len(list(table)) + assert "'updated_at_x'" in str(ext_ex.value) + + +def test_general_sql_database_config(sql_source_db: SQLAlchemySourceDB) -> None: + # set the credentials per table name + os.environ["SOURCES__SQL_DATABASE__CREDENTIALS"] = sql_source_db.engine.url.render_as_string( + False + ) + # applies to both sql table and sql database + table = sql_table(table="chat_message", schema=sql_source_db.schema) + assert len(list(table)) == sql_source_db.table_infos["chat_message"]["row_count"] + database = sql_database(schema=sql_source_db.schema).with_resources("chat_message") + assert len(list(database)) == sql_source_db.table_infos["chat_message"]["row_count"] + + # set backend + os.environ["SOURCES__SQL_DATABASE__BACKEND"] = "pandas" + table = 
sql_table(table="chat_message", schema=sql_source_db.schema) + # just one frame here + assert len(list(table)) == 1 + database = sql_database(schema=sql_source_db.schema).with_resources("chat_message") + assert len(list(database)) == 1 + + os.environ["SOURCES__SQL_DATABASE__CHUNK_SIZE"] = "1000" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + # now 10 frames with chunk size of 1000 + assert len(list(table)) == 10 + database = sql_database(schema=sql_source_db.schema).with_resources("chat_message") + assert len(list(database)) == 10 + + # make it fail on cursor + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCREMENTAL__CURSOR_PATH"] = "updated_at_x" + table = sql_table(table="chat_message", schema=sql_source_db.schema) + with pytest.raises(ResourceExtractionError) as ext_ex: + len(list(table)) + assert "'updated_at_x'" in str(ext_ex.value) + with pytest.raises(ResourceExtractionError) as ext_ex: + list(sql_database(schema=sql_source_db.schema).with_resources("chat_message")) + # other resources will be loaded, incremental is selective + assert len(list(sql_database(schema=sql_source_db.schema).with_resources("app_user"))) > 0 + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pandas", "pyarrow", "connectorx"]) +@pytest.mark.parametrize("row_order", ["asc", "desc", None]) +@pytest.mark.parametrize("last_value_func", [min, max, lambda x: max(x)]) +def test_load_sql_table_resource_incremental_end_value( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + row_order: TSortOrder, + last_value_func: Any, +) -> None: + start_id = sql_source_db.table_infos["chat_message"]["ids"][0] + end_id = sql_source_db.table_infos["chat_message"]["ids"][-1] // 2 + + if last_value_func is min: + start_id, end_id = end_id, start_id + + @dlt.source + def sql_table_source() -> List[DltResource]: + return [ + sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + backend=backend, + incremental=dlt.sources.incremental( + "id", + initial_value=start_id, + end_value=end_id, + row_order=row_order, + last_value_func=last_value_func, + ), + ) + ] + + try: + rows = list(sql_table_source()) + except Exception as exc: + if isinstance(exc.__context__, NotImplementedError): + pytest.skip("Test skipped due to: " + str(exc.__context__)) + raise + # half of the records loaded -1 record. 
end values is non inclusive + assert data_item_length(rows) == abs(end_id - start_id) + # check first and last id to see if order was applied + if backend == "sqlalchemy": + if row_order == "asc" and last_value_func is max: + assert rows[0]["id"] == start_id + assert rows[-1]["id"] == end_id - 1 # non inclusive + if row_order == "desc" and last_value_func is max: + assert rows[0]["id"] == end_id - 1 # non inclusive + assert rows[-1]["id"] == start_id + if row_order == "asc" and last_value_func is min: + assert rows[0]["id"] == start_id + assert ( + rows[-1]["id"] == end_id + 1 + ) # non inclusive, but + 1 because last value func is min + if row_order == "desc" and last_value_func is min: + assert ( + rows[0]["id"] == end_id + 1 + ) # non inclusive, but + 1 because last value func is min + assert rows[-1]["id"] == start_id + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("defer_table_reflect", (False, True)) +def test_load_sql_table_resource_select_columns( + sql_source_db: SQLAlchemySourceDB, defer_table_reflect: bool, backend: TableBackend +) -> None: + # get chat messages with content column removed + chat_messages = sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + defer_table_reflect=defer_table_reflect, + table_adapter_callback=lambda table: table._columns.remove(table.columns["content"]), # type: ignore[attr-defined] + backend=backend, + ) + pipeline = make_pipeline("duckdb") + load_info = pipeline.run(chat_messages) + assert_load_info(load_info) + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + assert "content" not in pipeline.default_schema.tables["chat_message"]["columns"] + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("defer_table_reflect", (False, True)) +def test_load_sql_table_source_select_columns( + sql_source_db: SQLAlchemySourceDB, defer_table_reflect: bool, backend: TableBackend +) -> None: + mod_tables: Set[str] = set() + + def adapt(table) -> None: + mod_tables.add(table) + if table.name == "chat_message": + table._columns.remove(table.columns["content"]) + + # get chat messages with content column removed + all_tables = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + defer_table_reflect=defer_table_reflect, + table_names=(list(sql_source_db.table_infos.keys()) if defer_table_reflect else None), + table_adapter_callback=adapt, + backend=backend, + ) + pipeline = make_pipeline("duckdb") + load_info = pipeline.run(all_tables) + assert_load_info(load_info) + assert_row_counts(pipeline, sql_source_db) + assert "content" not in pipeline.default_schema.tables["chat_message"]["columns"] + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("reflection_level", ["full", "full_with_precision"]) +@pytest.mark.parametrize("with_defer", [True, False]) +def test_extract_without_pipeline( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + reflection_level: ReflectionLevel, + with_defer: bool, +) -> None: + # make sure that we can evaluate tables without pipeline + source = sql_database( + credentials=sql_source_db.credentials, + table_names=["has_precision", "app_user", "chat_message", "chat_channel"], + schema=sql_source_db.schema, + reflection_level=reflection_level, + defer_table_reflect=with_defer, + backend=backend, + ) + assert len(list(source)) > 0 + + 
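+# Illustrative sketch, not part of the test suite: a bounded backfill of a single
+# table can be expressed with an incremental that sets both initial_value and
+# end_value, as exercised in test_load_sql_table_resource_incremental_end_value
+# above. The id range and the duckdb destination below are assumptions for the
+# example; sql_source_db stands in for any source database credentials.
+#
+# backfill = sql_table(
+#     credentials=sql_source_db.credentials,
+#     schema=sql_source_db.schema,
+#     table="chat_message",
+#     backend="pyarrow",
+#     incremental=dlt.sources.incremental(
+#         "id", initial_value=1, end_value=1000, row_order="asc"
+#     ),
+# )
+# make_pipeline("duckdb").run(backfill)
+
+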
+@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("reflection_level", ["minimal", "full", "full_with_precision"]) +@pytest.mark.parametrize("with_defer", [False, True]) +@pytest.mark.parametrize("standalone_resource", [True, False]) +def test_reflection_levels( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + reflection_level: ReflectionLevel, + with_defer: bool, + standalone_resource: bool, +) -> None: + """Test all reflection, correct schema is inferred""" + + def prepare_source(): + if standalone_resource: + + @dlt.source + def dummy_source(): + yield sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="has_precision", + backend=backend, + defer_table_reflect=with_defer, + reflection_level=reflection_level, + ) + yield sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="app_user", + backend=backend, + defer_table_reflect=with_defer, + reflection_level=reflection_level, + ) + + return dummy_source() + + return sql_database( + credentials=sql_source_db.credentials, + table_names=["has_precision", "app_user"], + schema=sql_source_db.schema, + reflection_level=reflection_level, + defer_table_reflect=with_defer, + backend=backend, + ) + + source = prepare_source() + + pipeline = make_pipeline("duckdb") + pipeline.extract(source) + + schema = pipeline.default_schema + assert "has_precision" in schema.tables + + col_names = [col["name"] for col in schema.tables["has_precision"]["columns"].values()] + expected_col_names = [col["name"] for col in PRECISION_COLUMNS] + + assert col_names == expected_col_names + + # Pk col is always reflected + pk_col = schema.tables["app_user"]["columns"]["id"] + assert pk_col["primary_key"] is True + + if reflection_level == "minimal": + resource_cols = source.resources["has_precision"].compute_table_schema()["columns"] + schema_cols = pipeline.default_schema.tables["has_precision"]["columns"] + # We should have all column names on resource hints after extract but no data type or precision + for col, schema_col in zip(resource_cols.values(), schema_cols.values()): + assert col.get("data_type") is None + assert col.get("precision") is None + assert col.get("scale") is None + if backend == "sqlalchemy": # Data types are inferred from pandas/arrow during extract + assert schema_col.get("data_type") is None + + pipeline.normalize() + # Check with/out precision after normalize + schema_cols = pipeline.default_schema.tables["has_precision"]["columns"] + if reflection_level == "full": + # Columns have data type set + assert_no_precision_columns(schema_cols, backend, False) + + elif reflection_level == "full_with_precision": + # Columns have data type and precision scale set + assert_precision_columns(schema_cols, backend, False) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("standalone_resource", [True, False]) +def test_type_adapter_callback( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend, standalone_resource: bool +) -> None: + def conversion_callback(t): + if isinstance(t, sa.JSON): + return sa.Text + elif isinstance(t, sa.Double): # type: ignore[attr-defined] + return sa.BIGINT + return t + + common_kwargs = dict( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + backend=backend, + type_adapter_callback=conversion_callback, + reflection_level="full", + ) + + if standalone_resource: + source = sql_table( + 
table="has_precision", + **common_kwargs, # type: ignore[arg-type] + ) + else: + source = sql_database( # type: ignore[assignment] + table_names=["has_precision"], + **common_kwargs, # type: ignore[arg-type] + ) + + pipeline = make_pipeline("duckdb") + pipeline.extract(source) + + schema = pipeline.default_schema + table = schema.tables["has_precision"] + assert table["columns"]["json_col"]["data_type"] == "text" + assert table["columns"]["float_col"]["data_type"] == "bigint" + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize( + "table_name,nullable", (("has_precision", False), ("has_precision_nullable", True)) +) +def test_all_types_with_precision_hints( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + table_name: str, + nullable: bool, +) -> None: + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + reflection_level="full_with_precision", + backend=backend, + ) + + pipeline = make_pipeline("duckdb") + + # add JSON unwrap for connectorx + if backend == "connectorx": + source.resources[table_name].add_map(unwrap_json_connector_x("json_col")) + pipeline.extract(source) + pipeline.normalize(loader_file_format="parquet") + info = pipeline.load() + assert_load_info(info) + + schema = pipeline.default_schema + table = schema.tables[table_name] + assert_precision_columns(table["columns"], backend, nullable) + assert_schema_on_data( + table, + load_tables_to_dicts(pipeline, table_name)[table_name], + nullable, + backend in ["sqlalchemy", "pyarrow"], + ) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize( + "table_name,nullable", (("has_precision", False), ("has_precision_nullable", True)) +) +def test_all_types_no_precision_hints( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, + table_name: str, + nullable: bool, +) -> None: + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + reflection_level="full", + backend=backend, + ) + + pipeline = make_pipeline("duckdb") + + # add JSON unwrap for connectorx + if backend == "connectorx": + source.resources[table_name].add_map(unwrap_json_connector_x("json_col")) + pipeline.extract(source) + pipeline.normalize(loader_file_format="parquet") + pipeline.load().raise_on_failed_jobs() + + schema = pipeline.default_schema + # print(pipeline.default_schema.to_pretty_yaml()) + table = schema.tables[table_name] + assert_no_precision_columns(table["columns"], backend, nullable) + assert_schema_on_data( + table, + load_tables_to_dicts(pipeline, table_name)[table_name], + nullable, + backend in ["sqlalchemy", "pyarrow"], + ) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_incremental_composite_primary_key_from_table( + sql_source_db: SQLAlchemySourceDB, + backend: TableBackend, +) -> None: + resource = sql_table( + credentials=sql_source_db.credentials, + table="has_composite_key", + schema=sql_source_db.schema, + backend=backend, + ) + + assert resource.incremental.primary_key == ["a", "b", "c"] + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("upfront_incremental", (True, False)) +def test_set_primary_key_deferred_incremental( + sql_source_db: SQLAlchemySourceDB, + upfront_incremental: bool, + backend: TableBackend, +) -> None: + # this tests dynamically adds primary key to resource and as 
consequence to incremental + updated_at = dlt.sources.incremental("updated_at") # type: ignore[var-annotated] + resource = sql_table( + credentials=sql_source_db.credentials, + table="chat_message", + schema=sql_source_db.schema, + defer_table_reflect=True, + incremental=updated_at if upfront_incremental else None, + backend=backend, + ) + + resource.apply_hints(incremental=None if upfront_incremental else updated_at) + + # nothing set for deferred reflect + assert resource.incremental.primary_key is None + + def _assert_incremental(item): + # for all the items, all keys must be present + _r = dlt.current.source().resources[dlt.current.resource_name()] + # assert _r.incremental._incremental is updated_at + if len(item) == 0: + # not yet propagated + assert _r.incremental.primary_key is None + else: + assert _r.incremental.primary_key == ["id"] + assert _r.incremental._incremental.primary_key == ["id"] + assert _r.incremental._incremental._transformers["json"].primary_key == ["id"] + assert _r.incremental._incremental._transformers["arrow"].primary_key == ["id"] + return item + + pipeline = make_pipeline("duckdb") + # must evaluate resource for primary key to be set + pipeline.extract(resource.add_step(_assert_incremental)) # type: ignore[arg-type] + + assert resource.incremental.primary_key == ["id"] + assert resource.incremental._incremental.primary_key == ["id"] + assert resource.incremental._incremental._transformers["json"].primary_key == ["id"] + assert resource.incremental._incremental._transformers["arrow"].primary_key == ["id"] + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_deferred_reflect_in_source( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + source = sql_database( + credentials=sql_source_db.credentials, + table_names=["has_precision", "chat_message"], + schema=sql_source_db.schema, + reflection_level="full_with_precision", + defer_table_reflect=True, + backend=backend, + ) + # mock the right json values for backends not supporting it + if backend in ("connectorx", "pandas"): + source.resources["has_precision"].add_map(mock_json_column("json_col")) + + # no columns in both tables + assert source.has_precision.columns == {} + assert source.chat_message.columns == {} + + pipeline = make_pipeline("duckdb") + pipeline.extract(source) + # use insert values to convert parquet into INSERT + pipeline.normalize(loader_file_format="insert_values") + pipeline.load().raise_on_failed_jobs() + precision_table = pipeline.default_schema.get_table("has_precision") + assert_precision_columns( + precision_table["columns"], + backend, + nullable=False, + ) + assert_schema_on_data( + precision_table, + load_tables_to_dicts(pipeline, "has_precision")["has_precision"], + True, + backend in ["sqlalchemy", "pyarrow"], + ) + assert len(source.chat_message.columns) > 0 # type: ignore[arg-type] + assert source.chat_message.compute_table_schema()["columns"]["id"]["primary_key"] is True + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_deferred_reflect_no_source_connect(backend: TableBackend) -> None: + source = sql_database( + credentials="mysql+pymysql://test@test/test", + table_names=["has_precision", "chat_message"], + schema="schema", + reflection_level="full_with_precision", + defer_table_reflect=True, + backend=backend, + ) + + # no columns in both tables + assert source.has_precision.columns == {} + assert source.chat_message.columns == {} + + 
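+# Illustrative sketch, not part of the test suite: defer_table_reflect=True skips
+# reflection when the source is created and reflects tables during extraction
+# instead, as the two deferred-reflect tests above verify. The duckdb destination
+# is an assumption for the example.
+#
+# deferred = sql_database(
+#     credentials=sql_source_db.credentials,
+#     schema=sql_source_db.schema,
+#     table_names=["chat_message"],
+#     defer_table_reflect=True,
+# )
+# assert deferred.chat_message.columns == {}  # nothing reflected yet
+# p = make_pipeline("duckdb")
+# p.extract(deferred)
+# assert "chat_message" in p.default_schema.tables  # reflected during extraction
+
+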
+@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_deferred_reflect_in_resource( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + table = sql_table( + credentials=sql_source_db.credentials, + table="has_precision", + schema=sql_source_db.schema, + reflection_level="full_with_precision", + defer_table_reflect=True, + backend=backend, + ) + # mock the right json values for backends not supporting it + if backend in ("connectorx", "pandas"): + table.add_map(mock_json_column("json_col")) + + # no columns in both tables + assert table.columns == {} + + pipeline = make_pipeline("duckdb") + pipeline.extract(table) + # use insert values to convert parquet into INSERT + pipeline.normalize(loader_file_format="insert_values") + pipeline.load().raise_on_failed_jobs() + precision_table = pipeline.default_schema.get_table("has_precision") + assert_precision_columns( + precision_table["columns"], + backend, + nullable=False, + ) + assert_schema_on_data( + precision_table, + load_tables_to_dicts(pipeline, "has_precision")["has_precision"], + True, + backend in ["sqlalchemy", "pyarrow"], + ) + + +@pytest.mark.parametrize("backend", ["pyarrow", "pandas", "connectorx"]) +def test_destination_caps_context(sql_source_db: SQLAlchemySourceDB, backend: TableBackend) -> None: + # use athena with timestamp precision == 3 + table = sql_table( + credentials=sql_source_db.credentials, + table="has_precision", + schema=sql_source_db.schema, + reflection_level="full_with_precision", + defer_table_reflect=True, + backend=backend, + ) + + # no columns in both tables + assert table.columns == {} + + pipeline = make_pipeline("athena") + pipeline.extract(table) + pipeline.normalize() + # timestamps are milliseconds + columns = pipeline.default_schema.get_table("has_precision")["columns"] + assert columns["datetime_tz_col"]["precision"] == columns["datetime_ntz_col"]["precision"] == 3 + # prevent drop + pipeline.destination = None + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_sql_table_from_view(sql_source_db: SQLAlchemySourceDB, backend: TableBackend) -> None: + """View can be extract by sql_table without any reflect flags""" + table = sql_table( + credentials=sql_source_db.credentials, + table="chat_message_view", + schema=sql_source_db.schema, + backend=backend, + # use minimal level so we infer types from DATA + reflection_level="minimal", + incremental=dlt.sources.incremental("_created_at"), + ) + + pipeline = make_pipeline("duckdb") + info = pipeline.run(table) + assert_load_info(info) + + assert_row_counts(pipeline, sql_source_db, ["chat_message_view"]) + assert "content" in pipeline.default_schema.tables["chat_message_view"]["columns"] + assert "_created_at" in pipeline.default_schema.tables["chat_message_view"]["columns"] + db_data = load_tables_to_dicts(pipeline, "chat_message_view")["chat_message_view"] + assert "content" in db_data[0] + assert "_created_at" in db_data[0] + # make sure that all NULLs is not present + assert "_null_ts" in pipeline.default_schema.tables["chat_message_view"]["columns"] + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_sql_database_include_views( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + """include_view flag reflects and extracts views as tables""" + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + include_views=True, + 
backend=backend, + ) + + pipeline = make_pipeline("duckdb") + pipeline.run(source) + + assert_row_counts(pipeline, sql_source_db, include_views=True) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +def test_sql_database_include_view_in_table_names( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend +) -> None: + """Passing a view explicitly in table_names should reflect it, regardless of include_views flag""" + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table_names=["app_user", "chat_message_view"], + include_views=False, + backend=backend, + ) + + pipeline = make_pipeline("duckdb") + pipeline.run(source) + + assert_row_counts(pipeline, sql_source_db, ["app_user", "chat_message_view"]) + + +@pytest.mark.parametrize("backend", ["pyarrow", "pandas", "sqlalchemy"]) +@pytest.mark.parametrize("standalone_resource", [True, False]) +@pytest.mark.parametrize("reflection_level", ["minimal", "full", "full_with_precision"]) +@pytest.mark.parametrize("type_adapter", [True, False]) +def test_infer_unsupported_types( + sql_source_db_unsupported_types: SQLAlchemySourceDB, + backend: TableBackend, + reflection_level: ReflectionLevel, + standalone_resource: bool, + type_adapter: bool, +) -> None: + def type_adapter_callback(t): + if isinstance(t, sa.ARRAY): + return sa.JSON + return t + + if backend == "pyarrow" and type_adapter: + pytest.skip("Arrow does not support type adapter for arrays") + + common_kwargs = dict( + credentials=sql_source_db_unsupported_types.credentials, + schema=sql_source_db_unsupported_types.schema, + reflection_level=reflection_level, + backend=backend, + type_adapter_callback=type_adapter_callback if type_adapter else None, + ) + if standalone_resource: + + @dlt.source + def dummy_source(): + yield sql_table( + **common_kwargs, # type: ignore[arg-type] + table="has_unsupported_types", + ) + + source = dummy_source() + source.max_table_nesting = 0 + else: + source = sql_database( + **common_kwargs, # type: ignore[arg-type] + table_names=["has_unsupported_types"], + ) + source.max_table_nesting = 0 + + pipeline = make_pipeline("duckdb") + pipeline.extract(source) + + columns = pipeline.default_schema.tables["has_unsupported_types"]["columns"] + + # unsupported columns have unknown data type here + assert "unsupported_daterange_1" in columns + + # Arrow and pandas infer types in extract + if backend == "pyarrow": + assert columns["unsupported_daterange_1"]["data_type"] == "complex" + elif backend == "pandas": + assert columns["unsupported_daterange_1"]["data_type"] == "text" + else: + assert "data_type" not in columns["unsupported_daterange_1"] + + pipeline.normalize() + pipeline.load() + + assert_row_counts(pipeline, sql_source_db_unsupported_types, ["has_unsupported_types"]) + + schema = pipeline.default_schema + assert "has_unsupported_types" in schema.tables + columns = schema.tables["has_unsupported_types"]["columns"] + + rows = load_tables_to_dicts(pipeline, "has_unsupported_types")["has_unsupported_types"] + + if backend == "pyarrow": + # TODO: duckdb writes structs as strings (not json encoded) to json columns + # Just check that it has a value + assert rows[0]["unsupported_daterange_1"] + + assert isinstance(json.loads(rows[0]["unsupported_array_1"]), list) + assert columns["unsupported_array_1"]["data_type"] == "complex" + # Other columns are loaded + assert isinstance(rows[0]["supported_text"], str) + assert isinstance(rows[0]["supported_datetime"], datetime) + 
assert isinstance(rows[0]["supported_int"], int) + elif backend == "sqlalchemy": + # sqla value is a dataclass and is inferred as complex + assert columns["unsupported_daterange_1"]["data_type"] == "complex" + + assert columns["unsupported_array_1"]["data_type"] == "complex" + + value = rows[0]["unsupported_daterange_1"] + assert set(json.loads(value).keys()) == {"lower", "upper", "bounds", "empty"} + elif backend == "pandas": + # pandas parses it as string + assert columns["unsupported_daterange_1"]["data_type"] == "text" + # Regex that matches daterange [2021-01-01, 2021-01-02) + assert re.match( + r"\[\d{4}-\d{2}-\d{2},\d{4}-\d{2}-\d{2}\)", + rows[0]["unsupported_daterange_1"], + ) + + if type_adapter and reflection_level != "minimal": + assert columns["unsupported_array_1"]["data_type"] == "complex" + + assert isinstance(json.loads(rows[0]["unsupported_array_1"]), list) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("defer_table_reflect", (False, True)) +def test_sql_database_included_columns( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend, defer_table_reflect: bool +) -> None: + # include only some columns from the table + os.environ["SOURCES__SQL_DATABASE__CHAT_MESSAGE__INCLUDED_COLUMNS"] = json.dumps( + ["id", "created_at"] + ) + + source = sql_database( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table_names=["chat_message"], + reflection_level="full", + defer_table_reflect=defer_table_reflect, + backend=backend, + ) + + pipeline = make_pipeline("duckdb") + pipeline.run(source) + + schema = pipeline.default_schema + schema_cols = set( + col + for col in schema.get_table_columns("chat_message", include_incomplete=True) + if not col.startswith("_dlt_") + ) + assert schema_cols == {"id", "created_at"} + + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("defer_table_reflect", (False, True)) +def test_sql_table_included_columns( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend, defer_table_reflect: bool +) -> None: + source = sql_table( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + table="chat_message", + reflection_level="full", + defer_table_reflect=defer_table_reflect, + backend=backend, + included_columns=["id", "created_at"], + ) + + pipeline = make_pipeline("duckdb") + pipeline.run(source) + + schema = pipeline.default_schema + schema_cols = set( + col + for col in schema.get_table_columns("chat_message", include_incomplete=True) + if not col.startswith("_dlt_") + ) + assert schema_cols == {"id", "created_at"} + + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +@pytest.mark.parametrize("backend", ["sqlalchemy", "pyarrow", "pandas", "connectorx"]) +@pytest.mark.parametrize("standalone_resource", [True, False]) +def test_query_adapter_callback( + sql_source_db: SQLAlchemySourceDB, backend: TableBackend, standalone_resource: bool +) -> None: + def query_adapter_callback(query, table): + if table.name == "chat_channel": + # Only select active channels + return query.where(table.c.active.is_(True)) + # Use the original query for other tables + return query + + common_kwargs = dict( + credentials=sql_source_db.credentials, + schema=sql_source_db.schema, + reflection_level="full", + backend=backend, + query_adapter_callback=query_adapter_callback, + ) + + if standalone_resource: + + 
@dlt.source + def dummy_source(): + yield sql_table( + **common_kwargs, # type: ignore[arg-type] + table="chat_channel", + ) + + yield sql_table( + **common_kwargs, # type: ignore[arg-type] + table="chat_message", + ) + + source = dummy_source() + else: + source = sql_database( + **common_kwargs, # type: ignore[arg-type] + table_names=["chat_message", "chat_channel"], + ) + + pipeline = make_pipeline("duckdb") + pipeline.extract(source) + + pipeline.normalize() + pipeline.load().raise_on_failed_jobs() + + channel_rows = load_tables_to_dicts(pipeline, "chat_channel")["chat_channel"] + assert channel_rows and all(row["active"] for row in channel_rows) + + # unfiltred table loads all rows + assert_row_counts(pipeline, sql_source_db, ["chat_message"]) + + +def assert_row_counts( + pipeline: dlt.Pipeline, + sql_source_db: SQLAlchemySourceDB, + tables: Optional[List[str]] = None, + include_views: bool = False, +) -> None: + with pipeline.sql_client() as c: + if not tables: + tables = [ + tbl_name + for tbl_name, info in sql_source_db.table_infos.items() + if include_views or not info["is_view"] + ] + for table in tables: + info = sql_source_db.table_infos[table] + with c.execute_query(f"SELECT count(*) FROM {table}") as cur: + row = cur.fetchone() + assert row[0] == info["row_count"] + + +def assert_precision_columns( + columns: TTableSchemaColumns, backend: TableBackend, nullable: bool +) -> None: + actual = list(columns.values()) + expected = NULL_PRECISION_COLUMNS if nullable else NOT_NULL_PRECISION_COLUMNS + # always has nullability set and always has hints + expected = cast(List[TColumnSchema], deepcopy(expected)) + if backend == "sqlalchemy": + expected = remove_timestamp_precision(expected) + actual = remove_dlt_columns(actual) + if backend == "pyarrow": + expected = add_default_decimal_precision(expected) + if backend == "pandas": + expected = remove_timestamp_precision(expected, with_timestamps=False) + if backend == "connectorx": + # connector x emits 32 precision which gets merged with sql alchemy schema + del columns["int_col"]["precision"] + assert actual == expected + + +def assert_no_precision_columns( + columns: TTableSchemaColumns, backend: TableBackend, nullable: bool +) -> None: + actual = list(columns.values()) + + # we always infer and emit nullability + expected = cast( + List[TColumnSchema], + deepcopy(NULL_NO_PRECISION_COLUMNS if nullable else NOT_NULL_NO_PRECISION_COLUMNS), + ) + if backend == "pyarrow": + expected = cast( + List[TColumnSchema], + deepcopy(NULL_PRECISION_COLUMNS if nullable else NOT_NULL_PRECISION_COLUMNS), + ) + # always has nullability set and always has hints + # default precision is not set + expected = remove_default_precision(expected) + expected = add_default_decimal_precision(expected) + elif backend == "sqlalchemy": + # no precision, no nullability, all hints inferred + # remove dlt columns + actual = remove_dlt_columns(actual) + elif backend == "pandas": + # no precision, no nullability, all hints inferred + # pandas destroys decimals + expected = convert_non_pandas_types(expected) + elif backend == "connectorx": + expected = cast( + List[TColumnSchema], + deepcopy(NULL_PRECISION_COLUMNS if nullable else NOT_NULL_PRECISION_COLUMNS), + ) + expected = convert_connectorx_types(expected) + + assert actual == expected + + +def convert_non_pandas_types(columns: List[TColumnSchema]) -> List[TColumnSchema]: + for column in columns: + if column["data_type"] == "timestamp": + column["precision"] = 6 + return columns + + +def remove_dlt_columns(columns: 
List[TColumnSchema]) -> List[TColumnSchema]: + return [col for col in columns if not col["name"].startswith("_dlt")] + + +def remove_default_precision(columns: List[TColumnSchema]) -> List[TColumnSchema]: + for column in columns: + if column["data_type"] == "bigint" and column.get("precision") == 32: + del column["precision"] + if column["data_type"] == "text" and column.get("precision"): + del column["precision"] + return columns + + +def remove_timestamp_precision( + columns: List[TColumnSchema], with_timestamps: bool = True +) -> List[TColumnSchema]: + for column in columns: + if column["data_type"] == "timestamp" and column["precision"] == 6 and with_timestamps: + del column["precision"] + if column["data_type"] == "time" and column["precision"] == 6: + del column["precision"] + return columns + + +def convert_connectorx_types(columns: List[TColumnSchema]) -> List[TColumnSchema]: + """connector x converts decimals to double, otherwise tries to keep data types and precision + nullability is not kept, string precision is not kept + """ + for column in columns: + if column["data_type"] == "bigint": + if column["name"] == "int_col": + column["precision"] = 32 # only int and bigint in connectorx + if column["data_type"] == "text" and column.get("precision"): + del column["precision"] + return columns + + +def add_default_decimal_precision(columns: List[TColumnSchema]) -> List[TColumnSchema]: + for column in columns: + if column["data_type"] == "decimal" and not column.get("precision"): + column["precision"] = 38 + column["scale"] = 9 + return columns + + +PRECISION_COLUMNS: List[TColumnSchema] = [ + { + "data_type": "bigint", + "name": "int_col", + }, + { + "data_type": "bigint", + "name": "bigint_col", + }, + { + "data_type": "bigint", + "precision": 32, + "name": "smallint_col", + }, + { + "data_type": "decimal", + "precision": 10, + "scale": 2, + "name": "numeric_col", + }, + { + "data_type": "decimal", + "name": "numeric_default_col", + }, + { + "data_type": "text", + "precision": 10, + "name": "string_col", + }, + { + "data_type": "text", + "name": "string_default_col", + }, + { + "data_type": "timestamp", + "precision": 6, + "name": "datetime_tz_col", + }, + { + "data_type": "timestamp", + "precision": 6, + "name": "datetime_ntz_col", + }, + { + "data_type": "date", + "name": "date_col", + }, + { + "data_type": "time", + "name": "time_col", + "precision": 6, + }, + { + "data_type": "double", + "name": "float_col", + }, + { + "data_type": "complex", + "name": "json_col", + }, + { + "data_type": "bool", + "name": "bool_col", + }, + { + "data_type": "text", + "name": "uuid_col", + }, +] + +NOT_NULL_PRECISION_COLUMNS = [{"nullable": False, **column} for column in PRECISION_COLUMNS] +NULL_PRECISION_COLUMNS: List[TColumnSchema] = [ + {"nullable": True, **column} for column in PRECISION_COLUMNS +] + +# but keep decimal precision +NO_PRECISION_COLUMNS: List[TColumnSchema] = [ + ( + {"name": column["name"], "data_type": column["data_type"]} # type: ignore[misc] + if column["data_type"] != "decimal" + else dict(column) + ) + for column in PRECISION_COLUMNS +] + +NOT_NULL_NO_PRECISION_COLUMNS: List[TColumnSchema] = [ + {"nullable": False, **column} for column in NO_PRECISION_COLUMNS +] +NULL_NO_PRECISION_COLUMNS: List[TColumnSchema] = [ + {"nullable": True, **column} for column in NO_PRECISION_COLUMNS +]
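+
+# The NOT_NULL_* / NULL_* variants above are the expected column schemas consumed by
+# assert_precision_columns and assert_no_precision_columns; they only layer the
+# nullable flag on top of the shared PRECISION_COLUMNS / NO_PRECISION_COLUMNS lists.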