From dad2a08d0d8591eced9c8a446a435eda6c7e01c9 Mon Sep 17 00:00:00 2001 From: rudolfix Date: Mon, 9 Sep 2024 12:56:37 +0200 Subject: [PATCH] prepares for nested references (#1774) * adds methods to detect nested and root tables via parent hint * skips linking in relational when no parent hint, removes linking skip for primary keys * moves schema config and normalizer importers to schema module, braks cyclic deps with dest capabilities * adds table_format override to pipeline run * resolves merge strategy using adapter, uses default for a destination if strategy not explicit * removes force_iceberg flag from athena, requires explicit table_format * adds PreparedTableSchema to indicate TTableSchemas that are prepared for loading, makes verify_schema explicit method to be called by load, simplifies methods to prepare tables * applies table and file format to run methods in all pipeline tests * shortens temp table names in sql jobs * adds filesystem to drop command tests * fixes tests * adds method to update table from diff into extract * athena iceberg does not create dlt pipeline state as iceberg by default * other test fixes * deprecates force_icebergs, adds hive table format to opt out * merges column props and hints, categorizes column props * moves type mappers into destination capabilities * fixes tests * fixes cap data types verification errors not being raised * adds missing deps * fixes more tests * allows precision and scale to be 0 * fixes more tests * corrects connectorx for 3.12 --- .github/workflows/test_common.yml | 11 +- .github/workflows/test_destination_athena.yml | 2 +- .github/workflows/test_local_sources.yml | 6 +- Makefile | 2 +- dlt/common/configuration/container.py | 8 +- dlt/common/data_writers/exceptions.py | 1 + dlt/common/data_writers/writers.py | 7 +- dlt/common/destination/__init__.py | 6 +- dlt/common/destination/capabilities.py | 115 ++++- dlt/common/destination/exceptions.py | 33 +- dlt/common/destination/reference.py | 109 ++-- dlt/common/destination/typing.py | 8 + dlt/common/destination/utils.py | 145 +++++- dlt/common/libs/pyarrow.py | 20 +- dlt/common/normalizers/json/relational.py | 198 ++++---- dlt/common/normalizers/utils.py | 183 +------ .../{normalizers => schema}/configuration.py | 2 +- dlt/common/schema/exceptions.py | 4 +- dlt/common/schema/migrations.py | 9 +- dlt/common/schema/normalizers.py | 186 +++++++ dlt/common/schema/schema.py | 62 ++- dlt/common/schema/typing.py | 91 +++- dlt/common/schema/utils.py | 147 +++--- dlt/common/storages/load_storage.py | 4 +- dlt/common/warnings.py | 12 +- dlt/destinations/impl/athena/athena.py | 167 ++---- .../impl/athena/athena_adapter.py | 6 +- dlt/destinations/impl/athena/configuration.py | 19 +- dlt/destinations/impl/athena/factory.py | 99 +++- dlt/destinations/impl/bigquery/bigquery.py | 114 +---- .../impl/bigquery/bigquery_adapter.py | 6 + dlt/destinations/impl/bigquery/factory.py | 75 +++ .../impl/clickhouse/clickhouse.py | 95 +--- dlt/destinations/impl/clickhouse/factory.py | 72 +++ .../impl/databricks/databricks.py | 116 +---- dlt/destinations/impl/databricks/factory.py | 91 ++++ .../impl/destination/destination.py | 15 +- dlt/destinations/impl/dremio/dremio.py | 76 +-- dlt/destinations/impl/dremio/factory.py | 68 +++ dlt/destinations/impl/duckdb/duck.py | 133 +---- dlt/destinations/impl/duckdb/factory.py | 114 ++++- dlt/destinations/impl/dummy/dummy.py | 13 +- dlt/destinations/impl/filesystem/factory.py | 24 +- .../impl/filesystem/filesystem.py | 68 ++- dlt/destinations/impl/lancedb/factory.py | 11 + .../impl/lancedb/lancedb_client.py | 99 +--- dlt/destinations/impl/lancedb/schema.py | 9 +- dlt/destinations/impl/lancedb/type_mapper.py | 85 ++++ dlt/destinations/impl/motherduck/factory.py | 2 + dlt/destinations/impl/mssql/factory.py | 68 +++ dlt/destinations/impl/mssql/mssql.py | 88 +--- dlt/destinations/impl/postgres/factory.py | 100 ++++ dlt/destinations/impl/postgres/postgres.py | 115 +---- .../impl/qdrant/qdrant_adapter.py | 2 +- .../impl/qdrant/qdrant_job_client.py | 14 +- dlt/destinations/impl/redshift/factory.py | 94 ++++ dlt/destinations/impl/redshift/redshift.py | 103 +--- dlt/destinations/impl/snowflake/factory.py | 79 +++ dlt/destinations/impl/snowflake/snowflake.py | 101 +--- dlt/destinations/impl/synapse/factory.py | 24 +- dlt/destinations/impl/synapse/synapse.py | 50 +- dlt/destinations/impl/weaviate/factory.py | 29 ++ .../impl/weaviate/weaviate_client.py | 46 +- dlt/destinations/insert_job_client.py | 18 +- dlt/destinations/job_client_impl.py | 75 +-- dlt/destinations/job_impl.py | 3 - dlt/destinations/sql_jobs.py | 101 ++-- dlt/destinations/type_mapping.py | 33 +- dlt/destinations/utils.py | 27 +- dlt/extract/decorators.py | 2 +- dlt/extract/extract.py | 22 +- dlt/extract/extractors.py | 9 +- dlt/extract/hints.py | 109 ++-- dlt/load/load.py | 43 +- dlt/load/utils.py | 30 +- dlt/normalize/items_normalizers.py | 3 +- dlt/normalize/normalize.py | 10 +- dlt/normalize/schema.py | 20 - dlt/normalize/validate.py | 43 ++ dlt/normalize/worker.py | 18 +- dlt/pipeline/__init__.py | 22 +- dlt/pipeline/pipeline.py | 48 +- dlt/pipeline/warnings.py | 1 - .../docs/dlt-ecosystem/destinations/athena.md | 7 - poetry.lock | 31 +- pyproject.toml | 10 + .../cases/schemas/eth/ethereum_schema_v9.yml | 9 +- tests/common/destination/__init__.py | 0 .../test_destination_capabilities.py | 224 ++++++++ .../test_reference.py} | 0 .../normalizers/test_json_relational.py | 93 ++-- .../test_import_normalizers.py | 26 +- tests/common/schema/test_merges.py | 16 +- tests/common/schema/test_schema.py | 2 +- tests/common/storages/test_schema_storage.py | 2 +- tests/common/utils.py | 2 +- .../cases/eth_source/ethereum.schema.yaml | 30 +- tests/libs/pyarrow/test_pyarrow_normalizer.py | 2 +- tests/libs/test_parquet_writer.py | 6 +- .../athena_iceberg/test_athena_iceberg.py | 75 ++- tests/load/bigquery/test_bigquery_client.py | 41 +- .../bigquery/test_bigquery_table_builder.py | 3 +- .../load/duckdb/test_duckdb_table_builder.py | 24 +- tests/load/lancedb/__init__.py | 2 + tests/load/lancedb/test_pipeline.py | 3 +- tests/load/pipeline/test_arrow_loading.py | 18 +- tests/load/pipeline/test_athena.py | 30 +- tests/load/pipeline/test_bigquery.py | 2 +- tests/load/pipeline/test_clickhouse.py | 4 +- .../load/pipeline/test_databricks_pipeline.py | 6 +- tests/load/pipeline/test_dbt_helper.py | 4 +- tests/load/pipeline/test_dremio.py | 2 +- tests/load/pipeline/test_drop.py | 108 +++- tests/load/pipeline/test_duckdb.py | 18 +- .../load/pipeline/test_filesystem_pipeline.py | 22 +- tests/load/pipeline/test_merge_disposition.py | 195 +++---- tests/load/pipeline/test_pipelines.py | 478 ++++-------------- tests/load/pipeline/test_postgres.py | 258 +++++++++- tests/load/pipeline/test_redshift.py | 2 +- tests/load/pipeline/test_refresh_modes.py | 58 ++- .../load/pipeline/test_replace_disposition.py | 30 +- tests/load/pipeline/test_restore_state.py | 62 ++- tests/load/pipeline/test_scd2.py | 70 +-- .../load/pipeline/test_snowflake_pipeline.py | 85 +++- tests/load/pipeline/test_stage_loading.py | 35 +- .../test_write_disposition_changes.py | 16 +- .../sql_database/test_sql_database_source.py | 28 +- .../synapse/test_synapse_table_builder.py | 2 +- tests/load/test_dummy_client.py | 30 +- tests/load/test_job_client.py | 6 +- tests/load/utils.py | 45 +- tests/load/weaviate/test_weaviate_client.py | 2 + tests/normalize/test_max_nesting.py | 6 +- tests/pipeline/test_dlt_versions.py | 2 +- tests/pipeline/test_pipeline.py | 2 +- tests/pipeline/utils.py | 3 +- .../test_rest_api_pipeline_template.py | 5 +- 137 files changed, 3946 insertions(+), 2794 deletions(-) create mode 100644 dlt/common/destination/typing.py rename dlt/common/{normalizers => schema}/configuration.py (91%) create mode 100644 dlt/common/schema/normalizers.py create mode 100644 dlt/destinations/impl/lancedb/type_mapper.py delete mode 100644 dlt/normalize/schema.py create mode 100644 dlt/normalize/validate.py create mode 100644 tests/common/destination/__init__.py create mode 100644 tests/common/destination/test_destination_capabilities.py rename tests/common/{test_destination.py => destination/test_reference.py} (100%) rename tests/common/{normalizers => schema}/test_import_normalizers.py (97%) diff --git a/.github/workflows/test_common.yml b/.github/workflows/test_common.yml index 6efa7ffc4c..674b38a776 100644 --- a/.github/workflows/test_common.yml +++ b/.github/workflows/test_common.yml @@ -17,7 +17,7 @@ env: # we need the secrets only for the rest_api_pipeline tests which are in tests/sources # so we inject them only at the end - DLT_SECRETS_TOML: ${{ secrets.DLT_SECRETS_TOML }} + SOURCES__GITHUB__ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} jobs: get_docs_changes: @@ -126,15 +126,8 @@ jobs: name: Run pipeline tests with pyarrow but no pandas installed Windows shell: cmd - - name: create secrets.toml for examples - run: pwd && echo "$DLT_SECRETS_TOML" > tests/.dlt/secrets.toml - - name: Install pipeline and sources dependencies - run: poetry install --no-interaction -E duckdb -E cli -E parquet --with sentry-sdk --with pipeline -E deltalake -E sql_database - - # TODO: this is needed for the filesystem tests, not sure if this should be in an extra? - - name: Install openpyxl for excel tests - run: poetry run pip install openpyxl + run: poetry install --no-interaction -E duckdb -E cli -E parquet -E deltalake -E sql_database --with sentry-sdk,pipeline,sources - run: | poetry run pytest tests/extract tests/pipeline tests/libs tests/cli/common tests/destinations tests/sources diff --git a/.github/workflows/test_destination_athena.yml b/.github/workflows/test_destination_athena.yml index 70a79cd218..a03c17d342 100644 --- a/.github/workflows/test_destination_athena.yml +++ b/.github/workflows/test_destination_athena.yml @@ -22,7 +22,7 @@ env: RUNTIME__DLTHUB_TELEMETRY_ENDPOINT: ${{ secrets.RUNTIME__DLTHUB_TELEMETRY_ENDPOINT }} ACTIVE_DESTINATIONS: "[\"athena\"]" ALL_FILESYSTEM_DRIVERS: "[\"memory\"]" - EXCLUDED_DESTINATION_CONFIGURATIONS: "[\"athena-parquet-staging-iceberg\", \"athena-parquet-no-staging-iceberg\"]" + EXCLUDED_DESTINATION_CONFIGURATIONS: "[\"athena-parquet-iceberg-no-staging-iceberg\", \"athena-parquet-iceberg-staging-iceberg\"]" jobs: get_docs_changes: diff --git a/.github/workflows/test_local_sources.yml b/.github/workflows/test_local_sources.yml index 0178f59322..3d9e7b29a5 100644 --- a/.github/workflows/test_local_sources.yml +++ b/.github/workflows/test_local_sources.yml @@ -81,12 +81,12 @@ jobs: path: .venv key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}-local-sources - # TODO: which deps should we enable? + # TODO: which deps should we enable? - name: Install dependencies - run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E sql_database --with sentry-sdk --with pipeline + run: poetry install --no-interaction -E postgres -E duckdb -E parquet -E filesystem -E cli -E sql_database --with sentry-sdk,pipeline,sources # run sources tests in load against configured destinations - - run: poetry run pytest tests/load/sources + - run: poetry run pytest tests/load/sources name: Run tests Linux env: DESTINATION__POSTGRES__CREDENTIALS: postgresql://loader:loader@localhost:5432/dlt_data diff --git a/Makefile b/Makefile index f47047a3fe..3878dddd15 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ has-poetry: poetry --version dev: has-poetry - poetry install --all-extras --with airflow,docs,providers,pipeline,sentry-sdk,dbt + poetry install --all-extras --with docs,providers,pipeline,sources,sentry-sdk lint: ./tools/check-package.sh diff --git a/dlt/common/configuration/container.py b/dlt/common/configuration/container.py index 84d6194966..d6b67b6e62 100644 --- a/dlt/common/configuration/container.py +++ b/dlt/common/configuration/container.py @@ -1,7 +1,7 @@ from contextlib import contextmanager, nullcontext, AbstractContextManager import re import threading -from typing import ClassVar, Dict, Iterator, Tuple, Type, TypeVar, Any +from typing import ClassVar, Dict, Iterator, Optional, Tuple, Type, TypeVar, Any from dlt.common.configuration.specs.base_configuration import ContainerInjectableContext from dlt.common.configuration.exceptions import ( @@ -171,6 +171,12 @@ def injectable_context( # value was modified in the meantime and not restored raise ContainerInjectableContextMangled(spec, context[spec], config) + def get(self, spec: Type[TConfiguration]) -> Optional[TConfiguration]: + try: + return self[spec] + except KeyError: + return None + @staticmethod def thread_pool_prefix() -> str: """Creates a container friendly pool prefix that contains starting thread id. Container implementation will automatically use it diff --git a/dlt/common/data_writers/exceptions.py b/dlt/common/data_writers/exceptions.py index 3b11ed70fc..cc63fdf9a8 100644 --- a/dlt/common/data_writers/exceptions.py +++ b/dlt/common/data_writers/exceptions.py @@ -1,4 +1,5 @@ from typing import NamedTuple, Sequence + from dlt.common.destination import TLoaderFileFormat from dlt.common.exceptions import DltException diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 4311fb270e..22df7ecea4 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -32,10 +32,11 @@ from dlt.common.destination import ( DestinationCapabilitiesContext, TLoaderFileFormat, - ALL_SUPPORTED_FILE_FORMATS, + LOADER_FILE_FORMATS, ) from dlt.common.metrics import DataWriterMetrics from dlt.common.schema.typing import TTableSchemaColumns +from dlt.common.schema.utils import is_nullable_column from dlt.common.typing import StrAny, TDataItem @@ -115,7 +116,7 @@ def item_format_from_file_extension(cls, extension: str) -> TDataItemFormat: elif extension == "parquet": return "arrow" # those files may be imported by normalizer as is - elif extension in ALL_SUPPORTED_FILE_FORMATS: + elif extension in LOADER_FILE_FORMATS: return "file" else: raise ValueError(f"Cannot figure out data item format for extension {extension}") @@ -331,7 +332,7 @@ def write_header(self, columns_schema: TTableSchemaColumns) -> None: self._caps, self.timestamp_timezone, ), - nullable=schema_item.get("nullable", True), + nullable=is_nullable_column(schema_item), ) for name, schema_item in columns_schema.items() ] diff --git a/dlt/common/destination/__init__.py b/dlt/common/destination/__init__.py index b7b98416a6..2f50b3e3d2 100644 --- a/dlt/common/destination/__init__.py +++ b/dlt/common/destination/__init__.py @@ -2,15 +2,17 @@ DestinationCapabilitiesContext, merge_caps_file_formats, TLoaderFileFormat, - ALL_SUPPORTED_FILE_FORMATS, + LOADER_FILE_FORMATS, ) from dlt.common.destination.reference import TDestinationReferenceArg, Destination, TDestination +from dlt.common.destination.typing import PreparedTableSchema __all__ = [ "DestinationCapabilitiesContext", "merge_caps_file_formats", "TLoaderFileFormat", - "ALL_SUPPORTED_FILE_FORMATS", + "LOADER_FILE_FORMATS", + "PreparedTableSchema", "TDestinationReferenceArg", "Destination", "TDestination", diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 52e7d74833..eed1d6189e 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -1,15 +1,20 @@ +from abc import ABC, abstractmethod from typing import ( Any, Callable, ClassVar, + Iterable, Literal, Optional, Sequence, Tuple, Set, Protocol, + Type, get_args, ) +from dlt.common.data_types import TDataType +from dlt.common.exceptions import TerminalValueError from dlt.common.normalizers.typing import TNamingConventionReferenceArg from dlt.common.typing import TLoaderFileFormat from dlt.common.configuration.utils import serialize_value @@ -20,36 +25,109 @@ DestinationLoadingViaStagingNotSupported, DestinationLoadingWithoutStagingNotSupported, ) -from dlt.common.normalizers.naming import NamingConvention +from dlt.common.destination.typing import PreparedTableSchema from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.schema.typing import ( + TColumnSchema, + TColumnType, + TTableSchema, + TLoaderMergeStrategy, + TTableFormat, +) from dlt.common.wei import EVM_DECIMAL_PRECISION TLoaderParallelismStrategy = Literal["parallel", "table-sequential", "sequential"] -ALL_SUPPORTED_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) +LOADER_FILE_FORMATS: Set[TLoaderFileFormat] = set(get_args(TLoaderFileFormat)) -class LoaderFileFormatAdapter(Protocol): - """Callback protocol for `loader_file_format_adapter` capability.""" +class LoaderFileFormatSelector(Protocol): + """Selects preferred and supported file formats for a given table schema""" + @staticmethod def __call__( - self, preferred_loader_file_format: TLoaderFileFormat, supported_loader_file_formats: Sequence[TLoaderFileFormat], /, *, - table_schema: "TTableSchema", # type: ignore[name-defined] # noqa: F821 + table_schema: TTableSchema, ) -> Tuple[TLoaderFileFormat, Sequence[TLoaderFileFormat]]: ... +class MergeStrategySelector(Protocol): + """Selects right set of merge strategies for a given table schema""" + + @staticmethod + def __call__( + supported_merge_strategies: Sequence[TLoaderMergeStrategy], + /, + *, + table_schema: TTableSchema, + ) -> Sequence["TLoaderMergeStrategy"]: ... + + +class DataTypeMapper(ABC): + def __init__(self, capabilities: "DestinationCapabilitiesContext") -> None: + """Maps dlt data types into destination data types""" + self.capabilities = capabilities + + @abstractmethod + def to_destination_type(self, column: TColumnSchema, table: PreparedTableSchema) -> str: + """Gets destination data type for a particular `column` in prepared `table`""" + pass + + @abstractmethod + def from_destination_type( + self, db_type: str, precision: Optional[int], scale: Optional[int] + ) -> TColumnType: + """Gets column type from db type""" + pass + + @abstractmethod + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + """Makes sure that dlt type in `column` in prepared `table` is supported by the destination for a given file format""" + pass + + +class UnsupportedTypeMapper(DataTypeMapper): + """Type Mapper that can't map any type""" + + def to_destination_type(self, column: TColumnSchema, table: PreparedTableSchema) -> str: + raise NotImplementedError("No types are supported, use real type mapper") + + def from_destination_type( + self, db_type: str, precision: Optional[int], scale: Optional[int] + ) -> TColumnType: + raise NotImplementedError("No types are supported, use real type mapper") + + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + raise TerminalValueError( + "No types are supported, use real type mapper", column["data_type"] + ) + + @configspec class DestinationCapabilitiesContext(ContainerInjectableContext): """Injectable destination capabilities required for many Pipeline stages ie. normalize""" + # do not allow to create default value, destination caps must be always explicitly inserted into container + can_create_default: ClassVar[bool] = False + preferred_loader_file_format: TLoaderFileFormat = None supported_loader_file_formats: Sequence[TLoaderFileFormat] = None - loader_file_format_adapter: LoaderFileFormatAdapter = None + loader_file_format_selector: LoaderFileFormatSelector = None """Callable that adapts `preferred_loader_file_format` and `supported_loader_file_formats` at runtime.""" - supported_table_formats: Sequence["TTableFormat"] = None # type: ignore[name-defined] # noqa: F821 + supported_table_formats: Sequence[TTableFormat] = None + type_mapper: Optional[Type[DataTypeMapper]] = None recommended_file_size: Optional[int] = None """Recommended file size in bytes when writing extract/load files""" preferred_staging_file_format: Optional[TLoaderFileFormat] = None @@ -89,14 +167,12 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): max_table_nesting: Optional[int] = None """Allows a destination to overwrite max_table_nesting from source""" - supported_merge_strategies: Sequence["TLoaderMergeStrategy"] = None # type: ignore[name-defined] # noqa: F821 + supported_merge_strategies: Sequence[TLoaderMergeStrategy] = None + merge_strategies_selector: MergeStrategySelector = None # TODO: also add `supported_replace_strategies` capability - # do not allow to create default value, destination caps must be always explicitly inserted into container - can_create_default: ClassVar[bool] = False - max_parallel_load_jobs: Optional[int] = None - """The destination can set the maxium amount of parallel load jobs being executed""" + """The destination can set the maximum amount of parallel load jobs being executed""" loader_parallelism_strategy: Optional[TLoaderParallelismStrategy] = None """The destination can override the parallelism strategy""" @@ -109,16 +185,17 @@ def generates_case_sensitive_identifiers(self) -> bool: def generic_capabilities( preferred_loader_file_format: TLoaderFileFormat = None, naming_convention: TNamingConventionReferenceArg = None, - loader_file_format_adapter: LoaderFileFormatAdapter = None, - supported_table_formats: Sequence["TTableFormat"] = None, # type: ignore[name-defined] # noqa: F821 - supported_merge_strategies: Sequence["TLoaderMergeStrategy"] = None, # type: ignore[name-defined] # noqa: F821 + loader_file_format_selector: LoaderFileFormatSelector = None, + supported_table_formats: Sequence[TTableFormat] = None, + supported_merge_strategies: Sequence[TLoaderMergeStrategy] = None, + merge_strategies_selector: MergeStrategySelector = None, ) -> "DestinationCapabilitiesContext": from dlt.common.data_writers.escape import format_datetime_literal caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = preferred_loader_file_format caps.supported_loader_file_formats = ["jsonl", "insert_values", "parquet", "csv"] - caps.loader_file_format_adapter = loader_file_format_adapter + caps.loader_file_format_selector = loader_file_format_selector caps.preferred_staging_file_format = None caps.supported_staging_file_formats = [] caps.naming_convention = naming_convention or caps.naming_convention @@ -140,8 +217,12 @@ def generic_capabilities( caps.supports_transactions = True caps.supports_multiple_statements = True caps.supported_merge_strategies = supported_merge_strategies or [] + caps.merge_strategies_selector = merge_strategies_selector return caps + def get_type_mapper(self) -> DataTypeMapper: + return self.type_mapper(self) + def merge_caps_file_formats( destination: str, diff --git a/dlt/common/destination/exceptions.py b/dlt/common/destination/exceptions.py index 49c9b822e3..50796998ad 100644 --- a/dlt/common/destination/exceptions.py +++ b/dlt/common/destination/exceptions.py @@ -1,4 +1,4 @@ -from typing import Any, Iterable, List +from typing import Any, Iterable, List, Sequence from dlt.common.exceptions import DltException, TerminalException, TransientException @@ -102,6 +102,37 @@ def __init__( ) +class UnsupportedDataType(DestinationTerminalException): + def __init__( + self, + destination_type: str, + table_name: str, + column: str, + data_type: str, + file_format: str, + available_in_formats: Sequence[str], + more_info: str, + ) -> None: + self.destination_type = destination_type + self.table_name = table_name + self.column = column + self.data_type = data_type + self.file_format = file_format + self.available_in_formats = available_in_formats + self.more_info = more_info + msg = ( + f"Destination {destination_type} cannot load data type '{data_type}' from" + f" '{file_format}' files. The affected table is '{table_name}' column '{column}'." + ) + if available_in_formats: + msg += f" Note: '{data_type}' can be loaded from {available_in_formats} formats(s)." + else: + msg += f" None of available file formats support '{data_type}' for this destination." + if more_info: + msg += " More info: " + more_info + super().__init__(msg) + + class DestinationHasFailedJobs(DestinationTerminalException): def __init__(self, destination_name: str, load_id: str, failed_jobs: List[Any]) -> None: self.destination_name = destination_name diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index e7bba266df..ef294d4298 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -21,22 +21,18 @@ ) from typing_extensions import Annotated import datetime # noqa: 251 -from copy import deepcopy import inspect from dlt.common import logger, pendulum from dlt.common.configuration.specs.base_configuration import extract_inner_hint -from dlt.common.destination.utils import verify_schema_capabilities +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.destination.utils import verify_schema_capabilities, verify_supported_data_types from dlt.common.exceptions import TerminalValueError from dlt.common.metrics import LoadJobMetrics from dlt.common.normalizers.naming import NamingConvention -from dlt.common.schema import Schema, TTableSchema, TSchemaTables -from dlt.common.schema.utils import ( - get_file_format, - get_write_disposition, - get_table_format, - get_merge_strategy, -) +from dlt.common.schema import Schema, TSchemaTables +from dlt.common.schema.typing import _TTableSchemaBase, TWriteDisposition +from dlt.common.schema.utils import fill_hints_from_parent_and_clone_table from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration from dlt.common.destination.capabilities import DestinationCapabilitiesContext @@ -257,7 +253,7 @@ class DestinationClientStagingConfiguration(DestinationClientDwhConfiguration): Also supports datasets and can act as standalone destination. """ - as_staging: bool = False + as_staging_destination: bool = False bucket_url: str = None # layout of the destination files layout: str = DEFAULT_FILE_LAYOUT @@ -347,11 +343,11 @@ def __init__(self, file_path: str) -> None: # variables needed by most jobs, set by the loader in set_run_vars self._schema: Schema = None - self._load_table: TTableSchema = None + self._load_table: PreparedTableSchema = None self._load_id: str = None self._job_client: "JobClientBase" = None - def set_run_vars(self, load_id: str, schema: Schema, load_table: TTableSchema) -> None: + def set_run_vars(self, load_id: str, schema: Schema, load_table: PreparedTableSchema) -> None: """ called by the loader right before the job is run """ @@ -457,6 +453,38 @@ def drop_storage(self) -> None: """Brings storage back into not initialized state. Typically data in storage is destroyed.""" pass + def verify_schema( + self, only_tables: Iterable[str] = None, new_jobs: Iterable[ParsedLoadJobFileName] = None + ) -> List[PreparedTableSchema]: + """Verifies schema before loading, returns a list of verified loaded tables.""" + if exceptions := verify_schema_capabilities( + self.schema, + self.capabilities, + self.config.destination_type, + warnings=False, + ): + for exception in exceptions: + logger.error(str(exception)) + raise exceptions[0] + + prepared_tables = [ + self.prepare_load_table(table_name) + for table_name in set( + list(only_tables or []) + self.schema.data_table_names(seen_data_only=True) + ) + ] + if exceptions := verify_supported_data_types( + prepared_tables, + new_jobs, + self.capabilities, + self.config.destination_type, + warnings=False, + ): + for exception in exceptions: + logger.error(str(exception)) + raise exceptions[0] + return prepared_tables + def update_stored_schema( self, only_tables: Iterable[str] = None, @@ -473,7 +501,6 @@ def update_stored_schema( Returns: Optional[TSchemaTables]: Returns an update that was applied at the destination. """ - self._verify_schema() # make sure that schema being saved was not modified from the moment it was loaded from storage version_hash = self.schema.version_hash if self.schema.is_modified: @@ -482,11 +509,19 @@ def update_stored_schema( ) return expected_update + def prepare_load_table(self, table_name: str) -> PreparedTableSchema: + """Prepares a table schema to be loaded by filling missing hints and doing other modifications requires by given destination.""" + try: + return fill_hints_from_parent_and_clone_table(self.schema.tables, self.schema.tables[table_name]) # type: ignore[return-value] + + except KeyError: + raise UnknownTableException(self.schema.name, table_name) + @abstractmethod def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: - """Creates a load job for a particular `table` with content in `file_path`""" + """Creates a load job for a particular `table` with content in `file_path`. Table is already prepared to be loaded.""" pass def prepare_load_job_execution( # noqa: B027, optional override @@ -495,15 +530,15 @@ def prepare_load_job_execution( # noqa: B027, optional override """Prepare the connected job client for the execution of a load job (used for query tags in sql clients)""" pass - def should_truncate_table_before_load(self, table: TTableSchema) -> bool: - return table["write_disposition"] == "replace" + def should_truncate_table_before_load(self, table_name: str) -> bool: + return self.prepare_load_table(table_name)["write_disposition"] == "replace" def create_table_chain_completed_followup_jobs( self, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJobRequest]: - """Creates a list of followup jobs that should be executed after a table chain is completed""" + """Creates a list of followup jobs that should be executed after a table chain is completed. Tables are already prepared to be loaded.""" return [] @abstractmethod @@ -521,34 +556,6 @@ def __exit__( ) -> None: pass - def _verify_schema(self) -> None: - """Verifies schema before loading""" - if exceptions := verify_schema_capabilities( - self.schema, self.capabilities, self.config.destination_type, warnings=False - ): - for exception in exceptions: - logger.error(str(exception)) - raise exceptions[0] - - def prepare_load_table( - self, table_name: str, prepare_for_staging: bool = False - ) -> TTableSchema: - try: - # make a copy of the schema so modifications do not affect the original document - table = deepcopy(self.schema.tables[table_name]) - # add write disposition if not specified - in child tables - if "write_disposition" not in table: - table["write_disposition"] = get_write_disposition(self.schema.tables, table_name) - if "x-merge-strategy" not in table: - table["x-merge-strategy"] = get_merge_strategy(self.schema.tables, table_name) # type: ignore[typeddict-unknown-key] - if "table_format" not in table: - table["table_format"] = get_table_format(self.schema.tables, table_name) - if "file_format" not in table: - table["file_format"] = get_file_format(self.schema.tables, table_name) - return table - except KeyError: - raise UnknownTableException(self.schema.name, table_name) - class WithStateSync(ABC): @abstractmethod @@ -571,7 +578,7 @@ class WithStagingDataset(ABC): """Adds capability to use staging dataset and request it from the loader""" @abstractmethod - def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + def should_load_data_to_staging_dataset(self, table_name: str) -> bool: return False @abstractmethod @@ -583,9 +590,7 @@ def with_staging_dataset(self) -> ContextManager["JobClientBase"]: class SupportsStagingDestination(ABC): """Adds capability to support a staging destination for the load""" - def should_load_data_to_staging_dataset_on_staging_destination( - self, table: TTableSchema - ) -> bool: + def should_load_data_to_staging_dataset_on_staging_destination(self, table_name: str) -> bool: """If set to True, and staging destination is configured, the data will be loaded to staging dataset on staging destination instead of a regular dataset on staging destination. Currently it is used by Athena Iceberg which uses staging dataset on staging destination to copy data to iceberg tables stored on regular dataset on staging destination. @@ -595,7 +600,7 @@ def should_load_data_to_staging_dataset_on_staging_destination( return False @abstractmethod - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: """If set to True, data in `table` will be truncated on staging destination (regular dataset). This is the default behavior which can be changed with a config flag. For Athena + Iceberg this setting is always False - Athena uses regular dataset to store Iceberg tables and we avoid touching it. diff --git a/dlt/common/destination/typing.py b/dlt/common/destination/typing.py new file mode 100644 index 0000000000..bdfbddaa8c --- /dev/null +++ b/dlt/common/destination/typing.py @@ -0,0 +1,8 @@ +from dlt.common.schema.typing import _TTableSchemaBase, TWriteDisposition + + +class PreparedTableSchema(_TTableSchemaBase, total=False): + """Table schema with all hints prepared to be loaded""" + + write_disposition: TWriteDisposition + _x_prepared: bool # needed for the type checker diff --git a/dlt/common/destination/utils.py b/dlt/common/destination/utils.py index 931413126c..0bad5b152e 100644 --- a/dlt/common/destination/utils.py +++ b/dlt/common/destination/utils.py @@ -1,14 +1,23 @@ -from typing import List +import contextlib +from typing import Dict, Iterable, List, Optional, Set from dlt.common import logger -from dlt.common.destination.exceptions import IdentifierTooLongException +from dlt.common.configuration.inject import with_config +from dlt.common.destination.exceptions import ( + DestinationCapabilitiesException, + IdentifierTooLongException, +) +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.destination.exceptions import UnsupportedDataType +from dlt.common.destination.capabilities import DestinationCapabilitiesContext, LOADER_FILE_FORMATS from dlt.common.schema import Schema from dlt.common.schema.exceptions import ( SchemaIdentifierNormalizationCollision, ) -from dlt.common.typing import DictStrStr - -from .capabilities import DestinationCapabilitiesContext +from dlt.common.schema.typing import TColumnType, TLoaderMergeStrategy, TSchemaTables, TTableSchema +from dlt.common.schema.utils import get_merge_strategy, is_complete_column +from dlt.common.storages import ParsedLoadJobFileName +from dlt.common.typing import ConfigValue, DictStrStr, TLoaderFileFormat def verify_schema_capabilities( @@ -17,7 +26,8 @@ def verify_schema_capabilities( destination_type: str, warnings: bool = True, ) -> List[Exception]: - """Verifies schema tables before loading against capabilities. Returns a list of exceptions representing critical problems with the schema. + """Verifies `load_tables` that have all hints filled by job client before loading against capabilities. + Returns a list of exceptions representing critical problems with the schema. It will log warnings by default. It is up to the caller to eventually raise exception * Checks all table and column name lengths against destination capabilities and raises on too long identifiers @@ -104,3 +114,126 @@ def verify_schema_capabilities( ) ) return exception_log + + +def column_type_to_str(column: TColumnType) -> str: + """Converts column type to db-like type string""" + data_type: str = column["data_type"] + precision = column.get("precision") + scale = column.get("scale") + if precision is not None and scale is not None: + data_type += f"({precision},{scale})" + elif precision is not None: + data_type += f"({precision})" + return data_type + + +def verify_supported_data_types( + prepared_tables: Iterable[PreparedTableSchema], + new_jobs: Iterable[ParsedLoadJobFileName], + capabilities: DestinationCapabilitiesContext, + destination_type: str, + warnings: bool = True, +) -> List[Exception]: + exception_log: List[Exception] = [] + # can't check types without type mapper + if capabilities.type_mapper is None or not new_jobs: + return exception_log + + type_mapper = capabilities.get_type_mapper() + + # index available file formats + table_file_formats: Dict[str, Set[TLoaderFileFormat]] = {} + for parsed_file in new_jobs: + formats = table_file_formats.setdefault(parsed_file.table_name, set()) + if parsed_file.file_format in LOADER_FILE_FORMATS: + formats.add(parsed_file.file_format) # type: ignore[arg-type] + # all file formats + all_file_formats = set(capabilities.supported_loader_file_formats or []) | set( + capabilities.supported_staging_file_formats or [] + ) + + for table in prepared_tables: + # map types + for column in table["columns"].values(): + # do not verify incomplete columns, those won't be created + if not is_complete_column(column): + continue + try: + type_mapper.to_destination_type(column, table) + except Exception as ex: + # collect mapping exceptions + exception_log.append(ex) + # ensure if types can be loaded from file formats present in jobs + for format_ in table_file_formats.get(table["name"], []): + try: + type_mapper.ensure_supported_type(column, table, format_) + except ValueError as err: + # figure out where data type is supported + available_in_formats: List[TLoaderFileFormat] = [] + for candidate_format in all_file_formats - set([format_]): + with contextlib.suppress(Exception): + type_mapper.ensure_supported_type(column, table, candidate_format) + available_in_formats.append(candidate_format) + exception_log.append( + UnsupportedDataType( + destination_type, + table["name"], + column["name"], + column_type_to_str(column), + format_, + available_in_formats, + err.args[0], + ) + ) + + return exception_log + + +@with_config +def resolve_merge_strategy( + tables: TSchemaTables, + table: TTableSchema, + destination_capabilities: Optional[DestinationCapabilitiesContext] = ConfigValue, +) -> Optional[TLoaderMergeStrategy]: + """Resolve merge strategy for a table, possibly resolving the 'x-merge-strategy from a table chain. strategies selector in `destination_capabilities` + is used if present. If `table` does not contain strategy hint, a default value will be used which is the first. + + `destination_capabilities` are injected from context if not explicitly passed. + + Returns None if table write disposition is not merge + """ + if table.get("write_disposition") == "merge": + destination_capabilities = ( + destination_capabilities or DestinationCapabilitiesContext.generic_capabilities() + ) + supported_strategies = destination_capabilities.supported_merge_strategies + table_name = table["name"] + if destination_capabilities.merge_strategies_selector: + supported_strategies = destination_capabilities.merge_strategies_selector( + supported_strategies, table_schema=table + ) + if not supported_strategies: + table_format_info = "" + if destination_capabilities.supported_table_formats: + table_format_info = ( + " or try different table format which may offer `merge`:" + f" {destination_capabilities.supported_table_formats}" + ) + logger.warning( + "Destination does not support any merge strategies and `merge` write disposition " + f" for table `{table_name}` cannot be met and will fall back to `append`. Change" + f" write disposition{table_format_info}." + ) + return None + merge_strategy = get_merge_strategy(tables, table_name) + # use first merge strategy as default + if merge_strategy is None and supported_strategies: + merge_strategy = supported_strategies[0] + if merge_strategy not in supported_strategies: + raise DestinationCapabilitiesException( + f"`{merge_strategy}` merge strategy not supported" + f" for table `{table_name}`. Available strategies: {supported_strategies}" + ) + return merge_strategy + return None diff --git a/dlt/common/libs/pyarrow.py b/dlt/common/libs/pyarrow.py index 14ca1fb46f..3f047e275a 100644 --- a/dlt/common/libs/pyarrow.py +++ b/dlt/common/libs/pyarrow.py @@ -21,6 +21,7 @@ from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema.typing import TColumnType +from dlt.common.schema.utils import is_nullable_column from dlt.common.typing import StrStr, TFileOrPath from dlt.common.normalizers.naming import NamingConvention @@ -56,7 +57,9 @@ def get_py_arrow_datatype( elif column_type == "timestamp": # sets timezone to None when timezone hint is false timezone = tz if column.get("timezone", True) else None - precision = column.get("precision") or caps.timestamp_precision + precision = column.get("precision") + if precision is None: + precision = caps.timestamp_precision return get_py_arrow_timestamp(precision, timezone) elif column_type == "bigint": return get_pyarrow_int(column.get("precision")) @@ -78,7 +81,10 @@ def get_py_arrow_datatype( elif column_type == "date": return pyarrow.date32() elif column_type == "time": - return get_py_arrow_time(column.get("precision") or caps.timestamp_precision) + precision = column.get("precision") + if precision is None: + precision = caps.timestamp_precision + return get_py_arrow_time(precision) else: raise ValueError(column_type) @@ -237,7 +243,7 @@ def should_normalize_arrow_schema( ) -> Tuple[bool, Mapping[str, str], Dict[str, str], Dict[str, bool], bool, TTableSchemaColumns]: rename_mapping = get_normalized_arrow_fields_mapping(schema, naming) rev_mapping = {v: k for k, v in rename_mapping.items()} - nullable_mapping = {k: v.get("nullable", True) for k, v in columns.items()} + nullable_mapping = {k: is_nullable_column(v) for k, v in columns.items()} # All fields from arrow schema that have nullable set to different value than in columns # Key is the renamed column name nullable_updates: Dict[str, bool] = {} @@ -246,8 +252,8 @@ def should_normalize_arrow_schema( if norm_name in nullable_mapping and field.nullable != nullable_mapping[norm_name]: nullable_updates[norm_name] = nullable_mapping[norm_name] - dlt_load_id_col = naming.normalize_table_identifier("_dlt_load_id") - dlt_id_col = naming.normalize_table_identifier("_dlt_id") + dlt_load_id_col = naming.normalize_identifier("_dlt_load_id") + dlt_id_col = naming.normalize_identifier("_dlt_id") dlt_columns = {dlt_load_id_col, dlt_id_col} # Do we need to add a load id column? @@ -326,7 +332,7 @@ def normalize_py_arrow_item( new_field = pyarrow.field( column_name, get_py_arrow_datatype(column, caps, "UTC"), - nullable=column.get("nullable", True), + nullable=is_nullable_column(column), ) new_fields.append(new_field) new_columns.append(pyarrow.nulls(item.num_rows, type=new_field.type)) @@ -343,7 +349,7 @@ def normalize_py_arrow_item( load_id_type = pyarrow.dictionary(pyarrow.int8(), pyarrow.string()) new_fields.append( pyarrow.field( - naming.normalize_table_identifier("_dlt_load_id"), + naming.normalize_identifier("_dlt_load_id"), load_id_type, nullable=False, ) diff --git a/dlt/common/normalizers/json/relational.py b/dlt/common/normalizers/json/relational.py index 33184640f0..b8e9fdaff3 100644 --- a/dlt/common/normalizers/json/relational.py +++ b/dlt/common/normalizers/json/relational.py @@ -1,5 +1,6 @@ from functools import lru_cache from typing import Dict, List, Mapping, Optional, Sequence, Tuple, cast, TypedDict, Any +from dlt.common.destination.utils import resolve_merge_strategy from dlt.common.json import json from dlt.common.normalizers.exceptions import InvalidJsonNormalizer from dlt.common.normalizers.typing import TJSONNormalizer, TRowIdType @@ -8,20 +9,19 @@ from dlt.common.typing import DictStrAny, TDataItem, StrAny from dlt.common.schema import Schema from dlt.common.schema.typing import ( - TLoaderMergeStrategy, TColumnSchema, TColumnName, TSimpleRegex, DLT_NAME_PREFIX, + TTableSchema, ) from dlt.common.schema.utils import ( column_name_validator, - get_validity_column_names, get_columns_names_with_prop, get_first_column_name_with_prop, - get_merge_strategy, + has_column_with_prop, + is_nested_table, ) -from dlt.common.schema.exceptions import ColumnNameConflictException from dlt.common.utils import digest128, update_dict_nested from dlt.common.normalizers.json import ( TNormalizedRowIterator, @@ -98,33 +98,6 @@ def _reset(self) -> None: # self.known_types: Dict[str, TDataType] = {} # self.primary_keys = Dict[str, ] - # for those paths the complex nested objects should be left in place - def _is_complex_type(self, table_name: str, field_name: str, _r_lvl: int) -> bool: - # turn everything at the recursion level into complex type - max_nesting = self.max_nesting - schema = self.schema - max_table_nesting = self._get_table_nesting_level(schema, table_name) - if max_table_nesting is not None: - max_nesting = max_table_nesting - - assert _r_lvl <= max_nesting - if _r_lvl == max_nesting: - return True - - # use cached value - # path = f"{table_name}▶{field_name}" - # or use definition in the schema - column: TColumnSchema = None - table = schema.tables.get(table_name) - if table: - column = table["columns"].get(field_name) - if column is None or "data_type" not in column: - data_type = schema.get_preferred_type(field_name) - else: - data_type = column["data_type"] - - return data_type == "complex" - def _flatten( self, table: str, dict_row: DictStrAny, _r_lvl: int ) -> Tuple[DictStrAny, Dict[Tuple[str, ...], Sequence[Any]]]: @@ -141,13 +114,15 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) - norm_k = self.EMPTY_KEY_IDENTIFIER # if norm_k != k: # print(f"{k} -> {norm_k}") - child_name = ( + nested_name = ( norm_k if path == () else schema_naming.shorten_fragments(*path, norm_k) ) # for lists and dicts we must check if type is possibly complex if isinstance(v, (dict, list)): - if not self._is_complex_type(table, child_name, __r_lvl): - # TODO: if schema contains table {table}__{child_name} then convert v into single element list + if not self._is_complex_type( + self.schema, table, nested_name, self.max_nesting, __r_lvl + ): + # TODO: if schema contains table {table}__{nested_name} then convert v into single element list if isinstance(v, dict): # flatten the dict more norm_row_dicts(v, __r_lvl + 1, path + (norm_k,)) @@ -159,7 +134,7 @@ def norm_row_dicts(dict_row: StrAny, __r_lvl: int, path: Tuple[str, ...] = ()) - # pass the complex value to out_rec_row pass - out_rec_row[child_name] = v + out_rec_row[nested_name] = v norm_row_dicts(dict_row, _r_lvl) return out_rec_row, out_rec_list @@ -179,10 +154,10 @@ def get_row_hash(row: Dict[str, Any], subset: Optional[List[str]] = None) -> str return digest128(row_str, DLT_ID_LENGTH_BYTES) @staticmethod - def _get_child_row_hash(parent_row_id: str, child_table: str, list_idx: int) -> str: - # create deterministic unique id of the child row taking into account that all lists are ordered - # and all child tables must be lists - return digest128(f"{parent_row_id}_{child_table}_{list_idx}", DLT_ID_LENGTH_BYTES) + def _get_nested_row_hash(parent_row_id: str, nested_table: str, list_idx: int) -> str: + # create deterministic unique id of the nested row taking into account that all lists are ordered + # and all nested tables must be lists + return digest128(f"{parent_row_id}_{nested_table}_{list_idx}", DLT_ID_LENGTH_BYTES) def _link_row(self, row: DictStrAny, parent_row_id: str, list_idx: int) -> DictStrAny: assert parent_row_id @@ -204,29 +179,27 @@ def _add_row_id( pos: int, _r_lvl: int, ) -> str: - primary_key = False - if _r_lvl > 0: # child table - primary_key = bool( - self.schema.filter_row_with_hint(table, "primary_key", flattened_row) - ) - row_id_type = self._get_row_id_type(self.schema, table, primary_key, _r_lvl) - - if row_id_type == "random": - row_id = generate_dlt_id() - else: - if _r_lvl == 0: # root table - if row_id_type in ("key_hash", "row_hash"): - subset = None - if row_id_type == "key_hash": - subset = self._get_primary_key(self.schema, table) - # base hash on `dict_row` instead of `flattened_row` - # so changes in child tables lead to new row id - row_id = self.get_row_hash(dict_row, subset=subset) - elif _r_lvl > 0: # child table - if row_id_type == "row_hash": - row_id = DataItemNormalizer._get_child_row_hash(parent_row_id, table, pos) - # link to parent table + if _r_lvl == 0: # root table + row_id_type = self._get_root_row_id_type(self.schema, table) + if row_id_type in ("key_hash", "row_hash"): + subset = None + if row_id_type == "key_hash": + subset = self._get_primary_key(self.schema, table) + # base hash on `dict_row` instead of `flattened_row` + # so changes in nested tables lead to new row id + row_id = self.get_row_hash(dict_row, subset=subset) + else: + row_id = generate_dlt_id() + else: # nested table + row_id_type, is_nested = self._get_nested_row_id_type(self.schema, table) + if row_id_type == "row_hash": + row_id = DataItemNormalizer._get_nested_row_hash(parent_row_id, table, pos) + # link to parent table + if is_nested: self._link_row(flattened_row, parent_row_id, pos) + else: + # do not create link if primary key was found for nested table + row_id = generate_dlt_id() flattened_row[self.c_dlt_id] = row_id return row_id @@ -236,7 +209,7 @@ def _get_propagated_values(self, table: str, row: DictStrAny, _r_lvl: int) -> St config = self.propagation_config if config: - # mapping(k:v): propagate property with name "k" as property with name "v" in child table + # mapping(k:v): propagate property with name "k" as property with name "v" in nested table mappings: Dict[TColumnName, TColumnName] = {} if _r_lvl == 0: mappings.update(config.get("root") or {}) @@ -249,7 +222,7 @@ def _get_propagated_values(self, table: str, row: DictStrAny, _r_lvl: int) -> St return extend - # generate child tables only for lists + # generate nested tables only for lists def _normalize_list( self, seq: Sequence[Any], @@ -262,8 +235,8 @@ def _normalize_list( table = self.schema.naming.shorten_fragments(*parent_path, *ident_path) for idx, v in enumerate(seq): - # yield child table row if isinstance(v, dict): + # found dict element in seq yield from self._normalize_row( v, extend, ident_path, parent_path, parent_row_id, idx, _r_lvl ) @@ -279,13 +252,11 @@ def _normalize_list( _r_lvl + 1, ) else: - # list of simple types - child_row_hash = DataItemNormalizer._get_child_row_hash(parent_row_id, table, idx) + # found non-dict in seq, so wrap it wrap_v = wrap_in_dict(self.c_value, v) - wrap_v[self.c_dlt_id] = child_row_hash - e = self._link_row(wrap_v, parent_row_id, idx) - DataItemNormalizer._extend_row(extend, e) - yield (table, self.schema.naming.shorten_fragments(*parent_path)), e + DataItemNormalizer._extend_row(extend, wrap_v) + self._add_row_id(table, wrap_v, wrap_v, parent_row_id, idx, _r_lvl) + yield (table, self.schema.naming.shorten_fragments(*parent_path)), wrap_v def _normalize_row( self, @@ -308,7 +279,7 @@ def _normalize_row( if not row_id: row_id = self._add_row_id(table, dict_row, flattened_row, parent_row_id, pos, _r_lvl) - # find fields to propagate to child tables in config + # find fields to propagate to nested tables in config extend.update(self._get_propagated_values(table, flattened_row, _r_lvl)) # yield parent table first @@ -368,7 +339,7 @@ def extend_table(self, table_name: str) -> None: Table name should be normalized. """ table = self.schema.tables.get(table_name) - if not table.get("parent") and table.get("write_disposition") == "merge": + if not is_nested_table(table) and table.get("write_disposition") == "merge": DataItemNormalizer.update_normalizer_config( self.schema, { @@ -392,11 +363,6 @@ def normalize_data_item( row = cast(DictStrAny, item) # identify load id if loaded data must be processed after loading incrementally row[self.c_dlt_load_id] = load_id - if self._get_merge_strategy(self.schema, table_name) == "scd2": - self._validate_validity_column_names( - self.schema.name, self._get_validity_column_names(self.schema, table_name), item - ) - yield from self._normalize_row( row, {}, @@ -459,18 +425,12 @@ def _normalize_prop( ) @staticmethod - @lru_cache(maxsize=None) def _get_table_nesting_level(schema: Schema, table_name: str) -> Optional[int]: table = schema.tables.get(table_name) if table: return table.get("x-normalizer", {}).get("max_nesting") # type: ignore return None - @staticmethod - @lru_cache(maxsize=None) - def _get_merge_strategy(schema: Schema, table_name: str) -> Optional[TLoaderMergeStrategy]: - return get_merge_strategy(schema.tables, table_name) - @staticmethod @lru_cache(maxsize=None) def _get_primary_key(schema: Schema, table_name: str) -> List[str]: @@ -481,16 +441,50 @@ def _get_primary_key(schema: Schema, table_name: str) -> List[str]: @staticmethod @lru_cache(maxsize=None) - def _get_validity_column_names(schema: Schema, table_name: str) -> List[Optional[str]]: - return get_validity_column_names(schema.get_table(table_name)) + def _is_complex_type( + schema: Schema, table_name: str, field_name: str, max_nesting: int, _r_lvl: int + ) -> bool: + """For those paths the complex nested objects should be left in place. + Cache perf: max_nesting < _r_lvl: ~2x faster, full check 10x faster + """ + # turn everything at the recursion level into complex type + max_table_nesting = DataItemNormalizer._get_table_nesting_level(schema, table_name) + if max_table_nesting is not None: + max_nesting = max_table_nesting + + assert _r_lvl <= max_nesting + if _r_lvl == max_nesting: + return True + + column: TColumnSchema = None + table = schema.tables.get(table_name) + if table: + column = table["columns"].get(field_name) + if column is None or "data_type" not in column: + data_type = schema.get_preferred_type(field_name) + else: + data_type = column["data_type"] + + return data_type == "complex" @staticmethod @lru_cache(maxsize=None) - def _get_row_id_type( - schema: Schema, table_name: str, primary_key: bool, _r_lvl: int - ) -> TRowIdType: - if _r_lvl == 0: # root table - merge_strategy = DataItemNormalizer._get_merge_strategy(schema, table_name) + def _get_nested_row_id_type(schema: Schema, table_name: str) -> Tuple[TRowIdType, bool]: + """Gets type of row id to be added to nested table and if linking information should be added""" + if table := schema.tables.get(table_name): + merge_strategy = resolve_merge_strategy(schema.tables, table) + if merge_strategy not in ("upsert", "scd2") and not is_nested_table(table): + return "random", False + else: + # table will be created, use standard linking + pass + return "row_hash", True + + @staticmethod + @lru_cache(maxsize=None) + def _get_root_row_id_type(schema: Schema, table_name: str) -> TRowIdType: + if table := schema.tables.get(table_name): + merge_strategy = resolve_merge_strategy(schema.tables, table) if merge_strategy == "upsert": return "key_hash" elif merge_strategy == "scd2": @@ -499,26 +493,8 @@ def _get_row_id_type( "x-row-version", include_incomplete=True, ) - if x_row_version_col == DataItemNormalizer.C_DLT_ID: + if x_row_version_col == schema.naming.normalize_identifier( + DataItemNormalizer.C_DLT_ID + ): return "row_hash" - elif _r_lvl > 0: # child table - merge_strategy = DataItemNormalizer._get_merge_strategy(schema, table_name) - if merge_strategy in ("upsert", "scd2"): - # these merge strategies rely on deterministic child row hash - return "row_hash" - if not primary_key: - return "row_hash" return "random" - - @staticmethod - def _validate_validity_column_names( - schema_name: str, validity_column_names: List[Optional[str]], item: TDataItem - ) -> None: - """Raises exception if configured validity column name appears in data item.""" - for validity_column_name in validity_column_names: - if validity_column_name in item.keys(): - raise ColumnNameConflictException( - schema_name, - "Found column in data item with same name as validity column" - f' "{validity_column_name}".', - ) diff --git a/dlt/common/normalizers/utils.py b/dlt/common/normalizers/utils.py index d852cfb7d9..c090aa1bde 100644 --- a/dlt/common/normalizers/utils.py +++ b/dlt/common/normalizers/utils.py @@ -1,188 +1,11 @@ import os -from importlib import import_module -from types import ModuleType -from typing import Any, Dict, Optional, Type, Tuple, cast, List +from typing import List -import dlt -from dlt.common import logger from dlt.common import known_env -from dlt.common.configuration.inject import with_config -from dlt.common.configuration.specs import known_sections -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.normalizers.configuration import NormalizersConfiguration -from dlt.common.normalizers.exceptions import InvalidJsonNormalizer -from dlt.common.normalizers.json import SupportsDataItemNormalizer, DataItemNormalizer -from dlt.common.normalizers.naming import NamingConvention -from dlt.common.normalizers.naming.exceptions import ( - NamingTypeNotFound, - UnknownNamingModule, - InvalidNamingType, -) -from dlt.common.normalizers.typing import ( - TJSONNormalizer, - TNormalizersConfig, - TNamingConventionReferenceArg, -) -from dlt.common.typing import is_subclass -from dlt.common.utils import get_full_class_name, uniq_id_base64, many_uniq_ids_base64 +from dlt.common.utils import uniq_id_base64, many_uniq_ids_base64 -DEFAULT_NAMING_NAMESPACE = os.environ.get( - known_env.DLT_DEFAULT_NAMING_NAMESPACE, "dlt.common.normalizers.naming" -) -DEFAULT_NAMING_MODULE = os.environ.get(known_env.DLT_DEFAULT_NAMING_MODULE, "snake_case") -DLT_ID_LENGTH_BYTES = int(os.environ.get(known_env.DLT_DLT_ID_LENGTH_BYTES, 10)) - - -def _section_for_schema(kwargs: Dict[str, Any]) -> Tuple[str, ...]: - """Uses the schema name to generate dynamic section normalizer settings""" - if schema_name := kwargs.get("schema_name"): - return (known_sections.SOURCES, schema_name) - else: - return (known_sections.SOURCES,) - - -@with_config(spec=NormalizersConfiguration, sections=_section_for_schema) # type: ignore[call-overload] -def explicit_normalizers( - naming: TNamingConventionReferenceArg = dlt.config.value, - json_normalizer: TJSONNormalizer = dlt.config.value, - allow_identifier_change_on_table_with_data: bool = None, - schema_name: Optional[str] = None, -) -> TNormalizersConfig: - """Gets explicitly configured normalizers without any defaults or capabilities injection. If `naming` - is a module or a type it will get converted into string form via import. - - If `schema_name` is present, a section ("sources", schema_name, "schema") is used to inject the config - """ - - norm_conf: TNormalizersConfig = {"names": serialize_reference(naming), "json": json_normalizer} - if allow_identifier_change_on_table_with_data is not None: - norm_conf["allow_identifier_change_on_table_with_data"] = ( - allow_identifier_change_on_table_with_data - ) - return norm_conf - - -@with_config -def import_normalizers( - explicit_normalizers: TNormalizersConfig, - default_normalizers: TNormalizersConfig = None, - destination_capabilities: DestinationCapabilitiesContext = None, -) -> Tuple[TNormalizersConfig, NamingConvention, Type[DataItemNormalizer[Any]]]: - """Imports the normalizers specified in `normalizers_config` or taken from defaults. Returns the updated config and imported modules. - - `destination_capabilities` are used to get naming convention, max length of the identifier and max nesting level. - """ - if default_normalizers is None: - default_normalizers = {} - # add defaults to normalizer_config - naming: TNamingConventionReferenceArg = explicit_normalizers.get("names") - if naming is None: - if destination_capabilities: - naming = destination_capabilities.naming_convention - if naming is None: - naming = default_normalizers.get("names") or DEFAULT_NAMING_MODULE - naming_convention = naming_from_reference(naming, destination_capabilities) - explicit_normalizers["names"] = serialize_reference(naming) - - item_normalizer = explicit_normalizers.get("json") or default_normalizers.get("json") or {} - item_normalizer.setdefault("module", "dlt.common.normalizers.json.relational") - # if max_table_nesting is set, we need to set the max_table_nesting in the json_normalizer - if destination_capabilities and destination_capabilities.max_table_nesting is not None: - # TODO: this is a hack, we need a better method to do this - from dlt.common.normalizers.json.relational import DataItemNormalizer - try: - DataItemNormalizer.ensure_this_normalizer(item_normalizer) - item_normalizer.setdefault("config", {}) - item_normalizer["config"]["max_nesting"] = destination_capabilities.max_table_nesting # type: ignore[index] - except InvalidJsonNormalizer: - # not a right normalizer - logger.warning(f"JSON Normalizer {item_normalizer} does not support max_nesting") - pass - json_module = cast(SupportsDataItemNormalizer, import_module(item_normalizer["module"])) - explicit_normalizers["json"] = item_normalizer - return ( - explicit_normalizers, - naming_convention, - json_module.DataItemNormalizer, - ) - - -def naming_from_reference( - names: TNamingConventionReferenceArg, - destination_capabilities: DestinationCapabilitiesContext = None, -) -> NamingConvention: - """Resolves naming convention from reference in `names` and applies max length from `destination_capabilities` - - Reference may be: (1) shorthand name pointing to `dlt.common.normalizers.naming` namespace - (2) a type name which is a module containing `NamingConvention` attribute (3) a type of class deriving from NamingConvention - """ - - def _import_naming(module: str) -> ModuleType: - if "." in module: - # TODO: bump schema engine version and migrate schema. also change the name in TNormalizersConfig from names to naming - if module == "dlt.common.normalizers.names.snake_case": - module = f"{DEFAULT_NAMING_NAMESPACE}.{DEFAULT_NAMING_MODULE}" - # this is full module name - naming_module = import_module(module) - else: - # from known location - try: - naming_module = import_module(f"{DEFAULT_NAMING_NAMESPACE}.{module}") - except ImportError: - # also import local module - naming_module = import_module(module) - return naming_module - - def _get_type(naming_module: ModuleType, cls: str) -> Type[NamingConvention]: - class_: Type[NamingConvention] = getattr(naming_module, cls, None) - if class_ is None: - raise NamingTypeNotFound(naming_module.__name__, cls) - if is_subclass(class_, NamingConvention): - return class_ - raise InvalidNamingType(naming_module.__name__, cls) - - if is_subclass(names, NamingConvention): - class_: Type[NamingConvention] = names # type: ignore[assignment] - elif isinstance(names, ModuleType): - class_ = _get_type(names, "NamingConvention") - elif isinstance(names, str): - try: - class_ = _get_type(_import_naming(names), "NamingConvention") - except ImportError: - parts = names.rsplit(".", 1) - # we have no more options to try - if len(parts) <= 1: - raise UnknownNamingModule(names) - try: - class_ = _get_type(_import_naming(parts[0]), parts[1]) - except UnknownNamingModule: - raise - except ImportError: - raise UnknownNamingModule(names) - else: - raise ValueError(names) - - # get max identifier length - if destination_capabilities: - max_length = min( - destination_capabilities.max_identifier_length, - destination_capabilities.max_column_identifier_length, - ) - else: - max_length = None - - return class_(max_length) - - -def serialize_reference(naming: Optional[TNamingConventionReferenceArg]) -> Optional[str]: - """Serializes generic `naming` reference to importable string.""" - if naming is None: - return naming - if isinstance(naming, str): - return naming - # import reference and use naming to get valid path to type - return get_full_class_name(naming_from_reference(naming)) +DLT_ID_LENGTH_BYTES = int(os.environ.get(known_env.DLT_DLT_ID_LENGTH_BYTES, 10)) def generate_dlt_ids(n_ids: int) -> List[str]: diff --git a/dlt/common/normalizers/configuration.py b/dlt/common/schema/configuration.py similarity index 91% rename from dlt/common/normalizers/configuration.py rename to dlt/common/schema/configuration.py index 6011ba4774..e64dd57494 100644 --- a/dlt/common/normalizers/configuration.py +++ b/dlt/common/schema/configuration.py @@ -7,7 +7,7 @@ @configspec -class NormalizersConfiguration(BaseConfiguration): +class SchemaConfiguration(BaseConfiguration): # always in section __section__: ClassVar[str] = known_sections.SCHEMA diff --git a/dlt/common/schema/exceptions.py b/dlt/common/schema/exceptions.py index 2e75b4b3a1..2b9a2d8cd1 100644 --- a/dlt/common/schema/exceptions.py +++ b/dlt/common/schema/exceptions.py @@ -73,8 +73,8 @@ def __init__(self, schema_name: str, table_name: str, prop_name: str, val1: str, self.val2 = val2 super().__init__( schema_name, - f"Cannot merge partial tables for {table_name} due to property {prop_name}: {val1} !=" - f" {val2}", + f"Cannot merge partial tables into table `{table_name}` due to property `{prop_name}`" + f' with different values: "{val1}" != "{val2}"', ) diff --git a/dlt/common/schema/migrations.py b/dlt/common/schema/migrations.py index b64714ba19..0dcfc2122a 100644 --- a/dlt/common/schema/migrations.py +++ b/dlt/common/schema/migrations.py @@ -1,7 +1,6 @@ from typing import Dict, List, cast from dlt.common.data_types import TDataType -from dlt.common.normalizers.utils import explicit_normalizers from dlt.common.typing import DictStrAny from dlt.common.schema.typing import ( LOADS_TABLE_NAME, @@ -9,11 +8,9 @@ TSimpleRegex, TStoredSchema, TTableSchemaColumns, - TColumnHint, + TColumnDefaultHint, ) from dlt.common.schema.exceptions import SchemaEngineNoUpgradePathException - -from dlt.common.normalizers.utils import import_normalizers from dlt.common.schema.utils import new_table, version_table, loads_table @@ -26,6 +23,8 @@ def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> schema_dict["excludes"] = [] from_engine = 2 if from_engine == 2 and to_engine > 2: + from dlt.common.schema.normalizers import import_normalizers, explicit_normalizers + # current version of the schema current = cast(TStoredSchema, schema_dict) # add default normalizers and root hash propagation @@ -35,7 +34,7 @@ def migrate_schema(schema_dict: DictStrAny, from_engine: int, to_engine: int) -> "propagation": {"root": {"_dlt_id": "_dlt_root_id"}} } # move settings, convert strings to simple regexes - d_h: Dict[TColumnHint, List[TSimpleRegex]] = schema_dict.pop("hints", {}) + d_h: Dict[TColumnDefaultHint, List[TSimpleRegex]] = schema_dict.pop("hints", {}) for h_k, h_l in d_h.items(): d_h[h_k] = list(map(lambda r: TSimpleRegex("re:" + r), h_l)) p_t: Dict[TSimpleRegex, TDataType] = schema_dict.pop("preferred_types", {}) diff --git a/dlt/common/schema/normalizers.py b/dlt/common/schema/normalizers.py new file mode 100644 index 0000000000..9b2a37e708 --- /dev/null +++ b/dlt/common/schema/normalizers.py @@ -0,0 +1,186 @@ +import os +from importlib import import_module +from types import ModuleType +from typing import Any, Dict, Optional, Type, Tuple, cast + +import dlt +from dlt.common import logger +from dlt.common import known_env +from dlt.common.configuration.inject import with_config +from dlt.common.configuration.specs import known_sections +from dlt.common.schema.configuration import SchemaConfiguration +from dlt.common.normalizers.exceptions import InvalidJsonNormalizer +from dlt.common.normalizers.json import SupportsDataItemNormalizer, DataItemNormalizer +from dlt.common.normalizers.naming import NamingConvention +from dlt.common.normalizers.naming.exceptions import ( + NamingTypeNotFound, + UnknownNamingModule, + InvalidNamingType, +) +from dlt.common.normalizers.typing import ( + TJSONNormalizer, + TNormalizersConfig, + TNamingConventionReferenceArg, +) +from dlt.common.typing import is_subclass +from dlt.common.utils import get_full_class_name + +DEFAULT_NAMING_NAMESPACE = os.environ.get( + known_env.DLT_DEFAULT_NAMING_NAMESPACE, "dlt.common.normalizers.naming" +) +DEFAULT_NAMING_MODULE = os.environ.get(known_env.DLT_DEFAULT_NAMING_MODULE, "snake_case") + + +def _section_for_schema(kwargs: Dict[str, Any]) -> Tuple[str, ...]: + """Uses the schema name to generate dynamic section normalizer settings""" + if schema_name := kwargs.get("schema_name"): + return (known_sections.SOURCES, schema_name) + else: + return (known_sections.SOURCES,) + + +@with_config(spec=SchemaConfiguration, sections=_section_for_schema) # type: ignore[call-overload] +def explicit_normalizers( + naming: TNamingConventionReferenceArg = dlt.config.value, + json_normalizer: TJSONNormalizer = dlt.config.value, + allow_identifier_change_on_table_with_data: bool = None, + schema_name: Optional[str] = None, +) -> TNormalizersConfig: + """Gets explicitly configured normalizers without any defaults or capabilities injection. If `naming` + is a module or a type it will get converted into string form via import. + + If `schema_name` is present, a section ("sources", schema_name, "schema") is used to inject the config + """ + + norm_conf: TNormalizersConfig = {"names": serialize_reference(naming), "json": json_normalizer} + if allow_identifier_change_on_table_with_data is not None: + norm_conf["allow_identifier_change_on_table_with_data"] = ( + allow_identifier_change_on_table_with_data + ) + return norm_conf + + +@with_config +def import_normalizers( + explicit_normalizers: TNormalizersConfig, + default_normalizers: TNormalizersConfig = None, +) -> Tuple[TNormalizersConfig, NamingConvention, Type[DataItemNormalizer[Any]]]: + """Imports the normalizers specified in `normalizers_config` or taken from defaults. Returns the updated config and imported modules. + + `destination_capabilities` are used to get naming convention, max length of the identifier and max nesting level. + """ + # use container to get destination capabilities, do not use config injection to resolve circular dependencies + from dlt.common.destination.capabilities import DestinationCapabilitiesContext + from dlt.common.configuration.container import Container + + destination_capabilities = Container().get(DestinationCapabilitiesContext) + if default_normalizers is None: + default_normalizers = {} + # add defaults to normalizer_config + naming: Optional[TNamingConventionReferenceArg] = explicit_normalizers.get("names") + if naming is None: + if destination_capabilities: + naming = destination_capabilities.naming_convention + if naming is None: + naming = default_normalizers.get("names") or DEFAULT_NAMING_MODULE + # get max identifier length + if destination_capabilities: + max_length = min( + destination_capabilities.max_identifier_length, + destination_capabilities.max_column_identifier_length, + ) + else: + max_length = None + naming_convention = naming_from_reference(naming, max_length) + explicit_normalizers["names"] = serialize_reference(naming) + + item_normalizer = explicit_normalizers.get("json") or default_normalizers.get("json") or {} + item_normalizer.setdefault("module", "dlt.common.normalizers.json.relational") + # if max_table_nesting is set, we need to set the max_table_nesting in the json_normalizer + if destination_capabilities and destination_capabilities.max_table_nesting is not None: + # TODO: this is a hack, we need a better method to do this + from dlt.common.normalizers.json.relational import DataItemNormalizer + + try: + DataItemNormalizer.ensure_this_normalizer(item_normalizer) + item_normalizer.setdefault("config", {}) + item_normalizer["config"]["max_nesting"] = destination_capabilities.max_table_nesting # type: ignore[index] + except InvalidJsonNormalizer: + # not a right normalizer + logger.warning(f"JSON Normalizer {item_normalizer} does not support max_nesting") + pass + json_module = cast(SupportsDataItemNormalizer, import_module(item_normalizer["module"])) + explicit_normalizers["json"] = item_normalizer + return ( + explicit_normalizers, + naming_convention, + json_module.DataItemNormalizer, + ) + + +def naming_from_reference( + names: TNamingConventionReferenceArg, + max_length: Optional[int] = None, +) -> NamingConvention: + """Resolves naming convention from reference in `names` and applies max length if specified + + Reference may be: (1) shorthand name pointing to `dlt.common.normalizers.naming` namespace + (2) a type name which is a module containing `NamingConvention` attribute (3) a type of class deriving from NamingConvention + """ + + def _import_naming(module: str) -> ModuleType: + if "." in module: + # TODO: bump schema engine version and migrate schema. also change the name in TNormalizersConfig from names to naming + if module == "dlt.common.normalizers.names.snake_case": + module = f"{DEFAULT_NAMING_NAMESPACE}.{DEFAULT_NAMING_MODULE}" + # this is full module name + naming_module = import_module(module) + else: + # from known location + try: + naming_module = import_module(f"{DEFAULT_NAMING_NAMESPACE}.{module}") + except ImportError: + # also import local module + naming_module = import_module(module) + return naming_module + + def _get_type(naming_module: ModuleType, cls: str) -> Type[NamingConvention]: + class_: Type[NamingConvention] = getattr(naming_module, cls, None) + if class_ is None: + raise NamingTypeNotFound(naming_module.__name__, cls) + if is_subclass(class_, NamingConvention): + return class_ + raise InvalidNamingType(naming_module.__name__, cls) + + if is_subclass(names, NamingConvention): + class_: Type[NamingConvention] = names # type: ignore[assignment] + elif isinstance(names, ModuleType): + class_ = _get_type(names, "NamingConvention") + elif isinstance(names, str): + try: + class_ = _get_type(_import_naming(names), "NamingConvention") + except ImportError: + parts = names.rsplit(".", 1) + # we have no more options to try + if len(parts) <= 1: + raise UnknownNamingModule(names) + try: + class_ = _get_type(_import_naming(parts[0]), parts[1]) + except UnknownNamingModule: + raise + except ImportError: + raise UnknownNamingModule(names) + else: + raise ValueError(names) + + return class_(max_length) + + +def serialize_reference(naming: Optional[TNamingConventionReferenceArg]) -> Optional[str]: + """Serializes generic `naming` reference to importable string.""" + if naming is None: + return naming + if isinstance(naming, str): + return naming + # import reference and use naming to get valid path to type + return get_full_class_name(naming_from_reference(naming)) diff --git a/dlt/common/schema/schema.py b/dlt/common/schema/schema.py index da9e581637..b0fec4a67a 100644 --- a/dlt/common/schema/schema.py +++ b/dlt/common/schema/schema.py @@ -23,12 +23,10 @@ TDataItem, ) from dlt.common.normalizers import TNormalizersConfig, NamingConvention -from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers from dlt.common.normalizers.json import DataItemNormalizer, TNormalizedRowIterator from dlt.common.schema import utils from dlt.common.data_types import py_type_to_sc_type, coerce_value, TDataType from dlt.common.schema.typing import ( - COLUMN_HINTS, DLT_NAME_PREFIX, SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, @@ -46,6 +44,7 @@ TColumnSchema, TColumnProp, TColumnHint, + TColumnDefaultHint, TTypeDetections, TSchemaContractDict, TSchemaContract, @@ -58,8 +57,9 @@ SchemaCorruptedException, TableIdentifiersFrozen, ) -from dlt.common.validation import validate_dict +from dlt.common.schema.normalizers import import_normalizers, explicit_normalizers from dlt.common.schema.exceptions import DataValidationError +from dlt.common.validation import validate_dict DEFAULT_SCHEMA_CONTRACT_MODE: TSchemaContractDict = { @@ -99,7 +99,7 @@ class Schema: # list of preferred types: map regex on columns into types _compiled_preferred_types: List[Tuple[REPattern, TDataType]] # compiled default hints - _compiled_hints: Dict[TColumnHint, Sequence[REPattern]] + _compiled_hints: Dict[TColumnDefaultHint, Sequence[REPattern]] # compiled exclude filters per table _compiled_excludes: Dict[str, Sequence[REPattern]] # compiled include filters per table @@ -387,7 +387,7 @@ def resolve_contract_settings_for_table( tables = self._schema_tables # find root table try: - table = utils.get_top_level_table(tables, table_name) + table = utils.get_root_table(tables, table_name) settings = table["schema_contract"] except KeyError: settings = self._settings.get("schema_contract", {}) @@ -396,14 +396,19 @@ def resolve_contract_settings_for_table( return Schema.expand_schema_contract_settings(settings) def update_table( - self, partial_table: TPartialTableSchema, normalize_identifiers: bool = True + self, + partial_table: TPartialTableSchema, + normalize_identifiers: bool = True, + from_diff: bool = False, ) -> TPartialTableSchema: - """Adds or merges `partial_table` into the schema. Identifiers are normalized by default""" + """Adds or merges `partial_table` into the schema. Identifiers are normalized by default. + `from_diff` + """ + parent_table_name = partial_table.get("parent") if normalize_identifiers: partial_table = utils.normalize_table_identifiers(partial_table, self.naming) table_name = partial_table["name"] - parent_table_name = partial_table.get("parent") # check if parent table present if parent_table_name is not None: if self._schema_tables.get(parent_table_name) is None: @@ -418,10 +423,14 @@ def update_table( table = self._schema_tables.get(table_name) if table is None: # add the whole new table to SchemaTables + assert not from_diff, "Cannot update the whole table from diff" self._schema_tables[table_name] = partial_table else: - # merge tables performing additional checks - partial_table = utils.merge_table(self.name, table, partial_table) + if from_diff: + partial_table = utils.merge_diff(table, partial_table) + else: + # merge tables performing additional checks + partial_table = utils.merge_table(self.name, table, partial_table) self.data_item_normalizer.extend_table(table_name) return partial_table @@ -447,7 +456,9 @@ def drop_tables( result.append(self._schema_tables.pop(table_name)) return result - def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: StrAny) -> StrAny: + def filter_row_with_hint( + self, table_name: str, hint_type: TColumnDefaultHint, row: StrAny + ) -> StrAny: rv_row: DictStrAny = {} column_prop: TColumnProp = utils.hint_to_column_prop(hint_type) try: @@ -459,7 +470,7 @@ def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: Str rv_row[column_name] = row[column_name] except KeyError: for k, v in row.items(): - if self._infer_hint(hint_type, v, k): + if self._infer_hint(hint_type, k): rv_row[k] = v # dicts are ordered and we will return the rows with hints in the same order as they appear in the columns @@ -467,7 +478,7 @@ def filter_row_with_hint(self, table_name: str, hint_type: TColumnHint, row: Str def merge_hints( self, - new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]], + new_hints: Mapping[TColumnDefaultHint, Sequence[TSimpleRegex]], normalize_identifiers: bool = True, ) -> None: """Merges existing default hints with `new_hints`. Normalizes names in column regexes if possible. Compiles setting at the end @@ -747,6 +758,7 @@ def update_normalizers(self) -> None: def will_update_normalizers(self) -> bool: """Checks if schema has any pending normalizer updates due to configuration or destination capabilities""" + # import desired modules _, to_naming, _ = import_normalizers( explicit_normalizers(schema_name=self._schema_name), self._normalizers_config @@ -765,11 +777,16 @@ def _infer_column( column_schema = TColumnSchema( name=k, data_type=data_type or self._infer_column_type(v, k), - nullable=not self._infer_hint("not_null", v, k), + nullable=not self._infer_hint("not_null", k), ) - for hint in COLUMN_HINTS: + # check other preferred hints that are available + for hint in self._compiled_hints: + # already processed + if hint == "not_null": + continue column_prop = utils.hint_to_column_prop(hint) - hint_value = self._infer_hint(hint, v, k) + hint_value = self._infer_hint(hint, k) + # set only non-default values if not utils.has_default_column_prop_value(column_prop, hint_value): column_schema[column_prop] = hint_value @@ -783,7 +800,7 @@ def _coerce_null_value( """Raises when column is explicitly not nullable""" if col_name in table_columns: existing_column = table_columns[col_name] - if not existing_column.get("nullable", True): + if not utils.is_nullable_column(existing_column): raise CannotCoerceNullException(self.name, table_name, col_name) def _coerce_non_null_value( @@ -872,7 +889,7 @@ def _infer_column_type(self, v: Any, col_name: str, skip_preferred: bool = False preferred_type = self.get_preferred_type(col_name) return preferred_type or mapped_type - def _infer_hint(self, hint_type: TColumnHint, _: Any, col_name: str) -> bool: + def _infer_hint(self, hint_type: TColumnDefaultHint, col_name: str) -> bool: if hint_type in self._compiled_hints: return any(h.search(col_name) for h in self._compiled_hints[hint_type]) else: @@ -880,7 +897,7 @@ def _infer_hint(self, hint_type: TColumnHint, _: Any, col_name: str) -> bool: def _merge_hints( self, - new_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]], + new_hints: Mapping[TColumnDefaultHint, Sequence[TSimpleRegex]], normalize_identifiers: bool = True, ) -> None: """Used by `merge_hints method, does not compile settings at the end""" @@ -968,8 +985,8 @@ def _add_standard_hints(self) -> None: self._settings["detections"] = type_detections def _normalize_default_hints( - self, default_hints: Mapping[TColumnHint, Sequence[TSimpleRegex]] - ) -> Dict[TColumnHint, List[TSimpleRegex]]: + self, default_hints: Mapping[TColumnDefaultHint, Sequence[TSimpleRegex]] + ) -> Dict[TColumnDefaultHint, List[TSimpleRegex]]: """Normalizes the column names in default hints. In case of column names that are regexes, normalization is skipped""" return { hint: [utils.normalize_simple_regex_column(self.naming, regex) for regex in regexes] @@ -1116,7 +1133,6 @@ def _renormalize_schema_identifiers( def _configure_normalizers(self, explicit_normalizers: TNormalizersConfig) -> None: """Gets naming and item normalizer from schema yaml, config providers and destination capabilities and applies them to schema.""" - # import desired modules normalizers_config, to_naming, item_normalizer_class = import_normalizers( explicit_normalizers, self._normalizers_config ) @@ -1136,7 +1152,7 @@ def _reset_schema(self, name: str, normalizers: TNormalizersConfig = None) -> No self._settings: TSchemaSettings = {} self._compiled_preferred_types: List[Tuple[REPattern, TDataType]] = [] - self._compiled_hints: Dict[TColumnHint, Sequence[REPattern]] = {} + self._compiled_hints: Dict[TColumnDefaultHint, Sequence[REPattern]] = {} self._compiled_excludes: Dict[str, Sequence[REPattern]] = {} self._compiled_includes: Dict[str, Sequence[REPattern]] = {} self._type_detections: Sequence[TTypeDetections] = None diff --git a/dlt/common/schema/typing.py b/dlt/common/schema/typing.py index a81e9046a9..c238c15a54 100644 --- a/dlt/common/schema/typing.py +++ b/dlt/common/schema/typing.py @@ -4,9 +4,11 @@ Dict, List, Literal, + NamedTuple, Optional, Sequence, Set, + Tuple, Type, TypedDict, NewType, @@ -36,8 +38,14 @@ TColumnProp = Literal[ "name", + # data type "data_type", + "precision", + "scale", + "timezone", "nullable", + "variant", + # hints "partition", "cluster", "primary_key", @@ -49,10 +57,11 @@ "hard_delete", "dedup_sort", ] -"""Known properties and hints of the column""" -# TODO: merge TColumnHint with TColumnProp +"""All known properties of the column, including name, data type info and hints""" +COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) + TColumnHint = Literal[ - "not_null", + "nullable", "partition", "cluster", "primary_key", @@ -64,9 +73,51 @@ "hard_delete", "dedup_sort", ] -"""Known hints of a column used to declare hint regexes.""" +"""Known hints of a column""" +COLUMN_HINTS: Set[TColumnHint] = set(get_args(TColumnHint)) + + +class TColumnPropInfo(NamedTuple): + name: Union[TColumnProp, str] + defaults: Tuple[Any, ...] = (None,) + is_hint: bool = False + + +_ColumnPropInfos = [ + TColumnPropInfo("name"), + TColumnPropInfo("data_type"), + TColumnPropInfo("precision"), + TColumnPropInfo("scale"), + TColumnPropInfo("timezone", (True, None)), + TColumnPropInfo("nullable", (True, None)), + TColumnPropInfo("variant", (False, None)), + TColumnPropInfo("partition", (False, None)), + TColumnPropInfo("cluster", (False, None)), + TColumnPropInfo("primary_key", (False, None)), + TColumnPropInfo("foreign_key", (False, None)), + TColumnPropInfo("sort", (False, None)), + TColumnPropInfo("unique", (False, None)), + TColumnPropInfo("merge_key", (False, None)), + TColumnPropInfo("root_key", (False, None)), + TColumnPropInfo("hard_delete", (False, None)), + TColumnPropInfo("dedup_sort", (False, None)), + # any x- hint with special settings ie. defaults + TColumnPropInfo("x-active-record-timestamp", (), is_hint=True), # no default values +] + +ColumnPropInfos: Dict[Union[TColumnProp, str], TColumnPropInfo] = { + info.name: info for info in _ColumnPropInfos +} +# verify column props and column hints infos +for hint in COLUMN_HINTS: + assert hint in COLUMN_PROPS, f"Hint {hint} must be a column prop" -TTableFormat = Literal["iceberg", "delta"] +for prop in COLUMN_PROPS: + assert prop in ColumnPropInfos, f"Column {prop} has no info, please define" + if prop in COLUMN_HINTS: + ColumnPropInfos[prop] = ColumnPropInfos[prop]._replace(is_hint=True) + +TTableFormat = Literal["iceberg", "delta", "hive"] TFileFormat = Literal[Literal["preferred"], TLoaderFileFormat] TTypeDetections = Literal[ "timestamp", "iso_timestamp", "iso_date", "large_integer", "hexbytes_to_text", "wei_to_double" @@ -75,20 +126,6 @@ TColumnNames = Union[str, Sequence[str]] """A string representing a column name or a list of""" -# COLUMN_PROPS: Set[TColumnProp] = set(get_args(TColumnProp)) -COLUMN_HINTS: Set[TColumnHint] = set( - [ - "partition", - "cluster", - "primary_key", - "foreign_key", - "sort", - "unique", - "merge_key", - "root_key", - ] -) - class TColumnType(TypedDict, total=False): data_type: Optional[TDataType] @@ -195,13 +232,9 @@ class TMergeDispositionDict(TWriteDispositionDict, total=False): TWriteDispositionConfig = Union[TWriteDisposition, TWriteDispositionDict, TMergeDispositionDict] -# TypedDict that defines properties of a table -class TTableSchema(TTableProcessingHints, total=False): - """TypedDict that defines properties of a table""" - +class _TTableSchemaBase(TTableProcessingHints, total=False): name: Optional[str] description: Optional[str] - write_disposition: Optional[TWriteDisposition] schema_contract: Optional[TSchemaContract] table_sealed: Optional[bool] parent: Optional[str] @@ -212,18 +245,26 @@ class TTableSchema(TTableProcessingHints, total=False): file_format: Optional[TFileFormat] +class TTableSchema(_TTableSchemaBase, total=False): + """TypedDict that defines properties of a table""" + + write_disposition: Optional[TWriteDisposition] + + class TPartialTableSchema(TTableSchema): pass TSchemaTables = Dict[str, TTableSchema] TSchemaUpdate = Dict[str, List[TPartialTableSchema]] +TColumnDefaultHint = Literal["not_null", TColumnHint] +"""Allows using not_null in default hints setting section""" class TSchemaSettings(TypedDict, total=False): schema_contract: Optional[TSchemaContract] detections: Optional[List[TTypeDetections]] - default_hints: Optional[Dict[TColumnHint, List[TSimpleRegex]]] + default_hints: Optional[Dict[TColumnDefaultHint, List[TSimpleRegex]]] preferred_types: Optional[Dict[TSimpleRegex, TDataType]] diff --git a/dlt/common/schema/utils.py b/dlt/common/schema/utils.py index 8b87a7e5fe..4eb147624e 100644 --- a/dlt/common/schema/utils.py +++ b/dlt/common/schema/utils.py @@ -17,12 +17,12 @@ from dlt.common.validation import TCustomValidator, validate_dict_ignoring_xkeys from dlt.common.schema import detections from dlt.common.schema.typing import ( - COLUMN_HINTS, SCHEMA_ENGINE_VERSION, LOADS_TABLE_NAME, SIMPLE_REGEX_PREFIX, VERSION_TABLE_NAME, PIPELINE_STATE_TABLE_NAME, + ColumnPropInfos, TColumnName, TFileFormat, TPartialTableSchema, @@ -36,7 +36,7 @@ TColumnSchema, TColumnProp, TTableFormat, - TColumnHint, + TColumnDefaultHint, TTableSchemaColumns, TTypeDetectionFunc, TTypeDetections, @@ -55,7 +55,6 @@ RE_NON_ALPHANUMERIC_UNDERSCORE = re.compile(r"[^a-zA-Z\d_]") DEFAULT_WRITE_DISPOSITION: TWriteDisposition = "append" -DEFAULT_MERGE_STRATEGY: TLoaderMergeStrategy = "delete-insert" def is_valid_schema_name(name: str) -> bool: @@ -67,6 +66,12 @@ def is_valid_schema_name(name: str) -> bool: ) +def is_nested_table(table: TTableSchema) -> bool: + """Checks if table is a dlt nested table: connected to parent table via row_key - parent_key reference""" + # "parent" table hint indicates NESTED table. + return bool(table.get("parent")) + + def normalize_schema_name(name: str) -> str: """Normalizes schema name by using snake case naming convention. The maximum length is 64 characters""" snake_case = SnakeCase(InvalidSchemaName.MAXIMUM_SCHEMA_NAME_LENGTH) @@ -81,12 +86,6 @@ def apply_defaults(stored_schema: TStoredSchema) -> TStoredSchema: for table_name, table in stored_schema["tables"].items(): # overwrite name table["name"] = table_name - # add default write disposition to root tables - if table.get("parent") is None: - if table.get("write_disposition") is None: - table["write_disposition"] = DEFAULT_WRITE_DISPOSITION - if table.get("resource") is None: - table["resource"] = table_name for column_name in table["columns"]: # add default hints to tables column = table["columns"][column_name] @@ -94,6 +93,12 @@ def apply_defaults(stored_schema: TStoredSchema) -> TStoredSchema: column["name"] = column_name # set column with default # table["columns"][column_name] = column + # add default write disposition to root tables + if not is_nested_table(table): + if table.get("write_disposition") is None: + table["write_disposition"] = DEFAULT_WRITE_DISPOSITION + if table.get("resource") is None: + table["resource"] = table_name return stored_schema @@ -124,15 +129,9 @@ def remove_defaults(stored_schema: TStoredSchema) -> TStoredSchema: def has_default_column_prop_value(prop: str, value: Any) -> bool: """Checks if `value` is a default for `prop`.""" # remove all boolean hints that are False, except "nullable" which is removed when it is True - # TODO: merge column props and hints - if prop in COLUMN_HINTS: - return value in (False, None) - # TODO: type all the hints including default value so those exceptions may be removed - if prop == "nullable": - return value in (True, None) - if prop == "x-active-record-timestamp": - # None is a valid value so it is not a default - return False + if prop in ColumnPropInfos: + return value in ColumnPropInfos[prop].defaults + # for any unknown hint ie. "x-" the defaults are return value in (None, False) @@ -357,14 +356,11 @@ def is_nullable_column(col: TColumnSchemaBase) -> bool: return col.get("nullable", True) -def find_incomplete_columns( - tables: List[TTableSchema], -) -> Iterable[Tuple[str, TColumnSchemaBase, bool]]: - """Yields (table_name, column, nullable) for all incomplete columns in `tables`""" - for table in tables: - for col in table["columns"].values(): - if not is_complete_column(col): - yield table["name"], col, is_nullable_column(col) +def find_incomplete_columns(table: TTableSchema) -> Iterable[Tuple[TColumnSchemaBase, bool]]: + """Yields (column, nullable) for all incomplete columns in `table`""" + for col in table["columns"].values(): + if not is_complete_column(col): + yield col, is_nullable_column(col) def compare_complete_columns(a: TColumnSchema, b: TColumnSchema) -> bool: @@ -431,6 +427,10 @@ def diff_table( * when columns with the same name have different data types * when table links to different parent tables """ + if tab_a["name"] != tab_b["name"]: + raise TablePropertiesConflictException( + schema_name, tab_a["name"], "name", tab_a["name"], tab_b["name"] + ) table_name = tab_a["name"] # check if table properties can be merged if tab_a.get("parent") != tab_b.get("parent"): @@ -476,7 +476,7 @@ def diff_table( partial_table[k] = v # type: ignore # this should not really happen - if tab_a.get("parent") is not None and (resource := tab_b.get("resource")): + if is_nested_table(tab_a) and (resource := tab_b.get("resource")): raise TablePropertiesConflictException( schema_name, table_name, "resource", resource, tab_a.get("parent") ) @@ -500,25 +500,24 @@ def merge_table( schema_name: str, table: TTableSchema, partial_table: TPartialTableSchema ) -> TPartialTableSchema: """Merges "partial_table" into "table". `table` is merged in place. Returns the diff partial table. + `table` and `partial_table` names must be identical. A table diff is generated and applied to `table` + """ + return merge_diff(table, diff_table(schema_name, table, partial_table)) + - `table` and `partial_table` names must be identical. A table diff is generated and applied to `table`: +def merge_diff(table: TTableSchema, table_diff: TPartialTableSchema) -> TPartialTableSchema: + """Merges a table diff `table_diff` into `table`. `table` is merged in place. Returns the diff. * new columns are added, updated columns are replaced from diff * incomplete columns in `table` that got completed in `partial_table` are removed to preserve order * table hints are added or replaced from diff * nothing gets deleted """ - - if table["name"] != partial_table["name"]: - raise TablePropertiesConflictException( - schema_name, table["name"], "name", table["name"], partial_table["name"] - ) - diff = diff_table(schema_name, table, partial_table) # add new columns when all checks passed - updated_columns = merge_columns(table["columns"], diff["columns"]) - table.update(diff) + updated_columns = merge_columns(table["columns"], table_diff["columns"]) + table.update(table_diff) table["columns"] = updated_columns - return diff + return table_diff def normalize_table_identifiers(table: TTableSchema, naming: NamingConvention) -> TTableSchema: @@ -584,7 +583,7 @@ def get_processing_hints(tables: TSchemaTables) -> Dict[str, List[str]]: return hints -def hint_to_column_prop(h: TColumnHint) -> TColumnProp: +def hint_to_column_prop(h: TColumnDefaultHint) -> TColumnProp: if h == "not_null": return "nullable" return h @@ -668,9 +667,8 @@ def get_inherited_table_hint( if hint: return hint - parent = table.get("parent") - if parent: - return get_inherited_table_hint(tables, parent, table_hint_name, allow_none) + if is_nested_table(table): + return get_inherited_table_hint(tables, table.get("parent"), table_hint_name, allow_none) if allow_none: return None @@ -713,13 +711,18 @@ def fill_hints_from_parent_and_clone_table( """Takes write disposition and table format from parent tables if not present""" # make a copy of the schema so modifications do not affect the original document table = deepcopy(table) - # add write disposition if not specified - in child tables + table_name = table["name"] if "write_disposition" not in table: - table["write_disposition"] = get_write_disposition(tables, table["name"]) + table["write_disposition"] = get_write_disposition(tables, table_name) if "table_format" not in table: - table["table_format"] = get_table_format(tables, table["name"]) + if table_format := get_table_format(tables, table_name): + table["table_format"] = table_format if "file_format" not in table: - table["file_format"] = get_file_format(tables, table["name"]) + if file_format := get_file_format(tables, table_name): + table["file_format"] = file_format + if "x-merge-strategy" not in table: + if strategy := get_merge_strategy(tables, table_name): + table["x-merge-strategy"] = strategy # type: ignore[typeddict-unknown-key] return table @@ -736,24 +739,27 @@ def table_schema_has_type_with_precision(table: TTableSchema, _typ: TDataType) - ) -def get_top_level_table(tables: TSchemaTables, table_name: str) -> TTableSchema: - """Finds top level (without parent) of a `table_name` following the ancestry hierarchy.""" +def get_root_table(tables: TSchemaTables, table_name: str) -> TTableSchema: + """Finds root (without parent) of a `table_name` following the nested references (row_key - parent_key).""" table = tables[table_name] - parent = table.get("parent") - if parent: - return get_top_level_table(tables, parent) + if is_nested_table(table): + return get_root_table(tables, table.get("parent")) return table -def get_child_tables(tables: TSchemaTables, table_name: str) -> List[TTableSchema]: - """Get child tables for table name and return a list of tables ordered by ancestry so the child tables are always after their parents""" +def get_nested_tables(tables: TSchemaTables, table_name: str) -> List[TTableSchema]: + """Get nested tables for table name and return a list of tables ordered by ancestry so the nested tables are always after their parents + + Note that this function follows only NESTED TABLE reference typically expressed on _dlt_parent_id (PARENT_KEY) to _dlt_id (ROW_KEY). + TABLE REFERENCES (foreign_key - primary_key) are not followed. + """ chain: List[TTableSchema] = [] def _child(t: TTableSchema) -> None: name = t["name"] chain.append(t) for candidate in tables.values(): - if candidate.get("parent") == name: + if is_nested_table(candidate) and candidate.get("parent") == name: _child(candidate) _child(tables[table_name]) @@ -771,7 +777,7 @@ def group_tables_by_resource( resource = table.get("resource") if resource and (pattern is None or pattern.match(resource)): resource_tables = result.setdefault(resource, []) - resource_tables.extend(get_child_tables(tables, table["name"])) + resource_tables.extend(get_nested_tables(tables, table["name"])) return result @@ -866,28 +872,33 @@ def new_table( "name": table_name, "columns": {} if columns is None else {c["name"]: c for c in columns}, } + + if write_disposition: + table["write_disposition"] = write_disposition + if resource: + table["resource"] = resource + if schema_contract is not None: + table["schema_contract"] = schema_contract + if table_format: + table["table_format"] = table_format + if file_format: + table["file_format"] = file_format if parent_table_name: table["parent"] = parent_table_name - assert write_disposition is None - assert resource is None - assert schema_contract is None else: - # set write disposition only for root tables - table["write_disposition"] = write_disposition or DEFAULT_WRITE_DISPOSITION - table["resource"] = resource or table_name - if schema_contract is not None: - table["schema_contract"] = schema_contract - if table_format: - table["table_format"] = table_format - if file_format: - table["file_format"] = file_format + # set only for root tables + if not write_disposition: + # set write disposition only for root tables + table["write_disposition"] = DEFAULT_WRITE_DISPOSITION + if not resource: + table["resource"] = table_name + if validate_schema: validate_dict_ignoring_xkeys( spec=TColumnSchema, doc=table["columns"], path=f"new_table/{table_name}", ) - return table @@ -916,7 +927,7 @@ def new_column( return column -def default_hints() -> Dict[TColumnHint, List[TSimpleRegex]]: +def default_hints() -> Dict[TColumnDefaultHint, List[TSimpleRegex]]: return None diff --git a/dlt/common/storages/load_storage.py b/dlt/common/storages/load_storage.py index 8ac1d74e9a..076615fa5b 100644 --- a/dlt/common/storages/load_storage.py +++ b/dlt/common/storages/load_storage.py @@ -5,7 +5,7 @@ from dlt.common.json import json from dlt.common.configuration import known_sections from dlt.common.configuration.inject import with_config -from dlt.common.destination import ALL_SUPPORTED_FILE_FORMATS, TLoaderFileFormat +from dlt.common.destination import LOADER_FILE_FORMATS, TLoaderFileFormat from dlt.common.configuration.accessors import config from dlt.common.schema import TSchemaTables from dlt.common.storages.file_storage import FileStorage @@ -46,7 +46,7 @@ class LoadStorage(VersionedStorage): LOADED_FOLDER = "loaded" # folder to keep the loads that were completely processed NEW_PACKAGES_FOLDER = "new" # folder where new packages are created - ALL_SUPPORTED_FILE_FORMATS = ALL_SUPPORTED_FILE_FORMATS + ALL_SUPPORTED_FILE_FORMATS = LOADER_FILE_FORMATS @with_config(spec=LoadStorageConfiguration, sections=(known_sections.LOAD,)) def __init__( diff --git a/dlt/common/warnings.py b/dlt/common/warnings.py index 9c62c69bf8..95d5a19f08 100644 --- a/dlt/common/warnings.py +++ b/dlt/common/warnings.py @@ -39,7 +39,8 @@ def __init__( if isinstance(expected_due, semver.VersionInfo) else semver.parse_version_info(expected_due) ) - self.expected_due = expected_due if expected_due is not None else self.since.bump_minor() + # we deprecate across major version since 1.0.0 + self.expected_due = expected_due if expected_due is not None else self.since.bump_major() def __str__(self) -> str: message = ( @@ -57,6 +58,15 @@ def __init__(self, message: str, *args: typing.Any, expected_due: VersionString ) +class Dlt100DeprecationWarning(DltDeprecationWarning): + V100 = semver.parse_version_info("1.0.0") + + def __init__(self, message: str, *args: typing.Any, expected_due: VersionString = None) -> None: + super().__init__( + message, *args, since=Dlt100DeprecationWarning.V100, expected_due=expected_due + ) + + # show dlt deprecations once warnings.simplefilter("once", DltDeprecationWarning) diff --git a/dlt/destinations/impl/athena/athena.py b/dlt/destinations/impl/athena/athena.py index c4a9bab212..04078dd510 100644 --- a/dlt/destinations/impl/athena/athena.py +++ b/dlt/destinations/impl/athena/athena.py @@ -17,7 +17,6 @@ import re from contextlib import contextmanager -from fsspec import AbstractFileSystem from pendulum.datetime import DateTime, Date from datetime import datetime # noqa: I251 @@ -33,19 +32,15 @@ ) from dlt.common import logger -from dlt.common.exceptions import TerminalValueError -from dlt.common.utils import uniq_id, without_none -from dlt.common.schema import TColumnSchema, Schema, TTableSchema +from dlt.common.utils import uniq_id +from dlt.common.schema import TColumnSchema, Schema from dlt.common.schema.typing import ( - TTableSchema, TColumnType, TTableFormat, TSortOrder, ) -from dlt.common.schema.utils import table_schema_has_type -from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import LoadJob -from dlt.common.destination.reference import FollowupJobRequest, SupportsStagingDestination +from dlt.common.destination import DestinationCapabilitiesContext, PreparedTableSchema +from dlt.common.destination.reference import FollowupJobRequest, SupportsStagingDestination, LoadJob from dlt.common.data_writers.escape import escape_hive_identifier from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlMergeFollowupJob @@ -54,7 +49,6 @@ DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation, - LoadJobTerminalException, ) from dlt.destinations.sql_client import ( SqlClientBase, @@ -63,73 +57,13 @@ raise_open_connection_error, ) from dlt.destinations.typing import DBApiCursor -from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs, FinalizedLoadJob from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration -from dlt.destinations.type_mapping import TypeMapper from dlt.destinations import path_utils from dlt.destinations.impl.athena.athena_adapter import PARTITION_HINT -class AthenaTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "complex": "string", - "text": "string", - "double": "double", - "bool": "boolean", - "date": "date", - "timestamp": "timestamp", - "bigint": "bigint", - "binary": "binary", - "time": "string", - } - - sct_to_dbt = {"decimal": "decimal(%i,%i)", "wei": "decimal(%i,%i)"} - - dbt_to_sct = { - "varchar": "text", - "double": "double", - "boolean": "bool", - "date": "date", - "timestamp": "timestamp", - "bigint": "bigint", - "binary": "binary", - "varbinary": "binary", - "decimal": "decimal", - "tinyint": "bigint", - "smallint": "bigint", - "int": "bigint", - } - - def __init__(self, capabilities: DestinationCapabilitiesContext): - super().__init__(capabilities) - - def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: - precision = column.get("precision") - table_format = table.get("table_format") - if precision is None: - return "bigint" - if precision <= 8: - return "int" if table_format == "iceberg" else "tinyint" - elif precision <= 16: - return "int" if table_format == "iceberg" else "smallint" - elif precision <= 32: - return "int" - elif precision <= 64: - return "bigint" - raise TerminalValueError( - f"bigint with {precision} bits precision cannot be mapped into athena integer type" - ) - - def from_db_type( - self, db_type: str, precision: Optional[int], scale: Optional[int] - ) -> TColumnType: - for key, val in self.dbt_to_sct.items(): - if db_type.startswith(key): - return without_none(dict(data_type=val, precision=precision, scale=scale)) # type: ignore[return-value] - return dict(data_type=None) - - # add a formatter for pendulum to be used by pyathen dbapi def _format_pendulum_datetime(formatter: Formatter, escaper: Callable[[str], str], val: Any) -> Any: # copied from https://github.com/laughingman7743/PyAthena/blob/f4b21a0b0f501f5c3504698e25081f491a541d4e/pyathena/formatter.py#L114 @@ -165,7 +99,9 @@ class AthenaMergeJob(SqlMergeFollowupJob): def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: # reproducible name so we know which table to drop with sql_client.with_staging_dataset(): - return sql_client.make_qualified_table_name(name_prefix) + return sql_client.make_qualified_table_name( + cls._shorten_table_name(name_prefix, sql_client) + ) @classmethod def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: @@ -366,7 +302,7 @@ def execute_query(self, query: AnyStr, *args: Any, **kwargs: Any) -> Iterator[DB yield DBApiCursorImpl(cursor) # type: ignore -class AthenaClient(SqlJobClientWithStaging, SupportsStagingDestination): +class AthenaClient(SqlJobClientWithStagingDataset, SupportsStagingDestination): def __init__( self, schema: Schema, @@ -391,7 +327,7 @@ def __init__( super().__init__(schema, config, sql_client) self.sql_client: AthenaSQLClient = sql_client # type: ignore self.config: AthenaClientConfiguration = config - self.type_mapper = AthenaTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: # only truncate tables in iceberg mode @@ -401,11 +337,11 @@ def initialize_storage(self, truncate_tables: Iterable[str] = None) -> None: def _from_db_type( self, hive_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(hive_t, precision, scale) + return self.type_mapper.from_destination_type(hive_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: return ( - f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_db_type(c, table)}" + f"{self.sql_client.escape_ddl_identifier(c['name'])} {self.type_mapper.to_destination_type(c, table)}" ) def _iceberg_partition_clause(self, partition_hints: Optional[Dict[str, str]]) -> str: @@ -428,15 +364,15 @@ def _get_table_update_sql( # for the system tables we need to create empty iceberg tables to be able to run, DELETE and UPDATE queries # or if we are in iceberg mode, we create iceberg tables for all tables - table = self.prepare_load_table(table_name, self.in_staging_mode) - - is_iceberg = self._is_iceberg_table(table) or table.get("write_disposition", None) == "skip" + table = self.prepare_load_table(table_name) + # do not create iceberg tables on staging dataset + create_iceberg = self._is_iceberg_table(table, self.in_staging_dataset_mode) columns = ", ".join([self._get_column_def_sql(c, table) for c in new_columns]) # create unique tag for iceberg table so it is never recreated in the same folder # athena requires some kind of special cleaning (or that is a bug) so we cannot refresh # iceberg tables without it - location_tag = uniq_id(6) if is_iceberg else "" + location_tag = uniq_id(6) if create_iceberg else "" # this will fail if the table prefix is not properly defined table_prefix = self.table_prefix_layout.format(table_name=table_name + location_tag) location = f"{bucket}/{dataset}/{table_prefix}" @@ -447,7 +383,7 @@ def _get_table_update_sql( # alter table to add new columns at the end sql.append(f"""ALTER TABLE {qualified_table_name} ADD COLUMNS ({columns});""") else: - if is_iceberg: + if create_iceberg: partition_clause = self._iceberg_partition_clause( cast(Optional[Dict[str, str]], table.get(PARTITION_HINT)) ) @@ -469,28 +405,22 @@ def _get_table_update_sql( return sql def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" - if table_schema_has_type(table, "time"): - raise LoadJobTerminalException( - file_path, - "Athena cannot load TIME columns from parquet tables. Please convert" - " `datetime.time` objects in your data to `str` or `datetime.datetime`.", - ) job = super().create_load_job(table, file_path, load_id, restore) if not job: job = ( FinalizedLoadJobWithFollowupJobs(file_path) - if self._is_iceberg_table(self.prepare_load_table(table["name"])) + if self._is_iceberg_table(table) else FinalizedLoadJob(file_path) ) return job def _create_append_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: - if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): + if self._is_iceberg_table(table_chain[0]): return [ SqlStagingCopyFollowupJob.from_table_chain( table_chain, self.sql_client, {"replace": False} @@ -499,9 +429,9 @@ def _create_append_followup_jobs( return super()._create_append_followup_jobs(table_chain) def _create_replace_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: - if self._is_iceberg_table(self.prepare_load_table(table_chain[0]["name"])): + if self._is_iceberg_table(table_chain[0]): return [ SqlStagingCopyFollowupJob.from_table_chain( table_chain, self.sql_client, {"replace": True} @@ -510,46 +440,43 @@ def _create_replace_followup_jobs( return super()._create_replace_followup_jobs(table_chain) def _create_merge_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return [AthenaMergeJob.from_table_chain(table_chain, self.sql_client)] - def _is_iceberg_table(self, table: TTableSchema) -> bool: + def _is_iceberg_table( + self, table: PreparedTableSchema, is_staging_dataset: bool = False + ) -> bool: table_format = table.get("table_format") - return table_format == "iceberg" + # all dlt tables that are not loaded via files are iceberg tables, no matter if they are on staging or regular dataset + # all other iceberg tables are HIVE (external) tables on staging dataset + table_format_iceberg = table_format == "iceberg" or ( + self.config.force_iceberg and table_format is None + ) + return (table_format_iceberg and not is_staging_dataset) or table[ + "write_disposition" + ] == "skip" - def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + def should_load_data_to_staging_dataset(self, table_name: str) -> bool: # all iceberg tables need staging - if self._is_iceberg_table(self.prepare_load_table(table["name"])): + table = self.prepare_load_table(table_name) + if self._is_iceberg_table(table): return True - return super().should_load_data_to_staging_dataset(table) + return super().should_load_data_to_staging_dataset(table_name) - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: # on athena we only truncate replace tables that are not iceberg - table = self.prepare_load_table(table["name"]) - if table["write_disposition"] == "replace" and not self._is_iceberg_table( - self.prepare_load_table(table["name"]) - ): + table = self.prepare_load_table(table_name) + if table["write_disposition"] == "replace" and not self._is_iceberg_table(table): return True return False - def should_load_data_to_staging_dataset_on_staging_destination( - self, table: TTableSchema - ) -> bool: + def should_load_data_to_staging_dataset_on_staging_destination(self, table_name: str) -> bool: """iceberg table data goes into staging on staging destination""" - if self._is_iceberg_table(self.prepare_load_table(table["name"])): + table = self.prepare_load_table(table_name) + if self._is_iceberg_table(table): return True - return super().should_load_data_to_staging_dataset_on_staging_destination(table) - - def prepare_load_table( - self, table_name: str, prepare_for_staging: bool = False - ) -> TTableSchema: - table = super().prepare_load_table(table_name, prepare_for_staging) - if self.config.force_iceberg: - table["table_format"] = "iceberg" - if prepare_for_staging and table.get("table_format", None) == "iceberg": - table.pop("table_format") - return table + return super().should_load_data_to_staging_dataset_on_staging_destination(table_name) @staticmethod def is_dbapi_exception(ex: Exception) -> bool: diff --git a/dlt/destinations/impl/athena/athena_adapter.py b/dlt/destinations/impl/athena/athena_adapter.py index 50f7abc54a..426c2ca1b8 100644 --- a/dlt/destinations/impl/athena/athena_adapter.py +++ b/dlt/destinations/impl/athena/athena_adapter.py @@ -1,9 +1,5 @@ -from typing import Any, Optional, Dict, Protocol, Sequence, Union, Final +from typing import Any, Dict, Sequence, Union, Final -from dateutil import parser - -from dlt.common.pendulum import timezone -from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns, TColumnSchema from dlt.destinations.utils import get_resource_for_adapter from dlt.extract import DltResource from dlt.extract.items import TTableHintTemplate diff --git a/dlt/destinations/impl/athena/configuration.py b/dlt/destinations/impl/athena/configuration.py index 59dfeee4ec..8a0f14b4cc 100644 --- a/dlt/destinations/impl/athena/configuration.py +++ b/dlt/destinations/impl/athena/configuration.py @@ -1,9 +1,12 @@ import dataclasses from typing import ClassVar, Final, List, Optional +import warnings +from dlt.common import logger from dlt.common.configuration import configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration from dlt.common.configuration.specs import AwsCredentials +from dlt.common.warnings import Dlt100DeprecationWarning @configspec @@ -13,14 +16,24 @@ class AthenaClientConfiguration(DestinationClientDwhWithStagingConfiguration): credentials: AwsCredentials = None athena_work_group: Optional[str] = None aws_data_catalog: Optional[str] = "awsdatacatalog" - supports_truncate_command: bool = False - force_iceberg: Optional[bool] = False + force_iceberg: Optional[bool] = None __config_gen_annotations__: ClassVar[List[str]] = ["athena_work_group"] + def on_resolved(self) -> None: + if self.force_iceberg is not None: + warnings.warn( + "The `force_iceberg` is deprecated.If you upgraded dlt on existing pipeline and you" + " have data already loaded, please keep this flag to make sure your data is" + " consistent.If you are creating a new dataset and no data was loaded, please set" + " `table_format='iceberg`` on your resources explicitly.", + Dlt100DeprecationWarning, + stacklevel=1, + ) + def __str__(self) -> str: """Return displayable destination location""" if self.staging_config: - return str(self.staging_config.credentials) + return f"{self.staging_config} on {self.aws_data_catalog}" else: return "[no staging set]" diff --git a/dlt/destinations/impl/athena/factory.py b/dlt/destinations/impl/athena/factory.py index 07d784ed49..d027f14dfb 100644 --- a/dlt/destinations/impl/athena/factory.py +++ b/dlt/destinations/impl/athena/factory.py @@ -1,5 +1,6 @@ import typing as t +from dlt.common.data_types.typing import TDataType from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.configuration.specs import AwsCredentials from dlt.common.data_writers.escape import ( @@ -7,13 +8,101 @@ format_bigquery_datetime_literal, ) from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.exceptions import TerminalValueError +from dlt.common.schema.typing import TColumnSchema, TColumnType, TLoaderMergeStrategy, TTableSchema +from dlt.common.typing import TLoaderFileFormat +from dlt.common.utils import without_none +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.athena.configuration import AthenaClientConfiguration if t.TYPE_CHECKING: from dlt.destinations.impl.athena.athena import AthenaClient +def athena_merge_strategies_selector( + supported_merge_strategies: t.Sequence[TLoaderMergeStrategy], + /, + *, + table_schema: TTableSchema, +) -> t.Sequence[TLoaderMergeStrategy]: + if table_schema.get("table_format") == "iceberg": + return supported_merge_strategies + else: + return [] + + +class AthenaTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "complex": "string", + "text": "string", + "double": "double", + "bool": "boolean", + "date": "date", + "timestamp": "timestamp", + "bigint": "bigint", + "binary": "binary", + "time": "string", + } + + sct_to_dbt = {"decimal": "decimal(%i,%i)", "wei": "decimal(%i,%i)"} + + dbt_to_sct = { + "varchar": "text", + "double": "double", + "boolean": "bool", + "date": "date", + "timestamp": "timestamp", + "bigint": "bigint", + "binary": "binary", + "varbinary": "binary", + "decimal": "decimal", + "tinyint": "bigint", + "smallint": "bigint", + "int": "bigint", + } + + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + # TIME is not supported for parquet on Athena + if loader_file_format == "parquet" and column["data_type"] == "time": + raise TerminalValueError( + "Please convert `datetime.time` objects in your data to `str` or" + " `datetime.datetime`.", + "time", + ) + + def to_db_integer_type(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: + precision = column.get("precision") + table_format = table.get("table_format") + if precision is None: + return "bigint" + if precision <= 8: + return "int" if table_format == "iceberg" else "tinyint" + elif precision <= 16: + return "int" if table_format == "iceberg" else "smallint" + elif precision <= 32: + return "int" + elif precision <= 64: + return "bigint" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into athena integer type" + ) + + def from_destination_type( + self, db_type: str, precision: t.Optional[int], scale: t.Optional[int] + ) -> TColumnType: + for key, val in self.dbt_to_sct.items(): + if db_type.startswith(key): + return without_none(dict(data_type=val, precision=precision, scale=scale)) # type: ignore[return-value] + return dict(data_type=None) + + class athena(Destination[AthenaClientConfiguration, "AthenaClient"]): spec = AthenaClientConfiguration @@ -22,9 +111,11 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: # athena only supports loading from staged files on s3 for now caps.preferred_loader_file_format = None caps.supported_loader_file_formats = [] - caps.supported_table_formats = ["iceberg"] + caps.supported_table_formats = ["iceberg", "hive"] caps.preferred_staging_file_format = "parquet" - caps.supported_staging_file_formats = ["parquet", "jsonl"] + caps.supported_staging_file_formats = ["parquet"] + caps.type_mapper = AthenaTypeMapper + # athena is storing all identifiers in lower case and is case insensitive # it also uses lower case in all the queries # https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html @@ -47,6 +138,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.timestamp_precision = 3 caps.supports_truncate_command = False caps.supported_merge_strategies = ["delete-insert", "upsert", "scd2"] + caps.merge_strategies_selector = athena_merge_strategies_selector return caps @property @@ -61,7 +153,6 @@ def __init__( credentials: t.Union[AwsCredentials, t.Dict[str, t.Any], t.Any] = None, athena_work_group: t.Optional[str] = None, aws_data_catalog: t.Optional[str] = "awsdatacatalog", - force_iceberg: bool = False, destination_name: t.Optional[str] = None, environment: t.Optional[str] = None, **kwargs: t.Any, @@ -75,7 +166,6 @@ def __init__( credentials: AWS credentials to connect to the Athena database. athena_work_group: Athena work group to use aws_data_catalog: Athena data catalog to use - force_iceberg: Force iceberg tables **kwargs: Additional arguments passed to the destination config """ super().__init__( @@ -83,7 +173,6 @@ def __init__( credentials=credentials, athena_work_group=athena_work_group, aws_data_catalog=aws_data_catalog, - force_iceberg=force_iceberg, destination_name=destination_name, environment=environment, **kwargs, diff --git a/dlt/destinations/impl/bigquery/bigquery.py b/dlt/destinations/impl/bigquery/bigquery.py index 9bc555bd0d..5bc7a64e7d 100644 --- a/dlt/destinations/impl/bigquery/bigquery.py +++ b/dlt/destinations/impl/bigquery/bigquery.py @@ -13,7 +13,7 @@ from dlt.common import logger from dlt.common.runtime.signals import sleep from dlt.common.json import json -from dlt.common.destination import DestinationCapabilitiesContext +from dlt.common.destination import DestinationCapabilitiesContext, PreparedTableSchema from dlt.common.destination.reference import ( HasFollowupJobs, FollowupJobRequest, @@ -23,14 +23,11 @@ LoadJob, ) from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.schema.typing import TColumnType from dlt.common.schema.utils import get_inherited_table_hint -from dlt.common.schema.utils import table_schema_has_type -from dlt.common.storages.file_storage import FileStorage from dlt.common.storages.load_package import destination_state from dlt.common.typing import DictStrAny from dlt.destinations.job_impl import DestinationJsonlLoadJob, DestinationParquetLoadJob -from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.exceptions import ( DatabaseTransientException, DatabaseUndefinedRelation, @@ -47,64 +44,13 @@ ROUND_HALF_EVEN_HINT, ROUND_HALF_AWAY_FROM_ZERO_HINT, TABLE_EXPIRATION_HINT, + should_autodetect_schema, ) from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration from dlt.destinations.impl.bigquery.sql_client import BigQuerySqlClient, BQ_TERMINAL_REASONS -from dlt.destinations.job_client_impl import SqlJobClientWithStaging +from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.sql_jobs import SqlMergeFollowupJob -from dlt.destinations.type_mapping import TypeMapper -from dlt.destinations.utils import parse_db_data_type_str_with_precision - - -class BigQueryTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "complex": "JSON", - "text": "STRING", - "double": "FLOAT64", - "bool": "BOOL", - "date": "DATE", - "timestamp": "TIMESTAMP", - "bigint": "INT64", - "binary": "BYTES", - "wei": "BIGNUMERIC", # non-parametrized should hold wei values - "time": "TIME", - } - - sct_to_dbt = { - "text": "STRING(%i)", - "binary": "BYTES(%i)", - } - - dbt_to_sct = { - "STRING": "text", - "FLOAT64": "double", - "BOOL": "bool", - "DATE": "date", - "TIMESTAMP": "timestamp", - "INT64": "bigint", - "BYTES": "binary", - "NUMERIC": "decimal", - "BIGNUMERIC": "decimal", - "JSON": "complex", - "TIME": "time", - } - - def to_db_decimal_type(self, column: TColumnSchema) -> str: - # Use BigQuery's BIGNUMERIC for large precision decimals - precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) - if precision > 38 or scale > 9: - return "BIGNUMERIC(%i,%i)" % (precision, scale) - return "NUMERIC(%i,%i)" % (precision, scale) - - # noinspection PyTypeChecker,PydanticTypeChecker - def from_db_type( - self, db_type: str, precision: Optional[int], scale: Optional[int] - ) -> TColumnType: - # precision is present in the type name - if db_type == "BIGNUMERIC": - return dict(data_type="wei") - return super().from_db_type(*parse_db_data_type_str_with_precision(db_type)) class BigQueryLoadJob(RunnableLoadJob, HasFollowupJobs): @@ -212,7 +158,7 @@ def gen_key_table_clauses( return sql -class BigQueryClient(SqlJobClientWithStaging, SupportsStagingDestination): +class BigQueryClient(SqlJobClientWithStagingDataset, SupportsStagingDestination): def __init__( self, schema: Schema, @@ -232,15 +178,15 @@ def __init__( super().__init__(schema, config, sql_client) self.config: BigQueryClientConfiguration = config self.sql_client: BigQuerySqlClient = sql_client # type: ignore - self.type_mapper = BigQueryTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() def _create_merge_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return [BigQueryMergeJob.from_table_chain(table_chain, self.sql_client)] def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job = super().create_load_job(table, file_path, load_id) @@ -283,10 +229,10 @@ def _get_table_update_sql( ) -> List[str]: # return empty columns which will skip table CREATE or ALTER # to let BigQuery autodetect table from data - if self._should_autodetect_schema(table_name): + table = self.prepare_load_table(table_name) + if should_autodetect_schema(table): return [] - table: Optional[TTableSchema] = self.prepare_load_table(table_name) sql = super()._get_table_update_sql(table_name, new_columns, generate_alter) canonical_name = self.sql_client.make_qualified_table_name(table_name) @@ -354,17 +300,23 @@ def _get_table_update_sql( return sql - def prepare_load_table( - self, table_name: str, prepare_for_staging: bool = False - ) -> Optional[TTableSchema]: - table = super().prepare_load_table(table_name, prepare_for_staging) - if table_name in self.schema.data_table_names(): + def prepare_load_table(self, table_name: str) -> Optional[PreparedTableSchema]: + table = super().prepare_load_table(table_name) + if table_name not in self.schema.dlt_table_names(): if TABLE_DESCRIPTION_HINT not in table: table[TABLE_DESCRIPTION_HINT] = ( # type: ignore[name-defined, typeddict-unknown-key, unused-ignore] get_inherited_table_hint( self.schema.tables, table_name, TABLE_DESCRIPTION_HINT, allow_none=True ) ) + if AUTODETECT_SCHEMA_HINT not in table: + table[AUTODETECT_SCHEMA_HINT] = ( # type: ignore[typeddict-unknown-key] + get_inherited_table_hint( + self.schema.tables, table_name, AUTODETECT_SCHEMA_HINT, allow_none=True + ) + or self.config.autodetect_schema + ) + return table def get_storage_tables( @@ -417,10 +369,10 @@ def _get_info_schema_columns_query( return query, folded_table_names - def _get_column_def_sql(self, column: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: name = self.sql_client.escape_column_name(column["name"]) column_def_sql = ( - f"{name} {self.type_mapper.to_db_type(column, table)} {self._gen_not_null(column.get('nullable', True))}" + f"{name} {self.type_mapper.to_destination_type(column, table)} {self._gen_not_null(column.get('nullable', True))}" ) if column.get(ROUND_HALF_EVEN_HINT, False): column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_EVEN')" @@ -428,7 +380,7 @@ def _get_column_def_sql(self, column: TColumnSchema, table: TTableSchema = None) column_def_sql += " OPTIONS (rounding_mode='ROUND_HALF_AWAY_FROM_ZERO')" return column_def_sql - def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.LoadJob: + def _create_load_job(self, table: PreparedTableSchema, file_path: str) -> bigquery.LoadJob: # append to table for merge loads (append to stage) and regular appends. table_name = table["name"] @@ -457,19 +409,12 @@ def _create_load_job(self, table: TTableSchema, file_path: str) -> bigquery.Load ignore_unknown_values=False, max_bad_records=0, ) - if self._should_autodetect_schema(table_name): + if should_autodetect_schema(table): # allow BigQuery to infer and evolve the schema, note that dlt is not # creating such tables at all job_config.autodetect = True job_config.schema_update_options = bigquery.SchemaUpdateOption.ALLOW_FIELD_ADDITION job_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED - elif ext == "parquet" and table_schema_has_type(table, "complex"): - # if table contains complex types, we cannot load with parquet - raise LoadJobTerminalException( - file_path, - "Bigquery cannot load into JSON data type from parquet. Enable autodetect_schema in" - " config or via BigQuery adapter or use jsonl format instead.", - ) if bucket_path: return self.sql_client.native_connection.load_table_from_uri( @@ -496,14 +441,9 @@ def _retrieve_load_job(self, file_path: str) -> bigquery.LoadJob: def _from_db_type( self, bq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(bq_t, precision, scale) - - def _should_autodetect_schema(self, table_name: str) -> bool: - return get_inherited_table_hint( - self.schema._schema_tables, table_name, AUTODETECT_SCHEMA_HINT, allow_none=True - ) or (self.config.autodetect_schema and table_name not in self.schema.dlt_table_names()) + return self.type_mapper.from_destination_type(bq_t, precision, scale) - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/bigquery/bigquery_adapter.py b/dlt/destinations/impl/bigquery/bigquery_adapter.py index 55fe1b6b74..ce4a455da0 100644 --- a/dlt/destinations/impl/bigquery/bigquery_adapter.py +++ b/dlt/destinations/impl/bigquery/bigquery_adapter.py @@ -2,6 +2,7 @@ from dateutil import parser +from dlt.common.destination import PreparedTableSchema from dlt.common.pendulum import timezone from dlt.common.schema.typing import ( TColumnNames, @@ -174,3 +175,8 @@ def bigquery_adapter( " specified." ) return resource + + +def should_autodetect_schema(table: PreparedTableSchema) -> bool: + """Tells if schema should be auto detected for a given prepared `table`""" + return table.get(AUTODETECT_SCHEMA_HINT, False) # type: ignore[return-value] diff --git a/dlt/destinations/impl/bigquery/factory.py b/dlt/destinations/impl/bigquery/factory.py index 34dd1790ae..7a2517e400 100644 --- a/dlt/destinations/impl/bigquery/factory.py +++ b/dlt/destinations/impl/bigquery/factory.py @@ -1,17 +1,91 @@ import typing as t +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.exceptions import TerminalValueError from dlt.common.normalizers.naming import NamingConvention from dlt.common.configuration.specs import GcpServiceAccountCredentials from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.data_writers.escape import escape_hive_identifier, format_bigquery_datetime_literal from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.schema.typing import TColumnSchema, TColumnType +from dlt.common.typing import TLoaderFileFormat +from dlt.destinations.type_mapping import TypeMapperImpl +from dlt.destinations.impl.bigquery.bigquery_adapter import should_autodetect_schema from dlt.destinations.impl.bigquery.configuration import BigQueryClientConfiguration +from dlt.destinations.utils import parse_db_data_type_str_with_precision + if t.TYPE_CHECKING: from dlt.destinations.impl.bigquery.bigquery import BigQueryClient +class BigQueryTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "complex": "JSON", + "text": "STRING", + "double": "FLOAT64", + "bool": "BOOL", + "date": "DATE", + "timestamp": "TIMESTAMP", + "bigint": "INT64", + "binary": "BYTES", + "wei": "BIGNUMERIC", # non-parametrized should hold wei values + "time": "TIME", + } + + sct_to_dbt = { + "text": "STRING(%i)", + "binary": "BYTES(%i)", + } + + dbt_to_sct = { + "STRING": "text", + "FLOAT64": "double", + "BOOL": "bool", + "DATE": "date", + "TIMESTAMP": "timestamp", + "INT64": "bigint", + "BYTES": "binary", + "NUMERIC": "decimal", + "BIGNUMERIC": "decimal", + "JSON": "complex", + "TIME": "time", + } + + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + # if table contains complex types, we cannot load with parquet + if ( + loader_file_format == "parquet" + and column["data_type"] == "complex" + and not should_autodetect_schema(table) + ): + raise TerminalValueError( + "Enable autodetect_schema in config or via BigQuery adapter", column["data_type"] + ) + + def to_db_decimal_type(self, column: TColumnSchema) -> str: + # Use BigQuery's BIGNUMERIC for large precision decimals + precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) + if precision > 38 or scale > 9: + return "BIGNUMERIC(%i,%i)" % (precision, scale) + return "NUMERIC(%i,%i)" % (precision, scale) + + # noinspection PyTypeChecker,PydanticTypeChecker + def from_destination_type( + self, db_type: str, precision: t.Optional[int], scale: t.Optional[int] + ) -> TColumnType: + # precision is present in the type name + if db_type == "BIGNUMERIC": + return dict(data_type="wei") + return super().from_destination_type(*parse_db_data_type_str_with_precision(db_type)) + + # noinspection PyPep8Naming class bigquery(Destination[BigQueryClientConfiguration, "BigQueryClient"]): spec = BigQueryClientConfiguration @@ -22,6 +96,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = ["jsonl", "parquet"] caps.preferred_staging_file_format = "parquet" caps.supported_staging_file_formats = ["parquet", "jsonl"] + caps.type_mapper = BigQueryTypeMapper # BigQuery is by default case sensitive but that cannot be turned off for a dataset # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#case_sensitivity caps.escape_identifier = escape_hive_identifier diff --git a/dlt/destinations/impl/clickhouse/clickhouse.py b/dlt/destinations/impl/clickhouse/clickhouse.py index 038735a84b..4d81b50731 100644 --- a/dlt/destinations/impl/clickhouse/clickhouse.py +++ b/dlt/destinations/impl/clickhouse/clickhouse.py @@ -16,6 +16,7 @@ ) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( + PreparedTableSchema, SupportsStagingDestination, TLoadJobState, HasFollowupJobs, @@ -26,9 +27,9 @@ from dlt.common.schema import Schema, TColumnSchema from dlt.common.schema.typing import ( TTableFormat, - TTableSchema, TColumnType, ) +from dlt.common.schema.utils import is_nullable_column from dlt.common.storages import FileStorage from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.clickhouse.configuration import ( @@ -50,78 +51,10 @@ ) from dlt.destinations.job_client_impl import ( SqlJobClientBase, - SqlJobClientWithStaging, + SqlJobClientWithStagingDataset, ) from dlt.destinations.job_impl import ReferenceFollowupJobRequest, FinalizedLoadJobWithFollowupJobs from dlt.destinations.sql_jobs import SqlMergeFollowupJob -from dlt.destinations.type_mapping import TypeMapper - - -class ClickHouseTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "complex": "String", - "text": "String", - "double": "Float64", - "bool": "Boolean", - "date": "Date", - "timestamp": "DateTime64(6,'UTC')", - "time": "String", - "bigint": "Int64", - "binary": "String", - "wei": "Decimal", - } - - sct_to_dbt = { - "decimal": "Decimal(%i,%i)", - "wei": "Decimal(%i,%i)", - "timestamp": "DateTime64(%i,'UTC')", - } - - dbt_to_sct = { - "String": "text", - "Float64": "double", - "Bool": "bool", - "Date": "date", - "DateTime": "timestamp", - "DateTime64": "timestamp", - "Time": "timestamp", - "Int64": "bigint", - "Object('json')": "complex", - "Decimal": "decimal", - } - - def from_db_type( - self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None - ) -> TColumnType: - # Remove "Nullable" wrapper. - db_type = re.sub(r"^Nullable\((?P.+)\)$", r"\g", db_type) - - # Remove timezone details. - if db_type == "DateTime('UTC')": - db_type = "DateTime" - if datetime_match := re.match( - r"DateTime64(?:\((?P\d+)(?:,?\s*'(?PUTC)')?\))?", - db_type, - ): - if datetime_match["precision"]: - precision = int(datetime_match["precision"]) - else: - precision = None - db_type = "DateTime64" - - # Extract precision and scale, parameters and remove from string. - if decimal_match := re.match( - r"Decimal\((?P\d+)\s*(?:,\s*(?P\d+))?\)", db_type - ): - precision, scale = decimal_match.groups() # type: ignore[assignment] - precision = int(precision) - scale = int(scale) if scale else 0 - db_type = "Decimal" - - if db_type == "Decimal" and (precision, scale) == self.capabilities.wei_precision: - return cast(TColumnType, dict(data_type="wei")) - - return super().from_db_type(db_type, precision, scale) class ClickHouseLoadJob(RunnableLoadJob, HasFollowupJobs): @@ -269,7 +202,7 @@ def requires_temp_table_for_delete(cls) -> bool: return True -class ClickHouseClient(SqlJobClientWithStaging, SupportsStagingDestination): +class ClickHouseClient(SqlJobClientWithStagingDataset, SupportsStagingDestination): def __init__( self, schema: Schema, @@ -286,14 +219,14 @@ def __init__( super().__init__(schema, config, self.sql_client) self.config: ClickHouseClientConfiguration = config self.active_hints = deepcopy(HINT_TO_CLICKHOUSE_ATTR) - self.type_mapper = ClickHouseTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() def _create_merge_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return [ClickHouseMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: # Build column definition. # The primary key and sort order definition is defined outside column specification. hints_ = " ".join( @@ -307,9 +240,9 @@ def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> s # Alter table statements only accept `Nullable` modifiers. # JSON type isn't nullable in ClickHouse. type_with_nullability_modifier = ( - f"Nullable({self.type_mapper.to_db_type(c,table)})" - if c.get("nullable", True) - else self.type_mapper.to_db_type(c, table) + f"Nullable({self.type_mapper.to_destination_type(c,table)})" + if is_nullable_column(c) + else self.type_mapper.to_destination_type(c, table) ) return ( @@ -318,7 +251,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> s ) def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: return super().create_load_job(table, file_path, load_id, restore) or ClickHouseLoadJob( file_path, @@ -333,7 +266,7 @@ def _get_table_update_sql( new_columns: Sequence[TColumnSchema], generate_alter: bool, ) -> List[str]: - table: TTableSchema = self.prepare_load_table(table_name, self.in_staging_mode) + table = self.prepare_load_table(table_name) sql = SqlJobClientBase._get_table_update_sql(self, table_name, new_columns, generate_alter) if generate_alter: @@ -371,7 +304,7 @@ def _gen_not_null(v: bool) -> str: def _from_db_type( self, ch_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(ch_t, precision, scale) + return self.type_mapper.from_destination_type(ch_t, precision, scale) - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/clickhouse/factory.py b/dlt/destinations/impl/clickhouse/factory.py index 93da6c866a..0e897dbefb 100644 --- a/dlt/destinations/impl/clickhouse/factory.py +++ b/dlt/destinations/impl/clickhouse/factory.py @@ -1,3 +1,4 @@ +import re import sys import typing as t @@ -8,6 +9,9 @@ format_clickhouse_datetime_literal, ) from dlt.common.destination import Destination, DestinationCapabilitiesContext + +from dlt.common.schema.typing import TColumnType +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.clickhouse.configuration import ( ClickHouseClientConfiguration, ClickHouseCredentials, @@ -19,6 +23,73 @@ from clickhouse_driver.dbapi import Connection # type: ignore[import-untyped] +class ClickHouseTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "complex": "String", + "text": "String", + "double": "Float64", + "bool": "Boolean", + "date": "Date", + "timestamp": "DateTime64(6,'UTC')", + "time": "String", + "bigint": "Int64", + "binary": "String", + "wei": "Decimal", + } + + sct_to_dbt = { + "decimal": "Decimal(%i,%i)", + "wei": "Decimal(%i,%i)", + "timestamp": "DateTime64(%i,'UTC')", + } + + dbt_to_sct = { + "String": "text", + "Float64": "double", + "Bool": "bool", + "Date": "date", + "DateTime": "timestamp", + "DateTime64": "timestamp", + "Time": "timestamp", + "Int64": "bigint", + "Object('json')": "complex", + "Decimal": "decimal", + } + + def from_destination_type( + self, db_type: str, precision: t.Optional[int] = None, scale: t.Optional[int] = None + ) -> TColumnType: + # Remove "Nullable" wrapper. + db_type = re.sub(r"^Nullable\((?P.+)\)$", r"\g", db_type) + + # Remove timezone details. + if db_type == "DateTime('UTC')": + db_type = "DateTime" + if datetime_match := re.match( + r"DateTime64(?:\((?P\d+)(?:,?\s*'(?PUTC)')?\))?", + db_type, + ): + if datetime_match["precision"]: + precision = int(datetime_match["precision"]) + else: + precision = None + db_type = "DateTime64" + + # Extract precision and scale, parameters and remove from string. + if decimal_match := re.match( + r"Decimal\((?P\d+)\s*(?:,\s*(?P\d+))?\)", db_type + ): + precision, scale = decimal_match.groups() # type: ignore[assignment] + precision = int(precision) + scale = int(scale) if scale else 0 + db_type = "Decimal" + + if db_type == "Decimal" and (precision, scale) == self.capabilities.wei_precision: + return t.cast(TColumnType, dict(data_type="wei")) + + return super().from_destination_type(db_type, precision, scale) + + class clickhouse(Destination[ClickHouseClientConfiguration, "ClickHouseClient"]): spec = ClickHouseClientConfiguration @@ -28,6 +99,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = ["parquet", "jsonl"] caps.preferred_staging_file_format = "jsonl" caps.supported_staging_file_formats = ["parquet", "jsonl"] + caps.type_mapper = ClickHouseTypeMapper caps.format_datetime_literal = format_clickhouse_datetime_literal caps.escape_identifier = escape_clickhouse_identifier diff --git a/dlt/destinations/impl/databricks/databricks.py b/dlt/destinations/impl/databricks/databricks.py index 0c19984b4c..2cdff8a82c 100644 --- a/dlt/destinations/impl/databricks/databricks.py +++ b/dlt/destinations/impl/databricks/databricks.py @@ -6,6 +6,7 @@ from dlt.common.destination.reference import ( HasFollowupJobs, FollowupJobRequest, + PreparedTableSchema, RunnableLoadJob, SupportsStagingDestination, LoadJob, @@ -17,92 +18,20 @@ from dlt.common.exceptions import TerminalValueError from dlt.common.storages.file_storage import FileStorage from dlt.common.schema import TColumnSchema, Schema -from dlt.common.schema.typing import TTableSchema, TColumnType, TSchemaTables, TTableFormat -from dlt.common.schema.utils import table_schema_has_type +from dlt.common.schema.typing import TColumnType from dlt.common.storages import FilesystemConfiguration, fsspec_from_config from dlt.destinations.insert_job_client import InsertValuesJobClient -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.databricks.configuration import DatabricksClientConfiguration from dlt.destinations.impl.databricks.sql_client import DatabricksSqlClient from dlt.destinations.sql_jobs import SqlMergeFollowupJob from dlt.destinations.job_impl import ReferenceFollowupJobRequest -from dlt.destinations.type_mapping import TypeMapper - AZURE_BLOB_STORAGE_PROTOCOLS = ["az", "abfss", "abfs"] -class DatabricksTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "complex": "STRING", # Databricks supports complex types like ARRAY - "text": "STRING", - "double": "DOUBLE", - "bool": "BOOLEAN", - "date": "DATE", - "timestamp": "TIMESTAMP", # TIMESTAMP for local timezone - "bigint": "BIGINT", - "binary": "BINARY", - "decimal": "DECIMAL", # DECIMAL(p,s) format - "time": "STRING", - } - - dbt_to_sct = { - "STRING": "text", - "DOUBLE": "double", - "BOOLEAN": "bool", - "DATE": "date", - "TIMESTAMP": "timestamp", - "BIGINT": "bigint", - "INT": "bigint", - "SMALLINT": "bigint", - "TINYINT": "bigint", - "BINARY": "binary", - "DECIMAL": "decimal", - } - - sct_to_dbt = { - "decimal": "DECIMAL(%i,%i)", - "wei": "DECIMAL(%i,%i)", - } - - def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: - precision = column.get("precision") - if precision is None: - return "BIGINT" - if precision <= 8: - return "TINYINT" - if precision <= 16: - return "SMALLINT" - if precision <= 32: - return "INT" - if precision <= 64: - return "BIGINT" - raise TerminalValueError( - f"bigint with {precision} bits precision cannot be mapped into databricks integer type" - ) - - def from_db_type( - self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None - ) -> TColumnType: - # precision and scale arguments here are meaningless as they're not included separately in information schema - # We use full_data_type from databricks which is either in form "typename" or "typename(precision, scale)" - type_parts = db_type.split("(") - if len(type_parts) > 1: - db_type = type_parts[0] - scale_str = type_parts[1].strip(")") - precision, scale = [int(val) for val in scale_str.split(",")] - else: - scale = precision = None - db_type = db_type.upper() - if db_type == "DECIMAL": - if (precision, scale) == self.wei_precision(): - return dict(data_type="wei", precision=precision, scale=scale) - return super().from_db_type(db_type, precision, scale) - - class DatabricksLoadJob(RunnableLoadJob, HasFollowupJobs): def __init__( self, @@ -199,31 +128,6 @@ def run(self) -> None: " compression in the data writer configuration:" " https://dlthub.com/docs/reference/performance#disabling-and-enabling-file-compression", ) - if table_schema_has_type(self._load_table, "decimal"): - raise LoadJobTerminalException( - self._file_path, - "Databricks loader cannot load DECIMAL type columns from json files. Switch to" - " parquet format to load decimals.", - ) - if table_schema_has_type(self._load_table, "binary"): - raise LoadJobTerminalException( - self._file_path, - "Databricks loader cannot load BINARY type columns from json files. Switch to" - " parquet format to load byte values.", - ) - if table_schema_has_type(self._load_table, "complex"): - raise LoadJobTerminalException( - self._file_path, - "Databricks loader cannot load complex columns (lists and dicts) from json" - " files. Switch to parquet format to load complex types.", - ) - if table_schema_has_type(self._load_table, "date"): - raise LoadJobTerminalException( - self._file_path, - "Databricks loader cannot load DATE type columns from json files. Switch to" - " parquet format to load dates.", - ) - source_format = "JSON" format_options_clause = "FORMAT_OPTIONS('inferTimestamp'='true')" # Databricks fails when trying to load empty json files, so we have to check the file size @@ -302,10 +206,10 @@ def __init__( super().__init__(schema, config, sql_client) self.config: DatabricksClientConfiguration = config self.sql_client: DatabricksSqlClient = sql_client # type: ignore[assignment] - self.type_mapper = DatabricksTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job = super().create_load_job(table, file_path, load_id, restore) @@ -317,12 +221,12 @@ def create_load_job( return job def _create_merge_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return [DatabricksMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None + self, new_columns: Sequence[TColumnSchema], table: PreparedTableSchema = None ) -> List[str]: # Override because databricks requires multiple columns in a single ADD COLUMN clause return [ @@ -350,12 +254,12 @@ def _get_table_update_sql( def _from_db_type( self, bq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(bq_t, precision, scale) + return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def _get_storage_table_query_columns(self) -> List[str]: @@ -365,5 +269,5 @@ def _get_storage_table_query_columns(self) -> List[str]: ) return fields - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/databricks/factory.py b/dlt/destinations/impl/databricks/factory.py index 6108b69da9..647b451161 100644 --- a/dlt/destinations/impl/databricks/factory.py +++ b/dlt/destinations/impl/databricks/factory.py @@ -1,9 +1,15 @@ import typing as t +from dlt.common.data_types.typing import TDataType from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.data_writers.escape import escape_databricks_identifier, escape_databricks_literal from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.exceptions import TerminalValueError +from dlt.common.schema.typing import TColumnSchema, TColumnType, TTableSchema +from dlt.common.typing import TLoaderFileFormat +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.databricks.configuration import ( DatabricksCredentials, DatabricksClientConfiguration, @@ -13,6 +19,89 @@ from dlt.destinations.impl.databricks.databricks import DatabricksClient +class DatabricksTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "complex": "STRING", # Databricks supports complex types like ARRAY + "text": "STRING", + "double": "DOUBLE", + "bool": "BOOLEAN", + "date": "DATE", + "timestamp": "TIMESTAMP", # TIMESTAMP for local timezone + "bigint": "BIGINT", + "binary": "BINARY", + "decimal": "DECIMAL", # DECIMAL(p,s) format + "time": "STRING", + } + + dbt_to_sct = { + "STRING": "text", + "DOUBLE": "double", + "BOOLEAN": "bool", + "DATE": "date", + "TIMESTAMP": "timestamp", + "BIGINT": "bigint", + "INT": "bigint", + "SMALLINT": "bigint", + "TINYINT": "bigint", + "BINARY": "binary", + "DECIMAL": "decimal", + } + + sct_to_dbt = { + "decimal": "DECIMAL(%i,%i)", + "wei": "DECIMAL(%i,%i)", + } + + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + if loader_file_format == "jsonl" and column["data_type"] in { + "decimal", + "wei", + "binary", + "complex", + "date", + }: + raise TerminalValueError("", column["data_type"]) + + def to_db_integer_type(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: + precision = column.get("precision") + if precision is None: + return "BIGINT" + if precision <= 8: + return "TINYINT" + if precision <= 16: + return "SMALLINT" + if precision <= 32: + return "INT" + if precision <= 64: + return "BIGINT" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into databricks integer type" + ) + + def from_destination_type( + self, db_type: str, precision: t.Optional[int] = None, scale: t.Optional[int] = None + ) -> TColumnType: + # precision and scale arguments here are meaningless as they're not included separately in information schema + # We use full_data_type from databricks which is either in form "typename" or "typename(precision, scale)" + type_parts = db_type.split("(") + if len(type_parts) > 1: + db_type = type_parts[0] + scale_str = type_parts[1].strip(")") + precision, scale = [int(val) for val in scale_str.split(",")] + else: + scale = precision = None + db_type = db_type.upper() + if db_type == "DECIMAL": + if (precision, scale) == self.wei_precision(): + return dict(data_type="wei", precision=precision, scale=scale) + return super().from_destination_type(db_type, precision, scale) + + class databricks(Destination[DatabricksClientConfiguration, "DatabricksClient"]): spec = DatabricksClientConfiguration @@ -22,6 +111,8 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = [] caps.preferred_staging_file_format = "parquet" caps.supported_staging_file_formats = ["jsonl", "parquet"] + caps.supported_table_formats = ["delta"] + caps.type_mapper = DatabricksTypeMapper caps.escape_identifier = escape_databricks_identifier # databricks identifiers are case insensitive and stored in lower case # https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html diff --git a/dlt/destinations/impl/destination/destination.py b/dlt/destinations/impl/destination/destination.py index 0c4da81471..253fb8722f 100644 --- a/dlt/destinations/impl/destination/destination.py +++ b/dlt/destinations/impl/destination/destination.py @@ -1,14 +1,13 @@ -from copy import deepcopy from types import TracebackType from typing import ClassVar, Optional, Type, Iterable, cast, List -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs, FinalizedLoadJob -from dlt.common.destination.reference import LoadJob +from dlt.destinations.job_impl import FinalizedLoadJob +from dlt.common.destination.reference import LoadJob, PreparedTableSchema from dlt.common.typing import AnyFun from dlt.common.storages.load_package import destination_state from dlt.common.configuration import create_resolved_partial -from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.schema import Schema, TSchemaTables from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( JobClientBase, @@ -56,7 +55,7 @@ def update_stored_schema( return super().update_stored_schema(only_tables, expected_update) def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: # skip internal tables and remove columns from schema if so configured if self.config.skip_dlt_columns_and_tables: @@ -89,10 +88,8 @@ def create_load_job( ) return None - def prepare_load_table( - self, table_name: str, prepare_for_staging: bool = False - ) -> TTableSchema: - table = super().prepare_load_table(table_name, prepare_for_staging) + def prepare_load_table(self, table_name: str) -> PreparedTableSchema: + table = super().prepare_load_table(table_name) if self.config.skip_dlt_columns_and_tables: for column in list(table["columns"].keys()): if column.startswith(self.schema._dlt_tables_prefix): diff --git a/dlt/destinations/impl/dremio/dremio.py b/dlt/destinations/impl/dremio/dremio.py index 91dc64f113..ab23f58ab4 100644 --- a/dlt/destinations/impl/dremio/dremio.py +++ b/dlt/destinations/impl/dremio/dremio.py @@ -4,76 +4,32 @@ from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( HasFollowupJobs, + PreparedTableSchema, TLoadJobState, RunnableLoadJob, SupportsStagingDestination, FollowupJobRequest, LoadJob, ) -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat, TColumnSchemaBase +from dlt.common.schema import TColumnSchema, Schema +from dlt.common.schema.typing import TColumnType, TTableFormat from dlt.common.storages.file_storage import FileStorage from dlt.common.utils import uniq_id from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.dremio.configuration import DremioClientConfiguration from dlt.destinations.impl.dremio.sql_client import DremioSqlClient -from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs +from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.sql_jobs import SqlMergeFollowupJob -from dlt.destinations.type_mapping import TypeMapper from dlt.destinations.sql_client import SqlClientBase -class DremioTypeMapper(TypeMapper): - BIGINT_PRECISION = 19 - sct_to_unbound_dbt = { - "complex": "VARCHAR", - "text": "VARCHAR", - "double": "DOUBLE", - "bool": "BOOLEAN", - "date": "DATE", - "timestamp": "TIMESTAMP", - "bigint": "BIGINT", - "binary": "VARBINARY", - "time": "TIME", - } - - sct_to_dbt = { - "decimal": "DECIMAL(%i,%i)", - "wei": "DECIMAL(%i,%i)", - } - - dbt_to_sct = { - "VARCHAR": "text", - "DOUBLE": "double", - "FLOAT": "double", - "BOOLEAN": "bool", - "DATE": "date", - "TIMESTAMP": "timestamp", - "VARBINARY": "binary", - "BINARY": "binary", - "BINARY VARYING": "binary", - "VARIANT": "complex", - "TIME": "time", - "BIGINT": "bigint", - "DECIMAL": "decimal", - } - - def from_db_type( - self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None - ) -> TColumnType: - if db_type == "DECIMAL": - if (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei") - return dict(data_type="decimal", precision=precision, scale=scale) - return super().from_db_type(db_type, precision, scale) - - class DremioMergeJob(SqlMergeFollowupJob): @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: - return sql_client.make_qualified_table_name(f"_temp_{name_prefix}_{uniq_id()}") + return sql_client.make_qualified_table_name( + cls._shorten_table_name(f"_temp_{name_prefix}_{uniq_id()}", sql_client) + ) @classmethod def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: @@ -134,7 +90,7 @@ def run(self) -> None: """) -class DremioClient(SqlJobClientWithStaging, SupportsStagingDestination): +class DremioClient(SqlJobClientWithStagingDataset, SupportsStagingDestination): def __init__( self, schema: Schema, @@ -150,10 +106,10 @@ def __init__( super().__init__(schema, config, sql_client) self.config: DremioClientConfiguration = config self.sql_client: DremioSqlClient = sql_client # type: ignore - self.type_mapper = DremioTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job = super().create_load_job(table, file_path, load_id, restore) @@ -193,21 +149,21 @@ def _get_table_update_sql( def _from_db_type( self, bq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(bq_t, precision, scale) + return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) def _create_merge_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return [DremioMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None + self, new_columns: Sequence[TColumnSchema], table: PreparedTableSchema = None ) -> List[str]: return [ "ADD COLUMNS (" @@ -215,5 +171,5 @@ def _make_add_column_sql( + ")" ] - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/dremio/factory.py b/dlt/destinations/impl/dremio/factory.py index b8c7e1b746..a8604b45e4 100644 --- a/dlt/destinations/impl/dremio/factory.py +++ b/dlt/destinations/impl/dremio/factory.py @@ -4,6 +4,11 @@ from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE from dlt.common.data_writers.escape import escape_dremio_identifier +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.exceptions import TerminalValueError +from dlt.common.schema.typing import TColumnSchema, TColumnType +from dlt.common.typing import TLoaderFileFormat +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.dremio.configuration import ( DremioCredentials, DremioClientConfiguration, @@ -13,6 +18,68 @@ from dlt.destinations.impl.dremio.dremio import DremioClient +class DremioTypeMapper(TypeMapperImpl): + BIGINT_PRECISION = 19 + sct_to_unbound_dbt = { + "complex": "VARCHAR", + "text": "VARCHAR", + "double": "DOUBLE", + "bool": "BOOLEAN", + "date": "DATE", + "timestamp": "TIMESTAMP", + "bigint": "BIGINT", + "binary": "VARBINARY", + "time": "TIME", + } + + sct_to_dbt = { + "decimal": "DECIMAL(%i,%i)", + "wei": "DECIMAL(%i,%i)", + } + + dbt_to_sct = { + "VARCHAR": "text", + "DOUBLE": "double", + "FLOAT": "double", + "BOOLEAN": "bool", + "DATE": "date", + "TIMESTAMP": "timestamp", + "VARBINARY": "binary", + "BINARY": "binary", + "BINARY VARYING": "binary", + "VARIANT": "complex", + "TIME": "time", + "BIGINT": "bigint", + "DECIMAL": "decimal", + } + + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + if loader_file_format == "insert_values": + return + if loader_file_format == "parquet": + # binary not supported on parquet if precision is set + if column.get("precision") is not None and column["data_type"] == "binary": + raise TerminalValueError( + "Dremio cannot load fixed width 'binary' columns from parquet files. Switch to" + " other file format or use binary columns without precision.", + "binary", + ) + + def from_destination_type( + self, db_type: str, precision: t.Optional[int] = None, scale: t.Optional[int] = None + ) -> TColumnType: + if db_type == "DECIMAL": + if (precision, scale) == self.capabilities.wei_precision: + return dict(data_type="wei") + return dict(data_type="decimal", precision=precision, scale=scale) + return super().from_destination_type(db_type, precision, scale) + + class dremio(Destination[DremioClientConfiguration, "DremioClient"]): spec = DremioClientConfiguration @@ -23,6 +90,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.preferred_staging_file_format = "parquet" caps.supported_staging_file_formats = ["jsonl", "parquet"] caps.escape_identifier = escape_dremio_identifier + caps.type_mapper = DremioTypeMapper # all identifiers are case insensitive but are stored as is # https://docs.dremio.com/current/sonar/data-sources caps.has_case_sensitive_identifiers = False diff --git a/dlt/destinations/impl/duckdb/duck.py b/dlt/destinations/impl/duckdb/duck.py index 5fa82f4977..3bd4c83e1f 100644 --- a/dlt/destinations/impl/duckdb/duck.py +++ b/dlt/destinations/impl/duckdb/duck.py @@ -1,133 +1,26 @@ -import threading -import logging -from typing import ClassVar, Dict, Optional +from typing import Dict, Optional from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.data_types import TDataType from dlt.common.exceptions import TerminalValueError from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.destination.reference import RunnableLoadJob, HasFollowupJobs, LoadJob -from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.destination.reference import ( + PreparedTableSchema, + RunnableLoadJob, + HasFollowupJobs, + LoadJob, +) +from dlt.common.schema.typing import TColumnType, TTableFormat from dlt.common.storages.file_storage import FileStorage -from dlt.common.utils import maybe_context from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.impl.duckdb.sql_client import DuckDbSqlClient from dlt.destinations.impl.duckdb.configuration import DuckDbClientConfiguration -from dlt.destinations.type_mapping import TypeMapper HINT_TO_POSTGRES_ATTR: Dict[TColumnHint, str] = {"unique": "UNIQUE"} -class DuckDbTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "complex": "JSON", - "text": "VARCHAR", - "double": "DOUBLE", - "bool": "BOOLEAN", - "date": "DATE", - # Duck does not allow specifying precision on timestamp with tz - "timestamp": "TIMESTAMP WITH TIME ZONE", - "bigint": "BIGINT", - "binary": "BLOB", - "time": "TIME", - } - - sct_to_dbt = { - # VARCHAR(n) is alias for VARCHAR in duckdb - # "text": "VARCHAR(%i)", - "decimal": "DECIMAL(%i,%i)", - "wei": "DECIMAL(%i,%i)", - } - - dbt_to_sct = { - "VARCHAR": "text", - "JSON": "complex", - "DOUBLE": "double", - "BOOLEAN": "bool", - "DATE": "date", - "TIMESTAMP WITH TIME ZONE": "timestamp", - "BLOB": "binary", - "DECIMAL": "decimal", - "TIME": "time", - # Int types - "TINYINT": "bigint", - "SMALLINT": "bigint", - "INTEGER": "bigint", - "BIGINT": "bigint", - "HUGEINT": "bigint", - "TIMESTAMP_S": "timestamp", - "TIMESTAMP_MS": "timestamp", - "TIMESTAMP_NS": "timestamp", - } - - def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: - precision = column.get("precision") - if precision is None: - return "BIGINT" - # Precision is number of bits - if precision <= 8: - return "TINYINT" - elif precision <= 16: - return "SMALLINT" - elif precision <= 32: - return "INTEGER" - elif precision <= 64: - return "BIGINT" - elif precision <= 128: - return "HUGEINT" - raise TerminalValueError( - f"bigint with {precision} bits precision cannot be mapped into duckdb integer type" - ) - - def to_db_datetime_type( - self, - column: TColumnSchema, - table: TTableSchema = None, - ) -> str: - column_name = column.get("name") - table_name = table.get("name") - timezone = column.get("timezone") - precision = column.get("precision") - - if timezone and precision is not None: - logging.warn( - f"DuckDB does not support both timezone and precision for column '{column_name}' in" - f" table '{table_name}'. Will default to timezone." - ) - - if timezone: - return "TIMESTAMP WITH TIME ZONE" - elif timezone is not None: # condition for when timezone is False given that none is falsy - return "TIMESTAMP" - - if precision is None or precision == 6: - return None - elif precision == 0: - return "TIMESTAMP_S" - elif precision == 3: - return "TIMESTAMP_MS" - elif precision == 9: - return "TIMESTAMP_NS" - - raise TerminalValueError( - f"DuckDB does not support precision '{precision}' for '{column_name}' in table" - f" '{table_name}'" - ) - - def from_db_type( - self, db_type: str, precision: Optional[int], scale: Optional[int] - ) -> TColumnType: - # duckdb provides the types with scale and precision - db_type = db_type.split("(")[0].upper() - if db_type == "DECIMAL": - if precision == 38 and scale == 0: - return dict(data_type="wei", precision=precision, scale=scale) - return super().from_db_type(db_type, precision, scale) - - class DuckDbCopyJob(RunnableLoadJob, HasFollowupJobs): def __init__(self, file_path: str) -> None: super().__init__(file_path) @@ -171,17 +64,17 @@ def __init__( self.config: DuckDbClientConfiguration = config self.sql_client: DuckDbSqlClient = sql_client # type: ignore self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} - self.type_mapper = DuckDbTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job = super().create_load_job(table, file_path, load_id, restore) if not job: job = DuckDbCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in self.active_hints.keys() @@ -189,10 +82,10 @@ def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> s ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def _from_db_type( self, pq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(pq_t, precision, scale) + return self.type_mapper.from_destination_type(pq_t, precision, scale) diff --git a/dlt/destinations/impl/duckdb/factory.py b/dlt/destinations/impl/duckdb/factory.py index 2c4df2cb58..ce861cc2f7 100644 --- a/dlt/destinations/impl/duckdb/factory.py +++ b/dlt/destinations/impl/duckdb/factory.py @@ -1,9 +1,13 @@ import typing as t +from dlt.common import logger from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE - +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.exceptions import TerminalValueError +from dlt.common.schema.typing import TColumnSchema, TColumnType +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.duckdb.configuration import DuckDbCredentials, DuckDbClientConfiguration if t.TYPE_CHECKING: @@ -11,6 +15,113 @@ from dlt.destinations.impl.duckdb.duck import DuckDbClient +class DuckDbTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "complex": "JSON", + "text": "VARCHAR", + "double": "DOUBLE", + "bool": "BOOLEAN", + "date": "DATE", + # Duck does not allow specifying precision on timestamp with tz + "timestamp": "TIMESTAMP WITH TIME ZONE", + "bigint": "BIGINT", + "binary": "BLOB", + "time": "TIME", + } + + sct_to_dbt = { + # VARCHAR(n) is alias for VARCHAR in duckdb + # "text": "VARCHAR(%i)", + "decimal": "DECIMAL(%i,%i)", + "wei": "DECIMAL(%i,%i)", + } + + dbt_to_sct = { + "VARCHAR": "text", + "JSON": "complex", + "DOUBLE": "double", + "BOOLEAN": "bool", + "DATE": "date", + "TIMESTAMP WITH TIME ZONE": "timestamp", + "BLOB": "binary", + "DECIMAL": "decimal", + "TIME": "time", + # Int types + "TINYINT": "bigint", + "SMALLINT": "bigint", + "INTEGER": "bigint", + "BIGINT": "bigint", + "HUGEINT": "bigint", + "TIMESTAMP_S": "timestamp", + "TIMESTAMP_MS": "timestamp", + "TIMESTAMP_NS": "timestamp", + } + + def to_db_integer_type(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: + precision = column.get("precision") + if precision is None: + return "BIGINT" + # Precision is number of bits + if precision <= 8: + return "TINYINT" + elif precision <= 16: + return "SMALLINT" + elif precision <= 32: + return "INTEGER" + elif precision <= 64: + return "BIGINT" + elif precision <= 128: + return "HUGEINT" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into duckdb integer type" + ) + + def to_db_datetime_type( + self, + column: TColumnSchema, + table: PreparedTableSchema = None, + ) -> str: + column_name = column["name"] + table_name = table["name"] + timezone = column.get("timezone", True) + precision = column.get("precision") + + if timezone and precision is not None: + logger.warn( + f"DuckDB does not support both timezone and precision for column '{column_name}' in" + f" table '{table_name}'. Will default to timezone. Please set timezone to False to" + " use precision types." + ) + + if timezone: + # default timestamp mapping for timezone + return None + + if precision is None or precision == 6: + return "TIMESTAMP" + elif precision == 0: + return "TIMESTAMP_S" + elif precision == 3: + return "TIMESTAMP_MS" + elif precision == 9: + return "TIMESTAMP_NS" + + raise TerminalValueError( + f"DuckDB does not support precision '{precision}' for '{column_name}' in table" + f" '{table_name}'" + ) + + def from_destination_type( + self, db_type: str, precision: t.Optional[int], scale: t.Optional[int] + ) -> TColumnType: + # duckdb provides the types with scale and precision + db_type = db_type.split("(")[0].upper() + if db_type == "DECIMAL": + if precision == 38 and scale == 0: + return dict(data_type="wei", precision=precision, scale=scale) + return super().from_destination_type(db_type, precision, scale) + + class duckdb(Destination[DuckDbClientConfiguration, "DuckDbClient"]): spec = DuckDbClientConfiguration @@ -20,6 +131,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = ["insert_values", "parquet", "jsonl"] caps.preferred_staging_file_format = None caps.supported_staging_file_formats = [] + caps.type_mapper = DuckDbTypeMapper caps.escape_identifier = escape_postgres_identifier # all identifiers are case insensitive but are stored as is caps.escape_literal = escape_duckdb_literal diff --git a/dlt/destinations/impl/dummy/dummy.py b/dlt/destinations/impl/dummy/dummy.py index fc87faaf5a..72563e903d 100644 --- a/dlt/destinations/impl/dummy/dummy.py +++ b/dlt/destinations/impl/dummy/dummy.py @@ -16,7 +16,7 @@ import time from dlt.common.metrics import LoadJobMetrics from dlt.common.pendulum import pendulum -from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.schema import Schema, TSchemaTables from dlt.common.storages import FileStorage from dlt.common.storages.load_package import LoadJobInfo from dlt.common.destination import DestinationCapabilitiesContext @@ -27,6 +27,7 @@ from dlt.common.destination.reference import ( HasFollowupJobs, FollowupJobRequest, + PreparedTableSchema, SupportsStagingDestination, TLoadJobState, RunnableLoadJob, @@ -160,7 +161,7 @@ def update_stored_schema( return applied_update def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job_id = FileStorage.get_file_name_from_file_path(file_path) if restore and job_id not in JOBS: @@ -178,7 +179,7 @@ def create_load_job( def create_table_chain_completed_followup_jobs( self, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJobRequest]: """Creates a list of followup jobs that should be executed after a table chain is completed""" @@ -199,10 +200,10 @@ def create_table_chain_completed_followup_jobs( def complete_load(self, load_id: str) -> None: pass - def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: - return super().should_load_data_to_staging_dataset(table) + def should_load_data_to_staging_dataset(self, table_name: str) -> bool: + return super().should_load_data_to_staging_dataset(table_name) - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load @contextmanager diff --git a/dlt/destinations/impl/filesystem/factory.py b/dlt/destinations/impl/filesystem/factory.py index ff3c8a59e1..c5218f14a3 100644 --- a/dlt/destinations/impl/filesystem/factory.py +++ b/dlt/destinations/impl/filesystem/factory.py @@ -2,18 +2,17 @@ from dlt.common.destination import Destination, DestinationCapabilitiesContext, TLoaderFileFormat from dlt.common.destination.reference import DEFAULT_FILE_LAYOUT -from dlt.common.schema.typing import TTableSchema +from dlt.common.schema.typing import TLoaderMergeStrategy, TTableSchema from dlt.common.storages.configuration import FileSystemCredentials from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration from dlt.destinations.impl.filesystem.typing import TCurrentDateTime, TExtraPlaceholders -from dlt.common.normalizers.naming.naming import NamingConvention if t.TYPE_CHECKING: from dlt.destinations.impl.filesystem.filesystem import FilesystemClient -def loader_file_format_adapter( +def filesystem_loader_file_format_selector( preferred_loader_file_format: TLoaderFileFormat, supported_loader_file_formats: t.Sequence[TLoaderFileFormat], /, @@ -25,22 +24,33 @@ def loader_file_format_adapter( return (preferred_loader_file_format, supported_loader_file_formats) +def filesystem_merge_strategies_selector( + supported_merge_strategies: t.Sequence[TLoaderMergeStrategy], + /, + *, + table_schema: TTableSchema, +) -> t.Sequence[TLoaderMergeStrategy]: + if table_schema.get("table_format") == "delta": + return supported_merge_strategies + else: + return [] + + class filesystem(Destination[FilesystemDestinationClientConfiguration, "FilesystemClient"]): spec = FilesystemDestinationClientConfiguration def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext.generic_capabilities( preferred_loader_file_format="jsonl", - loader_file_format_adapter=loader_file_format_adapter, + loader_file_format_selector=filesystem_loader_file_format_selector, supported_table_formats=["delta"], - # TODO: make `supported_merge_strategies` depend on configured - # `table_format` (perhaps with adapter similar to how we handle - # loader file format) supported_merge_strategies=["upsert"], + merge_strategies_selector=filesystem_merge_strategies_selector, ) caps.supported_loader_file_formats = list(caps.supported_loader_file_formats) + [ "reference", ] + caps.has_case_sensitive_identifiers = True return caps @property diff --git a/dlt/destinations/impl/filesystem/filesystem.py b/dlt/destinations/impl/filesystem/filesystem.py index ac5ffb9ef3..9d2072e701 100644 --- a/dlt/destinations/impl/filesystem/filesystem.py +++ b/dlt/destinations/impl/filesystem/filesystem.py @@ -9,20 +9,24 @@ import dlt from dlt.common import logger, time, json, pendulum +from dlt.common.destination.utils import resolve_merge_strategy from dlt.common.metrics import LoadJobMetrics +from dlt.common.schema.typing import TTableSchemaColumns from dlt.common.storages.fsspec_filesystem import glob_files from dlt.common.typing import DictStrAny -from dlt.common.schema import Schema, TSchemaTables, TTableSchema +from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage, fsspec_from_config from dlt.common.storages.load_package import ( LoadJobInfo, + ParsedLoadJobFileName, TPipelineStateDoc, load_package as current_load_package, ) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( FollowupJobRequest, + PreparedTableSchema, TLoadJobState, RunnableLoadJob, JobClientBase, @@ -42,6 +46,7 @@ from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration from dlt.destinations import path_utils from dlt.destinations.fs_client import FSClientBase +from dlt.destinations.utils import verify_schema_merge_disposition INIT_FILE_NAME = "init" FILENAME_SEPARATOR = "__" @@ -124,7 +129,9 @@ def run(self) -> None: with self.arrow_ds.scanner().to_reader() as arrow_rbr: # RecordBatchReader if self._load_table["write_disposition"] == "merge" and self._delta_table is not None: - assert self._load_table["x-merge-strategy"] in self._job_client.capabilities.supported_merge_strategies # type: ignore[typeddict-item] + self._load_table["x-merge-strategy"] = resolve_merge_strategy( # type: ignore[typeddict-unknown-key] + self._schema.tables, self._load_table, self._job_client.capabilities + ) merge_delta_table( table=self._delta_table, data=arrow_rbr, @@ -236,9 +243,7 @@ def dataset_path(self) -> str: def with_staging_dataset(self) -> Iterator["FilesystemClient"]: current_dataset_name = self.dataset_name try: - self.dataset_name = self.schema.naming.normalize_table_identifier( - current_dataset_name + "_staging" - ) + self.dataset_name = self.config.normalize_staging_dataset_name(self.schema) yield self finally: # restore previous dataset name @@ -265,6 +270,17 @@ def drop_tables(self, *tables: str, delete_schema: bool = True) -> None: if fileparts[0] == self.schema.name: self._delete_file(filename) + def get_storage_tables( + self, table_names: Iterable[str] + ) -> Iterable[Tuple[str, TTableSchemaColumns]]: + """Yields tables that have files in storage, does not return column schemas""" + for table_name in table_names: + if len(self.list_table_files(table_name)) > 0: + yield (table_name, {"_column": {}}) + else: + # if no columns we assume that table does not exist + yield (table_name, {}) + def truncate_tables(self, table_names: List[str]) -> None: """Truncate a set of regular tables with given `table_names`""" table_dirs = set(self.get_table_dirs(table_names)) @@ -291,6 +307,19 @@ def _delete_file(self, file_path: str) -> None: if self.fs_client.exists(file_path): raise FileExistsError(file_path) + def verify_schema( + self, only_tables: Iterable[str] = None, new_jobs: Iterable[ParsedLoadJobFileName] = None + ) -> List[PreparedTableSchema]: + loaded_tables = super().verify_schema(only_tables, new_jobs) + # TODO: finetune verify_schema_merge_disposition ie. hard deletes are not supported + if exceptions := verify_schema_merge_disposition( + self.schema, loaded_tables, self.capabilities, warnings=True + ): + for exception in exceptions: + logger.error(str(exception)) + raise exceptions[0] + return loaded_tables + def update_stored_schema( self, only_tables: Iterable[str] = None, @@ -307,13 +336,21 @@ def update_stored_schema( self.fs_client.touch(self.pathlib.join(directory, INIT_FILE_NAME)) # don't store schema when used as staging - if not self.config.as_staging: + if not self.config.as_staging_destination: self._store_current_schema() # we assume that expected_update == applied_update so table schemas in dest were not # externally changed return applied_update + def prepare_load_table(self, table_name: str) -> PreparedTableSchema: + table = super().prepare_load_table(table_name) + if self.config.as_staging_destination: + if table["write_disposition"] == "merge": + table["write_disposition"] = "append" + table.pop("table_format", None) + return table + def get_table_dir(self, table_name: str, remote: bool = False) -> str: # dlt tables do not respect layout (for now) table_prefix = self.get_table_prefix(table_name) @@ -369,12 +406,12 @@ def is_storage_initialized(self) -> bool: return self.fs_client.exists(self.pathlib.join(self.dataset_path, INIT_FILE_NAME)) # type: ignore[no-any-return] def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: # skip the state table, we create a jsonl file in the complete_load step # this does not apply to scenarios where we are using filesystem as staging # where we want to load the state the regular way - if table["name"] == self.schema.state_table_name and not self.config.as_staging: + if table["name"] == self.schema.state_table_name and not self.config.as_staging_destination: return FinalizedLoadJob(file_path) if table.get("table_format") == "delta": import dlt.common.libs.deltalake # assert dependencies are installed @@ -385,7 +422,11 @@ def create_load_job( # otherwise just continue return FinalizedLoadJobWithFollowupJobs(file_path) - cls = FilesystemLoadJobWithFollowup if self.config.as_staging else FilesystemLoadJob + cls = ( + FilesystemLoadJobWithFollowup + if self.config.as_staging_destination + else FilesystemLoadJob + ) return cls(file_path) def make_remote_url(self, remote_path: str) -> str: @@ -403,10 +444,11 @@ def __exit__( ) -> None: pass - def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + def should_load_data_to_staging_dataset(self, table_name: str) -> bool: return False - def should_truncate_table_before_load(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load(self, table_name: str) -> bool: + table = self.prepare_load_table(table_name) return ( table["write_disposition"] == "replace" and not table.get("table_format") == "delta" # Delta can do a logical replace @@ -472,7 +514,7 @@ def _get_state_file_name(self, pipeline_name: str, version_hash: str, load_id: s def _store_current_state(self, load_id: str) -> None: # don't save the state this way when used as staging - if self.config.as_staging: + if self.config.as_staging_destination: return # get state doc from current pipeline @@ -587,7 +629,7 @@ def get_stored_schema_by_hash(self, version_hash: str) -> Optional[StorageSchema def create_table_chain_completed_followup_jobs( self, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJobRequest]: assert completed_table_chain_jobs is not None diff --git a/dlt/destinations/impl/lancedb/factory.py b/dlt/destinations/impl/lancedb/factory.py index f2e17168b9..339453133f 100644 --- a/dlt/destinations/impl/lancedb/factory.py +++ b/dlt/destinations/impl/lancedb/factory.py @@ -1,11 +1,21 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination.capabilities import DataTypeMapper +from dlt.common.exceptions import MissingDependencyException from dlt.destinations.impl.lancedb.configuration import ( LanceDBCredentials, LanceDBClientConfiguration, ) +LanceDBTypeMapper: t.Type[DataTypeMapper] +try: + # lancedb type mapper cannot be used without pyarrow installed + from dlt.destinations.impl.lancedb.type_mapper import LanceDBTypeMapper +except MissingDependencyException: + # assign mock type mapper if no arrow + from dlt.common.destination.capabilities import UnsupportedTypeMapper as LanceDBTypeMapper + if t.TYPE_CHECKING: from dlt.destinations.impl.lancedb.lancedb_client import LanceDBClient @@ -18,6 +28,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "jsonl" caps.supported_loader_file_formats = ["jsonl"] + caps.type_mapper = LanceDBTypeMapper caps.max_identifier_length = 200 caps.max_column_identifier_length = 1024 diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 8d4b6303ef..5fd3e93411 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -14,6 +14,7 @@ TYPE_CHECKING, ) +from dlt.common.destination.capabilities import DataTypeMapper import lancedb # type: ignore import pyarrow as pa from lancedb import DBConnection @@ -32,6 +33,7 @@ ) from dlt.common.destination.reference import ( JobClientBase, + PreparedTableSchema, WithStateSync, RunnableLoadJob, StorageSchemaInfo, @@ -39,10 +41,8 @@ LoadJob, ) from dlt.common.pendulum import timedelta -from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TColumnSchema +from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.typing import ( - TColumnType, - TTableFormat, TTableSchemaColumns, TWriteDisposition, ) @@ -68,91 +68,15 @@ generate_uuid, set_non_standard_providers_environment_variables, ) -from dlt.destinations.type_mapping import TypeMapper if TYPE_CHECKING: NDArray = ndarray[Any, Any] else: NDArray = ndarray - -TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} -UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} EMPTY_STRING_PLACEHOLDER = "0uEoDNBpQUBwsxKbmxxB" -class LanceDBTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "text": pa.string(), - "double": pa.float64(), - "bool": pa.bool_(), - "bigint": pa.int64(), - "binary": pa.binary(), - "date": pa.date32(), - "complex": pa.string(), - } - - sct_to_dbt = {} - - dbt_to_sct = { - pa.string(): "text", - pa.float64(): "double", - pa.bool_(): "bool", - pa.int64(): "bigint", - pa.binary(): "binary", - pa.date32(): "date", - } - - def to_db_decimal_type(self, column: TColumnSchema) -> pa.Decimal128Type: - precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) - return pa.decimal128(precision, scale) - - def to_db_datetime_type( - self, - column: TColumnSchema, - table: TTableSchema = None, - ) -> pa.TimestampType: - column_name = column.get("name") - timezone = column.get("timezone") - precision = column.get("precision") - if timezone is not None or precision is not None: - logger.warning( - "LanceDB does not currently support column flags for timezone or precision." - f" These flags were used in column '{column_name}'." - ) - unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] - return pa.timestamp(unit, "UTC") - - def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> pa.Time64Type: - unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] - return pa.time64(unit) - - def from_db_type( - self, - db_type: pa.DataType, - precision: Optional[int] = None, - scale: Optional[int] = None, - ) -> TColumnType: - if isinstance(db_type, pa.TimestampType): - return dict( - data_type="timestamp", - precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], - scale=scale, - ) - if isinstance(db_type, pa.Time64Type): - return dict( - data_type="time", - precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], - scale=scale, - ) - if isinstance(db_type, pa.Decimal128Type): - precision, scale = db_type.precision, db_type.scale - if (precision, scale) == self.capabilities.wei_precision: - return cast(TColumnType, dict(data_type="wei")) - return dict(data_type="decimal", precision=precision, scale=scale) - return super().from_db_type(db_type, precision, scale) - - def upload_batch( records: List[DictStrAny], /, @@ -225,7 +149,7 @@ def __init__( read_consistency_interval=timedelta(0), ) self.registry = EmbeddingFunctionRegistry.get_instance() - self.type_mapper = LanceDBTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() self.sentinel_table_name = config.sentinel_table_name embedding_model_provider = self.config.embedding_model_provider @@ -371,9 +295,7 @@ def update_stored_schema( only_tables: Iterable[str] = None, expected_update: TSchemaTables = None, ) -> Optional[TSchemaTables]: - super().update_stored_schema(only_tables, expected_update) - applied_update: TSchemaTables = {} - + applied_update = super().update_stored_schema(only_tables, expected_update) try: schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) except DestinationUndefinedEntity: @@ -384,6 +306,7 @@ def update_stored_schema( f"Schema with hash {self.schema.stored_version_hash} " "not found in the storage. upgrading" ) + # TODO: return a real updated table schema (like in SQL job client) self._execute_schema_update(only_tables) else: logger.info( @@ -391,6 +314,8 @@ def update_stored_schema( f"inserted at {schema_info.inserted_at} found " "in storage, no upgrade required" ) + # we assume that expected_update == applied_update so table schemas in dest were not + # externally changed return applied_update def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]: @@ -410,7 +335,7 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns] name = self.schema.naming.normalize_identifier(field.name) table_schema[name] = { "name": name, - **self.type_mapper.from_db_type(field.type), + **self.type_mapper.from_destination_type(field.type, None, None), } return True, table_schema @@ -683,7 +608,7 @@ def complete_load(self, load_id: str) -> None: ) def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: return LanceDBLoadJob( file_path=file_path, @@ -702,12 +627,12 @@ class LanceDBLoadJob(RunnableLoadJob): def __init__( self, file_path: str, - type_mapper: LanceDBTypeMapper, + type_mapper: DataTypeMapper, model_func: TextEmbeddingFunction, fq_table_name: str, ) -> None: super().__init__(file_path) - self._type_mapper: TypeMapper = type_mapper + self._type_mapper = type_mapper self._fq_table_name: str = fq_table_name self._model_func = model_func self._job_client: "LanceDBClient" = None diff --git a/dlt/destinations/impl/lancedb/schema.py b/dlt/destinations/impl/lancedb/schema.py index c7cceec274..27c6fb33a1 100644 --- a/dlt/destinations/impl/lancedb/schema.py +++ b/dlt/destinations/impl/lancedb/schema.py @@ -13,7 +13,8 @@ from dlt.common.schema import Schema, TColumnSchema from dlt.common.typing import DictStrAny -from dlt.destinations.type_mapping import TypeMapper + +from dlt.common.destination.capabilities import DataTypeMapper TArrowSchema: TypeAlias = pa.Schema @@ -30,17 +31,17 @@ def arrow_schema_to_dict(schema: TArrowSchema) -> DictStrAny: def make_arrow_field_schema( column_name: str, column: TColumnSchema, - type_mapper: TypeMapper, + type_mapper: DataTypeMapper, ) -> TArrowField: """Creates a PyArrow field from a dlt column schema.""" - dtype = cast(TArrowDataType, type_mapper.to_db_type(column)) + dtype = cast(TArrowDataType, type_mapper.to_destination_type(column, None)) return pa.field(column_name, dtype) def make_arrow_table_schema( table_name: str, schema: Schema, - type_mapper: TypeMapper, + type_mapper: DataTypeMapper, id_field_name: Optional[str] = None, vector_field_name: Optional[str] = None, embedding_fields: Optional[List[str]] = None, diff --git a/dlt/destinations/impl/lancedb/type_mapper.py b/dlt/destinations/impl/lancedb/type_mapper.py new file mode 100644 index 0000000000..6f6d685d8e --- /dev/null +++ b/dlt/destinations/impl/lancedb/type_mapper.py @@ -0,0 +1,85 @@ +from typing import Dict, Optional, cast +from dlt.common import logger +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.libs.pyarrow import pyarrow as pa + +from dlt.common.schema.typing import TColumnSchema, TColumnType +from dlt.destinations.type_mapping import TypeMapperImpl + +TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"} +UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()} + + +# TODO: TypeMapperImpl must be a Generic where pa.DataType will be a concrete class +class LanceDBTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "text": pa.string(), + "double": pa.float64(), + "bool": pa.bool_(), + "bigint": pa.int64(), + "binary": pa.binary(), + "date": pa.date32(), + "complex": pa.string(), + } + + sct_to_dbt = {} + + dbt_to_sct = { + pa.string(): "text", + pa.float64(): "double", + pa.bool_(): "bool", + pa.int64(): "bigint", + pa.binary(): "binary", + pa.date32(): "date", + } + + def to_db_decimal_type(self, column: TColumnSchema) -> pa.Decimal128Type: + precision, scale = self.decimal_precision(column.get("precision"), column.get("scale")) + return pa.decimal128(precision, scale) + + def to_db_datetime_type( + self, + column: TColumnSchema, + table: PreparedTableSchema = None, + ) -> pa.TimestampType: + column_name = column.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + if timezone is not None or precision is not None: + logger.warning( + "LanceDB does not currently support column flags for timezone or precision." + f" These flags were used in column '{column_name}'." + ) + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.timestamp(unit, "UTC") + + def to_db_time_type( + self, column: TColumnSchema, table: PreparedTableSchema = None + ) -> pa.Time64Type: + unit: str = TIMESTAMP_PRECISION_TO_UNIT[self.capabilities.timestamp_precision] + return pa.time64(unit) + + def from_destination_type( + self, + db_type: pa.DataType, + precision: Optional[int] = None, + scale: Optional[int] = None, + ) -> TColumnType: + if isinstance(db_type, pa.TimestampType): + return dict( + data_type="timestamp", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Time64Type): + return dict( + data_type="time", + precision=UNIT_TO_TIMESTAMP_PRECISION[db_type.unit], + scale=scale, + ) + if isinstance(db_type, pa.Decimal128Type): + precision, scale = db_type.precision, db_type.scale + if (precision, scale) == self.capabilities.wei_precision: + return cast(TColumnType, dict(data_type="wei")) + return dict(data_type="decimal", precision=precision, scale=scale) + return super().from_destination_type(db_type, precision, scale) diff --git a/dlt/destinations/impl/motherduck/factory.py b/dlt/destinations/impl/motherduck/factory.py index 0f4218f7cb..ac5dc70b57 100644 --- a/dlt/destinations/impl/motherduck/factory.py +++ b/dlt/destinations/impl/motherduck/factory.py @@ -4,6 +4,7 @@ from dlt.common.data_writers.escape import escape_postgres_identifier, escape_duckdb_literal from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.destinations.impl.duckdb.factory import DuckDbTypeMapper from dlt.destinations.impl.motherduck.configuration import ( MotherDuckCredentials, MotherDuckClientConfiguration, @@ -21,6 +22,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "parquet" caps.supported_loader_file_formats = ["parquet", "insert_values", "jsonl"] + caps.type_mapper = DuckDbTypeMapper caps.escape_identifier = escape_postgres_identifier # all identifiers are case insensitive but are stored as is caps.escape_literal = escape_duckdb_literal diff --git a/dlt/destinations/impl/mssql/factory.py b/dlt/destinations/impl/mssql/factory.py index f1a8bb136a..c2767cedb9 100644 --- a/dlt/destinations/impl/mssql/factory.py +++ b/dlt/destinations/impl/mssql/factory.py @@ -1,16 +1,83 @@ import typing as t from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.exceptions import TerminalValueError from dlt.common.normalizers.naming.naming import NamingConvention from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.schema.typing import TColumnSchema, TColumnType +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.mssql.configuration import MsSqlCredentials, MsSqlClientConfiguration if t.TYPE_CHECKING: from dlt.destinations.impl.mssql.mssql import MsSqlJobClient +class MsSqlTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "complex": "nvarchar(max)", + "text": "nvarchar(max)", + "double": "float", + "bool": "bit", + "bigint": "bigint", + "binary": "varbinary(max)", + "date": "date", + "timestamp": "datetimeoffset", + "time": "time", + } + + sct_to_dbt = { + "complex": "nvarchar(%i)", + "text": "nvarchar(%i)", + "timestamp": "datetimeoffset(%i)", + "binary": "varbinary(%i)", + "decimal": "decimal(%i,%i)", + "time": "time(%i)", + "wei": "decimal(%i,%i)", + } + + dbt_to_sct = { + "nvarchar": "text", + "float": "double", + "bit": "bool", + "datetimeoffset": "timestamp", + "date": "date", + "bigint": "bigint", + "varbinary": "binary", + "decimal": "decimal", + "time": "time", + "tinyint": "bigint", + "smallint": "bigint", + "int": "bigint", + } + + def to_db_integer_type(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: + precision = column.get("precision") + if precision is None: + return "bigint" + if precision <= 8: + return "tinyint" + if precision <= 16: + return "smallint" + if precision <= 32: + return "int" + elif precision <= 64: + return "bigint" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into mssql integer type" + ) + + def from_destination_type( + self, db_type: str, precision: t.Optional[int], scale: t.Optional[int] + ) -> TColumnType: + if db_type == "decimal": + if (precision, scale) == self.capabilities.wei_precision: + return dict(data_type="wei") + return super().from_destination_type(db_type, precision, scale) + + class mssql(Destination[MsSqlClientConfiguration, "MsSqlJobClient"]): spec = MsSqlClientConfiguration @@ -20,6 +87,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = ["insert_values"] caps.preferred_staging_file_format = None caps.supported_staging_file_formats = [] + caps.type_mapper = MsSqlTypeMapper # mssql is by default case insensitive and stores identifiers as is # case sensitivity can be changed by database collation so we allow to reconfigure # capabilities in the mssql factory diff --git a/dlt/destinations/impl/mssql/mssql.py b/dlt/destinations/impl/mssql/mssql.py index a7e796b2d8..9eabfcf392 100644 --- a/dlt/destinations/impl/mssql/mssql.py +++ b/dlt/destinations/impl/mssql/mssql.py @@ -1,10 +1,9 @@ from typing import Dict, Optional, Sequence, List, Any -from dlt.common.exceptions import TerminalValueError -from dlt.common.destination.reference import FollowupJobRequest +from dlt.common.destination.reference import FollowupJobRequest, PreparedTableSchema from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.schema.typing import TColumnType from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlMergeFollowupJob, SqlJobParams @@ -13,7 +12,6 @@ from dlt.destinations.impl.mssql.sql_client import PyOdbcMsSqlClient from dlt.destinations.impl.mssql.configuration import MsSqlClientConfiguration from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.type_mapping import TypeMapper HINT_TO_MSSQL_ATTR: Dict[TColumnHint, str] = {"unique": "UNIQUE"} @@ -21,74 +19,11 @@ VARBINARY_MAX_N: int = 8000 -class MsSqlTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "complex": "nvarchar(max)", - "text": "nvarchar(max)", - "double": "float", - "bool": "bit", - "bigint": "bigint", - "binary": "varbinary(max)", - "date": "date", - "timestamp": "datetimeoffset", - "time": "time", - } - - sct_to_dbt = { - "complex": "nvarchar(%i)", - "text": "nvarchar(%i)", - "timestamp": "datetimeoffset(%i)", - "binary": "varbinary(%i)", - "decimal": "decimal(%i,%i)", - "time": "time(%i)", - "wei": "decimal(%i,%i)", - } - - dbt_to_sct = { - "nvarchar": "text", - "float": "double", - "bit": "bool", - "datetimeoffset": "timestamp", - "date": "date", - "bigint": "bigint", - "varbinary": "binary", - "decimal": "decimal", - "time": "time", - "tinyint": "bigint", - "smallint": "bigint", - "int": "bigint", - } - - def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: - precision = column.get("precision") - if precision is None: - return "bigint" - if precision <= 8: - return "tinyint" - if precision <= 16: - return "smallint" - if precision <= 32: - return "int" - elif precision <= 64: - return "bigint" - raise TerminalValueError( - f"bigint with {precision} bits precision cannot be mapped into mssql integer type" - ) - - def from_db_type( - self, db_type: str, precision: Optional[int], scale: Optional[int] - ) -> TColumnType: - if db_type == "decimal": - if (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei") - return super().from_db_type(db_type, precision, scale) - - class MsSqlStagingCopyJob(SqlStagingCopyFollowupJob): @classmethod def generate_sql( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None, ) -> List[str]: @@ -136,8 +71,7 @@ def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: - name = SqlMergeFollowupJob._new_temp_table_name(name_prefix, sql_client) - return "#" + name + return SqlMergeFollowupJob._new_temp_table_name("#" + name_prefix, sql_client) class MsSqlJobClient(InsertValuesJobClient): @@ -157,26 +91,26 @@ def __init__( self.config: MsSqlClientConfiguration = config self.sql_client = sql_client self.active_hints = HINT_TO_MSSQL_ATTR if self.config.create_indexes else {} - self.type_mapper = MsSqlTypeMapper(self.capabilities) + self.type_mapper = capabilities.get_type_mapper() def _create_merge_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return [MsSqlMergeJob.from_table_chain(table_chain, self.sql_client)] def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None + self, new_columns: Sequence[TColumnSchema], table: PreparedTableSchema = None ) -> List[str]: # Override because mssql requires multiple columns in a single ADD COLUMN clause return ["ADD \n" + ",\n".join(self._get_column_def_sql(c, table) for c in new_columns)] - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: sc_type = c["data_type"] if sc_type == "text" and c.get("unique"): # MSSQL does not allow index on large TEXT columns db_type = "nvarchar(%i)" % (c.get("precision") or 900) else: - db_type = self.type_mapper.to_db_type(c, table) + db_type = self.type_mapper.to_destination_type(c, table) hints_str = " ".join( self.active_hints.get(h, "") @@ -187,7 +121,7 @@ def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> s return f"{column_name} {db_type} {hints_str} {self._gen_not_null(c.get('nullable', True))}" def _create_replace_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: if self.config.replace_strategy == "staging-optimized": return [MsSqlStagingCopyJob.from_table_chain(table_chain, self.sql_client)] @@ -196,4 +130,4 @@ def _create_replace_followup_jobs( def _from_db_type( self, pq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(pq_t, precision, scale) + return self.type_mapper.from_destination_type(pq_t, precision, scale) diff --git a/dlt/destinations/impl/postgres/factory.py b/dlt/destinations/impl/postgres/factory.py index e14aa61465..b6a95e902e 100644 --- a/dlt/destinations/impl/postgres/factory.py +++ b/dlt/destinations/impl/postgres/factory.py @@ -4,8 +4,12 @@ from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.data_writers.escape import escape_postgres_identifier, escape_postgres_literal from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.exceptions import TerminalValueError +from dlt.common.schema.typing import TColumnSchema, TColumnType from dlt.common.wei import EVM_DECIMAL_PRECISION +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.postgres.configuration import ( PostgresCredentials, PostgresClientConfiguration, @@ -15,6 +19,101 @@ from dlt.destinations.impl.postgres.postgres import PostgresClient +class PostgresTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "complex": "jsonb", + "text": "varchar", + "double": "double precision", + "bool": "boolean", + "date": "date", + "bigint": "bigint", + "binary": "bytea", + "timestamp": "timestamp with time zone", + "time": "time without time zone", + } + + sct_to_dbt = { + "text": "varchar(%i)", + "timestamp": "timestamp (%i) with time zone", + "decimal": "numeric(%i,%i)", + "time": "time (%i) without time zone", + "wei": "numeric(%i,%i)", + } + + dbt_to_sct = { + "varchar": "text", + "jsonb": "complex", + "double precision": "double", + "boolean": "bool", + "timestamp with time zone": "timestamp", + "timestamp without time zone": "timestamp", + "date": "date", + "bigint": "bigint", + "bytea": "binary", + "numeric": "decimal", + "time without time zone": "time", + "character varying": "text", + "smallint": "bigint", + "integer": "bigint", + } + + def to_db_integer_type(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: + precision = column.get("precision") + if precision is None: + return "bigint" + # Precision is number of bits + if precision <= 16: + return "smallint" + elif precision <= 32: + return "integer" + elif precision <= 64: + return "bigint" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into postgres integer type" + ) + + def to_db_datetime_type( + self, + column: TColumnSchema, + table: PreparedTableSchema = None, + ) -> str: + column_name = column.get("name") + table_name = table.get("name") + timezone = column.get("timezone") + precision = column.get("precision") + + if timezone is None and precision is None: + return None + + timestamp = "timestamp" + + # append precision if specified and valid + if precision is not None: + if 0 <= precision <= 6: + timestamp += f" ({precision})" + else: + raise TerminalValueError( + f"Postgres does not support precision '{precision}' for '{column_name}' in" + f" table '{table_name}'" + ) + + # append timezone part + if timezone is None or timezone: # timezone True and None + timestamp += " with time zone" + else: # timezone is explicitly False + timestamp += " without time zone" + + return timestamp + + def from_destination_type( + self, db_type: str, precision: t.Optional[int] = None, scale: t.Optional[int] = None + ) -> TColumnType: + if db_type == "numeric": + if (precision, scale) == self.capabilities.wei_precision: + return dict(data_type="wei") + return super().from_destination_type(db_type, precision, scale) + + class postgres(Destination[PostgresClientConfiguration, "PostgresClient"]): spec = PostgresClientConfiguration @@ -25,6 +124,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = ["insert_values", "csv"] caps.preferred_staging_file_format = None caps.supported_staging_file_formats = [] + caps.type_mapper = PostgresTypeMapper caps.escape_identifier = escape_postgres_identifier # postgres has case sensitive identifiers but by default # it folds them to lower case which makes them case insensitive diff --git a/dlt/destinations/impl/postgres/postgres.py b/dlt/destinations/impl/postgres/postgres.py index 5777e46c90..682f70da04 100644 --- a/dlt/destinations/impl/postgres/postgres.py +++ b/dlt/destinations/impl/postgres/postgres.py @@ -8,6 +8,7 @@ ) from dlt.common.destination.reference import ( HasFollowupJobs, + PreparedTableSchema, RunnableLoadJob, FollowupJobRequest, LoadJob, @@ -16,7 +17,8 @@ from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.exceptions import TerminalValueError from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.schema.typing import TColumnType, TTableFormat +from dlt.common.schema.utils import is_nullable_column from dlt.common.storages.file_storage import FileStorage from dlt.destinations.sql_jobs import SqlStagingCopyFollowupJob, SqlJobParams @@ -24,110 +26,15 @@ from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.postgres.configuration import PostgresClientConfiguration from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.type_mapping import TypeMapper HINT_TO_POSTGRES_ATTR: Dict[TColumnHint, str] = {"unique": "UNIQUE"} -class PostgresTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "complex": "jsonb", - "text": "varchar", - "double": "double precision", - "bool": "boolean", - "date": "date", - "bigint": "bigint", - "binary": "bytea", - "timestamp": "timestamp with time zone", - "time": "time without time zone", - } - - sct_to_dbt = { - "text": "varchar(%i)", - "timestamp": "timestamp (%i) with time zone", - "decimal": "numeric(%i,%i)", - "time": "time (%i) without time zone", - "wei": "numeric(%i,%i)", - } - - dbt_to_sct = { - "varchar": "text", - "jsonb": "complex", - "double precision": "double", - "boolean": "bool", - "timestamp with time zone": "timestamp", - "date": "date", - "bigint": "bigint", - "bytea": "binary", - "numeric": "decimal", - "time without time zone": "time", - "character varying": "text", - "smallint": "bigint", - "integer": "bigint", - } - - def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: - precision = column.get("precision") - if precision is None: - return "bigint" - # Precision is number of bits - if precision <= 16: - return "smallint" - elif precision <= 32: - return "integer" - elif precision <= 64: - return "bigint" - raise TerminalValueError( - f"bigint with {precision} bits precision cannot be mapped into postgres integer type" - ) - - def to_db_datetime_type( - self, - column: TColumnSchema, - table: TTableSchema = None, - ) -> str: - column_name = column.get("name") - table_name = table.get("name") - timezone = column.get("timezone") - precision = column.get("precision") - - if timezone is None and precision is None: - return None - - timestamp = "timestamp" - - # append precision if specified and valid - if precision is not None: - if 0 <= precision <= 6: - timestamp += f" ({precision})" - else: - raise TerminalValueError( - f"Postgres does not support precision '{precision}' for '{column_name}' in" - f" table '{table_name}'" - ) - - # append timezone part - if timezone is None or timezone: # timezone True and None - timestamp += " with time zone" - else: # timezone is explicitly False - timestamp += " without time zone" - - return timestamp - - def from_db_type( - self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None - ) -> TColumnType: - if db_type == "numeric": - if (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei") - return super().from_db_type(db_type, precision, scale) - - class PostgresStagingCopyJob(SqlStagingCopyFollowupJob): @classmethod def generate_sql( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None, ) -> List[str]: @@ -202,7 +109,7 @@ def run(self) -> None: for col in self._job_client.schema.get_table_columns(table_name).values(): norm_col = sql_client.escape_column_name(col["name"], escape=True) split_columns.append(norm_col) - if norm_col in split_headers and col.get("nullable", True): + if norm_col in split_headers and is_nullable_column(col): split_null_headers.append(norm_col) split_unknown_headers = set(split_headers).difference(split_columns) if split_unknown_headers: @@ -255,17 +162,17 @@ def __init__( self.config: PostgresClientConfiguration = config self.sql_client: Psycopg2SqlClient = sql_client self.active_hints = HINT_TO_POSTGRES_ATTR if self.config.create_indexes else {} - self.type_mapper = PostgresTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job = super().create_load_job(table, file_path, load_id, restore) if not job and file_path.endswith("csv"): job = PostgresCsvCopyJob(file_path) return job - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: hints_str = " ".join( self.active_hints.get(h, "") for h in self.active_hints.keys() @@ -273,11 +180,11 @@ def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> s ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def _create_replace_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: if self.config.replace_strategy == "staging-optimized": return [PostgresStagingCopyJob.from_table_chain(table_chain, self.sql_client)] @@ -286,4 +193,4 @@ def _create_replace_followup_jobs( def _from_db_type( self, pq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(pq_t, precision, scale) + return self.type_mapper.from_destination_type(pq_t, precision, scale) diff --git a/dlt/destinations/impl/qdrant/qdrant_adapter.py b/dlt/destinations/impl/qdrant/qdrant_adapter.py index e39d3e3644..abe301fff0 100644 --- a/dlt/destinations/impl/qdrant/qdrant_adapter.py +++ b/dlt/destinations/impl/qdrant/qdrant_adapter.py @@ -1,7 +1,7 @@ from typing import Any from dlt.common.schema.typing import TColumnNames, TTableSchemaColumns -from dlt.extract import DltResource, resource as make_resource +from dlt.extract import DltResource from dlt.destinations.utils import get_resource_for_adapter VECTORIZE_HINT = "x-qdrant-embed" diff --git a/dlt/destinations/impl/qdrant/qdrant_job_client.py b/dlt/destinations/impl/qdrant/qdrant_job_client.py index 65019c6626..bfe35d4081 100644 --- a/dlt/destinations/impl/qdrant/qdrant_job_client.py +++ b/dlt/destinations/impl/qdrant/qdrant_job_client.py @@ -5,7 +5,7 @@ from dlt.common import logger from dlt.common.json import json from dlt.common.pendulum import pendulum -from dlt.common.schema import Schema, TTableSchema, TSchemaTables +from dlt.common.schema import Schema, TSchemaTables from dlt.common.schema.utils import ( get_columns_names_with_prop, loads_table, @@ -14,6 +14,7 @@ ) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( + PreparedTableSchema, TLoadJobState, RunnableLoadJob, JobClientBase, @@ -99,11 +100,11 @@ def _get_embedding_doc(self, data: Dict[str, Any], embedding_fields: List[str]) doc = "\n".join(str(data[key]) for key in embedding_fields) return doc - def _list_unique_identifiers(self, table_schema: TTableSchema) -> Sequence[str]: + def _list_unique_identifiers(self, table_schema: PreparedTableSchema) -> Sequence[str]: """Returns a list of unique identifiers for a table. Args: - table_schema (TTableSchema): a dlt table schema. + table_schema (PreparedTableSchema): a dlt table schema. Returns: Sequence[str]: A list of unique column identifiers. @@ -291,8 +292,7 @@ def update_stored_schema( only_tables: Iterable[str] = None, expected_update: TSchemaTables = None, ) -> Optional[TSchemaTables]: - super().update_stored_schema(only_tables, expected_update) - applied_update: TSchemaTables = {} + applied_update = super().update_stored_schema(only_tables, expected_update) schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) if schema_info is None: logger.info( @@ -306,6 +306,8 @@ def update_stored_schema( f"inserted at {schema_info.inserted_at} found " "in storage, no upgrade required" ) + # we assume that expected_update == applied_update so table schemas in dest were not + # externally changed return applied_update def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: @@ -440,7 +442,7 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaI raise def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: return QDrantLoadJob( file_path, diff --git a/dlt/destinations/impl/redshift/factory.py b/dlt/destinations/impl/redshift/factory.py index ef1ee6b754..a96bd03e63 100644 --- a/dlt/destinations/impl/redshift/factory.py +++ b/dlt/destinations/impl/redshift/factory.py @@ -1,10 +1,17 @@ import typing as t +from dlt.common.data_types.typing import TDataType from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.data_writers.escape import escape_redshift_identifier, escape_redshift_literal from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.exceptions import TerminalValueError from dlt.common.normalizers.naming import NamingConvention +from dlt.common.schema.typing import TColumnSchema, TColumnType +from dlt.common.typing import TLoaderFileFormat + +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.redshift.configuration import ( RedshiftCredentials, RedshiftClientConfiguration, @@ -14,6 +21,92 @@ from dlt.destinations.impl.redshift.redshift import RedshiftClient +class RedshiftTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "complex": "super", + "text": "varchar(max)", + "double": "double precision", + "bool": "boolean", + "date": "date", + "timestamp": "timestamp with time zone", + "bigint": "bigint", + "binary": "varbinary", + "time": "time without time zone", + } + + sct_to_dbt = { + "decimal": "numeric(%i,%i)", + "wei": "numeric(%i,%i)", + "text": "varchar(%i)", + "binary": "varbinary(%i)", + } + + dbt_to_sct = { + "super": "complex", + "varchar(max)": "text", + "double precision": "double", + "boolean": "bool", + "date": "date", + "timestamp with time zone": "timestamp", + "bigint": "bigint", + "binary varying": "binary", + "numeric": "decimal", + "time without time zone": "time", + "varchar": "text", + "smallint": "bigint", + "integer": "bigint", + } + + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + if loader_file_format == "insert_values": + return + # time not supported on staging file formats + if column["data_type"] == "time": + raise TerminalValueError( + "Please convert `datetime.time` objects in your data to `str` or" + " `datetime.datetime`.", + "time", + ) + if loader_file_format == "jsonl": + if column["data_type"] == "binary": + raise TerminalValueError("", "binary") + if loader_file_format == "parquet": + # binary not supported on parquet if precision is set + if column.get("precision") and column["data_type"] == "binary": + raise TerminalValueError( + "Redshift cannot load fixed width VARBYTE columns from parquet files. Switch" + " to other file format or use binary columns without precision.", + "binary", + ) + + def to_db_integer_type(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: + precision = column.get("precision") + if precision is None: + return "bigint" + if precision <= 16: + return "smallint" + elif precision <= 32: + return "integer" + elif precision <= 64: + return "bigint" + raise TerminalValueError( + f"bigint with {precision} bits precision cannot be mapped into postgres integer type" + ) + + def from_destination_type( + self, db_type: str, precision: t.Optional[int], scale: t.Optional[int] + ) -> TColumnType: + if db_type == "numeric": + if (precision, scale) == self.capabilities.wei_precision: + return dict(data_type="wei") + return super().from_destination_type(db_type, precision, scale) + + class redshift(Destination[RedshiftClientConfiguration, "RedshiftClient"]): spec = RedshiftClientConfiguration @@ -23,6 +116,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = ["insert_values"] caps.preferred_staging_file_format = "jsonl" caps.supported_staging_file_formats = ["jsonl", "parquet"] + caps.type_mapper = RedshiftTypeMapper # redshift is case insensitive and will lower case identifiers when stored # you can enable case sensitivity https://docs.aws.amazon.com/redshift/latest/dg/r_enable_case_sensitive_identifier.html # then redshift behaves like postgres diff --git a/dlt/destinations/impl/redshift/redshift.py b/dlt/destinations/impl/redshift/redshift.py index 9bba60af07..ed81b02ab4 100644 --- a/dlt/destinations/impl/redshift/redshift.py +++ b/dlt/destinations/impl/redshift/redshift.py @@ -16,26 +16,23 @@ from dlt.common.destination.reference import ( FollowupJobRequest, CredentialsConfiguration, + PreparedTableSchema, SupportsStagingDestination, LoadJob, ) -from dlt.common.data_types import TDataType from dlt.common.destination.capabilities import DestinationCapabilitiesContext from dlt.common.schema import TColumnSchema, TColumnHint, Schema -from dlt.common.exceptions import TerminalValueError -from dlt.common.schema.utils import table_schema_has_type, table_schema_has_type_with_precision -from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat, TTableSchemaColumns +from dlt.common.schema.utils import table_schema_has_type +from dlt.common.schema.typing import TColumnType from dlt.common.configuration.specs import AwsCredentialsWithoutDefaults from dlt.destinations.insert_job_client import InsertValuesJobClient from dlt.destinations.sql_jobs import SqlMergeFollowupJob -from dlt.destinations.exceptions import DatabaseTerminalException, LoadJobTerminalException +from dlt.destinations.exceptions import DatabaseTerminalException from dlt.destinations.job_client_impl import CopyRemoteFileLoadJob from dlt.destinations.impl.postgres.sql_client import Psycopg2SqlClient from dlt.destinations.impl.redshift.configuration import RedshiftClientConfiguration from dlt.destinations.job_impl import ReferenceFollowupJobRequest -from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.type_mapping import TypeMapper HINT_TO_REDSHIFT_ATTR: Dict[TColumnHint, str] = { @@ -46,65 +43,6 @@ } -class RedshiftTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "complex": "super", - "text": "varchar(max)", - "double": "double precision", - "bool": "boolean", - "date": "date", - "timestamp": "timestamp with time zone", - "bigint": "bigint", - "binary": "varbinary", - "time": "time without time zone", - } - - sct_to_dbt = { - "decimal": "numeric(%i,%i)", - "wei": "numeric(%i,%i)", - "text": "varchar(%i)", - "binary": "varbinary(%i)", - } - - dbt_to_sct = { - "super": "complex", - "varchar(max)": "text", - "double precision": "double", - "boolean": "bool", - "date": "date", - "timestamp with time zone": "timestamp", - "bigint": "bigint", - "binary varying": "binary", - "numeric": "decimal", - "time without time zone": "time", - "varchar": "text", - "smallint": "bigint", - "integer": "bigint", - } - - def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: - precision = column.get("precision") - if precision is None: - return "bigint" - if precision <= 16: - return "smallint" - elif precision <= 32: - return "integer" - elif precision <= 64: - return "bigint" - raise TerminalValueError( - f"bigint with {precision} bits precision cannot be mapped into postgres integer type" - ) - - def from_db_type( - self, db_type: str, precision: Optional[int], scale: Optional[int] - ) -> TColumnType: - if db_type == "numeric": - if (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei") - return super().from_db_type(db_type, precision, scale) - - class RedshiftSqlClient(Psycopg2SqlClient): @staticmethod def _maybe_make_terminal_exception_from_data_error( @@ -152,30 +90,11 @@ def run(self) -> None: file_type = "" dateformat = "" compression = "" - if table_schema_has_type(self._load_table, "time"): - raise LoadJobTerminalException( - self.file_name(), - f"Redshift cannot load TIME columns from {ext} files. Switch to direct INSERT file" - " format or convert `datetime.time` objects in your data to `str` or" - " `datetime.datetime`", - ) if ext == "jsonl": - if table_schema_has_type(self._load_table, "binary"): - raise LoadJobTerminalException( - self.file_name(), - "Redshift cannot load VARBYTE columns from json files. Switch to parquet to" - " load binaries.", - ) file_type = "FORMAT AS JSON 'auto'" dateformat = "dateformat 'auto' timeformat 'auto'" compression = "GZIP" elif ext == "parquet": - if table_schema_has_type_with_precision(self._load_table, "binary"): - raise LoadJobTerminalException( - self.file_name(), - f"Redshift cannot load fixed width VARBYTE columns from {ext} files. Switch to" - " direct INSERT file format or use binary columns without precision.", - ) file_type = "PARQUET" # if table contains complex types then SUPER field will be used. # https://docs.aws.amazon.com/redshift/latest/dg/ingest-super.html @@ -235,14 +154,14 @@ def __init__( super().__init__(schema, config, sql_client) self.sql_client = sql_client self.config: RedshiftClientConfiguration = config - self.type_mapper = RedshiftTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() def _create_merge_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return [RedshiftMergeJob.from_table_chain(table_chain, self.sql_client)] - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: hints_str = " ".join( HINT_TO_REDSHIFT_ATTR.get(h, "") for h in HINT_TO_REDSHIFT_ATTR.keys() @@ -250,11 +169,11 @@ def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> s ) column_name = self.sql_client.escape_column_name(c["name"]) return ( - f"{column_name} {self.type_mapper.to_db_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" + f"{column_name} {self.type_mapper.to_destination_type(c,table)} {hints_str} {self._gen_not_null(c.get('nullable', True))}" ) def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" job = super().create_load_job(table, file_path, load_id, restore) @@ -272,7 +191,7 @@ def create_load_job( def _from_db_type( self, pq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(pq_t, precision, scale) + return self.type_mapper.from_destination_type(pq_t, precision, scale) - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/snowflake/factory.py b/dlt/destinations/impl/snowflake/factory.py index c5fbd8600b..0013bfb5e2 100644 --- a/dlt/destinations/impl/snowflake/factory.py +++ b/dlt/destinations/impl/snowflake/factory.py @@ -4,7 +4,11 @@ from dlt.common.destination import Destination, DestinationCapabilitiesContext from dlt.common.data_writers.escape import escape_snowflake_identifier from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.destination.typing import PreparedTableSchema +from dlt.common.exceptions import TerminalValueError +from dlt.common.schema.typing import TColumnSchema, TColumnType +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.snowflake.configuration import ( SnowflakeCredentials, SnowflakeClientConfiguration, @@ -14,6 +18,80 @@ from dlt.destinations.impl.snowflake.snowflake import SnowflakeClient +class SnowflakeTypeMapper(TypeMapperImpl): + BIGINT_PRECISION = 19 + sct_to_unbound_dbt = { + "complex": "VARIANT", + "text": "VARCHAR", + "double": "FLOAT", + "bool": "BOOLEAN", + "date": "DATE", + "timestamp": "TIMESTAMP_TZ", + "bigint": f"NUMBER({BIGINT_PRECISION},0)", # Snowflake has no integer types + "binary": "BINARY", + "time": "TIME", + } + + sct_to_dbt = { + "text": "VARCHAR(%i)", + "timestamp": "TIMESTAMP_TZ(%i)", + "decimal": "NUMBER(%i,%i)", + "time": "TIME(%i)", + "wei": "NUMBER(%i,%i)", + } + + dbt_to_sct = { + "VARCHAR": "text", + "FLOAT": "double", + "BOOLEAN": "bool", + "DATE": "date", + "TIMESTAMP_TZ": "timestamp", + "BINARY": "binary", + "VARIANT": "complex", + "TIME": "time", + } + + def from_destination_type( + self, db_type: str, precision: t.Optional[int] = None, scale: t.Optional[int] = None + ) -> TColumnType: + if db_type == "NUMBER": + if precision == self.BIGINT_PRECISION and scale == 0: + return dict(data_type="bigint") + elif (precision, scale) == self.capabilities.wei_precision: + return dict(data_type="wei") + return dict(data_type="decimal", precision=precision, scale=scale) + if db_type == "TIMESTAMP_NTZ": + return dict(data_type="timestamp", precision=precision, scale=scale, timezone=False) + return super().from_destination_type(db_type, precision, scale) + + def to_db_datetime_type( + self, + column: TColumnSchema, + table: PreparedTableSchema = None, + ) -> str: + timezone = column.get("timezone", True) + precision = column.get("precision") + + if timezone and precision is None: + return None + + timestamp = "TIMESTAMP_TZ" if timezone else "TIMESTAMP_NTZ" + + # append precision if specified and valid + if precision is not None: + if 0 <= precision <= 9: + timestamp += f"({precision})" + else: + column_name = column["name"] + table_name = table["name"] + raise TerminalValueError( + f"Snowflake does not support precision '{precision}' for '{column_name}' in" + f" table '{table_name}'" + ) + + return timestamp + + class snowflake(Destination[SnowflakeClientConfiguration, "SnowflakeClient"]): spec = SnowflakeClientConfiguration @@ -23,6 +101,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = ["jsonl", "parquet", "csv"] caps.preferred_staging_file_format = "jsonl" caps.supported_staging_file_formats = ["jsonl", "parquet", "csv"] + caps.type_mapper = SnowflakeTypeMapper # snowflake is case sensitive but all unquoted identifiers are upper cased # so upper case identifiers are considered case insensitive caps.escape_identifier = escape_snowflake_identifier diff --git a/dlt/destinations/impl/snowflake/snowflake.py b/dlt/destinations/impl/snowflake/snowflake.py index 247b3233d0..41a8384754 100644 --- a/dlt/destinations/impl/snowflake/snowflake.py +++ b/dlt/destinations/impl/snowflake/snowflake.py @@ -6,6 +6,7 @@ from dlt.common.destination.reference import ( HasFollowupJobs, LoadJob, + PreparedTableSchema, RunnableLoadJob, CredentialsConfiguration, SupportsStagingDestination, @@ -16,96 +17,18 @@ ) from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.storages.file_storage import FileStorage -from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns -from dlt.common.schema.typing import TTableSchema, TColumnType, TTableFormat +from dlt.common.schema import TColumnSchema, Schema +from dlt.common.schema.typing import TColumnType from dlt.common.exceptions import TerminalValueError -from dlt.common.storages.load_package import ParsedLoadJobFileName from dlt.common.typing import TLoaderFileFormat -from dlt.destinations.job_client_impl import SqlJobClientWithStaging -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs +from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.snowflake.configuration import SnowflakeClientConfiguration from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.impl.snowflake.sql_client import SnowflakeSqlClient from dlt.destinations.job_impl import ReferenceFollowupJobRequest -from dlt.destinations.type_mapping import TypeMapper - - -class SnowflakeTypeMapper(TypeMapper): - BIGINT_PRECISION = 19 - sct_to_unbound_dbt = { - "complex": "VARIANT", - "text": "VARCHAR", - "double": "FLOAT", - "bool": "BOOLEAN", - "date": "DATE", - "timestamp": "TIMESTAMP_TZ", - "bigint": f"NUMBER({BIGINT_PRECISION},0)", # Snowflake has no integer types - "binary": "BINARY", - "time": "TIME", - } - - sct_to_dbt = { - "text": "VARCHAR(%i)", - "timestamp": "TIMESTAMP_TZ(%i)", - "decimal": "NUMBER(%i,%i)", - "time": "TIME(%i)", - "wei": "NUMBER(%i,%i)", - } - - dbt_to_sct = { - "VARCHAR": "text", - "FLOAT": "double", - "BOOLEAN": "bool", - "DATE": "date", - "TIMESTAMP_TZ": "timestamp", - "BINARY": "binary", - "VARIANT": "complex", - "TIME": "time", - } - - def from_db_type( - self, db_type: str, precision: Optional[int] = None, scale: Optional[int] = None - ) -> TColumnType: - if db_type == "NUMBER": - if precision == self.BIGINT_PRECISION and scale == 0: - return dict(data_type="bigint") - elif (precision, scale) == self.capabilities.wei_precision: - return dict(data_type="wei") - return dict(data_type="decimal", precision=precision, scale=scale) - return super().from_db_type(db_type, precision, scale) - - def to_db_datetime_type( - self, - column: TColumnSchema, - table: TTableSchema = None, - ) -> str: - column_name = column.get("name") - table_name = table.get("name") - timezone = column.get("timezone") - precision = column.get("precision") - - if timezone is None and precision is None: - return None - - timestamp = "TIMESTAMP_TZ" - - if timezone is not None and not timezone: # explicitaly handles timezone False - timestamp = "TIMESTAMP_NTZ" - - # append precision if specified and valid - if precision is not None: - if 0 <= precision <= 9: - timestamp += f"({precision})" - else: - raise TerminalValueError( - f"Snowflake does not support precision '{precision}' for '{column_name}' in" - f" table '{table_name}'" - ) - - return timestamp class SnowflakeLoadJob(RunnableLoadJob, HasFollowupJobs): @@ -282,7 +205,7 @@ def gen_copy_sql( """ -class SnowflakeClient(SqlJobClientWithStaging, SupportsStagingDestination): +class SnowflakeClient(SqlJobClientWithStagingDataset, SupportsStagingDestination): def __init__( self, schema: Schema, @@ -299,10 +222,10 @@ def __init__( super().__init__(schema, config, sql_client) self.config: SnowflakeClientConfiguration = config self.sql_client: SnowflakeSqlClient = sql_client # type: ignore - self.type_mapper = SnowflakeTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job = super().create_load_job(table, file_path, load_id, restore) @@ -319,7 +242,7 @@ def create_load_job( return job def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None + self, new_columns: Sequence[TColumnSchema], table: PreparedTableSchema = None ) -> List[str]: # Override because snowflake requires multiple columns in a single ADD COLUMN clause return [ @@ -347,13 +270,13 @@ def _get_table_update_sql( def _from_db_type( self, bq_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(bq_t, precision, scale) + return self.type_mapper.from_destination_type(bq_t, precision, scale) - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: name = self.sql_client.escape_column_name(c["name"]) return ( - f"{name} {self.type_mapper.to_db_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" + f"{name} {self.type_mapper.to_destination_type(c,table)} {self._gen_not_null(c.get('nullable', True))}" ) - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index d5a0281bec..f035f2f713 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -1,10 +1,15 @@ import typing as t -from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.common.data_types.typing import TDataType +from dlt.common.destination import Destination, DestinationCapabilitiesContext, PreparedTableSchema +from dlt.common.exceptions import TerminalValueError from dlt.common.normalizers.naming import NamingConvention from dlt.common.data_writers.escape import escape_postgres_identifier, escape_mssql_literal from dlt.common.arithmetics import DEFAULT_NUMERIC_PRECISION, DEFAULT_NUMERIC_SCALE +from dlt.common.schema.typing import TColumnSchema +from dlt.common.typing import TLoaderFileFormat +from dlt.destinations.impl.mssql.factory import MsSqlTypeMapper from dlt.destinations.impl.synapse.configuration import ( SynapseCredentials, SynapseClientConfiguration, @@ -15,6 +20,22 @@ from dlt.destinations.impl.synapse.synapse import SynapseClient +class SynapseTypeMapper(MsSqlTypeMapper): + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + # TIME is not supported for parquet + if loader_file_format == "parquet" and column["data_type"] == "time": + raise TerminalValueError( + "Please convert `datetime.time` objects in your data to `str` or" + " `datetime.datetime`.", + "time", + ) + + class synapse(Destination[SynapseClientConfiguration, "SynapseClient"]): spec = SynapseClientConfiguration @@ -30,6 +51,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps.supported_loader_file_formats = ["insert_values"] caps.preferred_staging_file_format = "parquet" caps.supported_staging_file_formats = ["parquet"] + caps.type_mapper = SynapseTypeMapper caps.insert_values_writer_type = "select_union" # https://stackoverflow.com/a/77014299 diff --git a/dlt/destinations/impl/synapse/synapse.py b/dlt/destinations/impl/synapse/synapse.py index 750a4895f0..15c979bafa 100644 --- a/dlt/destinations/impl/synapse/synapse.py +++ b/dlt/destinations/impl/synapse/synapse.py @@ -5,9 +5,14 @@ from urllib.parse import urlparse, urlunparse from dlt.common.destination import DestinationCapabilitiesContext -from dlt.common.destination.reference import SupportsStagingDestination, FollowupJobRequest, LoadJob +from dlt.common.destination.reference import ( + PreparedTableSchema, + SupportsStagingDestination, + FollowupJobRequest, + LoadJob, +) -from dlt.common.schema import TTableSchema, TColumnSchema, Schema, TColumnHint +from dlt.common.schema import TColumnSchema, Schema, TColumnHint from dlt.common.schema.utils import ( table_schema_has_type, get_inherited_table_hint, @@ -19,16 +24,15 @@ AzureServicePrincipalCredentialsWithoutDefaults, ) +from dlt.destinations.impl.mssql.factory import MsSqlTypeMapper from dlt.destinations.job_impl import ReferenceFollowupJobRequest from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.job_client_impl import ( SqlJobClientBase, CopyRemoteFileLoadJob, ) -from dlt.destinations.exceptions import LoadJobTerminalException from dlt.destinations.impl.mssql.mssql import ( - MsSqlTypeMapper, MsSqlJobClient, VARCHAR_MAX_N, VARBINARY_MAX_N, @@ -76,10 +80,16 @@ def __init__( def _get_table_update_sql( self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool ) -> List[str]: - table = self.prepare_load_table(table_name, staging=self.in_staging_mode) + table = self.prepare_load_table(table_name) + if self.in_staging_dataset_mode and self.config.replace_strategy == "insert-from-staging": + # Staging tables should always be heap tables, because "when you are + # temporarily landing data in dedicated SQL pool, you may find that + # using a heap table makes the overall process faster." + table[TABLE_INDEX_TYPE_HINT] = "heap" # type: ignore[typeddict-unknown-key] + table_index_type = cast(TTableIndexType, table.get(TABLE_INDEX_TYPE_HINT)) - if self.in_staging_mode: - final_table = self.prepare_load_table(table_name, staging=False) + if self.in_staging_dataset_mode: + final_table = self.prepare_load_table(table_name) final_table_index_type = cast(TTableIndexType, final_table.get(TABLE_INDEX_TYPE_HINT)) else: final_table_index_type = table_index_type @@ -130,18 +140,13 @@ def _get_columstore_valid_column(self, c: TColumnSchema) -> TColumnSchema: return c def _create_replace_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return SqlJobClientBase._create_replace_followup_jobs(self, table_chain) - def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSchema: - table = super().prepare_load_table(table_name, staging) - if staging and self.config.replace_strategy == "insert-from-staging": - # Staging tables should always be heap tables, because "when you are - # temporarily landing data in dedicated SQL pool, you may find that - # using a heap table makes the overall process faster." - table[TABLE_INDEX_TYPE_HINT] = "heap" # type: ignore[typeddict-unknown-key] - elif table_name in self.schema.dlt_table_names(): + def prepare_load_table(self, table_name: str) -> PreparedTableSchema: + table = super().prepare_load_table(table_name) + if table_name in self.schema.dlt_table_names(): # dlt tables should always be heap tables, because "for small lookup # tables, less than 60 million rows, consider using HEAP or clustered # index for faster query performance." @@ -159,7 +164,7 @@ def prepare_load_table(self, table_name: str, staging: bool = False) -> TTableSc return table def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job = super().create_load_job(table, file_path, load_id, restore) if not job: @@ -173,7 +178,7 @@ def create_load_job( ) return job - def should_truncate_table_before_load_on_staging_destination(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load_on_staging_destination(self, table_name: str) -> bool: return self.config.truncate_tables_on_staging_destination_before_load @@ -194,15 +199,6 @@ def run(self) -> None: # get format ext = os.path.splitext(self._bucket_path)[1][1:] if ext == "parquet": - if table_schema_has_type(self._load_table, "time"): - # Synapse interprets Parquet TIME columns as bigint, resulting in - # an incompatibility error. - raise LoadJobTerminalException( - self.file_name(), - "Synapse cannot load TIME columns from Parquet files. Switch to direct INSERT" - " file format or convert `datetime.time` objects in your data to `str` or" - " `datetime.datetime`", - ) file_type = "PARQUET" # dlt-generated DDL statements will still create the table, but diff --git a/dlt/destinations/impl/weaviate/factory.py b/dlt/destinations/impl/weaviate/factory.py index 3d78c9582a..1b6e90466d 100644 --- a/dlt/destinations/impl/weaviate/factory.py +++ b/dlt/destinations/impl/weaviate/factory.py @@ -2,6 +2,7 @@ from dlt.common.destination import Destination, DestinationCapabilitiesContext +from dlt.destinations.type_mapping import TypeMapperImpl from dlt.destinations.impl.weaviate.configuration import ( WeaviateCredentials, WeaviateClientConfiguration, @@ -11,6 +12,33 @@ from dlt.destinations.impl.weaviate.weaviate_client import WeaviateClient +class WeaviateTypeMapper(TypeMapperImpl): + sct_to_unbound_dbt = { + "text": "text", + "double": "number", + "bool": "boolean", + "timestamp": "date", + "date": "date", + "time": "text", + "bigint": "int", + "binary": "blob", + "decimal": "text", + "wei": "number", + "complex": "text", + } + + sct_to_dbt = {} + + dbt_to_sct = { + "text": "text", + "number": "double", + "boolean": "bool", + "date": "timestamp", + "int": "bigint", + "blob": "binary", + } + + class weaviate(Destination[WeaviateClientConfiguration, "WeaviateClient"]): spec = WeaviateClientConfiguration @@ -18,6 +46,7 @@ def _raw_capabilities(self) -> DestinationCapabilitiesContext: caps = DestinationCapabilitiesContext() caps.preferred_loader_file_format = "jsonl" caps.supported_loader_file_formats = ["jsonl"] + caps.type_mapper = WeaviateTypeMapper # weaviate names are case sensitive following GraphQL naming convention # https://weaviate.io/developers/weaviate/config-refs/schema caps.has_case_sensitive_identifiers = False diff --git a/dlt/destinations/impl/weaviate/weaviate_client.py b/dlt/destinations/impl/weaviate/weaviate_client.py index b8bf3d62c6..2f0112f90d 100644 --- a/dlt/destinations/impl/weaviate/weaviate_client.py +++ b/dlt/destinations/impl/weaviate/weaviate_client.py @@ -29,7 +29,7 @@ from dlt.common.pendulum import pendulum from dlt.common.typing import StrAny, TFun from dlt.common.time import ensure_pendulum_datetime -from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TTableSchemaColumns +from dlt.common.schema import Schema, TSchemaTables, TTableSchemaColumns from dlt.common.schema.typing import TColumnSchema, TColumnType from dlt.common.schema.utils import ( get_columns_names_with_prop, @@ -39,6 +39,7 @@ ) from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.destination.reference import ( + PreparedTableSchema, TLoadJobState, RunnableLoadJob, JobClientBase, @@ -48,11 +49,9 @@ from dlt.common.storages import FileStorage from dlt.destinations.impl.weaviate.weaviate_adapter import VECTORIZE_HINT, TOKENIZATION_HINT -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs from dlt.destinations.job_client_impl import StorageSchemaInfo, StateInfo from dlt.destinations.impl.weaviate.configuration import WeaviateClientConfiguration from dlt.destinations.impl.weaviate.exceptions import PropertyNameConflict, WeaviateGrpcError -from dlt.destinations.type_mapping import TypeMapper from dlt.destinations.utils import get_pipeline_state_query_columns @@ -64,33 +63,6 @@ } -class WeaviateTypeMapper(TypeMapper): - sct_to_unbound_dbt = { - "text": "text", - "double": "number", - "bool": "boolean", - "timestamp": "date", - "date": "date", - "time": "text", - "bigint": "int", - "binary": "blob", - "decimal": "text", - "wei": "number", - "complex": "text", - } - - sct_to_dbt = {} - - dbt_to_sct = { - "text": "text", - "number": "double", - "boolean": "bool", - "date": "timestamp", - "int": "bigint", - "blob": "binary", - } - - def wrap_weaviate_error(f: TFun) -> TFun: @wraps(f) def _wrap(self: JobClientBase, *args: Any, **kwargs: Any) -> Any: @@ -218,7 +190,7 @@ def check_batch_result(results: List[StrAny]) -> None: batch.add_data_object(data, self._class_name, uuid=uuid) - def list_unique_identifiers(self, table_schema: TTableSchema) -> Sequence[str]: + def list_unique_identifiers(self, table_schema: PreparedTableSchema) -> Sequence[str]: if table_schema.get("write_disposition") == "merge": primary_keys = get_columns_names_with_prop(table_schema, "primary_key") if primary_keys: @@ -259,7 +231,7 @@ def __init__( "vectorizer": config.vectorizer, "moduleConfig": config.module_config, } - self.type_mapper = WeaviateTypeMapper(self.capabilities) + self.type_mapper = self.capabilities.get_type_mapper() @property def dataset_name(self) -> str: @@ -435,9 +407,8 @@ def update_stored_schema( only_tables: Iterable[str] = None, expected_update: TSchemaTables = None, ) -> Optional[TSchemaTables]: - super().update_stored_schema(only_tables, expected_update) + applied_update = super().update_stored_schema(only_tables, expected_update) # Retrieve the schema from Weaviate - applied_update: TSchemaTables = {} try: schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash) except DestinationUndefinedEntity: @@ -447,6 +418,7 @@ def update_stored_schema( f"Schema with hash {self.schema.stored_version_hash} " "not found in the storage. upgrading" ) + # TODO: return a real updated table schema (like in SQL job client) self._execute_schema_update(only_tables) else: logger.info( @@ -670,12 +642,12 @@ def _make_property_schema( return { "name": column_name, - "dataType": [self.type_mapper.to_db_type(column)], + "dataType": [self.type_mapper.to_destination_type(column, None)], **extra_kv, } def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: return LoadWeaviateJob( file_path, @@ -727,4 +699,4 @@ def _update_schema_in_storage(self, schema: Schema) -> None: def _from_db_type( self, wt_t: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: - return self.type_mapper.from_db_type(wt_t, precision, scale) + return self.type_mapper.from_destination_type(wt_t, precision, scale) diff --git a/dlt/destinations/insert_job_client.py b/dlt/destinations/insert_job_client.py index 6ccc65705b..aa608ca2ad 100644 --- a/dlt/destinations/insert_job_client.py +++ b/dlt/destinations/insert_job_client.py @@ -1,15 +1,15 @@ -import os -import abc from typing import Any, Iterator, List -from dlt.common.destination.reference import RunnableLoadJob, HasFollowupJobs, LoadJob -from dlt.common.schema.typing import TTableSchema +from dlt.common.destination.reference import ( + PreparedTableSchema, + RunnableLoadJob, + HasFollowupJobs, + LoadJob, +) from dlt.common.storages import FileStorage from dlt.common.utils import chunks -from dlt.destinations.sql_client import SqlClientBase -from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs -from dlt.destinations.job_client_impl import SqlJobClientWithStaging, SqlJobClientBase +from dlt.destinations.job_client_impl import SqlJobClientWithStagingDataset, SqlJobClientBase class InsertValuesLoadJob(RunnableLoadJob, HasFollowupJobs): @@ -96,9 +96,9 @@ def _insert(self, qualified_table_name: str, file_path: str) -> Iterator[List[st yield insert_sql -class InsertValuesJobClient(SqlJobClientWithStaging): +class InsertValuesJobClient(SqlJobClientWithStagingDataset): def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: job = super().create_load_job(table, file_path, load_id, restore) if not job: diff --git a/dlt/destinations/job_client_impl.py b/dlt/destinations/job_client_impl.py index 3026baf753..c395b41a1f 100644 --- a/dlt/destinations/job_client_impl.py +++ b/dlt/destinations/job_client_impl.py @@ -24,19 +24,20 @@ COLUMN_HINTS, TColumnType, TColumnSchemaBase, - TTableSchema, TTableFormat, ) from dlt.common.schema.utils import ( get_inherited_table_hint, + has_default_column_prop_value, loads_table, normalize_table_identifiers, version_table, ) from dlt.common.storages import FileStorage -from dlt.common.storages.load_package import LoadJobInfo +from dlt.common.storages.load_package import LoadJobInfo, ParsedLoadJobFileName from dlt.common.schema import TColumnSchema, Schema, TTableSchemaColumns, TSchemaTables from dlt.common.destination.reference import ( + PreparedTableSchema, StateInfo, StorageSchemaInfo, WithStateSync, @@ -61,7 +62,7 @@ from dlt.destinations.utils import ( get_pipeline_state_query_columns, info_schema_null_to_bool, - verify_sql_job_client_schema, + verify_schema_merge_disposition, ) # this should suffice for now @@ -208,24 +209,25 @@ def maybe_ddl_transaction(self) -> Iterator[None]: else: yield - def should_truncate_table_before_load(self, table: TTableSchema) -> bool: + def should_truncate_table_before_load(self, table_name: str) -> bool: + table = self.prepare_load_table(table_name) return ( table["write_disposition"] == "replace" and self.config.replace_strategy == "truncate-and-insert" ) def _create_append_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return [] def _create_merge_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: return [SqlMergeFollowupJob.from_table_chain(table_chain, self.sql_client)] def _create_replace_followup_jobs( - self, table_chain: Sequence[TTableSchema] + self, table_chain: Sequence[PreparedTableSchema] ) -> List[FollowupJobRequest]: jobs: List[FollowupJobRequest] = [] if self.config.replace_strategy in ["insert-from-staging", "staging-optimized"]: @@ -238,7 +240,7 @@ def _create_replace_followup_jobs( def create_table_chain_completed_followup_jobs( self, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None, ) -> List[FollowupJobRequest]: """Creates a list of followup jobs for merge write disposition and staging replace strategies""" @@ -255,7 +257,7 @@ def create_table_chain_completed_followup_jobs( return jobs def create_load_job( - self, table: TTableSchema, file_path: str, load_id: str, restore: bool = False + self, table: PreparedTableSchema, file_path: str, load_id: str, restore: bool = False ) -> LoadJob: """Starts SqlLoadJob for files ending with .sql or returns None to let derived classes to handle their specific jobs""" if SqlLoadJob.is_sql_job(file_path): @@ -499,11 +501,11 @@ def _build_schema_update_sql( ): # this will skip incomplete columns new_columns = self._create_table_update(table_name, storage_columns) + generate_alter = len(storage_columns) > 0 if len(new_columns) > 0: # build and add sql to execute - sql_statements = self._get_table_update_sql( - table_name, new_columns, len(storage_columns) > 0 - ) + self._check_table_update_hints(table_name, new_columns, generate_alter) + sql_statements = self._get_table_update_sql(table_name, new_columns, generate_alter) for sql in sql_statements: if not sql.endswith(";"): sql += ";" @@ -517,12 +519,12 @@ def _build_schema_update_sql( return sql_updates, schema_update def _make_add_column_sql( - self, new_columns: Sequence[TColumnSchema], table: TTableSchema = None + self, new_columns: Sequence[TColumnSchema], table: PreparedTableSchema = None ) -> List[str]: """Make one or more ADD COLUMN sql clauses to be joined in ALTER TABLE statement(s)""" return [f"ADD COLUMN {self._get_column_def_sql(c, table)}" for c in new_columns] - def _make_create_table(self, qualified_name: str, table: TTableSchema) -> str: + def _make_create_table(self, qualified_name: str, table: PreparedTableSchema) -> str: not_exists_clause = " " if ( table["name"] in self.schema.dlt_table_names() @@ -555,38 +557,41 @@ def _get_table_update_sql( sql_result.extend( [sql_base + col_statement for col_statement in add_column_statements] ) + return sql_result + def _check_table_update_hints( + self, table_name: str, new_columns: Sequence[TColumnSchema], generate_alter: bool + ) -> None: # scan columns to get hints if generate_alter: # no hints may be specified on added columns for hint in COLUMN_HINTS: - if any(c.get(hint, False) is True for c in new_columns): + if any(not has_default_column_prop_value(hint, c.get(hint)) for c in new_columns): hint_columns = [ self.sql_client.escape_column_name(c["name"]) for c in new_columns if c.get(hint, False) ] - if hint == "not_null": + if hint == "null": logger.warning( f"Column(s) {hint_columns} with NOT NULL are being added to existing" - f" table {qualified_name}. If there's data in the table the operation" + f" table {table_name}. If there's data in the table the operation" " will fail." ) else: logger.warning( f"Column(s) {hint_columns} with hint {hint} are being added to existing" - f" table {qualified_name}. Several hint types may not be added to" + f" table {table_name}. Several hint types may not be added to" " existing tables." ) - return sql_result @abstractmethod - def _get_column_def_sql(self, c: TColumnSchema, table: TTableSchema = None) -> str: + def _get_column_def_sql(self, c: TColumnSchema, table: PreparedTableSchema = None) -> str: pass @staticmethod - def _gen_not_null(v: bool) -> str: - return "NOT NULL" if not v else "" + def _gen_not_null(nullable: bool) -> str: + return "NOT NULL" if not nullable else "" def _create_table_update( self, table_name: str, storage_columns: TTableSchemaColumns @@ -657,17 +662,22 @@ def _commit_schema_update(self, schema: Schema, schema_str: str) -> None: schema_str, ) - def _verify_schema(self) -> None: - super()._verify_schema() - if exceptions := verify_sql_job_client_schema(self.schema, warnings=True): + def verify_schema( + self, only_tables: Iterable[str] = None, new_jobs: Iterable[ParsedLoadJobFileName] = None + ) -> List[PreparedTableSchema]: + loaded_tables = super().verify_schema(only_tables, new_jobs) + if exceptions := verify_schema_merge_disposition( + self.schema, loaded_tables, self.capabilities, warnings=True + ): for exception in exceptions: logger.error(str(exception)) raise exceptions[0] + return loaded_tables def prepare_load_job_execution(self, job: RunnableLoadJob) -> None: self._set_query_tags_for_job(load_id=job._load_id, table=job._load_table) - def _set_query_tags_for_job(self, load_id: str, table: TTableSchema) -> None: + def _set_query_tags_for_job(self, load_id: str, table: PreparedTableSchema) -> None: """Sets query tags in sql_client for a job in package `load_id`, starting for a particular `table`""" from dlt.common.pipeline import current_pipeline @@ -678,7 +688,7 @@ def _set_query_tags_for_job(self, load_id: str, table: TTableSchema) -> None: "source": self.schema.name, "resource": ( get_inherited_table_hint( - self.schema._schema_tables, table["name"], "resource", allow_none=True + self.schema.tables, table["name"], "resource", allow_none=True ) or "" ), @@ -689,19 +699,20 @@ def _set_query_tags_for_job(self, load_id: str, table: TTableSchema) -> None: ) -class SqlJobClientWithStaging(SqlJobClientBase, WithStagingDataset): - in_staging_mode: bool = False +class SqlJobClientWithStagingDataset(SqlJobClientBase, WithStagingDataset): + in_staging_dataset_mode: bool = False @contextlib.contextmanager def with_staging_dataset(self) -> Iterator["SqlJobClientBase"]: try: with self.sql_client.with_staging_dataset(): - self.in_staging_mode = True + self.in_staging_dataset_mode = True yield self finally: - self.in_staging_mode = False + self.in_staging_dataset_mode = False - def should_load_data_to_staging_dataset(self, table: TTableSchema) -> bool: + def should_load_data_to_staging_dataset(self, table_name: str) -> bool: + table = self.prepare_load_table(table_name) if table["write_disposition"] == "merge": return True elif table["write_disposition"] == "replace" and ( diff --git a/dlt/destinations/job_impl.py b/dlt/destinations/job_impl.py index 1f54913064..3f261bafed 100644 --- a/dlt/destinations/job_impl.py +++ b/dlt/destinations/job_impl.py @@ -8,13 +8,10 @@ HasFollowupJobs, TLoadJobState, RunnableLoadJob, - JobClientBase, FollowupJobRequest, LoadJob, ) -from dlt.common.metrics import LoadJobMetrics from dlt.common.storages.load_package import commit_load_package_state -from dlt.common.schema import Schema, TTableSchema from dlt.common.storages import FileStorage from dlt.common.typing import TDataItems from dlt.common.storages.load_storage import ParsedLoadJobFileName diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index d5f005ee9a..a555fe8a1f 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -2,9 +2,10 @@ import yaml from dlt.common.time import ensure_pendulum_datetime +from dlt.common.destination.reference import PreparedTableSchema +from dlt.common.destination.utils import resolve_merge_strategy from dlt.common.schema.typing import ( - TTableSchema, TSortOrder, TColumnProp, ) @@ -14,7 +15,7 @@ get_dedup_sort_tuple, get_validity_column_names, get_active_record_timestamp, - DEFAULT_MERGE_STRATEGY, + is_nested_table, ) from dlt.common.storages.load_storage import ParsedLoadJobFileName from dlt.common.storages.load_package import load_package as current_load_package @@ -35,7 +36,9 @@ class SqlJobParams(TypedDict, total=False): class SqlJobCreationException(DestinationTransientException): - def __init__(self, original_exception: Exception, table_chain: Sequence[TTableSchema]) -> None: + def __init__( + self, original_exception: Exception, table_chain: Sequence[PreparedTableSchema] + ) -> None: tables_str = yaml.dump( table_chain, allow_unicode=True, default_flow_style=False, sort_keys=False ) @@ -51,18 +54,18 @@ class SqlFollowupJob(FollowupJobRequestImpl): @classmethod def from_table_chain( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None, ) -> FollowupJobRequestImpl: """Generates a list of sql statements, that will be executed by the sql client when the job is executed in the loader. - The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). + The `table_chain` contains a list of schemas of nested tables, ordered by the ancestry (the root of the tree is first on the list). """ params = cast(SqlJobParams, {**DEFAULTS, **(params or {})}) - top_table = table_chain[0] + root_table = table_chain[0] file_info = ParsedLoadJobFileName( - top_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "sql" + root_table["name"], ParsedLoadJobFileName.new_file_id(), 0, "sql" ) try: @@ -83,7 +86,7 @@ def from_table_chain( @classmethod def generate_sql( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None, ) -> List[str]: @@ -96,7 +99,7 @@ class SqlStagingCopyFollowupJob(SqlFollowupJob): @classmethod def _generate_clone_sql( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], ) -> List[str]: """Drop and clone the table for supported destinations""" @@ -113,7 +116,7 @@ def _generate_clone_sql( @classmethod def _generate_insert_sql( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], params: SqlJobParams = None, ) -> List[str]: @@ -138,7 +141,7 @@ def _generate_insert_sql( @classmethod def generate_sql( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], params: SqlJobParams = None, ) -> List[str]: @@ -154,13 +157,17 @@ class SqlMergeFollowupJob(SqlFollowupJob): """ @classmethod - def generate_sql( # type: ignore[return] + def generate_sql( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], params: Optional[SqlJobParams] = None, ) -> List[str]: - merge_strategy = table_chain[0].get("x-merge-strategy", DEFAULT_MERGE_STRATEGY) + # resolve only root table + root_table = table_chain[0] + merge_strategy = resolve_merge_strategy( + {root_table["name"]: root_table}, root_table, sql_client.capabilities + ) if merge_strategy == "delete-insert": return cls.gen_merge_sql(table_chain, sql_client) elif merge_strategy == "upsert": @@ -332,9 +339,18 @@ def gen_delete_from_sql( ); """ + @classmethod + def _shorten_table_name(cls, ident: str, sql_client: SqlClientBase[Any]) -> str: + """Trims identifier to max length supported by sql_client. Used for dynamically constructed table names""" + from dlt.common.normalizers.naming import NamingConvention + + return NamingConvention.shorten_identifier( + ident, ident, sql_client.capabilities.max_identifier_length + ) + @classmethod def _new_temp_table_name(cls, name_prefix: str, sql_client: SqlClientBase[Any]) -> str: - return f"{name_prefix}_{uniq_id()}" + return cls._shorten_table_name(f"{name_prefix}_{uniq_id()}", sql_client) @classmethod def _to_temp_table(cls, select_sql: str, temp_table_name: str) -> str: @@ -368,7 +384,7 @@ def _escape_list(cls, list_: List[str], escape_id: Callable[[str], str]) -> List @classmethod def _get_hard_delete_col_and_cond( cls, - table: TTableSchema, + table: PreparedTableSchema, escape_id: Callable[[str], str], escape_lit: Callable[[Any], Any], invert: bool = False, @@ -396,9 +412,9 @@ def _get_hard_delete_col_and_cond( @classmethod def _get_unique_col( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], - table: TTableSchema, + table: PreparedTableSchema, ) -> str: """Returns name of first column in `table` with `unique` property. @@ -418,9 +434,9 @@ def _get_unique_col( @classmethod def _get_root_key_col( cls, - table_chain: Sequence[TTableSchema], + table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any], - table: TTableSchema, + table: PreparedTableSchema, ) -> str: """Returns name of first column in `table` with `root_key` property. @@ -439,7 +455,7 @@ def _get_root_key_col( @classmethod def _get_prop_col_or_raise( - cls, table: TTableSchema, prop: Union[TColumnProp, str], exception: Exception + cls, table: PreparedTableSchema, prop: Union[TColumnProp, str], exception: Exception ) -> str: """Returns name of first column in `table` with `prop` property. @@ -452,15 +468,15 @@ def _get_prop_col_or_raise( @classmethod def gen_merge_sql( - cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + cls, table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any] ) -> List[str]: """Generates a list of sql statements that merge the data in staging dataset with the data in destination dataset. - The `table_chain` contains a list schemas of a tables with parent-child relationship, ordered by the ancestry (the root of the tree is first on the list). + The `table_chain` contains a list schemas of a tables with row_key - parent_key nested reference, ordered by the ancestry (the root of the tree is first on the list). The root table is merged using primary_key and merge_key hints which can be compound and be both specified. In that case the OR clause is generated. - The child tables are merged based on propagated `root_key` which is a type of foreign key but always leading to a root table. + The nested tables are merged based on propagated `root_key` which is a type of foreign key but always leading to a root table. - First we store the root_keys of root table elements to be deleted in the temp table. Then we use the temp table to delete records from root and all child tables in the destination dataset. + First we store the root_keys of root table elements to be deleted in the temp table. Then we use the temp table to delete records from root and all netsed tables in the destination dataset. At the end we copy the data from the staging dataset into destination dataset. If a hard_delete column is specified, records flagged as deleted will be excluded from the copy into the destination dataset. @@ -503,7 +519,7 @@ def gen_merge_sql( key_table_clauses = cls.gen_key_table_clauses( root_table_name, staging_root_table_name, key_clauses, for_delete=True ) - # if no child tables, just delete data from top table + # if no nested tables, just delete data from root table for clause in key_table_clauses: sql.append(f"DELETE {clause};") else: @@ -511,6 +527,7 @@ def gen_merge_sql( root_table_name, staging_root_table_name, key_clauses, for_delete=False ) # use unique hint to create temp table with all identifiers to delete + # TODO: use row_key hint, not unique when implemented to correctly handle to nested tables unique_column = escape_column_id( cls._get_unique_col(table_chain, sql_client, root_table) ) @@ -521,7 +538,7 @@ def gen_merge_sql( ) sql.extend(create_delete_temp_table_sql) - # delete from child tables first. This is important for databricks which does not support temporary tables, + # delete from nested tables first. This is important for databricks which does not support temporary tables, # but uses temporary views instead for table in table_chain[1:]: table_name = sql_client.make_qualified_table_name(table["name"]) @@ -534,7 +551,7 @@ def gen_merge_sql( ) ) - # delete from top table now that child tables have been processed + # delete from root table now that nested tables have been processed sql.append( cls.gen_delete_from_sql( root_table_name, unique_column, delete_temp_table_name, unique_column @@ -579,17 +596,17 @@ def gen_merge_sql( insert_cond = not_deleted_cond if hard_delete_col is not None else "1 = 1" if (len(primary_keys) > 0 and len(table_chain) > 1) or ( len(primary_keys) == 0 - and table.get("parent") is not None # child table + and is_nested_table(table) # nested table and hard_delete_col is not None ): - uniq_column = unique_column if table.get("parent") is None else root_key_column + uniq_column = root_key_column if is_nested_table(table) else unique_column insert_cond = f"{uniq_column} IN (SELECT * FROM {insert_temp_table_name})" columns = list(map(escape_column_id, get_columns_names_with_prop(table, "name"))) col_str = ", ".join(columns) select_sql = f"SELECT {col_str} FROM {staging_table_name} WHERE {insert_cond}" if len(primary_keys) > 0 and len(table_chain) == 1: - # without child tables we deduplicate inside the query instead of using a temp table + # without nested tables we deduplicate inside the query instead of using a temp table select_sql = cls.gen_select_from_dedup_sql( staging_table_name, primary_keys, columns, dedup_sort, insert_cond ) @@ -599,7 +616,7 @@ def gen_merge_sql( @classmethod def gen_upsert_sql( - cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + cls, table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any] ) -> List[str]: sql: List[str] = [] root_table = table_chain[0] @@ -641,13 +658,13 @@ def gen_upsert_sql( THEN INSERT ({col_str.format(alias="")}) VALUES ({col_str.format(alias="s.")}); """) - # generate statements for child tables if they exist - child_tables = table_chain[1:] - if child_tables: + # generate statements for nested tables if they exist + nested_tables = table_chain[1:] + if nested_tables: root_unique_column = escape_column_id( cls._get_unique_col(table_chain, sql_client, root_table) ) - for table in child_tables: + for table in nested_tables: unique_column = escape_column_id( cls._get_unique_col(table_chain, sql_client, table) ) @@ -690,14 +707,14 @@ def gen_upsert_sql( @classmethod def gen_scd2_sql( - cls, table_chain: Sequence[TTableSchema], sql_client: SqlClientBase[Any] + cls, table_chain: Sequence[PreparedTableSchema], sql_client: SqlClientBase[Any] ) -> List[str]: """Generates SQL statements for the `scd2` merge strategy. The root table can be inserted into and updated. Updates only take place when a record retires (because there is a new version or it is deleted) and only affect the "valid to" column. - Child tables are insert-only. + Nested tables are insert-only. """ sql: List[str] = [] root_table = table_chain[0] @@ -760,14 +777,14 @@ def gen_scd2_sql( WHERE {hash_} NOT IN (SELECT {hash_} FROM {root_table_name} WHERE {is_active_clause}); """) - # insert list elements for new active records in child tables - child_tables = table_chain[1:] - if child_tables: + # insert list elements for new active records in nested tables + nested_tables = table_chain[1:] + if nested_tables: # TODO: - based on deterministic child hashes (OK) # - if row hash changes all is right # - if it does not we only capture new records, while we should replace existing with those in stage # - this write disposition is way more similar to regular merge (how root tables are handled is different, other tables handled same) - for table in child_tables: + for table in nested_tables: unique_column = escape_column_id( cls._get_unique_col(table_chain, sql_client, table) ) diff --git a/dlt/destinations/type_mapping.py b/dlt/destinations/type_mapping.py index 5ac43e4f1f..d615675fa6 100644 --- a/dlt/destinations/type_mapping.py +++ b/dlt/destinations/type_mapping.py @@ -1,20 +1,18 @@ -from typing import Tuple, ClassVar, Dict, Optional +from typing import Tuple, Dict, Optional from dlt.common import logger +from dlt.common.destination.reference import PreparedTableSchema from dlt.common.schema.typing import ( TColumnSchema, TDataType, TColumnType, - TTableFormat, - TTableSchema, ) -from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.destination.capabilities import DataTypeMapper +from dlt.common.typing import TLoaderFileFormat from dlt.common.utils import without_none -class TypeMapper: - capabilities: DestinationCapabilitiesContext - +class TypeMapperImpl(DataTypeMapper): sct_to_unbound_dbt: Dict[TDataType, str] """Data types without precision or scale specified (e.g. `"text": "varchar"` in postgres)""" sct_to_dbt: Dict[TDataType, str] @@ -24,17 +22,22 @@ class TypeMapper: dbt_to_sct: Dict[str, TDataType] - def __init__(self, capabilities: DestinationCapabilitiesContext) -> None: - self.capabilities = capabilities + def ensure_supported_type( + self, + column: TColumnSchema, + table: PreparedTableSchema, + loader_file_format: TLoaderFileFormat, + ) -> None: + pass - def to_db_integer_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + def to_db_integer_type(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: # Override in subclass if db supports other integer types (e.g. smallint, integer, tinyint, etc.) return self.sct_to_unbound_dbt["bigint"] def to_db_datetime_type( self, column: TColumnSchema, - table: TTableSchema = None, + table: PreparedTableSchema = None, ) -> str: # Override in subclass if db supports other timestamp types (e.g. with different time resolutions) timezone = column.get("timezone") @@ -54,7 +57,7 @@ def to_db_datetime_type( return None - def to_db_time_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + def to_db_time_type(self, column: TColumnSchema, table: PreparedTableSchema = None) -> str: # Override in subclass if db supports other time types (e.g. with different time resolutions) return None @@ -64,8 +67,8 @@ def to_db_decimal_type(self, column: TColumnSchema) -> str: return self.sct_to_unbound_dbt["decimal"] return self.sct_to_dbt["decimal"] % (precision_tup[0], precision_tup[1]) - # TODO: refactor lancedb and wevavite to make table object required - def to_db_type(self, column: TColumnSchema, table: TTableSchema = None) -> str: + # TODO: refactor lancedb and weaviate to make table object required + def to_destination_type(self, column: TColumnSchema, table: PreparedTableSchema) -> str: sc_t = column["data_type"] if sc_t == "bigint": db_t = self.to_db_integer_type(column, table) @@ -131,7 +134,7 @@ def wei_precision( scale if scale is not None else default_scale, ) - def from_db_type( + def from_destination_type( self, db_type: str, precision: Optional[int], scale: Optional[int] ) -> TColumnType: return without_none( diff --git a/dlt/destinations/utils.py b/dlt/destinations/utils.py index fcc2c4fd16..cd3ee6a54d 100644 --- a/dlt/destinations/utils.py +++ b/dlt/destinations/utils.py @@ -1,16 +1,18 @@ import re -import inspect -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Sequence, Tuple from dlt.common import logger +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.destination.utils import resolve_merge_strategy from dlt.common.schema import Schema from dlt.common.schema.exceptions import SchemaCorruptedException -from dlt.common.schema.typing import MERGE_STRATEGIES, TTableSchema +from dlt.common.schema.typing import MERGE_STRATEGIES, TColumnType, TTableSchema from dlt.common.schema.utils import ( get_columns_names_with_prop, get_first_column_name_with_prop, has_column_with_prop, + is_nested_table, pipeline_state_table, ) from typing import Any, cast, Tuple, Dict, Type @@ -81,13 +83,22 @@ def get_pipeline_state_query_columns() -> TTableSchema: return state_table -def verify_sql_job_client_schema(schema: Schema, warnings: bool = True) -> List[Exception]: +def verify_schema_merge_disposition( + schema: Schema, + load_tables: Sequence[TTableSchema], + capabilities: DestinationCapabilitiesContext, + warnings: bool = True, +) -> List[Exception]: log = logger.warning if warnings else logger.info # collect all exceptions to show all problems in the schema exception_log: List[Exception] = [] # verifies schema settings specific to sql job client - for table in schema.data_tables(): + for table in load_tables: + # from now on validate only top level tables + if is_nested_table(table): + continue + table_name = table["name"] if table.get("write_disposition") == "merge": if "x-merge-strategy" in table and table["x-merge-strategy"] not in MERGE_STRATEGIES: # type: ignore[typeddict-item] @@ -98,7 +109,9 @@ def verify_sql_job_client_schema(schema: Schema, warnings: bool = True) -> List[ f"""Allowed values: {', '.join(['"' + s + '"' for s in MERGE_STRATEGIES])}.""", ) ) - if table.get("x-merge-strategy") == "delete-insert": + + merge_strategy = resolve_merge_strategy(schema.tables, table, capabilities) + if merge_strategy == "delete-insert": if not has_column_with_prop(table, "primary_key") and not has_column_with_prop( table, "merge_key" ): @@ -108,7 +121,7 @@ def verify_sql_job_client_schema(schema: Schema, warnings: bool = True) -> List[ " merge keys defined." " dlt will fall back to `append` for this table." ) - elif table.get("x-merge-strategy") == "upsert": + elif merge_strategy == "upsert": if not has_column_with_prop(table, "primary_key"): exception_log.append( SchemaCorruptedException( diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index 1eccd86aad..5df165adb7 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -160,7 +160,7 @@ def source( max_table_nesting (int, optional): A schema hint that sets the maximum depth of nested table above which the remaining nodes are loaded as structs or JSON. - root_key (bool): Enables merging on all resources by propagating root foreign key to child tables. This option is most useful if you plan to change write disposition of a resource to disable/enable merge. Defaults to False. + root_key (bool): Enables merging on all resources by propagating row key from root to all nested tables. This option is most useful if you plan to change write disposition of a resource to disable/enable merge. Defaults to False. schema (Schema, optional): An explicit `Schema` instance to be associated with the source. If not present, `dlt` creates a new `Schema` object with provided `name`. If such `Schema` already exists in the same folder as the module containing the decorated function, such schema will be loaded from file. diff --git a/dlt/extract/extract.py b/dlt/extract/extract.py index 485a01eb99..1c42dba329 100644 --- a/dlt/extract/extract.py +++ b/dlt/extract/extract.py @@ -25,6 +25,7 @@ TAnySchemaColumns, TColumnNames, TSchemaContract, + TTableFormat, TWriteDispositionConfig, ) from dlt.common.storages import NormalizeStorageConfiguration, LoadPackageInfo, SchemaStorage @@ -50,12 +51,14 @@ def data_to_sources( data: Any, pipeline: SupportsPipeline, + *, schema: Schema = None, table_name: str = None, parent_table_name: str = None, write_disposition: TWriteDispositionConfig = None, columns: TAnySchemaColumns = None, primary_key: TColumnNames = None, + table_format: TTableFormat = None, schema_contract: TSchemaContract = None, ) -> List[DltSource]: """Creates a list of sources for data items present in `data` and applies specified hints to all resources. @@ -65,12 +68,13 @@ def data_to_sources( def apply_hint_args(resource: DltResource) -> None: resource.apply_hints( - table_name, - parent_table_name, - write_disposition, - columns, - primary_key, + table_name=table_name, + parent_table_name=parent_table_name, + write_disposition=write_disposition, + columns=columns, + primary_key=primary_key, schema_contract=schema_contract, + table_format=table_format, ) def apply_settings(source_: DltSource) -> None: @@ -269,8 +273,8 @@ def _write_empty_files( if resource.name not in tables_by_resources: continue for table in tables_by_resources[resource.name]: - # we only need to write empty files for the top tables - if not table.get("parent", None): + # we only need to write empty files for the root tables + if not utils.is_nested_table(table): json_extractor.write_empty_items_file(table["name"]) # collect resources that received empty materialized lists and had no items @@ -287,8 +291,8 @@ def _write_empty_files( if tables := tables_by_resources.get("resource_name"): # write empty tables for table in tables: - # we only need to write empty files for the top tables - if not table.get("parent", None): + # we only need to write empty files for the root tables + if not utils.is_nested_table(table): json_extractor.write_empty_items_file(table["name"]) else: table_name = json_extractor._get_static_table_name(resource, None) diff --git a/dlt/extract/extractors.py b/dlt/extract/extractors.py index 8a91dd7477..cbee8ed286 100644 --- a/dlt/extract/extractors.py +++ b/dlt/extract/extractors.py @@ -243,7 +243,7 @@ def _compute_and_update_table( # this is a new table so allow evolve once if schema_contract["columns"] != "evolve" and self.schema.is_new_table(table_name): computed_table["x-normalizer"] = {"evolve-columns-once": True} - existing_table = self.schema._schema_tables.get(table_name, None) + existing_table = self.schema.tables.get(table_name, None) if existing_table: # TODO: revise this. computed table should overwrite certain hints (ie. primary and merge keys) completely diff_table = utils.diff_table(self.schema.name, existing_table, computed_table) @@ -257,7 +257,10 @@ def _compute_and_update_table( # merge with schema table if diff_table: - self.schema.update_table(diff_table) + # diff table identifiers already normalized + self.schema.update_table( + diff_table, normalize_identifiers=False, from_diff=bool(existing_table) + ) # process filters if filters: @@ -410,7 +413,7 @@ def _compute_table( arrow_table["columns"] = pyarrow.py_arrow_to_table_schema_columns(item.schema) # Add load_id column if needed - dlt_load_id_col = self.naming.normalize_table_identifier("_dlt_load_id") + dlt_load_id_col = self.naming.normalize_identifier("_dlt_load_id") if ( self._normalize_config.add_dlt_load_id and dlt_load_id_col not in arrow_table["columns"] diff --git a/dlt/extract/hints.py b/dlt/extract/hints.py index c828064288..dd460100aa 100644 --- a/dlt/extract/hints.py +++ b/dlt/extract/hints.py @@ -19,7 +19,6 @@ ) from dlt.common.schema.utils import ( DEFAULT_WRITE_DISPOSITION, - DEFAULT_MERGE_STRATEGY, merge_column, merge_columns, new_column, @@ -166,12 +165,16 @@ def columns(self) -> TTableHintTemplate[TTableSchemaColumns]: @property def schema_contract(self) -> TTableHintTemplate[TSchemaContract]: - return self._hints.get("schema_contract") + return None if self._hints is None else self._hints.get("schema_contract") @property def table_format(self) -> TTableHintTemplate[TTableFormat]: return None if self._hints is None else self._hints.get("table_format") + @property + def parent_table_name(self) -> TTableHintTemplate[str]: + return None if self._hints is None else self._hints.get("parent") + def compute_table_schema(self, item: TDataItem = None, meta: Any = None) -> TTableSchema: """Computes the table schema based on hints and column definitions passed during resource creation. `item` parameter is used to resolve table hints based on data. @@ -424,7 +427,7 @@ def _merge_key(hint: TColumnProp, keys: TColumnNames, partial: TPartialTableSche partial["columns"][key][hint] = True @staticmethod - def _merge_keys(dict_: Dict[str, Any]) -> None: + def _merge_keys(dict_: TResourceHints) -> None: """Merges primary and merge keys into columns in place.""" if "primary_key" in dict_: @@ -436,67 +439,67 @@ def _merge_keys(dict_: Dict[str, Any]) -> None: def _merge_write_disposition_dict(dict_: Dict[str, Any]) -> None: """Merges write disposition dictionary into write disposition shorthand and x-hints in place.""" - if dict_["write_disposition"]["disposition"] == "merge": + write_disposition = dict_["write_disposition"]["disposition"] + if write_disposition == "merge": DltResourceHints._merge_merge_disposition_dict(dict_) # reduce merge disposition from dict to shorthand - dict_["write_disposition"] = dict_["write_disposition"]["disposition"] + dict_["write_disposition"] = write_disposition @staticmethod def _merge_merge_disposition_dict(dict_: Dict[str, Any]) -> None: """Merges merge disposition dict into x-hints in place.""" - mddict: TMergeDispositionDict = deepcopy(dict_["write_disposition"]) - if mddict is not None: - dict_["x-merge-strategy"] = mddict.get("strategy", DEFAULT_MERGE_STRATEGY) - if "boundary_timestamp" in mddict: - dict_["x-boundary-timestamp"] = mddict["boundary_timestamp"] - # add columns for `scd2` merge strategy - if dict_.get("x-merge-strategy") == "scd2": - if mddict.get("validity_column_names") is None: - from_, to = DEFAULT_VALIDITY_COLUMN_NAMES - else: - from_, to = mddict["validity_column_names"] - dict_["columns"][from_] = { - "name": from_, - "data_type": "timestamp", - "nullable": ( - True - ), # validity columns are empty when first loaded into staging table - "x-valid-from": True, - } - dict_["columns"][to] = { - "name": to, - "data_type": "timestamp", - "nullable": True, - "x-valid-to": True, - "x-active-record-timestamp": mddict.get("active_record_timestamp"), - } - # unique constraint is dropped for C_DLT_ID when used to store - # SCD2 row hash (only applies to root table) - hash_ = mddict.get("row_version_column_name", DataItemNormalizer.C_DLT_ID) - dict_["columns"][hash_] = { - "name": hash_, - "nullable": False, - "x-row-version": True, - # duplicate value in row hash column is possible in case - # of insert-delete-reinsert pattern - "unique": False, - } + md_dict: TMergeDispositionDict = dict_.pop("write_disposition") + if merge_strategy := md_dict.get("strategy"): + dict_["x-merge-strategy"] = merge_strategy + if "boundary_timestamp" in md_dict: + dict_["x-boundary-timestamp"] = md_dict["boundary_timestamp"] + # add columns for `scd2` merge strategy + if merge_strategy == "scd2": + if md_dict.get("validity_column_names") is None: + from_, to = DEFAULT_VALIDITY_COLUMN_NAMES + else: + from_, to = md_dict["validity_column_names"] + dict_["columns"][from_] = { + "name": from_, + "data_type": "timestamp", + "nullable": True, # validity columns are empty when first loaded into staging table + "x-valid-from": True, + } + dict_["columns"][to] = { + "name": to, + "data_type": "timestamp", + "nullable": True, + "x-valid-to": True, + "x-active-record-timestamp": md_dict.get("active_record_timestamp"), + } + # unique constraint is dropped for C_DLT_ID when used to store + # SCD2 row hash (only applies to root table) + hash_ = md_dict.get("row_version_column_name", DataItemNormalizer.C_DLT_ID) + dict_["columns"][hash_] = { + "name": hash_, + "nullable": False, + "x-row-version": True, + # duplicate value in row hash column is possible in case + # of insert-delete-reinsert pattern + "unique": False, + } @staticmethod def _create_table_schema(resource_hints: TResourceHints, resource_name: str) -> TTableSchema: - """Creates table schema from resource hints and resource name.""" - - dict_ = cast(Dict[str, Any], resource_hints) - DltResourceHints._merge_keys(dict_) - dict_["resource"] = resource_name - if "write_disposition" in dict_: - if isinstance(dict_["write_disposition"], str): - dict_["write_disposition"] = { - "disposition": dict_["write_disposition"] + """Creates table schema from resource hints and resource name. Resource hints are resolved + (do not contain callables) and will be modified in place + """ + DltResourceHints._merge_keys(resource_hints) + if "write_disposition" in resource_hints: + if isinstance(resource_hints["write_disposition"], str): + resource_hints["write_disposition"] = { + "disposition": resource_hints["write_disposition"] } # wrap in dict - DltResourceHints._merge_write_disposition_dict(dict_) - return cast(TTableSchema, dict_) + DltResourceHints._merge_write_disposition_dict(resource_hints) # type: ignore[arg-type] + dict_ = cast(TTableSchema, resource_hints) + dict_["resource"] = resource_name + return dict_ @staticmethod def validate_dynamic_hints(template: TResourceHints) -> None: diff --git a/dlt/load/load.py b/dlt/load/load.py index f084c9d3d9..3b231f8fa9 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -10,7 +10,7 @@ from dlt.common.configuration import with_config, known_sections from dlt.common.configuration.accessors import config from dlt.common.pipeline import LoadInfo, LoadMetrics, SupportsPipeline, WithStepInfo -from dlt.common.schema.utils import get_top_level_table +from dlt.common.schema.utils import get_root_table from dlt.common.storages.load_storage import ( LoadPackageInfo, ParsedLoadJobFileName, @@ -167,20 +167,30 @@ def submit_job( ) logger.info(f"Will load file {file_path} with table name {job_info.table_name}") - # check write disposition + # determine which dataset to use + if is_staging_destination_job: + use_staging_dataset = isinstance( + job_client, SupportsStagingDestination + ) and job_client.should_load_data_to_staging_dataset_on_staging_destination( + job_info.table_name + ) + else: + use_staging_dataset = isinstance( + job_client, WithStagingDataset + ) and job_client.should_load_data_to_staging_dataset(job_info.table_name) + + # prepare table to be loaded load_table = active_job_client.prepare_load_table(job_info.table_name) if load_table["write_disposition"] not in ["append", "replace", "merge"]: raise LoadClientUnsupportedWriteDisposition( job_info.table_name, load_table["write_disposition"], file_path ) - job = active_job_client.create_load_job( load_table, self.load_storage.normalized_packages.storage.make_full_path(file_path), load_id, restore=restore, ) - if job is None: raise DestinationTerminalException( f"Destination could not create a job for file {file_path}. Typically the file" @@ -204,21 +214,8 @@ def submit_job( # only start a thread if this job is runnable if isinstance(job, RunnableLoadJob): - # determine which dataset to use - if is_staging_destination_job: - use_staging_dataset = isinstance( - job_client, SupportsStagingDestination - ) and job_client.should_load_data_to_staging_dataset_on_staging_destination( - load_table - ) - else: - use_staging_dataset = isinstance( - job_client, WithStagingDataset - ) and job_client.should_load_data_to_staging_dataset(load_table) - # set job vars job.set_run_vars(load_id=load_id, schema=schema, load_table=load_table) - # submit to pool self.pool.submit(Load.w_run_job, *(id(self), job, is_staging_destination_job, use_staging_dataset, schema)) # type: ignore @@ -321,7 +318,7 @@ def create_followup_jobs( starting_job_file_name = starting_job.file_name() if state == "completed" and not self.is_staging_destination_job(starting_job_file_name): client = self.destination.client(schema, self.initial_client_config) - top_job_table = get_top_level_table( + root_job_table = get_root_table( schema.tables, starting_job.job_file_info().table_name ) # if all tables of chain completed, create follow up jobs @@ -329,9 +326,13 @@ def create_followup_jobs( load_id ) if table_chain := get_completed_table_chain( - schema, all_jobs_states, top_job_table, starting_job.job_file_info().job_id() + schema, all_jobs_states, root_job_table, starting_job.job_file_info().job_id() ): table_chain_names = [table["name"] for table in table_chain] + # all tables will be prepared for main dataset + prep_table_chain = [ + client.prepare_load_table(table_name) for table_name in table_chain_names + ] table_chain_jobs = [ # we mark all jobs as completed, as by the time the followup job runs the starting job will be in this # folder too @@ -345,12 +346,12 @@ def create_followup_jobs( ] try: if follow_up_jobs := client.create_table_chain_completed_followup_jobs( - table_chain, table_chain_jobs + prep_table_chain, table_chain_jobs ): jobs = jobs + follow_up_jobs except Exception as e: raise TableChainFollowupJobCreationFailedException( - root_table_name=table_chain[0]["name"] + root_table_name=prep_table_chain[0]["name"] ) from e try: diff --git a/dlt/load/utils.py b/dlt/load/utils.py index e3a2ebcd79..ae67502d13 100644 --- a/dlt/load/utils.py +++ b/dlt/load/utils.py @@ -5,8 +5,8 @@ from dlt.common.storages.load_package import LoadJobInfo, PackageStorage, TPackageJobState from dlt.common.schema.utils import ( fill_hints_from_parent_and_clone_table, - get_child_tables, - get_top_level_table, + get_nested_tables, + get_root_table, has_table_seen_data, ) from dlt.common.storages.load_storage import ParsedLoadJobFileName @@ -38,7 +38,7 @@ def get_completed_table_chain( # make sure all the jobs for the table chain is completed for table in map( lambda t: fill_hints_from_parent_and_clone_table(schema.tables, t), - get_child_tables(schema.tables, top_merged_table["name"]), + get_nested_tables(schema.tables, top_merged_table["name"]), ): table_jobs = PackageStorage.filter_jobs_for_table(all_jobs, table["name"]) # skip tables that never seen data @@ -67,8 +67,8 @@ def init_client( schema: Schema, new_jobs: Iterable[ParsedLoadJobFileName], expected_update: TSchemaTables, - truncate_filter: Callable[[TTableSchema], bool], - load_staging_filter: Callable[[TTableSchema], bool], + truncate_filter: Callable[[str], bool], + load_staging_filter: Callable[[str], bool], drop_tables: Optional[List[TTableSchema]] = None, truncate_tables: Optional[List[TTableSchema]] = None, ) -> TSchemaTables: @@ -81,8 +81,8 @@ def init_client( schema (Schema): The schema as in load package new_jobs (Iterable[LoadJobInfo]): List of new jobs expected_update (TSchemaTables): Schema update as in load package. Always present even if empty - truncate_filter (Callable[[TTableSchema], bool]): A filter that tells which table in destination dataset should be truncated - load_staging_filter (Callable[[TTableSchema], bool]): A filter which tell which table in the staging dataset may be loaded into + truncate_filter (Callable[[str], bool]): A filter that tells which table in destination dataset should be truncated + load_staging_filter (Callable[[str], bool]): A filter which tell which table in the staging dataset may be loaded into drop_tables (Optional[List[TTableSchema]]): List of tables to drop before initializing storage truncate_tables (Optional[List[TTableSchema]]): List of tables to truncate before initializing storage @@ -106,13 +106,14 @@ def init_client( schema, tables_with_jobs, tables_with_jobs, - lambda t: truncate_filter(t) or t["name"] in initial_truncate_names, + lambda table_name: truncate_filter(table_name) + or (table_name in initial_truncate_names), ) ) # get tables to drop drop_table_names = {table["name"] for table in drop_tables} if drop_tables else set() - + job_client.verify_schema(only_tables=tables_with_jobs | dlt_tables, new_jobs=new_jobs) applied_update = _init_dataset_and_update_schema( job_client, expected_update, @@ -175,7 +176,6 @@ def _init_dataset_and_update_schema( f"Client for {job_client.config.destination_type} will update schema to package schema" f" {staging_text}" ) - applied_update = job_client.update_stored_schema( only_tables=update_tables, expected_update=expected_update ) @@ -192,17 +192,17 @@ def _extend_tables_with_table_chain( schema: Schema, tables: Iterable[str], tables_with_jobs: Iterable[str], - include_table_filter: Callable[[TTableSchema], bool] = lambda t: True, + include_table_filter: Callable[[str], bool] = lambda t: True, ) -> Iterable[str]: """Extend 'tables` with all their children and filter out tables that do not have jobs (in `tables_with_jobs`), haven't seen data or are not included by `include_table_filter`. - Note that for top tables with replace and merge, the filter for tables that do not have jobs + Note that for root tables with replace and merge, the filter for tables that do not have jobs Returns an unordered set of table names and their child tables """ result: Set[str] = set() for table_name in tables: - top_job_table = get_top_level_table(schema.tables, table_name) + top_job_table = get_root_table(schema.tables, table_name) # for replace and merge write dispositions we should include tables # without jobs in the table chain, because child tables may need # processing due to changes in the root table @@ -212,14 +212,14 @@ def _extend_tables_with_table_chain( ) for table in map( lambda t: fill_hints_from_parent_and_clone_table(schema.tables, t), - get_child_tables(schema.tables, top_job_table["name"]), + get_nested_tables(schema.tables, top_job_table["name"]), ): chain_table_name = table["name"] table_has_job = chain_table_name in tables_with_jobs # table that never seen data are skipped as they will not be created # also filter out tables # NOTE: this will ie. eliminate all non iceberg tables on ATHENA destination from staging (only iceberg needs that) - if not has_table_seen_data(table) or not include_table_filter(table): + if not has_table_seen_data(table) or not include_table_filter(chain_table_name): continue # if there's no job for the table and we are in append then skip if not table_has_job and skip_jobless_table: diff --git a/dlt/normalize/items_normalizers.py b/dlt/normalize/items_normalizers.py index 650d10c268..fff615f0bf 100644 --- a/dlt/normalize/items_normalizers.py +++ b/dlt/normalize/items_normalizers.py @@ -316,7 +316,8 @@ def _fix_schema_precisions( new_cols: TTableSchemaColumns = {} for key, column in table["columns"].items(): if column.get("data_type") in ("timestamp", "time"): - if prec := column.get("precision"): + prec = column.get("precision") + if prec is not None: # apply the arrow schema precision to dlt column schema data_type = pyarrow.get_column_type_from_py_arrow(arrow_schema.field(key).type) if data_type["data_type"] in ("timestamp", "time"): diff --git a/dlt/normalize/normalize.py b/dlt/normalize/normalize.py index 3df060b141..32db5034b4 100644 --- a/dlt/normalize/normalize.py +++ b/dlt/normalize/normalize.py @@ -34,7 +34,7 @@ from dlt.normalize.configuration import NormalizeConfiguration from dlt.normalize.exceptions import NormalizeJobFailed from dlt.normalize.worker import w_normalize_files, group_worker_files, TWorkerRV -from dlt.normalize.schema import verify_normalized_schema +from dlt.normalize.validate import verify_normalized_table # normalize worker wrapping function signature @@ -185,6 +185,7 @@ def spool_files( # update normalizer specific info for table_name in table_metrics: table = schema.tables[table_name] + verify_normalized_table(schema, table, self.config.destination_capabilities) x_normalizer = table.setdefault("x-normalizer", {}) # drop evolve once for all tables that seen data x_normalizer.pop("evolve-columns-once", None) @@ -196,7 +197,6 @@ def spool_files( x_normalizer["seen-data"] = True # schema is updated, save it to schema volume if schema.is_modified: - verify_normalized_schema(schema) logger.info( f"Saving schema {schema.name} with version {schema.stored_version}:{schema.version}" ) @@ -297,12 +297,18 @@ def run(self, pool: Optional[Executor]) -> TRunMetrics: with self.collector(f"Normalize {schema.name} in {load_id}"): self.collector.update("Files", 0, len(schema_files)) self.collector.update("Items", 0) + # self.verify_package(load_id, schema, schema_files) self._step_info_start_load_id(load_id) self.spool_schema_files(load_id, schema, schema_files) # return info on still pending packages (if extractor saved something in the meantime) return TRunMetrics(False, len(self.normalize_storage.extracted_packages.list_packages())) + # def verify_package(self, load_id, schema: Schema, schema_files: Sequence[str]) -> None: + # """Verifies package schema and jobs against destination capabilities""" + # # get all tables in schema files + # table_names = set(ParsedLoadJobFileName.parse(job).table_name for job in schema_files) + def get_load_package_info(self, load_id: str) -> LoadPackageInfo: """Returns information on extracted/normalized/completed package with given load_id, all jobs and their statuses.""" try: diff --git a/dlt/normalize/schema.py b/dlt/normalize/schema.py deleted file mode 100644 index c01d184c92..0000000000 --- a/dlt/normalize/schema.py +++ /dev/null @@ -1,20 +0,0 @@ -from dlt.common.schema import Schema -from dlt.common.schema.utils import find_incomplete_columns -from dlt.common.schema.exceptions import UnboundColumnException -from dlt.common import logger - - -def verify_normalized_schema(schema: Schema) -> None: - """Verify the schema is valid for next stage after normalization. - - 1. Log warning if any incomplete nullable columns are in any data tables - 2. Raise `UnboundColumnException` on incomplete non-nullable columns (e.g. missing merge/primary key) - """ - for table_name, column, nullable in find_incomplete_columns( - schema.data_tables(seen_data_only=True) - ): - exc = UnboundColumnException(schema.name, table_name, column) - if nullable: - logger.warning(str(exc)) - else: - raise exc diff --git a/dlt/normalize/validate.py b/dlt/normalize/validate.py new file mode 100644 index 0000000000..d680b5bddd --- /dev/null +++ b/dlt/normalize/validate.py @@ -0,0 +1,43 @@ +from dlt.common.destination.capabilities import DestinationCapabilitiesContext +from dlt.common.schema import Schema +from dlt.common.schema.typing import TTableSchema +from dlt.common.schema.utils import find_incomplete_columns +from dlt.common.schema.exceptions import UnboundColumnException +from dlt.common import logger + + +def verify_normalized_table( + schema: Schema, table: TTableSchema, capabilities: DestinationCapabilitiesContext +) -> None: + """Verify `table` schema is valid for next stage after normalization. Only tables that have seen data are verified. + Verification happens before seen-data flag is set so new tables can be detected. + + 1. Log warning if any incomplete nullable columns are in any data tables + 2. Raise `UnboundColumnException` on incomplete non-nullable columns (e.g. missing merge/primary key) + 3. Log warning if table format is not supported by destination capabilities + """ + for column, nullable in find_incomplete_columns(table): + exc = UnboundColumnException(schema.name, table["name"], column) + if nullable: + logger.warning(str(exc)) + else: + raise exc + + # TODO: 3. raise if we detect name conflict for SCD2 columns + # until we track data per column we won't be able to implement this + # if resolve_merge_strategy(schema.tables, table, capabilities) == "scd2": + # for validity_column_name in get_validity_column_names(table): + # if validity_column_name in item.keys(): + # raise ColumnNameConflictException( + # schema_name, + # "Found column in data item with same name as validity column" + # f' "{validity_column_name}".', + # ) + + supported_table_formats = capabilities.supported_table_formats or [] + if "table_format" in table and table["table_format"] not in supported_table_formats: + logger.warning( + "Destination does not support the configured `table_format` value " + f"`{table['table_format']}` for table `{table['name']}`. " + "The setting will probably be ignored." + ) diff --git a/dlt/normalize/worker.py b/dlt/normalize/worker.py index b8969f64a3..53a856f7d0 100644 --- a/dlt/normalize/worker.py +++ b/dlt/normalize/worker.py @@ -73,7 +73,6 @@ def w_normalize_files( ) # TODO: capabilities.supported_*_formats can be None, it should have defaults supported_file_formats = destination_caps.supported_loader_file_formats or [] - supported_table_formats = destination_caps.supported_table_formats or [] # process all files with data items and write to buffered item storage with Container().injectable_context(destination_caps): @@ -90,21 +89,11 @@ def _get_items_normalizer( if table_name in item_normalizers: return item_normalizers[table_name] - if ( - "table_format" in table_schema - and table_schema["table_format"] not in supported_table_formats - ): - logger.warning( - "Destination does not support the configured `table_format` value " - f"`{table_schema['table_format']}` for table `{table_schema['name']}`. " - "The setting will probably be ignored." - ) - items_preferred_file_format = preferred_file_format items_supported_file_formats = supported_file_formats - if destination_caps.loader_file_format_adapter is not None: + if destination_caps.loader_file_format_selector is not None: items_preferred_file_format, items_supported_file_formats = ( - destination_caps.loader_file_format_adapter( + destination_caps.loader_file_format_selector( preferred_file_format, ( supported_file_formats.copy() @@ -233,9 +222,10 @@ def _gather_metrics_and_close( parsed_file_name.table_name ) root_tables.add(root_table_name) + root_table = stored_schema["tables"].get(root_table_name, {"name": root_table_name}) normalizer = _get_items_normalizer( parsed_file_name, - stored_schema["tables"].get(root_table_name, {"name": root_table_name}), + root_table, ) logger.debug( f"Processing extracted items in {extracted_items_file} in load_id" diff --git a/dlt/pipeline/__init__.py b/dlt/pipeline/__init__.py index 8041ca72e0..7af965e989 100644 --- a/dlt/pipeline/__init__.py +++ b/dlt/pipeline/__init__.py @@ -2,7 +2,12 @@ from typing_extensions import TypeVar from dlt.common.schema import Schema -from dlt.common.schema.typing import TColumnSchema, TWriteDispositionConfig, TSchemaContract +from dlt.common.schema.typing import ( + TColumnSchema, + TTableFormat, + TWriteDispositionConfig, + TSchemaContract, +) from dlt.common.typing import TSecretValue, Any from dlt.common.configuration import with_config @@ -219,7 +224,9 @@ def run( columns: Sequence[TColumnSchema] = None, schema: Schema = None, loader_file_format: TLoaderFileFormat = None, + table_format: TTableFormat = None, schema_contract: TSchemaContract = None, + refresh: Optional[TRefreshMode] = None, ) -> LoadInfo: """Loads the data in `data` argument into the destination specified in `destination` and dataset specified in `dataset_name`. @@ -263,6 +270,17 @@ def run( schema (Schema, optional): An explicit `Schema` object in which all table schemas will be grouped. By default `dlt` takes the schema from the source (if passed in `data` argument) or creates a default one itself. + loader_file_format (Literal["jsonl", "insert_values", "parquet"], optional). The file format the loader will use to create the load package. Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination. + + table_format (Literal["delta", "iceberg"], optional). The table format used by the destination to store tables. Currently you can select table format on filesystem and Athena destinations. + + schema_contract (TSchemaContract, optional): On override for the schema contract settings, this will replace the schema contract settings for all tables in the schema. Defaults to None. + + refresh (str | TRefreshMode): Fully or partially reset sources before loading new data in this run. The following refresh modes are supported: + * `drop_sources`: Drop tables and source and resource state for all sources currently being processed in `run` or `extract` methods of the pipeline. (Note: schema history is erased) + * `drop_resources`: Drop tables and resource state for all resources being processed. Source level state is not modified. (Note: schema history is erased) + * `drop_data`: Wipe all data and resource state for all resources being processed. Schema is not modified. + Raises: PipelineStepFailed when a problem happened during `extract`, `normalize` or `load` steps. Returns: @@ -279,7 +297,9 @@ def run( columns=columns, schema=schema, loader_file_format=loader_file_format, + table_format=table_format, schema_contract=schema_contract, + refresh=refresh, ) diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index 4f29ca4c87..6ad443e3d8 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -42,6 +42,7 @@ from dlt.common.schema.typing import ( TColumnNames, TSchemaTables, + TTableFormat, TWriteDispositionConfig, TAnySchemaColumns, TSchemaContract, @@ -68,7 +69,7 @@ DestinationCapabilitiesContext, merge_caps_file_formats, TDestination, - ALL_SUPPORTED_FILE_FORMATS, + LOADER_FILE_FORMATS, TLoaderFileFormat, ) from dlt.common.destination.reference import ( @@ -401,6 +402,7 @@ def extract( schema: Schema = None, max_parallel_items: int = ConfigValue, workers: int = ConfigValue, + table_format: TTableFormat = None, schema_contract: TSchemaContract = None, refresh: Optional[TRefreshMode] = None, ) -> ExtractInfo: @@ -419,13 +421,14 @@ def extract( for source in data_to_sources( data, self, - schema, - table_name, - parent_table_name, - write_disposition, - columns, - primary_key, - schema_contract, + schema=schema, + table_name=table_name, + parent_table_name=parent_table_name, + write_disposition=write_disposition, + columns=columns, + primary_key=primary_key, + schema_contract=schema_contract, + table_format=table_format, ): if source.exhausted: raise SourceExhausted(source.name) @@ -472,23 +475,6 @@ def _verify_destination_capabilities( set(caps.supported_loader_file_formats), ) - # verify merge strategy - for table in self.default_schema.data_tables(include_incomplete=True): - if ( - "x-merge-strategy" in table - and caps.supported_merge_strategies - and table["x-merge-strategy"] not in caps.supported_merge_strategies # type: ignore[typeddict-item] - ): - if self.destination.destination_name == "filesystem" and table["x-merge-strategy"] == "delete-insert": # type: ignore[typeddict-item] - # `filesystem` does not support `delete-insert`, but no - # error should be raised because it falls back to `append` - pass - else: - raise DestinationCapabilitiesException( - f"`{table.get('x-merge-strategy')}` merge strategy not supported" - f" for `{self.destination.destination_name}` destination." - ) - @with_runtime_trace() @with_schemas_sync @with_config_section((known_sections.NORMALIZE,)) @@ -499,7 +485,7 @@ def normalize( if is_interactive(): workers = 1 - if loader_file_format and loader_file_format not in ALL_SUPPORTED_FILE_FORMATS: + if loader_file_format and loader_file_format not in LOADER_FILE_FORMATS: raise ValueError(f"{loader_file_format} is unknown.") # check if any schema is present, if not then no data was extracted if not self.default_schema_name: @@ -610,6 +596,7 @@ def run( primary_key: TColumnNames = None, schema: Schema = None, loader_file_format: TLoaderFileFormat = None, + table_format: TTableFormat = None, schema_contract: TSchemaContract = None, refresh: Optional[TRefreshMode] = None, ) -> LoadInfo: @@ -662,6 +649,8 @@ def run( loader_file_format (Literal["jsonl", "insert_values", "parquet"], optional). The file format the loader will use to create the load package. Not all file_formats are compatible with all destinations. Defaults to the preferred file format of the selected destination. + table_format (Literal["delta", "iceberg"], optional). The table format used by the destination to store tables. Currently you can select table format on filesystem and Athena destinations. + schema_contract (TSchemaContract, optional): On override for the schema contract settings, this will replace the schema contract settings for all tables in the schema. Defaults to None. refresh (str | TRefreshMode): Fully or partially reset sources before loading new data in this run. The following refresh modes are supported: @@ -714,6 +703,7 @@ def run( columns=columns, primary_key=primary_key, schema=schema, + table_format=table_format, schema_contract=schema_contract, refresh=refresh or self.refresh, ) @@ -1213,7 +1203,9 @@ def _get_destination_client_initial_config( ) if issubclass(client_spec, DestinationClientStagingConfiguration): - spec: DestinationClientDwhConfiguration = client_spec(as_staging=as_staging) + spec: DestinationClientDwhConfiguration = client_spec( + as_staging_destination=as_staging + ) else: spec = client_spec() spec._bind_dataset_name(self.dataset_name, default_schema_name) @@ -1677,7 +1669,7 @@ def _bump_version_and_extract_state( load_package_state_update["pipeline_state"] = doc self._extract_source( extract_, - data_to_sources(data, self, schema)[0], + data_to_sources(data, self, schema=schema)[0], 1, 1, load_package_state_update=load_package_state_update, diff --git a/dlt/pipeline/warnings.py b/dlt/pipeline/warnings.py index ac46a4eef0..a4e917f970 100644 --- a/dlt/pipeline/warnings.py +++ b/dlt/pipeline/warnings.py @@ -2,7 +2,6 @@ import warnings from dlt.common.warnings import Dlt04DeprecationWarning -from dlt.common.destination import Destination, TDestinationReferenceArg def full_refresh_argument_deprecated(caller_name: str, full_refresh: t.Optional[bool]) -> None: diff --git a/docs/website/docs/dlt-ecosystem/destinations/athena.md b/docs/website/docs/dlt-ecosystem/destinations/athena.md index 2a8b8c6b9d..e6f99adc48 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/athena.md +++ b/docs/website/docs/dlt-ecosystem/destinations/athena.md @@ -131,13 +131,6 @@ def data() -> Iterable[TDataItem]: ... ``` -Alternatively, you can set all tables to use the iceberg format with a config variable: - -```toml -[destination.athena] -force_iceberg = "True" -``` - For every table created as an iceberg table, the Athena destination will create a regular Athena table in the staging dataset of both the filesystem and the Athena glue catalog, and then copy all data into the final iceberg table that lives with the non-iceberg tables in the same dataset on both the filesystem and the glue catalog. Switching from iceberg to regular table or vice versa is not supported. #### `merge` support diff --git a/poetry.lock b/poetry.lock index 0bb8ec1fb3..9cbc4b66ea 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1939,6 +1939,31 @@ files = [ [package.extras] testing = ["flake8", "pytest", "pytest-cov", "pytest-virtualenv", "pytest-xdist", "sphinx"] +[[package]] +name = "connectorx" +version = "0.3.2" +description = "" +optional = false +python-versions = "*" +files = [ + {file = "connectorx-0.3.2-cp310-cp310-macosx_10_7_x86_64.whl", hash = "sha256:98274242c64a2831a8b1c86e0fa2c46a557dd8cbcf00c3adcf5a602455fb02d7"}, + {file = "connectorx-0.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e2b11ba49efd330a7348bef3ce09c98218eea21d92a12dd75cd8f0ade5c99ffc"}, + {file = "connectorx-0.3.2-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:3f6431a30304271f9137bd7854d2850231041f95164c6b749d9ede4c0d92d10c"}, + {file = "connectorx-0.3.2-cp310-none-win_amd64.whl", hash = "sha256:b370ebe8f44d2049254dd506f17c62322cc2db1b782a57f22cce01ddcdcc8fed"}, + {file = "connectorx-0.3.2-cp311-cp311-macosx_10_7_x86_64.whl", hash = "sha256:d5277fc936a80da3d1dcf889020e45da3493179070d9be8a47500c7001fab967"}, + {file = "connectorx-0.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8cc6c963237c3d3b02f7dcd47e1be9fc6e8b93ef0aeed8694f65c62b3c4688a1"}, + {file = "connectorx-0.3.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:9403902685b3423cba786db01a36f36efef90ae3d429e45b74dadb4ae9e328dc"}, + {file = "connectorx-0.3.2-cp311-none-win_amd64.whl", hash = "sha256:6b5f518194a2cf12d5ad031d488ded4e4678eff3b63551856f2a6f1a83197bb8"}, + {file = "connectorx-0.3.2-cp38-cp38-macosx_10_7_x86_64.whl", hash = "sha256:a5602ae0531e55c58af8cfca92b8e9454fc1ccd82c801cff8ee0f17c728b4988"}, + {file = "connectorx-0.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c5959bfb4a049bb8ce1f590b5824cd1105460b6552ffec336c4bd740eebd5bd"}, + {file = "connectorx-0.3.2-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:c4387bb27ba3acde0ab6921fdafa3811e09fce0db3d1f1ede8547d9de3aab685"}, + {file = "connectorx-0.3.2-cp38-none-win_amd64.whl", hash = "sha256:4b1920c191be9a372629c31c92d5f71fc63f49f283e5adfc4111169de40427d9"}, + {file = "connectorx-0.3.2-cp39-cp39-macosx_10_7_x86_64.whl", hash = "sha256:4473fc06ac3618c673cea63a7050e721fe536782d5c1b6e433589c37a63de704"}, + {file = "connectorx-0.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4009b16399457340326137a223921a24e3e166b45db4dbf3ef637b9981914dc2"}, + {file = "connectorx-0.3.2-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:74f5b93535663cf47f9fc3d7964f93e652c07003fa71c38d7a68f42167f54bba"}, + {file = "connectorx-0.3.2-cp39-none-win_amd64.whl", hash = "sha256:0b80acca13326856c14ee726b47699011ab1baa10897180240c8783423ca5e8c"}, +] + [[package]] name = "connectorx" version = "0.3.3" @@ -2659,7 +2684,7 @@ prefixed = ">=0.3.2" name = "et-xmlfile" version = "1.1.0" description = "An implementation of lxml.xmlfile for the standard library" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "et_xmlfile-1.1.0-py3-none-any.whl", hash = "sha256:a2ba85d1d6a74ef63837eed693bcb89c3f752169b0e3e7ae5b16ca5e1b3deada"}, @@ -5806,7 +5831,7 @@ datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] name = "openpyxl" version = "3.1.2" description = "A Python library to read/write Excel 2010 xlsx/xlsm files" -optional = true +optional = false python-versions = ">=3.6" files = [ {file = "openpyxl-3.1.2-py2.py3-none-any.whl", hash = "sha256:f91456ead12ab3c6c2e9491cf33ba6d08357d802192379bb482f1033ade496f5"}, @@ -9645,4 +9670,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "ae02db22861b419596adea95c7ddff27317ae91579c6e9138f777489fe20c05a" +content-hash = "e5342d5cdc135a27b89747a3665ff68aa76025efcfde6f86318144ce0fd70284" diff --git a/pyproject.toml b/pyproject.toml index 28d6056f60..77cc0b7824 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,6 +161,16 @@ types-regex = "^2024.5.15.20240519" flake8-print = "^5.0.0" mimesis = "^7.0.0" +[tool.poetry.group.sources] +optional = true +[tool.poetry.group.sources.dependencies] +connectorx = [ + {version = "0.3.2", python = "3.8"}, + {version = ">=0.3.3", python = ">=3.9"} +] +pymysql = "^1.1.0" +openpyxl = "^3" + [tool.poetry.group.pipeline] optional = true diff --git a/tests/common/cases/schemas/eth/ethereum_schema_v9.yml b/tests/common/cases/schemas/eth/ethereum_schema_v9.yml index c56ff85a9f..a7413575a5 100644 --- a/tests/common/cases/schemas/eth/ethereum_schema_v9.yml +++ b/tests/common/cases/schemas/eth/ethereum_schema_v9.yml @@ -1,5 +1,5 @@ version: 17 -version_hash: PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4= +version_hash: oHfYGTI2GHOxuzwVz6+yvMilXUvHYhxrxkanC2T6MAI= engine_version: 9 name: ethereum tables: @@ -166,7 +166,6 @@ tables: x-normalizer: seen-data: true blocks__transactions: - parent: blocks columns: _dlt_id: nullable: false @@ -178,6 +177,7 @@ tables: primary_key: true foreign_key: true data_type: bigint + merge_key: true name: block_number transaction_index: nullable: false @@ -267,7 +267,6 @@ tables: x-normalizer: seen-data: true blocks__transactions__logs: - parent: blocks__transactions columns: _dlt_id: nullable: false @@ -291,13 +290,13 @@ tables: block_number: nullable: false primary_key: true - foreign_key: true + merge_key: true data_type: bigint name: block_number transaction_index: nullable: false primary_key: true - foreign_key: true + merge_key: true data_type: bigint name: transaction_index log_index: diff --git a/tests/common/destination/__init__.py b/tests/common/destination/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/common/destination/test_destination_capabilities.py b/tests/common/destination/test_destination_capabilities.py new file mode 100644 index 0000000000..938b9836e5 --- /dev/null +++ b/tests/common/destination/test_destination_capabilities.py @@ -0,0 +1,224 @@ +import pytest + +from dlt.common.destination.exceptions import DestinationCapabilitiesException, UnsupportedDataType +from dlt.common.destination.utils import ( + resolve_merge_strategy, + verify_schema_capabilities, + verify_supported_data_types, +) +from dlt.common.exceptions import TerminalValueError +from dlt.common.schema.exceptions import SchemaIdentifierNormalizationCollision +from dlt.common.schema.schema import Schema +from dlt.common.schema.utils import new_table +from dlt.common.storages.load_package import ParsedLoadJobFileName +from dlt.destinations.impl.bigquery.bigquery_adapter import AUTODETECT_SCHEMA_HINT + + +def test_resolve_merge_strategy() -> None: + schema = Schema("schema") + + table = new_table("table", write_disposition="merge") + delta_table = new_table("delta_table", table_format="delta", write_disposition="merge") + iceberg_table = new_table("delta_table", table_format="iceberg", write_disposition="merge") + + schema.update_table(table) + schema.update_table(delta_table) + schema.update_table(iceberg_table) + + assert resolve_merge_strategy(schema.tables, table) is None + assert resolve_merge_strategy(schema.tables, delta_table) is None + assert resolve_merge_strategy(schema.tables, iceberg_table) is None + + # try default merge dispositions + from dlt.destinations import athena, filesystem, duckdb + + assert resolve_merge_strategy(schema.tables, table, filesystem().capabilities()) is None + assert ( + resolve_merge_strategy(schema.tables, delta_table, filesystem().capabilities()) == "upsert" + ) + assert ( + resolve_merge_strategy(schema.tables, iceberg_table, athena().capabilities()) + == "delete-insert" + ) + + # unknown table formats + assert resolve_merge_strategy(schema.tables, iceberg_table, filesystem().capabilities()) is None + assert resolve_merge_strategy(schema.tables, delta_table, athena().capabilities()) is None + + # not supported strategy + schema.tables["delta_table"]["x-merge-strategy"] = "delete-insert" # type: ignore[typeddict-unknown-key] + with pytest.raises(DestinationCapabilitiesException): + resolve_merge_strategy(schema.tables, delta_table, filesystem().capabilities()) + + # non-default strategy + schema.tables["table"]["x-merge-strategy"] = "scd2" # type: ignore[typeddict-unknown-key] + assert resolve_merge_strategy(schema.tables, table, filesystem().capabilities()) is None + assert resolve_merge_strategy(schema.tables, table, duckdb().capabilities()) == "scd2" + + +def test_verify_capabilities_ident_collisions() -> None: + schema = Schema("schema") + table = new_table( + "table", + write_disposition="merge", + columns=[{"name": "col1", "data_type": "bigint"}, {"name": "COL1", "data_type": "bigint"}], + ) + schema.update_table(table, normalize_identifiers=False) + from dlt.destinations import athena, filesystem + + # case sensitive - no name collision + exceptions = verify_schema_capabilities(schema, filesystem().capabilities(), "filesystem") + assert len(exceptions) == 0 + # case insensitive - collision on column name + exceptions = verify_schema_capabilities(schema, athena().capabilities(), "filesystem") + assert len(exceptions) == 1 + assert isinstance(exceptions[0], SchemaIdentifierNormalizationCollision) + assert exceptions[0].identifier_type == "column" + + table = new_table( + "TABLE", write_disposition="merge", columns=[{"name": "col1", "data_type": "bigint"}] + ) + schema.update_table(table, normalize_identifiers=False) + exceptions = verify_schema_capabilities(schema, filesystem().capabilities(), "filesystem") + assert len(exceptions) == 0 + # case insensitive - collision on table name + exceptions = verify_schema_capabilities(schema, athena().capabilities(), "filesystem") + assert len(exceptions) == 2 + assert isinstance(exceptions[1], SchemaIdentifierNormalizationCollision) + assert exceptions[1].identifier_type == "table" + + +def test_verify_capabilities_data_types() -> None: + schema = Schema("schema") + table = new_table( + "table", + write_disposition="merge", + columns=[{"name": "col1", "data_type": "time"}, {"name": "col2", "data_type": "date"}], + ) + schema.update_table(table, normalize_identifiers=False) + + schema.update_table(table, normalize_identifiers=False) + from dlt.destinations import athena, filesystem, databricks, redshift + + new_jobs_parquet = [ParsedLoadJobFileName.parse("table.12345.1.parquet")] + new_jobs_jsonl = [ParsedLoadJobFileName.parse("table.12345.1.jsonl")] + + # all data types supported (no mapper) + exceptions = verify_supported_data_types( + schema.tables.values(), new_jobs_parquet, filesystem().capabilities(), "filesystem" # type: ignore[arg-type] + ) + assert len(exceptions) == 0 + # time not supported via list + exceptions = verify_supported_data_types( + schema.tables.values(), new_jobs_parquet, athena().capabilities(), "athena" # type: ignore[arg-type] + ) + assert len(exceptions) == 1 + assert isinstance(exceptions[0], UnsupportedDataType) + assert exceptions[0].destination_type == "athena" + assert exceptions[0].table_name == "table" + assert exceptions[0].column == "col1" + assert exceptions[0].file_format == "parquet" + assert exceptions[0].available_in_formats == [] + + # all supported on parquet + exceptions = verify_supported_data_types( + schema.tables.values(), new_jobs_parquet, databricks().capabilities(), "databricks" # type: ignore[arg-type] + ) + assert len(exceptions) == 0 + # date not supported on jsonl + exceptions = verify_supported_data_types( + schema.tables.values(), new_jobs_jsonl, databricks().capabilities(), "databricks" # type: ignore[arg-type] + ) + assert len(exceptions) == 1 + assert isinstance(exceptions[0], UnsupportedDataType) + assert exceptions[0].column == "col2" + assert exceptions[0].available_in_formats == ["parquet"] + + # exclude binary type if precision is set on column + schema_bin = Schema("schema_bin") + table = new_table( + "table", + write_disposition="merge", + columns=[ + {"name": "binary_1", "data_type": "binary"}, + {"name": "binary_2", "data_type": "binary", "precision": 128}, + ], + ) + schema_bin.update_table(table, normalize_identifiers=False) + exceptions = verify_supported_data_types( + schema_bin.tables.values(), # type: ignore[arg-type] + new_jobs_jsonl, + redshift().capabilities(), + "redshift", + ) + # binary not supported on jsonl + assert len(exceptions) == 2 + exceptions = verify_supported_data_types( + schema_bin.tables.values(), new_jobs_parquet, redshift().capabilities(), "redshift" # type: ignore[arg-type] + ) + # fixed length not supported on parquet + assert len(exceptions) == 1 + assert isinstance(exceptions[0], UnsupportedDataType) + assert exceptions[0].data_type == "binary(128)" + assert exceptions[0].column == "binary_2" + assert exceptions[0].available_in_formats == ["insert_values"] + + # check complex type on bigquery + from dlt.destinations import bigquery + + schema_complex = Schema("complex") + table = new_table( + "table", + write_disposition="merge", + columns=[ + {"name": "complex_1", "data_type": "complex"}, + ], + ) + schema_complex.update_table(table, normalize_identifiers=False) + exceptions = verify_supported_data_types( + schema_complex.tables.values(), new_jobs_parquet, bigquery().capabilities(), "bigquery" # type: ignore[arg-type] + ) + assert len(exceptions) == 1 + assert isinstance(exceptions[0], UnsupportedDataType) + assert exceptions[0].data_type == "complex" + + # enable schema autodetect + table[AUTODETECT_SCHEMA_HINT] = True # type: ignore[typeddict-unknown-key] + exceptions = verify_supported_data_types( + schema_complex.tables.values(), new_jobs_parquet, bigquery().capabilities(), "bigquery" # type: ignore[arg-type] + ) + assert len(exceptions) == 0 + + # lancedb uses arrow types in type mapper + from dlt.destinations import lancedb + + exceptions = verify_supported_data_types( + schema_bin.tables.values(), # type: ignore[arg-type] + new_jobs_jsonl, + lancedb().capabilities(), + "lancedb", + ) + try: + import pyarrow + + assert len(exceptions) == 0 + except ImportError: + assert len(exceptions) > 0 + + # provoke mapping error, precision not supported on NTZ timestamp + schema_timezone = Schema("tx") + table = new_table( + "table", + write_disposition="merge", + columns=[ + {"name": "ts_1", "data_type": "timestamp", "precision": 12, "timezone": False}, + ], + ) + schema_timezone.update_table(table, normalize_identifiers=False) + from dlt.destinations import motherduck + + exceptions = verify_supported_data_types( + schema_timezone.tables.values(), new_jobs_parquet, motherduck().capabilities(), "motherduck" # type: ignore[arg-type] + ) + assert len(exceptions) == 1 + assert isinstance(exceptions[0], TerminalValueError) diff --git a/tests/common/test_destination.py b/tests/common/destination/test_reference.py similarity index 100% rename from tests/common/test_destination.py rename to tests/common/destination/test_reference.py diff --git a/tests/common/normalizers/test_json_relational.py b/tests/common/normalizers/test_json_relational.py index 159e33da4d..c3791b2a40 100644 --- a/tests/common/normalizers/test_json_relational.py +++ b/tests/common/normalizers/test_json_relational.py @@ -133,13 +133,16 @@ def test_child_table_linking(norm: RelationalNormalizer) -> None: assert [e[1]["value"] for e in list_rows] == ["a", "b", "c"] -def test_child_table_linking_primary_key(norm: RelationalNormalizer) -> None: +def test_skip_nested_link_when_no_parent(norm: RelationalNormalizer) -> None: row = { "id": "level0", "f": [{"id": "level1", "l": ["a", "b", "c"], "v": 120, "o": [{"a": 1}, {"a": 2}]}], } - norm.schema.merge_hints({"primary_key": [TSimpleRegex("id")]}) - norm.schema._compile_settings() + + # create table__f without parent so it is not seen as nested table + # still normalizer will write data to it but not link + table__f = new_table("table__f", parent_table_name=None) + norm.schema.update_table(table__f) rows = list(norm._normalize_row(row, {}, ("table",))) root = next(t for t in rows if t[0][0] == "table")[1] @@ -415,6 +418,10 @@ def test_list_in_list() -> None: "zen__webpath", parent_table_name="zen", columns=[{"name": "list", "data_type": "complex"}] ) schema.update_table(path_table) + assert "zen__webpath" in schema.tables + # clear cache with complex paths + schema.data_item_normalizer._is_complex_type.cache_clear() # type: ignore[attr-defined] + rows = list(schema.normalize_data_item(chats, "1762162.1212", "zen")) # both lists are complex types now assert len(rows) == 3 @@ -734,39 +741,41 @@ def test_table_name_meta_normalized() -> None: assert rows[0][0][0] == "channel_surfing" -def test_parse_with_primary_key() -> None: - schema = create_schema_with_name("discord") - schema._merge_hints({"primary_key": ["id"]}) # type: ignore[list-item] - schema._compile_settings() - add_dlt_root_id_propagation(schema.data_item_normalizer) # type: ignore[arg-type] - - row = {"id": "817949077341208606", "w_id": [{"id": 9128918293891111, "wo_id": [1, 2, 3]}]} - rows = list(schema.normalize_data_item(row, "load_id", "discord")) - # get root - root = next(t[1] for t in rows if t[0][0] == "discord") - assert root["_dlt_id"] != digest128("817949077341208606", DLT_ID_LENGTH_BYTES) - assert "_dlt_parent_id" not in root - assert "_dlt_root_id" not in root - assert root["_dlt_load_id"] == "load_id" - - el_w_id = next(t[1] for t in rows if t[0][0] == "discord__w_id") - # this also has primary key - assert el_w_id["_dlt_id"] != digest128("9128918293891111", DLT_ID_LENGTH_BYTES) - assert "_dlt_parent_id" not in el_w_id - assert "_dlt_list_idx" not in el_w_id - # if enabled, dlt_root is always propagated - assert "_dlt_root_id" in el_w_id - - # this must have deterministic child key - f_wo_id = next( - t[1] for t in rows if t[0][0] == "discord__w_id__wo_id" and t[1]["_dlt_list_idx"] == 2 - ) - assert f_wo_id["value"] == 3 - assert f_wo_id["_dlt_root_id"] != digest128("817949077341208606", DLT_ID_LENGTH_BYTES) - assert f_wo_id["_dlt_parent_id"] != digest128("9128918293891111", DLT_ID_LENGTH_BYTES) - assert f_wo_id["_dlt_id"] == RelationalNormalizer._get_child_row_hash( - f_wo_id["_dlt_parent_id"], "discord__w_id__wo_id", 2 - ) +def test_row_id_is_primary_key() -> None: + # TODO: if there's a column with row_id hint and primary_key, it should get propagated + pass + # schema = create_schema_with_name("discord") + # schema._merge_hints({"primary_key": ["id"]}) # type: ignore[list-item] + # schema._compile_settings() + # add_dlt_root_id_propagation(schema.data_item_normalizer) # type: ignore[arg-type] + + # row = {"id": "817949077341208606", "w_id": [{"id": 9128918293891111, "wo_id": [1, 2, 3]}]} + # rows = list(schema.normalize_data_item(row, "load_id", "discord")) + # # get root + # root = next(t[1] for t in rows if t[0][0] == "discord") + # assert root["_dlt_id"] != digest128("817949077341208606", DLT_ID_LENGTH_BYTES) + # assert "_dlt_parent_id" not in root + # assert "_dlt_root_id" not in root + # assert root["_dlt_load_id"] == "load_id" + + # el_w_id = next(t[1] for t in rows if t[0][0] == "discord__w_id") + # # this also has primary key + # assert el_w_id["_dlt_id"] != digest128("9128918293891111", DLT_ID_LENGTH_BYTES) + # assert "_dlt_parent_id" not in el_w_id + # assert "_dlt_list_idx" not in el_w_id + # # if enabled, dlt_root is always propagated + # assert "_dlt_root_id" in el_w_id + + # # this must have deterministic child key + # f_wo_id = next( + # t[1] for t in rows if t[0][0] == "discord__w_id__wo_id" and t[1]["_dlt_list_idx"] == 2 + # ) + # assert f_wo_id["value"] == 3 + # assert f_wo_id["_dlt_root_id"] != digest128("817949077341208606", DLT_ID_LENGTH_BYTES) + # assert f_wo_id["_dlt_parent_id"] != digest128("9128918293891111", DLT_ID_LENGTH_BYTES) + # assert f_wo_id["_dlt_id"] == RelationalNormalizer._get_nested_row_hash( + # f_wo_id["_dlt_parent_id"], "discord__w_id__wo_id", 2 + # ) def test_keeps_none_values() -> None: @@ -868,6 +877,18 @@ def test_propagation_update_on_table_change(norm: RelationalNormalizer): ] == {"_dlt_id": "_dlt_root_id", "prop1": "prop2"} +def test_caching_perf(norm: RelationalNormalizer) -> None: + from time import time + + table = new_table("test") + table["x-normalizer"] = {} + start = time() + for _ in range(100000): + norm._is_complex_type(norm.schema, "test", "field", 0, 0) + # norm._get_table_nesting_level(norm.schema, "test") + print(f"{time() - start}") + + def set_max_nesting(norm: RelationalNormalizer, max_nesting: int) -> None: RelationalNormalizer.update_normalizer_config(norm.schema, {"max_nesting": max_nesting}) norm._reset() diff --git a/tests/common/normalizers/test_import_normalizers.py b/tests/common/schema/test_import_normalizers.py similarity index 97% rename from tests/common/normalizers/test_import_normalizers.py rename to tests/common/schema/test_import_normalizers.py index fe356de327..a1e3d775f0 100644 --- a/tests/common/normalizers/test_import_normalizers.py +++ b/tests/common/schema/test_import_normalizers.py @@ -4,13 +4,6 @@ from dlt.common.configuration.container import Container from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.normalizers.typing import TNormalizersConfig -from dlt.common.normalizers.utils import ( - DEFAULT_NAMING_NAMESPACE, - explicit_normalizers, - import_normalizers, - naming_from_reference, - serialize_reference, -) from dlt.common.normalizers.json.relational import DataItemNormalizer as RelationalNormalizer from dlt.common.normalizers.naming import snake_case, direct from dlt.common.normalizers.naming.exceptions import ( @@ -18,10 +11,17 @@ NamingTypeNotFound, UnknownNamingModule, ) - from tests.common.normalizers.custom_normalizers import ( DataItemNormalizer as CustomRelationalNormalizer, ) +from dlt.common.schema.normalizers import ( + DEFAULT_NAMING_NAMESPACE, + explicit_normalizers, + import_normalizers, + naming_from_reference, + serialize_reference, +) + from tests.utils import preserve_environ @@ -87,7 +87,9 @@ def test_naming_from_reference() -> None: import sys try: - sys.path.insert(0, os.path.dirname(__file__)) + from tests.common.normalizers import custom_normalizers + + sys.path.insert(0, os.path.dirname(custom_normalizers.__file__)) assert naming_from_reference("custom_normalizers").name() == "custom_normalizers" assert ( naming_from_reference("custom_normalizers.NamingConvention").name() @@ -113,10 +115,8 @@ def test_naming_from_reference() -> None: with pytest.raises(ValueError): naming_from_reference(snake_case.NamingConvention()) # type: ignore[arg-type] - # with capabilities - caps = DestinationCapabilitiesContext.generic_capabilities() - caps.max_identifier_length = 120 - naming = naming_from_reference(snake_case.NamingConvention, caps) + # with max length + naming = naming_from_reference(snake_case.NamingConvention, 120) assert naming.max_length == 120 diff --git a/tests/common/schema/test_merges.py b/tests/common/schema/test_merges.py index 893fd1db5f..7b82cee1eb 100644 --- a/tests/common/schema/test_merges.py +++ b/tests/common/schema/test_merges.py @@ -304,12 +304,15 @@ def test_diff_tables() -> None: changed = deepcopy(table) changed["description"] = "new description" changed["name"] = "new name" - partial = utils.diff_table("schema", deepcopy(table), changed) + # names must be identical + renamed_table = deepcopy(table) + renamed_table["name"] = "new name" + partial = utils.diff_table("schema", renamed_table, changed) print(partial) assert partial == {"name": "new name", "description": "new description", "columns": {}} # ignore identical table props - existing = deepcopy(table) + existing = deepcopy(renamed_table) changed["write_disposition"] = "append" changed["schema_contract"] = "freeze" partial = utils.diff_table("schema", deepcopy(existing), changed) @@ -360,12 +363,19 @@ def test_diff_tables_conflicts() -> None: "columns": {"test": COL_1_HINTS, "test_2": COL_2_HINTS}, } - other = utils.new_table("table_2") + other = utils.new_table("table") with pytest.raises(TablePropertiesConflictException) as cf_ex: utils.diff_table("schema", table, other) assert cf_ex.value.table_name == "table" assert cf_ex.value.prop_name == "parent" + # conflict on name + other = utils.new_table("other_name") + with pytest.raises(TablePropertiesConflictException) as cf_ex: + utils.diff_table("schema", table, other) + assert cf_ex.value.table_name == "table" + assert cf_ex.value.prop_name == "name" + # conflict on data types in columns changed = deepcopy(table) changed["columns"]["test"]["data_type"] = "bigint" diff --git a/tests/common/schema/test_schema.py b/tests/common/schema/test_schema.py index 93be165358..8fef0184b1 100644 --- a/tests/common/schema/test_schema.py +++ b/tests/common/schema/test_schema.py @@ -726,7 +726,7 @@ def test_compare_columns() -> None: ) # any of the hints may differ for hint in COLUMN_HINTS: - table["columns"]["col3"][hint] = True # type: ignore[typeddict-unknown-key] + table["columns"]["col3"][hint] = True # name may not differ assert ( utils.compare_complete_columns(table["columns"]["col3"], table["columns"]["col4"]) is False diff --git a/tests/common/storages/test_schema_storage.py b/tests/common/storages/test_schema_storage.py index ffbd2ecf1b..1f82a85abb 100644 --- a/tests/common/storages/test_schema_storage.py +++ b/tests/common/storages/test_schema_storage.py @@ -3,7 +3,7 @@ import yaml from dlt.common import json -from dlt.common.normalizers.utils import explicit_normalizers +from dlt.common.schema.normalizers import explicit_normalizers from dlt.common.schema.schema import Schema from dlt.common.storages.exceptions import ( InStorageSchemaModified, diff --git a/tests/common/utils.py b/tests/common/utils.py index 32741128b8..85f79d7560 100644 --- a/tests/common/utils.py +++ b/tests/common/utils.py @@ -20,7 +20,7 @@ def IMPORTED_VERSION_HASH_ETH_V9() -> str: # for import schema tests, change when upgrading the schema version eth_V9 = load_yml_case("schemas/eth/ethereum_schema_v9") - assert eth_V9["version_hash"] == "PgEHvn5+BHV1jNzNYpx9aDpq6Pq1PSSetufj/h0hKg4=" + assert eth_V9["version_hash"] == "oHfYGTI2GHOxuzwVz6+yvMilXUvHYhxrxkanC2T6MAI=" # remove processing hints before installing as import schema # ethereum schema is a "dirty" schema with processing hints eth = Schema.from_dict(eth_V9, remove_processing_hints=True) diff --git a/tests/extract/cases/eth_source/ethereum.schema.yaml b/tests/extract/cases/eth_source/ethereum.schema.yaml index 5a8db47163..a7413575a5 100644 --- a/tests/extract/cases/eth_source/ethereum.schema.yaml +++ b/tests/extract/cases/eth_source/ethereum.schema.yaml @@ -1,6 +1,6 @@ -version: 14 -version_hash: VuzNqiLOk7XuPxYLndMFMPHTDVItKU5ayiy70nQLdus= -engine_version: 7 +version: 17 +version_hash: oHfYGTI2GHOxuzwVz6+yvMilXUvHYhxrxkanC2T6MAI= +engine_version: 9 name: ethereum tables: _dlt_loads: @@ -163,8 +163,9 @@ tables: schema_contract: {} name: blocks resource: blocks + x-normalizer: + seen-data: true blocks__transactions: - parent: blocks columns: _dlt_id: nullable: false @@ -176,6 +177,7 @@ tables: primary_key: true foreign_key: true data_type: bigint + merge_key: true name: block_number transaction_index: nullable: false @@ -262,8 +264,9 @@ tables: data_type: decimal name: eth_value name: blocks__transactions + x-normalizer: + seen-data: true blocks__transactions__logs: - parent: blocks__transactions columns: _dlt_id: nullable: false @@ -287,13 +290,13 @@ tables: block_number: nullable: false primary_key: true - foreign_key: true + merge_key: true data_type: bigint name: block_number transaction_index: nullable: false primary_key: true - foreign_key: true + merge_key: true data_type: bigint name: transaction_index log_index: @@ -314,6 +317,8 @@ tables: data_type: text name: transaction_hash name: blocks__transactions__logs + x-normalizer: + seen-data: true blocks__transactions__logs__topics: parent: blocks__transactions__logs columns: @@ -341,6 +346,8 @@ tables: data_type: text name: value name: blocks__transactions__logs__topics + x-normalizer: + seen-data: true blocks__transactions__access_list: parent: blocks__transactions columns: @@ -368,6 +375,8 @@ tables: data_type: text name: address name: blocks__transactions__access_list + x-normalizer: + seen-data: true blocks__transactions__access_list__storage_keys: parent: blocks__transactions__access_list columns: @@ -395,6 +404,8 @@ tables: data_type: text name: value name: blocks__transactions__access_list__storage_keys + x-normalizer: + seen-data: true blocks__uncles: parent: blocks columns: @@ -422,6 +433,8 @@ tables: data_type: text name: value name: blocks__uncles + x-normalizer: + seen-data: true settings: default_hints: foreign_key: @@ -456,4 +469,7 @@ normalizers: blocks: timestamp: block_timestamp hash: block_hash +previous_hashes: +- C5An8WClbavalXDdNSqXbdI7Swqh/mTWMcwWKCF//EE= +- yjMtV4Zv0IJlfR5DPMwuXxGg8BRhy7E79L26XAHWEGE= diff --git a/tests/libs/pyarrow/test_pyarrow_normalizer.py b/tests/libs/pyarrow/test_pyarrow_normalizer.py index d975702ad8..32ee5fdafc 100644 --- a/tests/libs/pyarrow/test_pyarrow_normalizer.py +++ b/tests/libs/pyarrow/test_pyarrow_normalizer.py @@ -4,8 +4,8 @@ import pytest from dlt.common.libs.pyarrow import normalize_py_arrow_item, NameNormalizationCollision -from dlt.common.normalizers.utils import explicit_normalizers, import_normalizers from dlt.common.schema.utils import new_column, TColumnSchema +from dlt.common.schema.normalizers import explicit_normalizers, import_normalizers from dlt.common.destination import DestinationCapabilitiesContext diff --git a/tests/libs/test_parquet_writer.py b/tests/libs/test_parquet_writer.py index 5fdc6d6cc2..80172bf5d8 100644 --- a/tests/libs/test_parquet_writer.py +++ b/tests/libs/test_parquet_writer.py @@ -296,7 +296,7 @@ def _assert_arrow_field(field: int, prec: str) -> None: else: assert column_type.tz is None - _assert_arrow_field(0, "us") + _assert_arrow_field(0, "s") _assert_arrow_field(1, "ms") _assert_arrow_field(2, "us") _assert_arrow_field(3, "ns") @@ -306,10 +306,12 @@ def _assert_arrow_field(field: int, prec: str) -> None: def _assert_pq_column(col: int, prec: str) -> None: info = json.loads(reader.metadata.schema.column(col).logical_type.to_json()) + print(info) assert info["isAdjustedToUTC"] is adjusted assert info["timeUnit"] == prec - _assert_pq_column(0, "microseconds") + # apparently storting seconds is not supported + _assert_pq_column(0, "milliseconds") _assert_pq_column(1, "milliseconds") _assert_pq_column(2, "microseconds") _assert_pq_column(3, "nanoseconds") diff --git a/tests/load/athena_iceberg/test_athena_iceberg.py b/tests/load/athena_iceberg/test_athena_iceberg.py index 0ef935a8bc..6190f8793a 100644 --- a/tests/load/athena_iceberg/test_athena_iceberg.py +++ b/tests/load/athena_iceberg/test_athena_iceberg.py @@ -1,8 +1,8 @@ import pytest -import os from typing import Iterator, Any import dlt +from tests.load.utils import DestinationTestConfiguration, destinations_configs from tests.pipeline.utils import load_table_counts from dlt.destinations.exceptions import DatabaseTerminalException @@ -11,19 +11,22 @@ pytestmark = pytest.mark.essential -def test_iceberg() -> None: +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, + with_table_format="iceberg", + subset=["athena"], + ), + ids=lambda x: x.name, +) +def test_iceberg(destination_config: DestinationTestConfiguration) -> None: """ We write two tables, one with the iceberg flag, one without. We expect the iceberg table and its subtables to accept update commands and the other table to reject them. """ - os.environ["DESTINATION__FILESYSTEM__BUCKET_URL"] = "s3://dlt-ci-test-bucket" - pipeline = dlt.pipeline( - pipeline_name="athena-iceberg", - destination="athena", - staging="filesystem", - dev_mode=True, - ) + pipeline = destination_config.setup_pipeline("test_iceberg", dev_mode=True) def items() -> Iterator[Any]: yield { @@ -67,3 +70,57 @@ def items_iceberg(): # modifying iceberg table will succeed client.execute_sql("UPDATE items_iceberg SET name='new name'") client.execute_sql("UPDATE items_iceberg__sub_items SET name='super new name'") + + +@pytest.mark.parametrize( + "destination_config", + destinations_configs( + default_sql_configs=True, + with_table_format="iceberg", + subset=["athena"], + ), + ids=lambda x: x.name, +) +def test_force_iceberg_deprecation(destination_config: DestinationTestConfiguration) -> None: + """Fails on deprecated force_iceberg option""" + destination_config.force_iceberg = True + pipeline = destination_config.setup_pipeline("test_force_iceberg_deprecation", dev_mode=True) + + def items() -> Iterator[Any]: + yield { + "id": 1, + "name": "item", + "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], + } + + @dlt.resource(name="items_normal", write_disposition="append") + def items_normal(): + yield from items() + + @dlt.resource(name="items_hive", write_disposition="append", table_format="hive") + def items_hive(): + yield from items() + + print(pipeline.run([items_normal, items_hive])) + + # items_normal should load as iceberg + # _dlt_pipeline_state should load as iceberg (IMPORTANT for backward comp) + + with pipeline.sql_client() as client: + client.execute_sql("SELECT * FROM items_normal") + client.execute_sql("SELECT * FROM items_hive") + + with pytest.raises(DatabaseTerminalException) as dbex: + client.execute_sql("UPDATE items_hive SET name='new name'") + assert "Modifying Hive table rows is only supported for transactional tables" in str(dbex) + + # modifying iceberg table will succeed + client.execute_sql("UPDATE items_normal SET name='new name'") + client.execute_sql("UPDATE items_normal__sub_items SET name='super new name'") + client.execute_sql("UPDATE _dlt_pipeline_state SET pipeline_name='new name'") + + # trigger deprecation warning + from dlt.destinations import athena + + athena_c = athena(force_iceberg=True).configuration(athena().spec()._bind_dataset_name("ds")) + assert athena_c.force_iceberg is True diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index c92f18e159..10ee55cc6c 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -23,7 +23,10 @@ from dlt.destinations.impl.bigquery.bigquery import BigQueryClient, BigQueryClientConfiguration from dlt.destinations.exceptions import LoadJobNotExistsException, LoadJobTerminalException -from dlt.destinations.impl.bigquery.bigquery_adapter import AUTODETECT_SCHEMA_HINT +from dlt.destinations.impl.bigquery.bigquery_adapter import ( + AUTODETECT_SCHEMA_HINT, + should_autodetect_schema, +) from tests.utils import TEST_STORAGE_ROOT, delete_test_storage from tests.common.utils import json_case_path as common_json_case_path from tests.common.configuration.utils import environment @@ -277,23 +280,35 @@ def test_bigquery_different_project_id(bigquery_project_id) -> None: def test_bigquery_autodetect_configuration(client: BigQueryClient) -> None: # no schema autodetect - assert client._should_autodetect_schema("event_slot") is False - assert client._should_autodetect_schema("_dlt_loads") is False + event_slot = client.prepare_load_table("event_slot") + _dlt_loads = client.prepare_load_table("_dlt_loads") + assert should_autodetect_schema(event_slot) is False + assert should_autodetect_schema(_dlt_loads) is False # add parent table child = new_table("event_slot__values", "event_slot") - client.schema.update_table(child) - assert client._should_autodetect_schema("event_slot__values") is False + client.schema.update_table(child, normalize_identifiers=False) + event_slot__values = client.prepare_load_table("event_slot__values") + assert should_autodetect_schema(event_slot__values) is False + # enable global config client.config.autodetect_schema = True - assert client._should_autodetect_schema("event_slot") is True - assert client._should_autodetect_schema("_dlt_loads") is False - assert client._should_autodetect_schema("event_slot__values") is True + # prepare again + event_slot = client.prepare_load_table("event_slot") + _dlt_loads = client.prepare_load_table("_dlt_loads") + event_slot__values = client.prepare_load_table("event_slot__values") + assert should_autodetect_schema(event_slot) is True + assert should_autodetect_schema(_dlt_loads) is False + assert should_autodetect_schema(event_slot__values) is True + # enable hint per table client.config.autodetect_schema = False client.schema.get_table("event_slot")[AUTODETECT_SCHEMA_HINT] = True # type: ignore[typeddict-unknown-key] - assert client._should_autodetect_schema("event_slot") is True - assert client._should_autodetect_schema("_dlt_loads") is False - assert client._should_autodetect_schema("event_slot__values") is True + event_slot = client.prepare_load_table("event_slot") + _dlt_loads = client.prepare_load_table("_dlt_loads") + event_slot__values = client.prepare_load_table("event_slot__values") + assert should_autodetect_schema(event_slot) is True + assert should_autodetect_schema(_dlt_loads) is False + assert should_autodetect_schema(event_slot__values) is True def test_bigquery_job_resuming(client: BigQueryClient, file_storage: FileStorage) -> None: @@ -311,14 +326,14 @@ def test_bigquery_job_resuming(client: BigQueryClient, file_storage: FileStorage r_job = cast( RunnableLoadJob, client.create_load_job( - client.schema.get_table(user_table_name), + client.prepare_load_table(user_table_name), file_storage.make_full_path(job.file_name()), uniq_id(), ), ) # job will be automatically found and resumed - r_job.set_run_vars(uniq_id(), client.schema, client.schema.tables[user_table_name]) + r_job.set_run_vars(uniq_id(), client.schema, client.prepare_load_table(user_table_name)) r_job.run_managed(client) assert r_job.state() == "completed" assert r_job._resumed_job # type: ignore diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index 63ac645113..31199bd8e4 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -197,7 +197,7 @@ def test_create_table_case_insensitive(ci_gcp_client: BigQueryClient) -> None: ) assert "Event_TEST_tablE" in ci_gcp_client.schema.tables with pytest.raises(SchemaIdentifierNormalizationCollision) as coll_ex: - ci_gcp_client.update_stored_schema([]) + ci_gcp_client.verify_schema() assert coll_ex.value.conflict_identifier_name == "Event_test_tablE" assert coll_ex.value.table_name == "Event_TEST_tablE" @@ -205,6 +205,7 @@ def test_create_table_case_insensitive(ci_gcp_client: BigQueryClient) -> None: ci_gcp_client.capabilities.has_case_sensitive_identifiers = True # now the check passes, we are stopped because it is not allowed to change schema in the loader with pytest.raises(DestinationSchemaTampered): + ci_gcp_client.verify_schema() ci_gcp_client.update_stored_schema([]) diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 85f86ce84d..eee3c1dc58 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -1,7 +1,9 @@ +from typing import List import pytest from copy import deepcopy import sqlfluff +from dlt.common.schema.typing import TColumnSchema from dlt.common.utils import uniq_id from dlt.common.schema import Schema @@ -31,7 +33,9 @@ def client(empty_schema: Schema) -> DuckDbClient: def test_create_table(client: DuckDbClient) -> None: # non existing table - sql = client._get_table_update_sql("event_test_table", TABLE_UPDATE, False)[0] + sql = client._get_table_update_sql( + "event_test_table", add_timezone_false_on_precision(TABLE_UPDATE), False + )[0] sqlfluff.parse(sql, dialect="duckdb") assert "event_test_table" in sql assert '"col1" BIGINT NOT NULL' in sql @@ -57,13 +61,15 @@ def test_create_table_all_precisions(client: DuckDbClient) -> None: # non existing table sql = client._get_table_update_sql( "event_test_table", - TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS + TABLE_UPDATE_ALL_INT_PRECISIONS, + add_timezone_false_on_precision( + TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS + TABLE_UPDATE_ALL_INT_PRECISIONS + ), False, )[0] sqlfluff.parse(sql, dialect="duckdb") assert '"col1_ts" TIMESTAMP_S ' in sql assert '"col2_ts" TIMESTAMP_MS ' in sql - assert '"col3_ts" TIMESTAMP WITH TIME ZONE ' in sql + assert '"col3_ts" TIMESTAMP ' in sql assert '"col4_ts" TIMESTAMP_NS ' in sql assert '"col1_int" TINYINT ' in sql assert '"col2_int" SMALLINT ' in sql @@ -74,7 +80,9 @@ def test_create_table_all_precisions(client: DuckDbClient) -> None: def test_alter_table(client: DuckDbClient) -> None: # existing table has no columns - sqls = client._get_table_update_sql("event_test_table", TABLE_UPDATE, True) + sqls = client._get_table_update_sql( + "event_test_table", add_timezone_false_on_precision(TABLE_UPDATE), True + ) for sql in sqls: sqlfluff.parse(sql, dialect="duckdb") canonical_name = client.sql_client.make_qualified_table_name("event_test_table") @@ -127,3 +135,11 @@ def test_create_table_with_hints(client: DuckDbClient) -> None: sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql) assert '"col2" DOUBLE UNIQUE NOT NULL' in sql + + +def add_timezone_false_on_precision(table_update: List[TColumnSchema]) -> List[TColumnSchema]: + table_update = deepcopy(table_update) + for column in table_update: + if column["data_type"] == "timestamp" and column.get("precision") is not None: + column["timezone"] = False + return table_update diff --git a/tests/load/lancedb/__init__.py b/tests/load/lancedb/__init__.py index fb4bf0b35d..69eb3fb011 100644 --- a/tests/load/lancedb/__init__.py +++ b/tests/load/lancedb/__init__.py @@ -1,3 +1,5 @@ +import pytest from tests.utils import skip_if_not_active skip_if_not_active("lancedb") +pytest.importorskip("lancedb") diff --git a/tests/load/lancedb/test_pipeline.py b/tests/load/lancedb/test_pipeline.py index 728127f833..3dc2a999d4 100644 --- a/tests/load/lancedb/test_pipeline.py +++ b/tests/load/lancedb/test_pipeline.py @@ -1,8 +1,8 @@ import multiprocessing from typing import Iterator, Generator, Any, List, Mapping -import lancedb # type: ignore import pytest +import lancedb # type: ignore from lancedb import DBConnection from lancedb.embeddings import EmbeddingFunctionRegistry # type: ignore @@ -19,7 +19,6 @@ from tests.load.utils import sequence_generator, drop_active_pipeline_data from tests.pipeline.utils import assert_load_info - # Mark all tests as essential, do not remove. pytestmark = pytest.mark.essential diff --git a/tests/load/pipeline/test_arrow_loading.py b/tests/load/pipeline/test_arrow_loading.py index 6d78968996..a86deea799 100644 --- a/tests/load/pipeline/test_arrow_loading.py +++ b/tests/load/pipeline/test_arrow_loading.py @@ -76,11 +76,11 @@ def some_data(): yield item # use csv for postgres to get native arrow processing - file_format = ( + destination_config.file_format = ( destination_config.file_format if destination_config.destination != "postgres" else "csv" ) - load_info = pipeline.run(some_data(), loader_file_format=file_format) + load_info = pipeline.run(some_data(), **destination_config.run_kwargs) assert_load_info(load_info) # assert the table types some_table_columns = pipeline.default_schema.get_table("some_data")["columns"] @@ -234,6 +234,14 @@ def test_load_arrow_with_not_null_columns( item_type: TestDataItemFormat, destination_config: DestinationTestConfiguration ) -> None: """Resource schema contains non-nullable columns. Arrow schema should be written accordingly""" + if ( + destination_config.destination in ("databricks", "redshift") + and destination_config.file_format == "jsonl" + ): + pytest.skip( + "databricks + redshift / json cannot load most of the types so we skip this test" + ) + item, records, _ = arrow_table_all_data_types(item_type, include_json=False, include_time=False) @dlt.resource(primary_key="string", columns=[{"name": "int", "nullable": False}]) @@ -242,7 +250,7 @@ def some_data(): pipeline = destination_config.setup_pipeline("arrow_" + uniq_id()) - pipeline.extract(some_data()) + pipeline.extract(some_data(), table_format=destination_config.table_format) norm_storage = pipeline._get_normalize_storage() extract_files = [ @@ -258,7 +266,7 @@ def some_data(): assert result_tbl.schema.field("int").nullable is False assert result_tbl.schema.field("int").type == pa.int64() - pipeline.normalize() - # Load is succesful + pipeline.normalize(loader_file_format=destination_config.file_format) + # Load is successful info = pipeline.load() assert_load_info(info) diff --git a/tests/load/pipeline/test_athena.py b/tests/load/pipeline/test_athena.py index 3197a19d14..21a0e69794 100644 --- a/tests/load/pipeline/test_athena.py +++ b/tests/load/pipeline/test_athena.py @@ -4,7 +4,9 @@ import dlt, os from dlt.common import pendulum +from dlt.common.destination.exceptions import UnsupportedDataType from dlt.common.utils import uniq_id +from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import table_update_and_row, assert_all_data_types_row from tests.pipeline.utils import assert_load_info, load_table_counts from tests.pipeline.utils import load_table_counts @@ -40,7 +42,7 @@ def items(): "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], } - pipeline.run(items, loader_file_format=destination_config.file_format) + pipeline.run(items, **destination_config.run_kwargs) # see if we have athena tables with items table_counts = load_table_counts( @@ -71,7 +73,7 @@ def items2(): ], } - pipeline.run(items2) + pipeline.run(items2, **destination_config.run_kwargs) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()] ) @@ -103,7 +105,7 @@ def my_resource() -> Iterator[Any]: def my_source() -> Any: return my_resource - info = pipeline.run(my_source()) + info = pipeline.run(my_source(), **destination_config.run_kwargs) assert_load_info(info) with pipeline.sql_client() as sql_client: @@ -190,14 +192,10 @@ def my_resource() -> Iterator[Any]: def my_source() -> Any: return my_resource - info = pipeline.run(my_source()) - - assert info.has_failed_jobs - - assert ( - "Athena cannot load TIME columns from parquet tables" - in info.load_packages[0].jobs["failed_jobs"][0].failed_message - ) + with pytest.raises(PipelineStepFailed) as pip_ex: + pipeline.run(my_source(), **destination_config.run_kwargs) + assert isinstance(pip_ex.value.__cause__, UnsupportedDataType) + assert pip_ex.value.__cause__.data_type == "time" @pytest.mark.parametrize( @@ -223,10 +221,10 @@ def test_athena_file_layouts(destination_config: DestinationTestConfiguration, l FILE_LAYOUT_TABLE_NOT_FIRST, # table not the first variable ]: with pytest.raises(CantExtractTablePrefix): - pipeline.run(resources) + pipeline.run(resources, **destination_config.run_kwargs) return - info = pipeline.run(resources) + info = pipeline.run(resources, **destination_config.run_kwargs) assert_load_info(info) table_counts = load_table_counts( @@ -237,11 +235,11 @@ def test_athena_file_layouts(destination_config: DestinationTestConfiguration, l @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, subset=["athena"], force_iceberg=True), + destinations_configs(default_sql_configs=True, subset=["athena"], with_table_format="iceberg"), ids=lambda x: x.name, ) def test_athena_partitioned_iceberg_table(destination_config: DestinationTestConfiguration): - """Load an iceberg table with partition hints and verifiy partitions are created correctly.""" + """Load an iceberg table with partition hints and verify partitions are created correctly.""" pipeline = destination_config.setup_pipeline("athena_" + uniq_id(), dev_mode=True) data_items = [ @@ -269,7 +267,7 @@ def partitioned_table(): ], ) - info = pipeline.run(partitioned_table) + info = pipeline.run(partitioned_table, **destination_config.run_kwargs) assert_load_info(info) # Get partitions from metadata diff --git a/tests/load/pipeline/test_bigquery.py b/tests/load/pipeline/test_bigquery.py index fd0a55e273..809bd11bc0 100644 --- a/tests/load/pipeline/test_bigquery.py +++ b/tests/load/pipeline/test_bigquery.py @@ -33,7 +33,7 @@ def test_bigquery_numeric_types(destination_config: DestinationTestConfiguration }, ] - info = pipeline.run(iter(data), table_name="big_numeric", columns=columns) # type: ignore[arg-type] + info = pipeline.run(iter(data), table_name="big_numeric", columns=columns, **destination_config.run_kwargs) # type: ignore[arg-type] assert_load_info(info) with pipeline.sql_client() as client: diff --git a/tests/load/pipeline/test_clickhouse.py b/tests/load/pipeline/test_clickhouse.py index 8ad3a7f1a7..9e9c156144 100644 --- a/tests/load/pipeline/test_clickhouse.py +++ b/tests/load/pipeline/test_clickhouse.py @@ -32,7 +32,7 @@ def items() -> Iterator[TDataItem]: pipeline.run( items, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, staging=destination_config.staging, ) @@ -64,7 +64,7 @@ def items2() -> Iterator[TDataItem]: ], } - pipeline.run(items2) + pipeline.run(items2, **destination_config.run_kwargs) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()] ) diff --git a/tests/load/pipeline/test_databricks_pipeline.py b/tests/load/pipeline/test_databricks_pipeline.py index 5f8641f9fa..73868f4e97 100644 --- a/tests/load/pipeline/test_databricks_pipeline.py +++ b/tests/load/pipeline/test_databricks_pipeline.py @@ -52,7 +52,7 @@ def test_databricks_external_location(destination_config: DestinationTestConfigu destination=bricks, staging=stage, ) - info = pipeline.run([1, 2, 3], table_name="digits") + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) assert info.has_failed_jobs is True assert ( "Invalid configuration value detected" @@ -67,7 +67,7 @@ def test_databricks_external_location(destination_config: DestinationTestConfigu destination=bricks, staging=stage, ) - info = pipeline.run([1, 2, 3], table_name="digits") + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) assert info.has_failed_jobs is True assert ( "credential_x" in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message @@ -78,7 +78,7 @@ def test_databricks_external_location(destination_config: DestinationTestConfigu pipeline = destination_config.setup_pipeline( "test_databricks_external_location", dataset_name=dataset_name, destination=bricks ) - info = pipeline.run([1, 2, 3], table_name="digits") + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) assert info.has_failed_jobs is True assert ( "credential_x" in pipeline.list_failed_jobs_in_package(info.loads_ids[0])[0].failed_message diff --git a/tests/load/pipeline/test_dbt_helper.py b/tests/load/pipeline/test_dbt_helper.py index 86ee1a646e..0ee02ba4b5 100644 --- a/tests/load/pipeline/test_dbt_helper.py +++ b/tests/load/pipeline/test_dbt_helper.py @@ -93,7 +93,7 @@ def test_run_chess_dbt(destination_config: DestinationTestConfiguration, dbt_ven with pytest.raises(PrerequisitesException): transforms.run_all(source_tests_selector="source:*") # load data - info = pipeline.run(chess(max_players=5, month=9)) + info = pipeline.run(chess(max_players=5, month=9), **destination_config.run_kwargs) print(info) assert pipeline.schema_names == ["chess"] # run all the steps (deps -> seed -> source tests -> run) @@ -150,7 +150,7 @@ def test_run_chess_dbt_to_other_dataset( transforms = dlt.dbt.package(pipeline, "docs/examples/chess/dbt_transform", venv=dbt_venv) # assert pipeline.default_schema_name is None # load data - info = pipeline.run(chess(max_players=5, month=9)) + info = pipeline.run(chess(max_players=5, month=9), **destination_config.run_kwargs) print(info) assert pipeline.schema_names == ["chess"] # store transformations in alternative dataset diff --git a/tests/load/pipeline/test_dremio.py b/tests/load/pipeline/test_dremio.py index 66d1b0be4f..f19f9f44d9 100644 --- a/tests/load/pipeline/test_dremio.py +++ b/tests/load/pipeline/test_dremio.py @@ -22,7 +22,7 @@ def items() -> Iterator[Any]: "sub_items": [{"id": 101, "name": "sub item 101"}, {"id": 101, "name": "sub item 102"}], } - print(pipeline.run([items])) + print(pipeline.run([items], **destination_config.run_kwargs)) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema._schema_tables.values()] diff --git a/tests/load/pipeline/test_drop.py b/tests/load/pipeline/test_drop.py index e1c6ec9d79..f6ddd79b99 100644 --- a/tests/load/pipeline/test_drop.py +++ b/tests/load/pipeline/test_drop.py @@ -6,6 +6,7 @@ import pytest import dlt +from dlt.common.destination.reference import JobClientBase from dlt.extract import DltResource from dlt.common.utils import uniq_id from dlt.pipeline import helpers, state_sync, Pipeline @@ -18,6 +19,7 @@ from dlt.destinations.job_client_impl import SqlJobClientBase from tests.load.utils import destinations_configs, DestinationTestConfiguration +from tests.pipeline.utils import assert_load_info, load_table_counts def _attach(pipeline: Pipeline) -> Pipeline: @@ -124,24 +126,45 @@ def assert_destination_state_loaded(pipeline: Pipeline) -> None: assert pipeline_state == destination_state +@pytest.mark.essential @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs( + default_sql_configs=True, local_filesystem_configs=True, all_buckets_filesystem_configs=True + ), + ids=lambda x: x.name, ) def test_drop_command_resources_and_state(destination_config: DestinationTestConfiguration) -> None: """Test the drop command with resource and state path options and verify correct data is deleted from destination and locally""" source = droppable_source() pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(source) + info = pipeline.run(source, **destination_config.run_kwargs) + assert_load_info(info) + assert load_table_counts(pipeline, *pipeline.default_schema.tables.keys()) == { + "_dlt_version": 1, + "_dlt_loads": 1, + "droppable_a": 2, + "droppable_b": 1, + "droppable_c": 1, + "droppable_d": 2, + "droppable_no_state": 3, + "_dlt_pipeline_state": 1, + "droppable_b__items": 2, + "droppable_c__items": 1, + "droppable_c__items__labels": 2, + } attached = _attach(pipeline) helpers.drop( - attached, resources=["droppable_c", "droppable_d"], state_paths="data_from_d.*.bar" + attached, + resources=["droppable_c", "droppable_d", "droppable_no_state"], + state_paths="data_from_d.*.bar", ) attached = _attach(pipeline) - assert_dropped_resources(attached, ["droppable_c", "droppable_d"]) + assert_dropped_resources(attached, ["droppable_c", "droppable_d", "droppable_no_state"]) # Verify extra json paths are removed from state sources_state = pipeline.state["sources"] @@ -149,15 +172,38 @@ def test_drop_command_resources_and_state(destination_config: DestinationTestCon assert_destination_state_loaded(pipeline) + # now run the same droppable_source to see if tables are recreated and they contain right number of items + info = pipeline.run(source, **destination_config.run_kwargs) + assert_load_info(info) + # 2 versions (one dropped and replaced with schema with dropped tables, then we added missing tables) + # 3 loads (one for drop) + # droppable_no_state correctly replaced + # all other resources stay at the same count (they are incremental so they got loaded again or not loaded at all ie droppable_a) + assert load_table_counts(pipeline, *pipeline.default_schema.tables.keys()) == { + "_dlt_version": 2, + "_dlt_loads": 3, + "droppable_a": 2, + "droppable_b": 1, + "_dlt_pipeline_state": 3, + "droppable_b__items": 2, + "droppable_c": 1, + "droppable_d": 2, + "droppable_no_state": 3, + "droppable_c__items": 1, + "droppable_c__items__labels": 2, + } + @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, ) def test_drop_command_only_state(destination_config: DestinationTestConfiguration) -> None: """Test drop command that deletes part of the state and syncs with destination""" source = droppable_source() pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(source) + pipeline.run(source, **destination_config.run_kwargs) attached = _attach(pipeline) helpers.drop(attached, state_paths="data_from_d.*.bar") @@ -174,13 +220,15 @@ def test_drop_command_only_state(destination_config: DestinationTestConfiguratio @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, ) def test_drop_command_only_tables(destination_config: DestinationTestConfiguration) -> None: """Test drop only tables and makes sure that schema and state are synced""" source = droppable_source() pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(source) + pipeline.run(source, **destination_config.run_kwargs) sources_state = pipeline.state["sources"] attached = _attach(pipeline) @@ -196,13 +244,15 @@ def test_drop_command_only_tables(destination_config: DestinationTestConfigurati @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, ) def test_drop_destination_tables_fails(destination_config: DestinationTestConfiguration) -> None: """Fail on DROP TABLES in destination init. Command runs again.""" source = droppable_source() pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(source) + pipeline.run(source, **destination_config.run_kwargs) attached = _attach(pipeline) @@ -224,13 +274,15 @@ def test_drop_destination_tables_fails(destination_config: DestinationTestConfig @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, ) def test_fail_after_drop_tables(destination_config: DestinationTestConfiguration) -> None: """Fail directly after drop tables. Command runs again ignoring destination tables missing.""" source = droppable_source() pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(source) + pipeline.run(source, **destination_config.run_kwargs) attached = _attach(pipeline) @@ -255,13 +307,15 @@ def test_fail_after_drop_tables(destination_config: DestinationTestConfiguration @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, ) def test_load_step_fails(destination_config: DestinationTestConfiguration) -> None: """Test idempotence. pipeline.load() fails. Command can be run again successfully""" source = droppable_source() pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(source) + pipeline.run(source, **destination_config.run_kwargs) attached = _attach(pipeline) @@ -278,12 +332,14 @@ def test_load_step_fails(destination_config: DestinationTestConfiguration) -> No @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, ) def test_resource_regex(destination_config: DestinationTestConfiguration) -> None: source = droppable_source() pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(source) + pipeline.run(source, **destination_config.run_kwargs) attached = _attach(pipeline) @@ -296,13 +352,15 @@ def test_resource_regex(destination_config: DestinationTestConfiguration) -> Non @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, ) def test_drop_nothing(destination_config: DestinationTestConfiguration) -> None: """No resources, no state keys. Nothing is changed.""" source = droppable_source() pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(source) + pipeline.run(source, **destination_config.run_kwargs) attached = _attach(pipeline) previous_state = dict(attached.state) @@ -320,7 +378,7 @@ def test_drop_all_flag(destination_config: DestinationTestConfiguration) -> None """Using drop_all flag. Destination dataset and all local state is deleted""" source = droppable_source() pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(source) + pipeline.run(source, **destination_config.run_kwargs) dlt_tables = [ t["name"] for t in pipeline.default_schema.dlt_tables() ] # Original _dlt tables to check for @@ -340,12 +398,14 @@ def test_drop_all_flag(destination_config: DestinationTestConfiguration) -> None @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, ) def test_run_pipeline_after_partial_drop(destination_config: DestinationTestConfiguration) -> None: """Pipeline can be run again after dropping some resources""" pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(droppable_source()) + pipeline.run(droppable_source(), **destination_config.run_kwargs) attached = _attach(pipeline) @@ -359,12 +419,14 @@ def test_run_pipeline_after_partial_drop(destination_config: DestinationTestConf @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, local_filesystem_configs=True), + ids=lambda x: x.name, ) def test_drop_state_only(destination_config: DestinationTestConfiguration) -> None: """Pipeline can be run again after dropping some resources""" pipeline = destination_config.setup_pipeline("drop_test_" + uniq_id(), dev_mode=True) - pipeline.run(droppable_source()) + pipeline.run(droppable_source(), **destination_config.run_kwargs) attached = _attach(pipeline) diff --git a/tests/load/pipeline/test_duckdb.py b/tests/load/pipeline/test_duckdb.py index 2fa44d77c5..1129523318 100644 --- a/tests/load/pipeline/test_duckdb.py +++ b/tests/load/pipeline/test_duckdb.py @@ -13,6 +13,7 @@ from dlt.pipeline.exceptions import PipelineStepFailed from tests.cases import TABLE_UPDATE_ALL_INT_PRECISIONS, TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS +from tests.load.duckdb.test_duckdb_table_builder import add_timezone_false_on_precision from tests.load.utils import destinations_configs, DestinationTestConfiguration from tests.pipeline.utils import airtable_emojis, assert_data_table_counts, load_table_counts @@ -32,13 +33,13 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No # create tables and columns with emojis and other special characters info = pipeline.run( airtable_emojis().with_resources("📆 Schedule", "🦚Peacock", "🦚WidePeacock"), - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) info.raise_on_failed_jobs() info = pipeline.run( [{"🐾Feet": 2, "1+1": "two", "\nhey": "value"}], table_name="🦚Peacocks🦚", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) info.raise_on_failed_jobs() table_counts = load_table_counts( @@ -58,7 +59,7 @@ def test_duck_case_names(destination_config: DestinationTestConfiguration) -> No pipeline.run( [{"🐾Feet": 2, "1+1": "two", "🐾feet": "value"}], table_name="🦚peacocks🦚", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert isinstance(pip_ex.value.__context__, SchemaIdentifierNormalizationCollision) assert pip_ex.value.__context__.conflict_identifier_name == "🦚Peacocks🦚" @@ -106,8 +107,10 @@ def test_duck_precision_types(destination_config: DestinationTestConfiguration) info = pipeline.run( row, table_name="row", - loader_file_format=destination_config.file_format, - columns=TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS + TABLE_UPDATE_ALL_INT_PRECISIONS, + **destination_config.run_kwargs, + columns=add_timezone_false_on_precision( + TABLE_UPDATE_ALL_TIMESTAMP_PRECISIONS + TABLE_UPDATE_ALL_INT_PRECISIONS + ), ) info.raise_on_failed_jobs() @@ -117,7 +120,7 @@ def test_duck_precision_types(destination_config: DestinationTestConfiguration) # only us has TZ aware timestamp in duckdb, also we have UTC here assert table.schema.field(0).type == pa.timestamp("s") assert table.schema.field(1).type == pa.timestamp("ms") - assert table.schema.field(2).type == pa.timestamp("us", tz="UTC") + assert table.schema.field(2).type == pa.timestamp("us") assert table.schema.field(3).type == pa.timestamp("ns") assert table.schema.field(4).type == pa.int8() @@ -129,6 +132,7 @@ def test_duck_precision_types(destination_config: DestinationTestConfiguration) table_row = table.to_pylist()[0] table_row["col1_ts"] = ensure_pendulum_datetime(table_row["col1_ts"]) table_row["col2_ts"] = ensure_pendulum_datetime(table_row["col2_ts"]) + table_row["col3_ts"] = ensure_pendulum_datetime(table_row["col3_ts"]) table_row["col4_ts"] = ensure_pendulum_datetime(table_row["col4_ts"]) table_row.pop("_dlt_id") table_row.pop("_dlt_load_id") @@ -242,7 +246,7 @@ def _get_shuffled_events(repeat: int = 1): pipeline = destination_config.setup_pipeline("test_provoke_parallel_parquet_same_table") - info = pipeline.run(_get_shuffled_events(50)) + info = pipeline.run(_get_shuffled_events(50), **destination_config.run_kwargs) info.raise_on_failed_jobs() assert_data_table_counts( pipeline, diff --git a/tests/load/pipeline/test_filesystem_pipeline.py b/tests/load/pipeline/test_filesystem_pipeline.py index bc6cbd9848..98d6cce294 100644 --- a/tests/load/pipeline/test_filesystem_pipeline.py +++ b/tests/load/pipeline/test_filesystem_pipeline.py @@ -257,7 +257,7 @@ def foo(): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_exclude=(MEMORY_BUCKET), ), ids=lambda x: x.name, @@ -330,7 +330,7 @@ def data_types(): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET), ), ids=lambda x: x.name, @@ -375,7 +375,7 @@ def delta_table(): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET), ), ids=lambda x: x.name, @@ -421,7 +421,7 @@ def delta_table(): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET), ), ids=lambda x: x.name, @@ -498,7 +498,7 @@ def complex_table(): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET), ), ids=lambda x: x.name, @@ -584,7 +584,7 @@ def two_part(): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET), ), ids=lambda x: x.name, @@ -681,7 +681,7 @@ def delta_table(data): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET, AZ_BUCKET), ), ids=lambda x: x.name, @@ -755,7 +755,7 @@ def delta_table(data): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET), ), ids=lambda x: x.name, @@ -804,7 +804,7 @@ def s(): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET), ), ids=lambda x: x.name, @@ -832,7 +832,7 @@ def github_events(): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET, AZ_BUCKET), ), ids=lambda x: x.name, @@ -912,7 +912,7 @@ def parent_delta(): "destination_config", destinations_configs( table_format_filesystem_configs=True, - table_format="delta", + with_table_format="delta", bucket_subset=(FILE_BUCKET,), ), ids=lambda x: x.name, diff --git a/tests/load/pipeline/test_merge_disposition.py b/tests/load/pipeline/test_merge_disposition.py index b2197dd273..0706511cc8 100644 --- a/tests/load/pipeline/test_merge_disposition.py +++ b/tests/load/pipeline/test_merge_disposition.py @@ -41,10 +41,6 @@ AZ_BUCKET, ) -# uncomment add motherduck tests -# NOTE: the tests are passing but we disable them due to frequent ATTACH DATABASE timeouts -# ACTIVE_DESTINATIONS += ["motherduck"] - def skip_if_not_supported( merge_strategy: TLoaderMergeStrategy, @@ -78,32 +74,46 @@ def test_merge_on_keys_in_schema( skip_if_not_supported(merge_strategy, p.destination) - with open("tests/common/cases/schemas/eth/ethereum_schema_v5.yml", "r", encoding="utf-8") as f: + with open("tests/common/cases/schemas/eth/ethereum_schema_v9.yml", "r", encoding="utf-8") as f: schema = dlt.Schema.from_dict(yaml.safe_load(f)) - # make block uncles unseen to trigger filtering loader in loader for child tables + # make block uncles unseen to trigger filtering loader in loader for nested tables if has_table_seen_data(schema.tables["blocks__uncles"]): del schema.tables["blocks__uncles"]["x-normalizer"] assert not has_table_seen_data(schema.tables["blocks__uncles"]) - @dlt.resource( - table_name="blocks", - write_disposition={"disposition": "merge", "strategy": merge_strategy}, - table_format=destination_config.table_format, - ) - def data(slice_: slice = None): - with open( - "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", - "r", - encoding="utf-8", - ) as f: - yield json.load(f) if slice_ is None else json.load(f)[slice_] + @dlt.source(schema=schema) + def ethereum(slice_: slice = None): + @dlt.resource( + table_name="blocks", + write_disposition={"disposition": "merge", "strategy": merge_strategy}, + ) + def data(): + with open( + "tests/normalize/cases/ethereum.blocks.9c1d9b504ea240a482b007788d5cd61c_2.json", + "r", + encoding="utf-8", + ) as f: + yield json.load(f) if slice_ is None else json.load(f)[slice_] + + # also modify the child tables (not nested) + schema_ = dlt.current.source_schema() + blocks__transactions = schema_.tables["blocks__transactions"] + blocks__transactions["write_disposition"] = "merge" + blocks__transactions["x-merge-strategy"] = merge_strategy # type: ignore[typeddict-unknown-key] + blocks__transactions["table_format"] = destination_config.table_format + + blocks__transactions__logs = schema_.tables["blocks__transactions__logs"] + blocks__transactions__logs["write_disposition"] = "merge" + blocks__transactions__logs["x-merge-strategy"] = merge_strategy # type: ignore[typeddict-unknown-key] + blocks__transactions__logs["table_format"] = destination_config.table_format + + return data # take only the first block. the first block does not have uncles so this table should not be created and merged info = p.run( - data(slice(1)), - schema=schema, - loader_file_format=destination_config.file_format, + ethereum(slice(1)), + **destination_config.run_kwargs, ) assert_load_info(info) eth_1_counts = load_table_counts(p, "blocks") @@ -117,18 +127,17 @@ def data(slice_: slice = None): # now we load the whole dataset. blocks should be created which adds columns to blocks # if the table would be created before the whole load would fail because new columns have hints info = p.run( - data, - schema=schema, - loader_file_format=destination_config.file_format, + ethereum(), + **destination_config.run_kwargs, ) + assert_load_info(info) eth_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) # we have 2 blocks in dataset assert eth_2_counts["blocks"] == 2 if destination_config.supports_merge else 3 # make sure we have same record after merging full dataset again info = p.run( - data, - schema=schema, - loader_file_format=destination_config.file_format, + ethereum(), + **destination_config.run_kwargs, ) assert_load_info(info) # for non merge destinations we just check that the run passes @@ -163,7 +172,6 @@ def test_merge_record_updates( table_name="parent", write_disposition={"disposition": "merge", "strategy": merge_strategy}, primary_key="id", - table_format=destination_config.table_format, ) def r(data): yield data @@ -173,7 +181,7 @@ def r(data): {"id": 1, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, {"id": 2, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, ] - info = p.run(r(run_1)) + info = p.run(r(run_1), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "parent", "parent__child", "parent__child__grandchild") == { "parent": 2, @@ -194,7 +202,7 @@ def r(data): {"id": 1, "foo": 2, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, {"id": 2, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, ] - info = p.run(r(run_2)) + info = p.run(r(run_2), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "parent", "parent__child", "parent__child__grandchild") == { "parent": 2, @@ -215,7 +223,7 @@ def r(data): {"id": 1, "foo": 2, "child": [{"bar": 2, "grandchild": [{"baz": 1}]}]}, {"id": 2, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, ] - info = p.run(r(run_3)) + info = p.run(r(run_3), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "parent", "parent__child", "parent__child__grandchild") == { "parent": 2, @@ -236,7 +244,7 @@ def r(data): {"id": 1, "foo": 2, "child": [{"bar": 2, "grandchild": [{"baz": 2}]}]}, {"id": 2, "foo": 1, "child": [{"bar": 1, "grandchild": [{"baz": 1}]}]}, ] - info = p.run(r(run_3)) + info = p.run(r(run_3), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "parent", "parent__child", "parent__child__grandchild") == { "parent": 2, @@ -285,10 +293,7 @@ def data(slice_: slice = None): yield json.load(f) if slice_ is None else json.load(f)[slice_] # note: NodeId will be normalized to "node_id" which exists in the schema - info = p.run( - data(slice(0, 17)), - loader_file_format=destination_config.file_format, - ) + info = p.run(data(slice(0, 17)), **destination_config.run_kwargs) assert_load_info(info) github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) # 17 issues @@ -298,10 +303,7 @@ def data(slice_: slice = None): assert p.default_schema.tables["issues"]["columns"]["node_id"]["data_type"] == "text" assert p.default_schema.tables["issues"]["columns"]["node_id"]["nullable"] is False - info = p.run( - data(slice(5, None)), - loader_file_format=destination_config.file_format, - ) + info = p.run(data(slice(5, None)), **destination_config.run_kwargs) assert_load_info(info) # for non merge destinations we just check that the run passes if not destination_config.supports_merge: @@ -339,7 +341,7 @@ def test_merge_source_compound_keys_and_changes( ) -> None: p = destination_config.setup_pipeline("github_3", dev_mode=True) - info = p.run(github(), loader_file_format=destination_config.file_format) + info = p.run(github(), **destination_config.run_kwargs) assert_load_info(info) github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) # 100 issues total @@ -359,11 +361,7 @@ def test_merge_source_compound_keys_and_changes( ) # append load_issues resource - info = p.run( - github().load_issues, - write_disposition="append", - loader_file_format=destination_config.file_format, - ) + info = p.run(github().load_issues, write_disposition="append", **destination_config.run_kwargs) assert_load_info(info) assert p.default_schema.tables["issues"]["write_disposition"] == "append" # the counts of all tables must be double @@ -371,9 +369,7 @@ def test_merge_source_compound_keys_and_changes( assert {k: v * 2 for k, v in github_1_counts.items()} == github_2_counts # now replace all resources - info = p.run( - github(), write_disposition="replace", loader_file_format=destination_config.file_format - ) + info = p.run(github(), write_disposition="replace", **destination_config.run_kwargs) assert_load_info(info) assert p.default_schema.tables["issues"]["write_disposition"] == "replace" # assert p.default_schema.tables["issues__labels"]["write_disposition"] == "replace" @@ -383,7 +379,9 @@ def test_merge_source_compound_keys_and_changes( @pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name + "destination_config", + destinations_configs(default_sql_configs=True, supports_merge=True), + ids=lambda x: x.name, ) def test_merge_no_child_tables(destination_config: DestinationTestConfiguration) -> None: p = destination_config.setup_pipeline("github_3", dev_mode=True) @@ -398,7 +396,7 @@ def test_merge_no_child_tables(destination_config: DestinationTestConfiguration) # take only first 15 elements github_data.load_issues.add_filter(take_first(15)) - info = p.run(github_data, loader_file_format=destination_config.file_format) + info = p.run(github_data, **destination_config.run_kwargs) assert len(p.default_schema.data_tables()) == 1 assert "issues" in p.default_schema.tables assert_load_info(info) @@ -408,7 +406,7 @@ def test_merge_no_child_tables(destination_config: DestinationTestConfiguration) # load all github_data = github() github_data.max_table_nesting = 0 - info = p.run(github_data, loader_file_format=destination_config.file_format) + info = p.run(github_data, **destination_config.run_kwargs) assert_load_info(info) github_2_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) # 100 issues total, or 115 if merge is not supported @@ -432,7 +430,7 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) - github_data.load_issues.apply_hints(merge_key=(), primary_key=()) # skip first 45 rows github_data.load_issues.add_filter(skip_first(45)) - info = p.run(github_data, loader_file_format=destination_config.file_format) + info = p.run(github_data, **destination_config.run_kwargs) assert_load_info(info) github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) assert github_1_counts["issues"] == 100 - 45 @@ -443,7 +441,7 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) - github_data.load_issues.apply_hints(merge_key=(), primary_key=()) # skip first 45 rows github_data.load_issues.add_filter(take_first(10)) - info = p.run(github_data, loader_file_format=destination_config.file_format) + info = p.run(github_data, **destination_config.run_kwargs) assert_load_info(info) github_1_counts = load_table_counts(p, *[t["name"] for t in p.default_schema.data_tables()]) # we have 10 rows more, merge falls back to append if no keys present @@ -452,7 +450,9 @@ def test_merge_no_merge_keys(destination_config: DestinationTestConfiguration) - @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, file_format="parquet"), + destinations_configs( + default_sql_configs=True, with_file_format="parquet", local_filesystem_configs=True + ), ids=lambda x: x.name, ) def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) -> None: @@ -465,7 +465,9 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) github_data_copy = github() github_data_copy.max_table_nesting = 2 info = p.run( - [github_data, github_data_copy], loader_file_format="parquet", write_disposition="merge" + [github_data, github_data_copy], + write_disposition="merge", + **destination_config.run_kwargs, ) assert_load_info(info) # make sure it was parquet or sql transforms @@ -486,11 +488,15 @@ def test_pipeline_load_parquet(destination_config: DestinationTestConfiguration) github_data = github() # generate some complex types github_data.max_table_nesting = 2 - info = p.run(github_data, loader_file_format="parquet", write_disposition="replace") + info = p.run( + github_data, + write_disposition="replace", + **destination_config.run_kwargs, + ) assert_load_info(info) # make sure it was parquet or sql inserts files = p.get_load_package_info(p.list_completed_load_packages()[1]).jobs["completed_jobs"] - if destination_config.force_iceberg: + if destination_config.destination == "athena" and destination_config.table_format == "iceberg": # iceberg uses sql to copy tables expected_formats.append("sql") assert all(f.job_file_info.file_format in expected_formats for f in files) @@ -590,7 +596,7 @@ def _updated_event(node_id): p = destination_config.setup_pipeline("github_3", dev_mode=True) info = p.run( _get_shuffled_events(True) | github_resource, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_load_info(info) # get top tables @@ -602,15 +608,13 @@ def _updated_event(node_id): # this should skip all events due to incremental load info = p.run( _get_shuffled_events(True) | github_resource, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) # no packages were loaded assert len(info.loads_ids) == 0 # load one more event with a new id - info = p.run( - _new_event("new_node") | github_resource, loader_file_format=destination_config.file_format - ) + info = p.run(_new_event("new_node") | github_resource, **destination_config.run_kwargs) assert_load_info(info) counts = load_table_counts( p, *[t["name"] for t in p.default_schema.data_tables() if t.get("parent") is None] @@ -625,7 +629,7 @@ def _updated_event(node_id): # load updated event info = p.run( _updated_event("new_node_X") | github_resource, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_load_info(info) # still 101 @@ -656,7 +660,7 @@ def duplicates(): {"id": 1, "name": "row2", "child": [4, 5, 6]}, ] - info = p.run(duplicates(), loader_file_format=destination_config.file_format) + info = p.run(duplicates(), **destination_config.run_kwargs) assert_load_info(info) counts = load_table_counts(p, "duplicates", "duplicates__child") assert counts["duplicates"] == 1 if destination_config.supports_merge else 2 @@ -668,7 +672,7 @@ def duplicates(): def duplicates_no_child(): yield [{"id": 1, "subkey": "AX", "name": "row1"}, {"id": 1, "subkey": "AX", "name": "row2"}] - info = p.run(duplicates_no_child(), loader_file_format=destination_config.file_format) + info = p.run(duplicates_no_child(), **destination_config.run_kwargs) assert_load_info(info) counts = load_table_counts(p, "duplicates_no_child") assert counts["duplicates_no_child"] == 1 if destination_config.supports_merge else 2 @@ -687,7 +691,7 @@ def duplicates(): {"id": 1, "name": "row2", "child": [4, 5, 6]}, ] - info = p.run(duplicates(), loader_file_format=destination_config.file_format) + info = p.run(duplicates(), **destination_config.run_kwargs) assert_load_info(info) counts = load_table_counts(p, "duplicates", "duplicates__child") assert counts["duplicates"] == 2 @@ -697,7 +701,7 @@ def duplicates(): def duplicates_no_child(): yield [{"id": 1, "subkey": "AX", "name": "row1"}, {"id": 1, "subkey": "AX", "name": "row2"}] - info = p.run(duplicates_no_child(), loader_file_format=destination_config.file_format) + info = p.run(duplicates_no_child(), **destination_config.run_kwargs) assert_load_info(info) counts = load_table_counts(p, "duplicates_no_child") assert counts["duplicates_no_child"] == 2 @@ -743,7 +747,7 @@ def r(data): {"id": 1, "simple": "foo", "complex": [1, 2, 3]}, {"id": 2, "simple": "foo", "complex": [1, 2]}, ] - info = p.run(r(data), loader_file_format=destination_config.file_format) + info = p.run(r(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 2 assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 5 @@ -752,7 +756,7 @@ def r(data): data = [ {"id": 1, "simple": "bar"}, ] - info = p.run(r(data), loader_file_format=destination_config.file_format) + info = p.run(r(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 2 assert load_table_counts(p, table_name + "__complex")[table_name + "__complex"] == 2 @@ -800,7 +804,7 @@ def data_resource(data): {"id": 1, "val": "foo", "deleted": False}, {"id": 2, "val": "bar", "deleted": False}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 2 @@ -808,7 +812,7 @@ def data_resource(data): data = [ {"id": 1, "deleted": True}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == (1 if key_type != "no_key" else 2) @@ -816,7 +820,7 @@ def data_resource(data): data = [ {"id": 2, "val": "baz", "deleted": None}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == (1 if key_type != "no_key" else 3) @@ -837,7 +841,7 @@ def data_resource(data): ] if merge_strategy == "upsert": del data[0] # `upsert` requires unique `primary_key` - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) counts = load_table_counts(p, table_name)[table_name] if key_type == "primary_key": @@ -855,7 +859,7 @@ def data_resource(data): data = [ {"id": 3, "val": "foo", "deleted": True}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) counts = load_table_counts(p, table_name)[table_name] assert load_table_counts(p, table_name)[table_name] == 1 @@ -881,7 +885,7 @@ def data_resource(data): "deleted": False, }, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 2 assert load_table_counts(p, table_name + "__child_1")[table_name + "__child_1"] == 3 @@ -897,7 +901,7 @@ def data_resource(data): data = [ {"id": 1, "deleted": True}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 assert load_table_counts(p, table_name + "__child_1")[table_name + "__child_1"] == 1 @@ -912,7 +916,7 @@ def data_resource(data): data = [ {"id": 2, "deleted": True}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 0 assert load_table_counts(p, table_name + "__child_1")[table_name + "__child_1"] == 0 @@ -955,7 +959,7 @@ def data_resource(data): {"id": 1, "val": "foo", "deleted_timestamp": None}, {"id": 2, "val": "bar", "deleted_timestamp": None}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 2 @@ -963,7 +967,7 @@ def data_resource(data): data = [ {"id": 1, "deleted_timestamp": "2024-02-15T17:16:53Z"}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 @@ -986,7 +990,7 @@ def r(): yield {"id": 1, "val": "foo", "deleted_1": True, "deleted_2": False} with pytest.raises(PipelineStepFailed): - info = p.run(r(), loader_file_format=destination_config.file_format) + info = p.run(r(), **destination_config.run_kwargs) @pytest.mark.essential @@ -1018,7 +1022,7 @@ def data_resource(data): {"id": 1, "val": "baz", "sequence": 3}, {"id": 1, "val": "bar", "sequence": 2}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 @@ -1035,7 +1039,7 @@ def data_resource(data): # now test "asc" sorting data_resource.apply_hints(columns={"sequence": {"dedup_sort": "asc", "nullable": False}}) - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 @@ -1062,7 +1066,7 @@ def data_resource(data): {"id": 1, "val": [7, 8, 9], "sequence": 3}, {"id": 1, "val": [4, 5, 6], "sequence": 2}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 assert load_table_counts(p, table_name + "__val")[table_name + "__val"] == 3 @@ -1089,7 +1093,7 @@ def data_resource(data): {"id": 1, "val": "baz", "sequence": 3, "deleted": True}, {"id": 1, "val": "bar", "sequence": 2, "deleted": False}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 0 @@ -1100,7 +1104,7 @@ def data_resource(data): {"id": 1, "val": "bar", "sequence": 2, "deleted": True}, {"id": 1, "val": "baz", "sequence": 3, "deleted": False}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 @@ -1122,7 +1126,7 @@ def data_resource(data): {"id": 1, "val": "foo", "sequence": 1}, {"id": 1, "val": "bar", "sequence": 2, "deleted": True}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 0 @@ -1132,7 +1136,7 @@ def data_resource(data): {"id": 1, "val": "foo", "sequence": 2}, {"id": 1, "val": "bar", "sequence": 1, "deleted": True}, ] - info = p.run(data_resource(data), loader_file_format=destination_config.file_format) + info = p.run(data_resource(data), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, table_name)[table_name] == 1 @@ -1147,14 +1151,14 @@ def r(): # invalid value for "dedup_sort" hint with pytest.raises(PipelineStepFailed): - info = p.run(r(), loader_file_format=destination_config.file_format) + info = p.run(r(), **destination_config.run_kwargs) # more than one "dedup_sort" column hints are provided r.apply_hints( columns={"dedup_sort_1": {"dedup_sort": "desc"}, "dedup_sort_2": {"dedup_sort": "desc"}} ) with pytest.raises(PipelineStepFailed): - info = p.run(r(), loader_file_format=destination_config.file_format) + info = p.run(r(), **destination_config.run_kwargs) def test_merge_strategy_config() -> None: @@ -1177,8 +1181,11 @@ def r(): yield {"foo": "bar"} assert "scd2" not in p.destination.capabilities().supported_merge_strategies - with pytest.raises(DestinationCapabilitiesException): + with pytest.raises(PipelineStepFailed) as pip_ex: p.run(r()) + assert pip_ex.value.step == "normalize" # failed already in normalize when generating row ids + # PipelineStepFailed -> NormalizeJobFailed -> DestinationCapabilitiesException + assert isinstance(pip_ex.value.__cause__.__cause__, DestinationCapabilitiesException) @pytest.mark.parametrize( @@ -1207,7 +1214,7 @@ def r(): p = destination_config.setup_pipeline("upsert_pipeline", dev_mode=True) assert "primary_key" not in r._hints with pytest.raises(PipelineStepFailed) as pip_ex: - p.run(r()) + p.run(r(), **destination_config.run_kwargs) assert isinstance(pip_ex.value.__context__, SchemaCorruptedException) @@ -1225,7 +1232,7 @@ def merging_test_table(): p = destination_config.setup_pipeline("abstract", full_refresh=True) with pytest.raises(PipelineStepFailed) as pip_ex: - p.run(merging_test_table()) + p.run(merging_test_table(), **destination_config.run_kwargs) ex = pip_ex.value assert ex.step == "normalize" @@ -1250,7 +1257,7 @@ def r(): p = destination_config.setup_pipeline("abstract", full_refresh=True) with pytest.raises(PipelineStepFailed) as pip_ex: - p.run(r()) + p.run(r(), **destination_config.run_kwargs) ex = pip_ex.value assert ex.step == "normalize" diff --git a/tests/load/pipeline/test_pipelines.py b/tests/load/pipeline/test_pipelines.py index 2792cec085..edc210800c 100644 --- a/tests/load/pipeline/test_pipelines.py +++ b/tests/load/pipeline/test_pipelines.py @@ -85,12 +85,12 @@ def data_fun() -> Iterator[Any]: yield data # this will create default schema - p.extract(data_fun) + p.extract(data_fun, table_format=destination_config.table_format) # _pipeline suffix removed when creating default schema name assert p.default_schema_name in ["dlt_pytest", "dlt", "dlt_jb_pytest_runner"] # this will create additional schema - p.extract(data_fun(), schema=dlt.Schema("names")) + p.extract(data_fun(), schema=dlt.Schema("names"), table_format=destination_config.table_format) assert p.default_schema_name in ["dlt_pytest", "dlt", "dlt_jb_pytest_runner"] assert "names" in p.schemas.keys() @@ -119,7 +119,7 @@ def data_fun() -> Iterator[Any]: state_package = p.get_load_package_info(last_load_id) assert len(state_package.jobs["new_jobs"]) == 1 assert state_package.schema_name == p.default_schema_name - p.normalize() + p.normalize(loader_file_format=destination_config.file_format) info = p.load(dataset_name="d" + uniq_id()) print(p.dataset_name) assert info.pipeline is p @@ -170,8 +170,13 @@ def test_default_schema_name( dataset_name=dataset_name, ) p.config.use_single_dataset = use_single_dataset - p.extract(data, table_name="test", schema=Schema("default")) - p.normalize() + p.extract( + data, + table_name="test", + schema=Schema("default"), + table_format=destination_config.table_format, + ) + p.normalize(loader_file_format=destination_config.file_format) info = p.load() print(info) @@ -210,7 +215,7 @@ def _data(): destination=destination_config.destination, staging=destination_config.staging, dataset_name="specific" + uniq_id(), - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) with pytest.raises(CannotRestorePipelineException): @@ -247,7 +252,7 @@ def _data(): yield d p = destination_config.setup_pipeline("test_skip_sync_schema_for_tables", dev_mode=True) - p.extract(_data) + p.extract(_data, table_format=destination_config.table_format) schema = p.default_schema assert "data_table" in schema.tables assert schema.tables["data_table"]["columns"] == {} @@ -286,7 +291,7 @@ def _data(): destination=destination_config.destination, staging=destination_config.staging, dataset_name="iteration" + uniq_id(), - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert info.dataset_name == p.dataset_name assert info.dataset_name.endswith(p._pipeline_instance_id) @@ -357,9 +362,11 @@ def extended_rows(): "my_pipeline", import_schema_path=import_schema_path, export_schema_path=export_schema_path ) - p.extract(source(10).with_resources("simple_rows")) + p.extract( + source(10).with_resources("simple_rows"), table_format=destination_config.table_format + ) # print(p.default_schema.to_pretty_yaml()) - p.normalize() + p.normalize(loader_file_format=destination_config.file_format) info = p.load(dataset_name=dataset_name) # test __str__ print(info) @@ -377,7 +384,7 @@ def extended_rows(): assert p.dataset_name == dataset_name err_info = p.run( source(1).with_resources("simple_rows"), - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) version_history.append(p.default_schema.stored_version_hash) # print(err_info) @@ -388,8 +395,7 @@ def extended_rows(): # - new column in "simple_rows" table # - new "simple" table info_ext = dlt.run( - source(10).with_resources("extended_rows", "simple"), - loader_file_format=destination_config.file_format, + source(10).with_resources("extended_rows", "simple"), **destination_config.run_kwargs ) print(info_ext) # print(p.default_schema.to_pretty_yaml()) @@ -432,14 +438,14 @@ def test_pipeline_data_writer_compression( "disable_compression": disable_compression } # not sure how else to set this p = destination_config.setup_pipeline("compression_test", dataset_name=dataset_name) - p.extract(dlt.resource(data, name="data")) + p.extract(dlt.resource(data, name="data"), table_format=destination_config.table_format) s = p._get_normalize_storage() # check that files are not compressed if compression is disabled if disable_compression: for f in s.list_files_to_normalize_sorted(): with pytest.raises(gzip.BadGzipFile): gzip.open(s.extracted_packages.storage.make_full_path(f), "rb").read() - p.normalize() + p.normalize(loader_file_format=destination_config.file_format) info = p.load() assert_table(p, "data", data, info=info) @@ -461,7 +467,7 @@ def complex_data(): destination=destination_config.destination, staging=destination_config.staging, dataset_name="ds_" + uniq_id(), - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) print(info) with dlt.pipeline().sql_client() as client: @@ -474,359 +480,11 @@ def complex_data(): assert cn_val == complex_part -@pytest.mark.parametrize( - "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name -) -def test_dataset_name_change(destination_config: DestinationTestConfiguration) -> None: - destination_config.setup() - # standard name - ds_1_name = "iteration" + uniq_id() - # will go to snake case - ds_2_name = "IteRation" + uniq_id() - # illegal name that will be later normalized - ds_3_name = "1it/era 👍 tion__" + uniq_id() - p, s = simple_nested_pipeline(destination_config, dataset_name=ds_1_name, dev_mode=False) - try: - info = p.run(s(), loader_file_format=destination_config.file_format) - assert_load_info(info) - assert info.dataset_name == ds_1_name - ds_1_counts = load_table_counts(p, "lists", "lists__value") - # run to another dataset - info = p.run(s(), dataset_name=ds_2_name, loader_file_format=destination_config.file_format) - assert_load_info(info) - assert info.dataset_name.startswith("ite_ration") - # save normalized dataset name to delete correctly later - ds_2_name = info.dataset_name - ds_2_counts = load_table_counts(p, "lists", "lists__value") - assert ds_1_counts == ds_2_counts - # set name and run to another dataset - p.dataset_name = ds_3_name - info = p.run(s(), loader_file_format=destination_config.file_format) - assert_load_info(info) - assert info.dataset_name.startswith("_1it_era_tion_") - ds_3_counts = load_table_counts(p, "lists", "lists__value") - assert ds_1_counts == ds_3_counts - - finally: - # we have to clean dataset ourselves - with p.sql_client() as client: - delete_dataset(client, ds_1_name) - delete_dataset(client, ds_2_name) - # delete_dataset(client, ds_3_name) # will be deleted by the fixture - - -# do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -def test_pipeline_explicit_destination_credentials( - destination_config: DestinationTestConfiguration, -) -> None: - from dlt.destinations import postgres - from dlt.destinations.impl.postgres.configuration import PostgresCredentials - - # explicit credentials resolved - p = dlt.pipeline( - destination=Destination.from_reference( - "postgres", - destination_name="mydest", - credentials="postgresql://loader:loader@localhost:7777/dlt_data", - ), - ) - c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] - assert c.config.credentials.port == 7777 # type: ignore[attr-defined] - - # TODO: may want to clear the env completely and ignore/mock config files somehow to avoid side effects - # explicit credentials resolved ignoring the config providers - os.environ["DESTINATION__MYDEST__CREDENTIALS__HOST"] = "HOST" - p = dlt.pipeline( - destination=Destination.from_reference( - "postgres", - destination_name="mydest", - credentials="postgresql://loader:loader@localhost:5432/dlt_data", - ), - ) - c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] - assert c.config.credentials.host == "localhost" # type: ignore[attr-defined] - - # explicit partial credentials will use config providers - os.environ["DESTINATION__MYDEST__CREDENTIALS__USERNAME"] = "UN" - os.environ["DESTINATION__MYDEST__CREDENTIALS__PASSWORD"] = "PW" - p = dlt.pipeline( - destination=Destination.from_reference( - "postgres", - destination_name="mydest", - credentials="postgresql://localhost:5432/dlt_data", - ), - ) - c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] - assert c.config.credentials.username == "UN" # type: ignore[attr-defined] - # host is taken form explicit credentials - assert c.config.credentials.host == "localhost" # type: ignore[attr-defined] - - # instance of credentials will be simply passed - cred = PostgresCredentials("postgresql://user:pass@localhost/dlt_data") - p = dlt.pipeline(destination=postgres(credentials=cred)) - inner_c = p.destination_client() - assert inner_c.config.credentials is cred - - # with staging - p = dlt.pipeline( - pipeline_name="postgres_pipeline", - staging=filesystem("_storage"), - destination=redshift(credentials="redshift://loader:password@localhost:5432/dlt_data"), - ) - config = p.destination_client().config - assert config.credentials.is_resolved() - assert ( - config.credentials.to_native_representation() - == "redshift://loader:password@localhost:5432/dlt_data?connect_timeout=15" - ) - - -# do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -def test_pipeline_with_sources_sharing_schema( - destination_config: DestinationTestConfiguration, -) -> None: - schema = Schema("shared") - - @dlt.source(schema=schema, max_table_nesting=1) - def source_1(): - @dlt.resource(primary_key="user_id") - def gen1(): - dlt.current.source_state()["source_1"] = True - dlt.current.resource_state()["source_1"] = True - yield {"id": "Y", "user_id": "user_y"} - - @dlt.resource(columns={"col": {"data_type": "bigint"}}) - def conflict(): - yield "conflict" - - return gen1, conflict - - @dlt.source(schema=schema, max_table_nesting=2) - def source_2(): - @dlt.resource(primary_key="id") - def gen1(): - dlt.current.source_state()["source_2"] = True - dlt.current.resource_state()["source_2"] = True - yield {"id": "X", "user_id": "user_X"} - - def gen2(): - yield from "CDE" - - @dlt.resource(columns={"col": {"data_type": "bool"}}, selected=False) - def conflict(): - yield "conflict" - - return gen2, gen1, conflict - - # all selected tables with hints should be there - discover_1 = source_1().discover_schema() - assert "gen1" in discover_1.tables - assert discover_1.tables["gen1"]["columns"]["user_id"]["primary_key"] is True - assert "data_type" not in discover_1.tables["gen1"]["columns"]["user_id"] - assert "conflict" in discover_1.tables - assert discover_1.tables["conflict"]["columns"]["col"]["data_type"] == "bigint" - - discover_2 = source_2().discover_schema() - assert "gen1" in discover_2.tables - assert "gen2" in discover_2.tables - # conflict deselected - assert "conflict" not in discover_2.tables - - p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) - p.extract([source_1(), source_2()]) - default_schema = p.default_schema - gen1_table = default_schema.tables["gen1"] - assert "user_id" in gen1_table["columns"] - assert "id" in gen1_table["columns"] - assert "conflict" in default_schema.tables - assert "gen2" in default_schema.tables - p.normalize() - assert "gen2" in default_schema.tables - p.load() - table_names = [t["name"] for t in default_schema.data_tables()] - counts = load_table_counts(p, *table_names) - assert counts == {"gen1": 2, "gen2": 3, "conflict": 1} - # both sources share the same state - assert p.state["sources"] == { - "shared": { - "source_1": True, - "resources": {"gen1": {"source_1": True, "source_2": True}}, - "source_2": True, - } - } - drop_active_pipeline_data() - - # same pipeline but enable conflict - p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) - with pytest.raises(PipelineStepFailed) as py_ex: - p.extract([source_1(), source_2().with_resources("conflict")]) - assert isinstance(py_ex.value.__context__, CannotCoerceColumnException) - - -# do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["postgres"]), - ids=lambda x: x.name, -) -def test_many_pipelines_single_dataset(destination_config: DestinationTestConfiguration) -> None: - schema = Schema("shared") - - @dlt.source(schema=schema, max_table_nesting=1) - def source_1(): - @dlt.resource(primary_key="user_id") - def gen1(): - dlt.current.source_state()["source_1"] = True - dlt.current.resource_state()["source_1"] = True - yield {"id": "Y", "user_id": "user_y"} - - return gen1 - - @dlt.source(schema=schema, max_table_nesting=2) - def source_2(): - @dlt.resource(primary_key="id") - def gen1(): - dlt.current.source_state()["source_2"] = True - dlt.current.resource_state()["source_2"] = True - yield {"id": "X", "user_id": "user_X"} - - def gen2(): - yield from "CDE" - - return gen2, gen1 - - # load source_1 to common dataset - p = dlt.pipeline( - pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset" - ) - p.run(source_1(), credentials="duckdb:///_storage/test_quack.duckdb") - counts = load_table_counts(p, *p.default_schema.tables.keys()) - assert counts.items() >= {"gen1": 1, "_dlt_pipeline_state": 1, "_dlt_loads": 1}.items() - p._wipe_working_folder() - p.deactivate() - - p = dlt.pipeline( - pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset" - ) - p.run(source_2(), credentials="duckdb:///_storage/test_quack.duckdb") - # table_names = [t["name"] for t in p.default_schema.data_tables()] - counts = load_table_counts(p, *p.default_schema.tables.keys()) - # gen1: one record comes from source_1, 1 record from source_2 - assert counts.items() >= {"gen1": 2, "_dlt_pipeline_state": 2, "_dlt_loads": 2}.items() - # assert counts == {'gen1': 2, 'gen2': 3} - p._wipe_working_folder() - p.deactivate() - - # restore from destination, check state - p = dlt.pipeline( - pipeline_name="source_1_pipeline", - destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), - dataset_name="shared_dataset", - ) - p.sync_destination() - # we have our separate state - assert p.state["sources"]["shared"] == { - "source_1": True, - "resources": {"gen1": {"source_1": True}}, - } - # but the schema was common so we have the earliest one - assert "gen2" in p.default_schema.tables - p._wipe_working_folder() - p.deactivate() - - p = dlt.pipeline( - pipeline_name="source_2_pipeline", - destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), - dataset_name="shared_dataset", - ) - p.sync_destination() - # we have our separate state - assert p.state["sources"]["shared"] == { - "source_2": True, - "resources": {"gen1": {"source_2": True}}, - } - - -# do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["snowflake"]), - ids=lambda x: x.name, -) -def test_snowflake_custom_stage(destination_config: DestinationTestConfiguration) -> None: - """Using custom stage name instead of the table stage""" - os.environ["DESTINATION__SNOWFLAKE__STAGE_NAME"] = "my_non_existing_stage" - pipeline, data = simple_nested_pipeline(destination_config, f"custom_stage_{uniq_id()}", False) - info = pipeline.run(data(), loader_file_format=destination_config.file_format) - with pytest.raises(DestinationHasFailedJobs) as f_jobs: - info.raise_on_failed_jobs() - assert "MY_NON_EXISTING_STAGE" in f_jobs.value.failed_jobs[0].failed_message - - drop_active_pipeline_data() - - # NOTE: this stage must be created in DLT_DATA database for this test to pass! - # CREATE STAGE MY_CUSTOM_LOCAL_STAGE; - # GRANT READ, WRITE ON STAGE DLT_DATA.PUBLIC.MY_CUSTOM_LOCAL_STAGE TO ROLE DLT_LOADER_ROLE; - stage_name = "PUBLIC.MY_CUSTOM_LOCAL_STAGE" - os.environ["DESTINATION__SNOWFLAKE__STAGE_NAME"] = stage_name - pipeline, data = simple_nested_pipeline(destination_config, f"custom_stage_{uniq_id()}", False) - info = pipeline.run(data(), loader_file_format=destination_config.file_format) - assert_load_info(info) - - load_id = info.loads_ids[0] - - # Get a list of the staged files and verify correct number of files in the "load_id" dir - with pipeline.sql_client() as client: - staged_files = client.execute_sql(f'LIST @{stage_name}/"{load_id}"') - assert len(staged_files) == 3 - # check data of one table to ensure copy was done successfully - tbl_name = client.make_qualified_table_name("lists") - assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ["a", None, None]) - - -# do not remove - it allows us to filter tests by destination -@pytest.mark.parametrize( - "destination_config", - destinations_configs(default_sql_configs=True, subset=["snowflake"]), - ids=lambda x: x.name, -) -def test_snowflake_delete_file_after_copy(destination_config: DestinationTestConfiguration) -> None: - """Using keep_staged_files = false option to remove staged files after copy""" - os.environ["DESTINATION__SNOWFLAKE__KEEP_STAGED_FILES"] = "FALSE" - - pipeline, data = simple_nested_pipeline( - destination_config, f"delete_staged_files_{uniq_id()}", False - ) - - info = pipeline.run(data(), loader_file_format=destination_config.file_format) - assert_load_info(info) - - load_id = info.loads_ids[0] - - with pipeline.sql_client() as client: - # no files are left in table stage - stage_name = client.make_qualified_table_name("%lists") - staged_files = client.execute_sql(f'LIST @{stage_name}/"{load_id}"') - assert len(staged_files) == 0 - - # ensure copy was done - tbl_name = client.make_qualified_table_name("lists") - assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ["a", None, None]) - - @pytest.mark.parametrize( "destination_config", - destinations_configs(default_sql_configs=True, all_staging_configs=True, file_format="parquet"), + destinations_configs( + default_sql_configs=True, all_staging_configs=True, with_file_format="parquet" + ), ids=lambda x: x.name, ) def test_parquet_loading(destination_config: DestinationTestConfiguration) -> None: @@ -854,7 +512,7 @@ def other_data(): # duckdb 0.9.1 does not support TIME other than 6 if destination_config.destination in ["duckdb", "motherduck"]: - column_schemas["col11_precision"]["precision"] = 0 + column_schemas["col11_precision"]["precision"] = None # also we do not want to test col4_precision (datetime) because # those timestamps are not TZ aware in duckdb and we'd need to # disable TZ when generating parquet @@ -890,7 +548,7 @@ def my_resource(): def some_source(): return [some_data(), other_data(), my_resource()] - info = pipeline.run(some_source(), loader_file_format="parquet") + info = pipeline.run(some_source(), **destination_config.run_kwargs) package_info = pipeline.get_load_package_info(info.loads_ids[0]) # print(package_info.asstr(verbosity=2)) assert package_info.state == "loaded" @@ -901,9 +559,9 @@ def some_source(): # add sql merge job if destination_config.supports_merge: expected_completed_jobs += 1 - # add iceberg copy jobs - if destination_config.force_iceberg: - expected_completed_jobs += 3 if destination_config.supports_merge else 4 + # add iceberg copy jobs + if destination_config.destination == "athena": + expected_completed_jobs += 2 # if destination_config.supports_merge else 4 assert len(package_info.jobs["completed_jobs"]) == expected_completed_jobs with pipeline.sql_client() as sql_client: @@ -930,6 +588,47 @@ def some_source(): ) +@pytest.mark.parametrize( + "destination_config", destinations_configs(default_sql_configs=True), ids=lambda x: x.name +) +def test_dataset_name_change(destination_config: DestinationTestConfiguration) -> None: + destination_config.setup() + # standard name + ds_1_name = "iteration" + uniq_id() + # will go to snake case + ds_2_name = "IteRation" + uniq_id() + # illegal name that will be later normalized + ds_3_name = "1it/era 👍 tion__" + uniq_id() + p, s = simple_nested_pipeline(destination_config, dataset_name=ds_1_name, dev_mode=False) + try: + info = p.run(s(), **destination_config.run_kwargs) + assert_load_info(info) + assert info.dataset_name == ds_1_name + ds_1_counts = load_table_counts(p, "lists", "lists__value") + # run to another dataset + info = p.run(s(), dataset_name=ds_2_name, **destination_config.run_kwargs) + assert_load_info(info) + assert info.dataset_name.startswith("ite_ration") + # save normalized dataset name to delete correctly later + ds_2_name = info.dataset_name + ds_2_counts = load_table_counts(p, "lists", "lists__value") + assert ds_1_counts == ds_2_counts + # set name and run to another dataset + p.dataset_name = ds_3_name + info = p.run(s(), **destination_config.run_kwargs) + assert_load_info(info) + assert info.dataset_name.startswith("_1it_era_tion_") + ds_3_counts = load_table_counts(p, "lists", "lists__value") + assert ds_1_counts == ds_3_counts + + finally: + # we have to clean dataset ourselves + with p.sql_client() as client: + delete_dataset(client, ds_1_name) + delete_dataset(client, ds_2_name) + # delete_dataset(client, ds_3_name) # will be deleted by the fixture + + @pytest.mark.parametrize( "destination_config", destinations_configs(default_staging_configs=True, default_sql_configs=True), @@ -989,9 +688,7 @@ def table_3(make_data=False): # now we use this schema but load just one resource source = two_tables() # push state, table 3 not created - load_info_1 = pipeline.run( - source.table_3, schema=schema, loader_file_format=destination_config.file_format - ) + load_info_1 = pipeline.run(source.table_3, schema=schema, **destination_config.run_kwargs) assert_load_info(load_info_1) with pytest.raises(DatabaseUndefinedRelation): load_table_counts(pipeline, "table_3") @@ -1001,15 +698,13 @@ def table_3(make_data=False): ) # load with one empty job, table 3 not created - load_info = pipeline.run(source.table_3, loader_file_format=destination_config.file_format) + load_info = pipeline.run(source.table_3, **destination_config.run_kwargs) assert_load_info(load_info, expected_load_packages=0) with pytest.raises(DatabaseUndefinedRelation): load_table_counts(pipeline, "table_3") # print(pipeline.default_schema.to_pretty_yaml()) - load_info_2 = pipeline.run( - [source.table_1, source.table_3], loader_file_format=destination_config.file_format - ) + load_info_2 = pipeline.run([source.table_1, source.table_3], **destination_config.run_kwargs) assert_load_info(load_info_2) # 1 record in table 1 assert pipeline.last_trace.last_normalize_info.row_counts["table_1"] == 1 @@ -1031,7 +726,7 @@ def table_3(make_data=False): # also we make the replace resource to load its 1 record load_info_3 = pipeline.run( [source.table_3(make_data=True), source.table_2], - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_load_info(load_info_3) assert_data_table_counts(pipeline, {"table_1": 1, "table_2": 1, "table_3": 1}) @@ -1050,9 +745,7 @@ def table_3(make_data=False): with pipeline.sql_client() as client: table_name = f"table_{i}" - if job_client.should_load_data_to_staging_dataset( - job_client.schema.tables[table_name] - ): + if job_client.should_load_data_to_staging_dataset(table_name): with client.with_staging_dataset(): tab_name = client.make_qualified_table_name(table_name) with client.execute_query(f"SELECT * FROM {tab_name}") as cur: @@ -1067,7 +760,7 @@ def test_query_all_info_tables_fallback(destination_config: DestinationTestConfi "parquet_test_" + uniq_id(), dataset_name="parquet_test_" + uniq_id() ) with mock.patch.object(SqlJobClientBase, "INFO_TABLES_QUERY_THRESHOLD", 0): - info = pipeline.run([1, 2, 3], table_name="digits_1") + info = pipeline.run([1, 2, 3], table_name="digits_1", **destination_config.run_kwargs) assert_load_info(info) # create empty table client: SqlJobClientBase @@ -1081,7 +774,7 @@ def test_query_all_info_tables_fallback(destination_config: DestinationTestConfi # remove it from schema del pipeline.default_schema._schema_tables["existing_table"] # store another table - info = pipeline.run([1, 2, 3], table_name="digits_2") + info = pipeline.run([1, 2, 3], table_name="digits_2", **destination_config.run_kwargs) assert_data_table_counts(pipeline, {"digits_1": 3, "digits_2": 3}) @@ -1160,7 +853,13 @@ def test_dest_column_invalid_timestamp_precision( invalid_precision = 10 @dlt.resource( - columns={"event_tstamp": {"data_type": "timestamp", "precision": invalid_precision}}, + columns={ + "event_tstamp": { + "data_type": "timestamp", + "precision": invalid_precision, + "timezone": False, + } + }, primary_key="event_id", ) def events(): @@ -1169,7 +868,7 @@ def events(): pipeline = destination_config.setup_pipeline(uniq_id()) with pytest.raises((TerminalValueError, PipelineStepFailed)): - pipeline.run(events()) + pipeline.run(events(), **destination_config.run_kwargs) @pytest.mark.parametrize( @@ -1282,7 +981,10 @@ def events_timezone_unset(): f"{destination}_" + uniq_id(), dataset_name="experiments" ) - pipeline.run([events_timezone_off(), events_timezone_on(), events_timezone_unset()]) + pipeline.run( + [events_timezone_off(), events_timezone_on(), events_timezone_unset()], + **destination_config.run_kwargs, + ) with pipeline.sql_client() as client: for t in output_map[destination]["tables"].keys(): # type: ignore diff --git a/tests/load/pipeline/test_postgres.py b/tests/load/pipeline/test_postgres.py index 5cadf701a2..c8dc0e10cc 100644 --- a/tests/load/pipeline/test_postgres.py +++ b/tests/load/pipeline/test_postgres.py @@ -4,10 +4,22 @@ from string import ascii_lowercase import pytest +import dlt +from dlt.common.destination.reference import Destination +from dlt.common.schema.exceptions import CannotCoerceColumnException +from dlt.common.schema.schema import Schema from dlt.common.utils import uniq_id -from tests.load.utils import destinations_configs, DestinationTestConfiguration -from tests.pipeline.utils import assert_load_info, load_tables_to_dicts +from dlt.destinations import filesystem, redshift + +from dlt.pipeline.exceptions import PipelineStepFailed + +from tests.load.utils import ( + destinations_configs, + DestinationTestConfiguration, + drop_active_pipeline_data, +) +from tests.pipeline.utils import assert_load_info, load_table_counts, load_tables_to_dicts from tests.utils import TestDataItemFormat @@ -44,6 +56,248 @@ def test_postgres_encoded_binary( assert data["table"][0]["hash"].tobytes() == blob +# do not remove - it allows us to filter tests by destination +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres"]), + ids=lambda x: x.name, +) +def test_pipeline_explicit_destination_credentials( + destination_config: DestinationTestConfiguration, +) -> None: + from dlt.destinations import postgres + from dlt.destinations.impl.postgres.configuration import PostgresCredentials + + # explicit credentials resolved + p = dlt.pipeline( + destination=Destination.from_reference( + "postgres", + destination_name="mydest", + credentials="postgresql://loader:loader@localhost:7777/dlt_data", + ), + ) + c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] + assert c.config.credentials.port == 7777 # type: ignore[attr-defined] + + # TODO: may want to clear the env completely and ignore/mock config files somehow to avoid side effects + # explicit credentials resolved ignoring the config providers + os.environ["DESTINATION__MYDEST__CREDENTIALS__HOST"] = "HOST" + p = dlt.pipeline( + destination=Destination.from_reference( + "postgres", + destination_name="mydest", + credentials="postgresql://loader:loader@localhost:5432/dlt_data", + ), + ) + c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] + assert c.config.credentials.host == "localhost" # type: ignore[attr-defined] + + # explicit partial credentials will use config providers + os.environ["DESTINATION__MYDEST__CREDENTIALS__USERNAME"] = "UN" + os.environ["DESTINATION__MYDEST__CREDENTIALS__PASSWORD"] = "PW" + p = dlt.pipeline( + destination=Destination.from_reference( + "postgres", + destination_name="mydest", + credentials="postgresql://localhost:5432/dlt_data", + ), + ) + c = p._get_destination_clients(Schema("s"), p._get_destination_client_initial_config())[0] + assert c.config.credentials.username == "UN" # type: ignore[attr-defined] + # host is taken form explicit credentials + assert c.config.credentials.host == "localhost" # type: ignore[attr-defined] + + # instance of credentials will be simply passed + cred = PostgresCredentials("postgresql://user:pass@localhost/dlt_data") + p = dlt.pipeline(destination=postgres(credentials=cred)) + inner_c = p.destination_client() + assert inner_c.config.credentials is cred + + # with staging + p = dlt.pipeline( + pipeline_name="postgres_pipeline", + staging=filesystem("_storage"), + destination=redshift(credentials="redshift://loader:password@localhost:5432/dlt_data"), + ) + config = p.destination_client().config + assert config.credentials.is_resolved() + assert ( + config.credentials.to_native_representation() + == "redshift://loader:password@localhost:5432/dlt_data?connect_timeout=15" + ) + + +# do not remove - it allows us to filter tests by destination +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres"]), + ids=lambda x: x.name, +) +def test_pipeline_with_sources_sharing_schema( + destination_config: DestinationTestConfiguration, +) -> None: + schema = Schema("shared") + + @dlt.source(schema=schema, max_table_nesting=1) + def source_1(): + @dlt.resource(primary_key="user_id") + def gen1(): + dlt.current.source_state()["source_1"] = True + dlt.current.resource_state()["source_1"] = True + yield {"id": "Y", "user_id": "user_y"} + + @dlt.resource(columns={"col": {"data_type": "bigint"}}) + def conflict(): + yield "conflict" + + return gen1, conflict + + @dlt.source(schema=schema, max_table_nesting=2) + def source_2(): + @dlt.resource(primary_key="id") + def gen1(): + dlt.current.source_state()["source_2"] = True + dlt.current.resource_state()["source_2"] = True + yield {"id": "X", "user_id": "user_X"} + + def gen2(): + yield from "CDE" + + @dlt.resource(columns={"col": {"data_type": "bool"}}, selected=False) + def conflict(): + yield "conflict" + + return gen2, gen1, conflict + + # all selected tables with hints should be there + discover_1 = source_1().discover_schema() + assert "gen1" in discover_1.tables + assert discover_1.tables["gen1"]["columns"]["user_id"]["primary_key"] is True + assert "data_type" not in discover_1.tables["gen1"]["columns"]["user_id"] + assert "conflict" in discover_1.tables + assert discover_1.tables["conflict"]["columns"]["col"]["data_type"] == "bigint" + + discover_2 = source_2().discover_schema() + assert "gen1" in discover_2.tables + assert "gen2" in discover_2.tables + # conflict deselected + assert "conflict" not in discover_2.tables + + p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) + p.extract([source_1(), source_2()], table_format=destination_config.table_format) + default_schema = p.default_schema + gen1_table = default_schema.tables["gen1"] + assert "user_id" in gen1_table["columns"] + assert "id" in gen1_table["columns"] + assert "conflict" in default_schema.tables + assert "gen2" in default_schema.tables + p.normalize(loader_file_format=destination_config.file_format) + assert "gen2" in default_schema.tables + p.load() + table_names = [t["name"] for t in default_schema.data_tables()] + counts = load_table_counts(p, *table_names) + assert counts == {"gen1": 2, "gen2": 3, "conflict": 1} + # both sources share the same state + assert p.state["sources"] == { + "shared": { + "source_1": True, + "resources": {"gen1": {"source_1": True, "source_2": True}}, + "source_2": True, + } + } + drop_active_pipeline_data() + + # same pipeline but enable conflict + p = dlt.pipeline(pipeline_name="multi", destination="duckdb", dev_mode=True) + with pytest.raises(PipelineStepFailed) as py_ex: + p.extract([source_1(), source_2().with_resources("conflict")]) + assert isinstance(py_ex.value.__context__, CannotCoerceColumnException) + + +# do not remove - it allows us to filter tests by destination +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["postgres"]), + ids=lambda x: x.name, +) +def test_many_pipelines_single_dataset(destination_config: DestinationTestConfiguration) -> None: + schema = Schema("shared") + + @dlt.source(schema=schema, max_table_nesting=1) + def source_1(): + @dlt.resource(primary_key="user_id") + def gen1(): + dlt.current.source_state()["source_1"] = True + dlt.current.resource_state()["source_1"] = True + yield {"id": "Y", "user_id": "user_y"} + + return gen1 + + @dlt.source(schema=schema, max_table_nesting=2) + def source_2(): + @dlt.resource(primary_key="id") + def gen1(): + dlt.current.source_state()["source_2"] = True + dlt.current.resource_state()["source_2"] = True + yield {"id": "X", "user_id": "user_X"} + + def gen2(): + yield from "CDE" + + return gen2, gen1 + + # load source_1 to common dataset + p = dlt.pipeline( + pipeline_name="source_1_pipeline", destination="duckdb", dataset_name="shared_dataset" + ) + p.run(source_1(), credentials="duckdb:///_storage/test_quack.duckdb") + counts = load_table_counts(p, *p.default_schema.tables.keys()) + assert counts.items() >= {"gen1": 1, "_dlt_pipeline_state": 1, "_dlt_loads": 1}.items() + p._wipe_working_folder() + p.deactivate() + + p = dlt.pipeline( + pipeline_name="source_2_pipeline", destination="duckdb", dataset_name="shared_dataset" + ) + p.run(source_2(), credentials="duckdb:///_storage/test_quack.duckdb") + # table_names = [t["name"] for t in p.default_schema.data_tables()] + counts = load_table_counts(p, *p.default_schema.tables.keys()) + # gen1: one record comes from source_1, 1 record from source_2 + assert counts.items() >= {"gen1": 2, "_dlt_pipeline_state": 2, "_dlt_loads": 2}.items() + # assert counts == {'gen1': 2, 'gen2': 3} + p._wipe_working_folder() + p.deactivate() + + # restore from destination, check state + p = dlt.pipeline( + pipeline_name="source_1_pipeline", + destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), + dataset_name="shared_dataset", + ) + p.sync_destination() + # we have our separate state + assert p.state["sources"]["shared"] == { + "source_1": True, + "resources": {"gen1": {"source_1": True}}, + } + # but the schema was common so we have the earliest one + assert "gen2" in p.default_schema.tables + p._wipe_working_folder() + p.deactivate() + + p = dlt.pipeline( + pipeline_name="source_2_pipeline", + destination=dlt.destinations.duckdb(credentials="duckdb:///_storage/test_quack.duckdb"), + dataset_name="shared_dataset", + ) + p.sync_destination() + # we have our separate state + assert p.state["sources"]["shared"] == { + "source_2": True, + "resources": {"gen1": {"source_2": True}}, + } + + # TODO: uncomment and finalize when we implement encoding for psycopg2 # @pytest.mark.parametrize( # "destination_config", diff --git a/tests/load/pipeline/test_redshift.py b/tests/load/pipeline/test_redshift.py index bfdc15459c..7c26ac17a9 100644 --- a/tests/load/pipeline/test_redshift.py +++ b/tests/load/pipeline/test_redshift.py @@ -32,7 +32,7 @@ def my_resource() -> Iterator[Any]: def my_source() -> Any: return my_resource - info = pipeline.run(my_source(), loader_file_format=destination_config.file_format) + info = pipeline.run(my_source(), **destination_config.run_kwargs) assert info.has_failed_jobs diff --git a/tests/load/pipeline/test_refresh_modes.py b/tests/load/pipeline/test_refresh_modes.py index f4bf3b0311..59063aacea 100644 --- a/tests/load/pipeline/test_refresh_modes.py +++ b/tests/load/pipeline/test_refresh_modes.py @@ -110,13 +110,16 @@ def test_refresh_drop_sources(destination_config: DestinationTestConfiguration): pipeline = destination_config.setup_pipeline("refresh_full_test", refresh="drop_sources") # First run pipeline so destination so tables are created - info = pipeline.run(refresh_source(first_run=True, drop_sources=True)) + info = pipeline.run( + refresh_source(first_run=True, drop_sources=True), **destination_config.run_kwargs + ) assert_load_info(info) # Second run of pipeline with only selected resources info = pipeline.run( refresh_source(first_run=False, drop_sources=True).with_resources( "some_data_1", "some_data_2" - ) + ), + **destination_config.run_kwargs, ) assert set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) == { @@ -154,7 +157,9 @@ def test_existing_schema_hash(destination_config: DestinationTestConfiguration): """ pipeline = destination_config.setup_pipeline("refresh_full_test", refresh="drop_sources") - info = pipeline.run(refresh_source(first_run=True, drop_sources=True)) + info = pipeline.run( + refresh_source(first_run=True, drop_sources=True), **destination_config.run_kwargs + ) assert_load_info(info) first_schema_hash = pipeline.default_schema.version_hash @@ -162,7 +167,8 @@ def test_existing_schema_hash(destination_config: DestinationTestConfiguration): info = pipeline.run( refresh_source(first_run=False, drop_sources=True).with_resources( "some_data_1", "some_data_2" - ) + ), + **destination_config.run_kwargs, ) # Just check the local schema @@ -173,7 +179,9 @@ def test_existing_schema_hash(destination_config: DestinationTestConfiguration): # Run again with all tables to ensure they are re-created # The new schema in this case should match the schema of the first run exactly - info = pipeline.run(refresh_source(first_run=True, drop_sources=True)) + info = pipeline.run( + refresh_source(first_run=True, drop_sources=True), **destination_config.run_kwargs + ) # Check table 3 was re-created data = load_tables_to_dicts(pipeline, "some_data_3")["some_data_3"] result = sorted([(row["id"], row["name"]) for row in data]) @@ -195,12 +203,13 @@ def test_refresh_drop_resources(destination_config: DestinationTestConfiguration # First run pipeline with load to destination so tables are created pipeline = destination_config.setup_pipeline("refresh_full_test", refresh="drop_tables") - info = pipeline.run(refresh_source(first_run=True)) + info = pipeline.run(refresh_source(first_run=True), **destination_config.run_kwargs) assert_load_info(info) # Second run of pipeline with only selected resources info = pipeline.run( - refresh_source(first_run=False).with_resources("some_data_1", "some_data_2") + refresh_source(first_run=False).with_resources("some_data_1", "some_data_2"), + **destination_config.run_kwargs, ) # Confirm resource tables not selected on second run are untouched @@ -244,7 +253,9 @@ def test_refresh_drop_data_only(destination_config: DestinationTestConfiguration # First run pipeline with load to destination so tables are created pipeline = destination_config.setup_pipeline("refresh_full_test", refresh="drop_data") - info = pipeline.run(refresh_source(first_run=True), write_disposition="append") + info = pipeline.run( + refresh_source(first_run=True), write_disposition="append", **destination_config.run_kwargs + ) assert_load_info(info) first_schema_hash = pipeline.default_schema.version_hash @@ -253,6 +264,7 @@ def test_refresh_drop_data_only(destination_config: DestinationTestConfiguration info = pipeline.run( refresh_source(first_run=False).with_resources("some_data_1", "some_data_2"), write_disposition="append", + **destination_config.run_kwargs, ) assert_load_info(info) @@ -348,11 +360,15 @@ def source_2_data_2(): # Run both sources info = pipeline.run( - [refresh_source(first_run=True, drop_sources=True), refresh_source_2(first_run=True)] + [refresh_source(first_run=True, drop_sources=True), refresh_source_2(first_run=True)], + **destination_config.run_kwargs, ) assert_load_info(info, 2) # breakpoint() - info = pipeline.run(refresh_source_2(first_run=False).with_resources("source_2_data_1")) + info = pipeline.run( + refresh_source_2(first_run=False).with_resources("source_2_data_1"), + **destination_config.run_kwargs, + ) assert_load_info(info, 2) # Check source 1 schema still has all tables @@ -394,11 +410,12 @@ def source_2_data_2(): def test_refresh_argument_to_run(destination_config: DestinationTestConfiguration): pipeline = destination_config.setup_pipeline("refresh_full_test") - info = pipeline.run(refresh_source(first_run=True)) + info = pipeline.run(refresh_source(first_run=True), **destination_config.run_kwargs) assert_load_info(info) info = pipeline.run( refresh_source(first_run=False).with_resources("some_data_3"), + **destination_config.run_kwargs, refresh="drop_sources", ) assert_load_info(info) @@ -408,7 +425,10 @@ def test_refresh_argument_to_run(destination_config: DestinationTestConfiguratio assert tables == {"some_data_3"} # Run again without refresh to confirm refresh option doesn't persist on pipeline - info = pipeline.run(refresh_source(first_run=False).with_resources("some_data_2")) + info = pipeline.run( + refresh_source(first_run=False).with_resources("some_data_2"), + **destination_config.run_kwargs, + ) assert_load_info(info) # Nothing is dropped @@ -426,11 +446,12 @@ def test_refresh_argument_to_run(destination_config: DestinationTestConfiguratio def test_refresh_argument_to_extract(destination_config: DestinationTestConfiguration): pipeline = destination_config.setup_pipeline("refresh_full_test") - info = pipeline.run(refresh_source(first_run=True)) + info = pipeline.run(refresh_source(first_run=True), **destination_config.run_kwargs) assert_load_info(info) pipeline.extract( refresh_source(first_run=False).with_resources("some_data_3"), + table_format=destination_config.table_format, refresh="drop_sources", ) @@ -439,7 +460,10 @@ def test_refresh_argument_to_extract(destination_config: DestinationTestConfigur assert tables == {"some_data_3"} # Run again without refresh to confirm refresh option doesn't persist on pipeline - pipeline.extract(refresh_source(first_run=False).with_resources("some_data_2")) + pipeline.extract( + refresh_source(first_run=False).with_resources("some_data_2"), + table_format=destination_config.table_format, + ) tables = set(t["name"] for t in pipeline.default_schema.data_tables(include_incomplete=True)) assert tables == {"some_data_2", "some_data_3"} @@ -470,7 +494,7 @@ def test_refresh_staging_dataset(destination_config: DestinationTestConfiguratio ], ) # create two tables so two tables need to be dropped - info = pipeline.run(source) + info = pipeline.run(source, **destination_config.run_kwargs) assert_load_info(info) # make data so inserting on mangled tables is not possible @@ -487,7 +511,7 @@ def test_refresh_staging_dataset(destination_config: DestinationTestConfiguratio dlt.resource(data_i, name="data_2", primary_key="id", write_disposition="append"), ], ) - info = pipeline.run(source_i, refresh="drop_resources") + info = pipeline.run(source_i, refresh="drop_resources", **destination_config.run_kwargs) assert_load_info(info) # now replace the whole source and load different tables @@ -499,7 +523,7 @@ def test_refresh_staging_dataset(destination_config: DestinationTestConfiguratio dlt.resource(data_i, name="data_2_v2", primary_key="id", write_disposition="append"), ], ) - info = pipeline.run(source_i, refresh="drop_sources") + info = pipeline.run(source_i, refresh="drop_sources", **destination_config.run_kwargs) assert_load_info(info) # tables got dropped diff --git a/tests/load/pipeline/test_replace_disposition.py b/tests/load/pipeline/test_replace_disposition.py index d49ce2904f..82cef83019 100644 --- a/tests/load/pipeline/test_replace_disposition.py +++ b/tests/load/pipeline/test_replace_disposition.py @@ -94,9 +94,7 @@ def append_items(): } # first run with offset 0 - info = pipeline.run( - [load_items, append_items], loader_file_format=destination_config.file_format - ) + info = pipeline.run([load_items, append_items], **destination_config.run_kwargs) assert_load_info(info) # count state records that got extracted state_records = increase_state_loads(info) @@ -105,9 +103,7 @@ def append_items(): # second run with higher offset so we can check the results offset = 1000 - info = pipeline.run( - [load_items, append_items], loader_file_format=destination_config.file_format - ) + info = pipeline.run([load_items, append_items], **destination_config.run_kwargs) assert_load_info(info) state_records += increase_state_loads(info) dlt_loads += 1 @@ -153,9 +149,7 @@ def load_items_none(): if False: yield - info = pipeline.run( - [load_items_none, append_items], loader_file_format=destination_config.file_format - ) + info = pipeline.run([load_items_none, append_items], **destination_config.run_kwargs) assert_load_info(info) state_records += increase_state_loads(info) dlt_loads += 1 @@ -186,9 +180,7 @@ def load_items_none(): pipeline_2 = destination_config.setup_pipeline( "test_replace_strategies_2", dataset_name=dataset_name ) - info = pipeline_2.run( - load_items, table_name="items_copy", loader_file_format=destination_config.file_format - ) + info = pipeline_2.run(load_items, table_name="items_copy", **destination_config.run_kwargs) assert_load_info(info) new_state_records = increase_state_loads(info) assert new_state_records == 1 @@ -202,7 +194,7 @@ def load_items_none(): "_dlt_pipeline_state": 1, } - info = pipeline_2.run(append_items, loader_file_format=destination_config.file_format) + info = pipeline_2.run(append_items, **destination_config.run_kwargs) assert_load_info(info) new_state_records = increase_state_loads(info) assert new_state_records == 0 @@ -321,9 +313,7 @@ def yield_empty_list(): yield [] # regular call - pipeline.run( - [items_with_subitems, static_items], loader_file_format=destination_config.file_format - ) + pipeline.run([items_with_subitems, static_items], **destination_config.run_kwargs) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) @@ -345,7 +335,7 @@ def yield_empty_list(): } # see if child table gets cleared - pipeline.run(items_without_subitems, loader_file_format=destination_config.file_format) + pipeline.run(items_without_subitems, **destination_config.run_kwargs) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) @@ -360,8 +350,8 @@ def yield_empty_list(): # see if yield none clears everything for empty_resource in [yield_none, no_yield, yield_empty_list]: - pipeline.run(items_with_subitems, loader_file_format=destination_config.file_format) - pipeline.run(empty_resource, loader_file_format=destination_config.file_format) + pipeline.run(items_with_subitems, **destination_config.run_kwargs) + pipeline.run(empty_resource, **destination_config.run_kwargs) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) @@ -375,7 +365,7 @@ def yield_empty_list(): assert pipeline.last_trace.last_normalize_info.row_counts == {"items": 0, "other_items": 0} # see if yielding something next to other none entries still goes into db - pipeline.run(items_with_subitems_yield_none, loader_file_format=destination_config.file_format) + pipeline.run(items_with_subitems_yield_none, **destination_config.run_kwargs) table_counts = load_table_counts( pipeline, *[t["name"] for t in pipeline.default_schema.data_tables()] ) diff --git a/tests/load/pipeline/test_restore_state.py b/tests/load/pipeline/test_restore_state.py index c3968e2e74..22d4fd7404 100644 --- a/tests/load/pipeline/test_restore_state.py +++ b/tests/load/pipeline/test_restore_state.py @@ -357,15 +357,15 @@ def some_data(): p.extract([data1, some_data("state2")], schema=Schema("default")) data_two = source_two("state3") - p.extract(data_two) + p.extract(data_two, table_format=destination_config.table_format) data_three = source_three("state4") - p.extract(data_three) + p.extract(data_three, table_format=destination_config.table_format) data_four = source_four() - p.extract(data_four) + p.extract(data_four, table_format=destination_config.table_format) - p.normalize() + p.normalize(loader_file_format=destination_config.file_format) p.load() # keep the orig state orig_state = p.state @@ -374,14 +374,14 @@ def some_data(): p._wipe_working_folder() os.environ["RESTORE_FROM_DESTINATION"] = "False" p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) - p.run(loader_file_format=destination_config.file_format) + p.run(**destination_config.run_kwargs) # restore was not requested so schema is empty assert p.default_schema_name is None p._wipe_working_folder() # request restore os.environ["RESTORE_FROM_DESTINATION"] = "True" p = destination_config.setup_pipeline(pipeline_name=pipeline_name, dataset_name=dataset_name) - p.run(loader_file_format=destination_config.file_format) + p.run(**destination_config.run_kwargs) assert p.default_schema_name == "default" assert set(p.schema_names) == set(["default", "two", "three", "four"]) assert p.state["sources"] == { @@ -402,7 +402,7 @@ def some_data(): p = destination_config.setup_pipeline( pipeline_name=pipeline_name, dataset_name=dataset_name, dev_mode=True ) - p.run(loader_file_format=destination_config.file_format) + p.run(**destination_config.run_kwargs) assert p.default_schema_name is None drop_active_pipeline_data() @@ -415,7 +415,7 @@ def some_data(): assert p.dataset_name == dataset_name assert p.default_schema_name is None # restore - p.run(loader_file_format=destination_config.file_format) + p.run(**destination_config.run_kwargs) assert p.default_schema_name is not None restored_state = p.state assert restored_state["_state_version"] == orig_state["_state_version"] @@ -426,7 +426,7 @@ def some_data(): ) # this will modify state, run does not sync if states are identical assert p.state["_state_version"] > orig_state["_state_version"] # print(p.state) - p.run(loader_file_format=destination_config.file_format) + p.run(**destination_config.run_kwargs) assert set(p.schema_names) == set( ["default", "two", "three", "second", "four"] ) # we keep our local copy @@ -435,7 +435,7 @@ def some_data(): state["_state_version"] -= 1 p._save_state(state) p._state_restored = False - p.run(loader_file_format=destination_config.file_format) + p.run(**destination_config.run_kwargs) assert set(p.schema_names) == set(["default", "two", "three", "four"]) @@ -458,9 +458,9 @@ def some_data(param: str) -> Any: job_client: WithStateSync # Load some complete load packages with state to the destination - p.run(some_data("state1"), loader_file_format=destination_config.file_format) - p.run(some_data("state2"), loader_file_format=destination_config.file_format) - p.run(some_data("state3"), loader_file_format=destination_config.file_format) + p.run(some_data("state1"), **destination_config.run_kwargs) + p.run(some_data("state2"), **destination_config.run_kwargs) + p.run(some_data("state3"), **destination_config.run_kwargs) with p._get_destination_clients(p.default_schema)[0] as job_client: # type: ignore[assignment] state = load_pipeline_state_from_destination(pipeline_name, job_client) @@ -472,7 +472,7 @@ def complete_package_mock(self, load_id: str, schema: Schema, aborted: bool = Fa self.load_storage.complete_load_package(load_id, aborted) with patch.object(Load, "complete_package", complete_package_mock): - p.run(some_data("fix_1"), loader_file_format=destination_config.file_format) + p.run(some_data("fix_1"), **destination_config.run_kwargs) # assert complete_package.called with p._get_destination_clients(p.default_schema)[0] as job_client: # type: ignore[assignment] @@ -523,7 +523,7 @@ def test_restore_schemas_while_import_schemas_exist( ["A", "B", "C"], table_name="labels", schema=schema, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) # schema should be up to date normalized_labels = schema.naming.normalize_table_identifier("labels") @@ -534,9 +534,7 @@ def test_restore_schemas_while_import_schemas_exist( # re-attach the pipeline p = destination_config.attach_pipeline(pipeline_name=pipeline_name) - p.run( - ["C", "D", "E"], table_name="annotations", loader_file_format=destination_config.file_format - ) + p.run(["C", "D", "E"], table_name="annotations", **destination_config.run_kwargs) schema = p.schemas["ethereum"] assert normalized_labels in schema.tables assert normalized_annotations in schema.tables @@ -555,7 +553,7 @@ def test_restore_schemas_while_import_schemas_exist( destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) schema = p.schemas["ethereum"] assert normalized_labels in schema.tables @@ -564,9 +562,7 @@ def test_restore_schemas_while_import_schemas_exist( # check if attached to import schema assert schema._imported_version_hash == IMPORTED_VERSION_HASH_ETH_V9() # extract some data with restored pipeline - p.run( - ["C", "D", "E"], table_name="blacklist", loader_file_format=destination_config.file_format - ) + p.run(["C", "D", "E"], table_name="blacklist", **destination_config.run_kwargs) assert normalized_labels in schema.tables assert normalized_annotations in schema.tables assert normalized_blacklist in schema.tables @@ -608,7 +604,7 @@ def some_data(param: str) -> Any: destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) orig_state = p.state @@ -618,7 +614,7 @@ def some_data(param: str) -> Any: destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert production_p.default_schema_name == "default" @@ -630,7 +626,7 @@ def some_data(param: str) -> Any: # rename extract table/ data2.apply_hints(table_name="state1_data2") print("---> run production") - production_p.run(data2, loader_file_format=destination_config.file_format) + production_p.run(data2, **destination_config.run_kwargs) assert production_p.state["_state_version"] == prod_state["_state_version"] normalize = production_p.default_schema.naming.normalize_table_identifier @@ -645,7 +641,7 @@ def some_data(param: str) -> Any: data3 = some_data("state3") data3.apply_hints(table_name="state1_data2") print("---> run production") - production_p.run(data3, loader_file_format=destination_config.file_format) + production_p.run(data3, **destination_config.run_kwargs) assert production_p.state["_state_version"] > prod_state["_state_version"] # and will be detected locally # print(p.default_schema) @@ -658,14 +654,14 @@ def some_data(param: str) -> Any: # change state locally data4 = some_data("state4") data4.apply_hints(table_name="state1_data4") - p.run(data4, loader_file_format=destination_config.file_format) + p.run(data4, **destination_config.run_kwargs) # and on production in parallel data5 = some_data("state5") data5.apply_hints(table_name="state1_data5") - production_p.run(data5, loader_file_format=destination_config.file_format) + production_p.run(data5, **destination_config.run_kwargs) data6 = some_data("state6") data6.apply_hints(table_name="state1_data6") - production_p.run(data6, loader_file_format=destination_config.file_format) + production_p.run(data6, **destination_config.run_kwargs) # production state version ahead of local state version prod_state = production_p.state assert p.state["_state_version"] == prod_state["_state_version"] - 1 @@ -726,11 +722,11 @@ def some_data(param: str) -> Any: destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) data5 = some_data("state4") data5.apply_hints(table_name="state1_data5") - p.run(data5, schema=Schema("sch2"), loader_file_format=destination_config.file_format) + p.run(data5, schema=Schema("sch2"), **destination_config.run_kwargs) assert p.state["_state_version"] == 3 assert p.first_run is False with p.destination_client() as job_client: @@ -756,7 +752,7 @@ def some_data(param: str) -> Any: destination=destination_config.destination, staging=destination_config.staging, dataset_name=dataset_name, - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert p.first_run is False assert p.state["_local"]["first_run"] is False @@ -766,7 +762,7 @@ def some_data(param: str) -> Any: p.config.restore_from_destination = True data5 = some_data("state4") data5.apply_hints(table_name="state1_data5") - p.run(data5, schema=Schema("sch2"), loader_file_format=destination_config.file_format) + p.run(data5, schema=Schema("sch2"), **destination_config.run_kwargs) # the pipeline was not wiped out, the actual presence if the dataset was checked assert set(p.schema_names) == set(["sch2", "sch1"]) diff --git a/tests/load/pipeline/test_scd2.py b/tests/load/pipeline/test_scd2.py index 065da5ce94..9eac505a7f 100644 --- a/tests/load/pipeline/test_scd2.py +++ b/tests/load/pipeline/test_scd2.py @@ -134,7 +134,7 @@ def r(data): {"nk": 1, "c1": "foo", "c2": "foo" if simple else {"nc1": "foo"}}, {"nk": 2, "c1": "bar", "c2": "bar" if simple else {"nc1": "bar"}}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) # assert x-hints table = p.default_schema.get_table("dim_test") @@ -171,7 +171,7 @@ def r(data): {"nk": 1, "c1": "foo", "c2": "foo_updated" if simple else {"nc1": "foo_updated"}}, {"nk": 2, "c1": "bar", "c2": "bar" if simple else {"nc1": "bar"}}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_2 = get_load_package_created_at(p, info) assert_load_info(info) assert get_table(p, "dim_test", cname) == [ @@ -196,7 +196,7 @@ def r(data): dim_snap = [ {"nk": 1, "c1": "foo", "c2": "foo_updated" if simple else {"nc1": "foo_updated"}}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_3 = get_load_package_created_at(p, info) assert_load_info(info) assert get_table(p, "dim_test", cname) == [ @@ -216,7 +216,7 @@ def r(data): {"nk": 1, "c1": "foo", "c2": "foo_updated" if simple else {"nc1": "foo_updated"}}, {"nk": 3, "c1": "baz", "c2": "baz" if simple else {"nc1": "baz"}}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_4 = get_load_package_created_at(p, info) assert_load_info(info) assert get_table(p, "dim_test", cname) == [ @@ -263,7 +263,7 @@ def r(data): l1_1 := {"nk": 1, "c1": "foo", "c2": [1] if simple else [{"cc1": 1}]}, l1_2 := {"nk": 2, "c1": "bar", "c2": [2, 3] if simple else [{"cc1": 2}, {"cc1": 3}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_1 = get_load_package_created_at(p, info) assert_load_info(info) assert get_table(p, "dim_test", "c1") == [ @@ -282,7 +282,7 @@ def r(data): l2_1 := {"nk": 1, "c1": "foo_updated", "c2": [1] if simple else [{"cc1": 1}]}, {"nk": 2, "c1": "bar", "c2": [2, 3] if simple else [{"cc1": 2}, {"cc1": 3}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_2 = get_load_package_created_at(p, info) assert_load_info(info) assert get_table(p, "dim_test", "c1") == [ @@ -309,7 +309,7 @@ def r(data): }, {"nk": 2, "c1": "bar", "c2": [2, 3] if simple else [{"cc1": 2}, {"cc1": 3}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_3 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( @@ -335,7 +335,7 @@ def r(data): dim_snap = [ {"nk": 1, "c1": "foo_updated", "c2": [1, 2] if simple else [{"cc1": 1}, {"cc1": 2}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_4 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( @@ -356,7 +356,7 @@ def r(data): {"nk": 1, "c1": "foo_updated", "c2": [1, 2] if simple else [{"cc1": 1}, {"cc1": 2}]}, l5_3 := {"nk": 3, "c1": "baz", "c2": [1, 2] if simple else [{"cc1": 1}, {"cc1": 2}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) ts_5 = get_load_package_created_at(p, info) assert_load_info(info) assert_records_as_set( @@ -403,7 +403,7 @@ def r(data): l1_1 := {"nk": 1, "c1": "foo", "c2": [{"cc1": [1]}]}, l1_2 := {"nk": 2, "c1": "bar", "c2": [{"cc1": [1, 2]}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert_records_as_set( get_table(p, "dim_test__c2__cc1"), @@ -419,7 +419,7 @@ def r(data): l2_1 := {"nk": 1, "c1": "foo_updated", "c2": [{"cc1": [1]}]}, l1_2 := {"nk": 2, "c1": "bar", "c2": [{"cc1": [1, 2]}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert_records_as_set( (get_table(p, "dim_test__c2__cc1")), @@ -436,7 +436,7 @@ def r(data): l3_1 := {"nk": 1, "c1": "foo_updated", "c2": [{"cc1": [1, 2]}]}, {"nk": 2, "c1": "bar", "c2": [{"cc1": [1, 2]}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) exp_3 = [ {"_dlt_root_id": get_row_hash(l1_1), "value": 1}, @@ -452,7 +452,7 @@ def r(data): dim_snap = [ {"nk": 1, "c1": "foo_updated", "c2": [{"cc1": [1, 2]}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert_records_as_set(get_table(p, "dim_test__c2__cc1"), exp_3) @@ -461,7 +461,7 @@ def r(data): {"nk": 1, "c1": "foo_updated", "c2": [{"cc1": [1, 2]}]}, l5_3 := {"nk": 3, "c1": "baz", "c2": [{"cc1": [1]}]}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert_records_as_set( get_table(p, "dim_test__c2__cc1"), @@ -496,7 +496,7 @@ def r(data): r1 := {"nk": 1, "c1": "foo", "c2": "foo", "child": [1]}, r2 := {"nk": 2, "c1": "bar", "c2": "bar", "child": [2, 3]}, ] - info = p.run(r(dim_snap)) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "dim_test")["dim_test"] == 2 assert load_table_counts(p, "dim_test__child")["dim_test__child"] == 3 @@ -504,7 +504,7 @@ def r(data): # load 2 — delete natural key 1 dim_snap = [r2] - info = p.run(r(dim_snap)) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "dim_test")["dim_test"] == 2 assert load_table_counts(p, "dim_test__child")["dim_test__child"] == 3 @@ -512,7 +512,7 @@ def r(data): # load 3 — reinsert natural key 1 dim_snap = [r1, r2] - info = p.run(r(dim_snap)) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "dim_test")["dim_test"] == 3 assert load_table_counts(p, "dim_test__child")["dim_test__child"] == 3 # no new record @@ -557,15 +557,17 @@ def test_validity_column_name_conflict(destination_config: DestinationTestConfig def r(data): yield data - # configuring a validity column name that appears in the data should cause an exception - dim_snap = {"nk": 1, "foo": 1, "from": 1} # conflict on "from" column - with pytest.raises(PipelineStepFailed) as pip_ex: - p.run(r(dim_snap), loader_file_format=destination_config.file_format) - assert isinstance(pip_ex.value.__context__.__context__, ColumnNameConflictException) + # a schema check against an items got dropped because it was very costly and done on each row + dim_snap = {"nk": 1, "foo": 1, "from": "X"} # conflict on "from" column + p.run(r(dim_snap), **destination_config.run_kwargs) dim_snap = {"nk": 1, "foo": 1, "to": 1} # conflict on "to" column - with pytest.raises(PipelineStepFailed): - p.run(r(dim_snap), loader_file_format=destination_config.file_format) - assert isinstance(pip_ex.value.__context__.__context__, ColumnNameConflictException) + p.run(r(dim_snap), **destination_config.run_kwargs) + + # instead the variant columns got generated + dim_test_table = p.default_schema.tables["dim_test"] + assert "from__v_text" in dim_test_table["columns"] + + # but `to` column was coerced and then overwritten, this is the cost of dropping the check @pytest.mark.parametrize( @@ -610,7 +612,7 @@ def test_active_record_timestamp( def r(): yield {"foo": "bar"} - p.run(r()) + p.run(r(), **destination_config.run_kwargs) actual_active_record_timestamp = ensure_pendulum_datetime( load_tables_to_dicts(p, "dim_test")["dim_test"][0]["_dlt_valid_to"] ) @@ -648,7 +650,7 @@ def r(data): l1_1 := {"nk": 1, "foo": "foo"}, l1_2 := {"nk": 2, "foo": "foo"}, ] - info = p.run(r(dim_snap)) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "dim_test")["dim_test"] == 2 from_, to = DEFAULT_VALIDITY_COLUMN_NAMES @@ -671,7 +673,7 @@ def r(data): # l1_2, # natural key 2 no longer present l2_3 := {"nk": 3, "foo": "foo"}, # new natural key ] - info = p.run(r(dim_snap)) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "dim_test")["dim_test"] == 4 expected = [ @@ -693,7 +695,7 @@ def r(data): } ) dim_snap = [l2_1] # natural key 3 no longer present - info = p.run(r(dim_snap)) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) assert load_table_counts(p, "dim_test")["dim_test"] == 4 expected = [ @@ -743,7 +745,7 @@ def _make_scd2_r(table_: Any) -> DltResource: ).add_map(add_row_hash_to_table("row_hash")) p = destination_config.setup_pipeline("abstract", dev_mode=True) - info = p.run(_make_scd2_r(table), loader_file_format=destination_config.file_format) + info = p.run(_make_scd2_r(table), **destination_config.run_kwargs) assert_load_info(info) # make sure we have scd2 columns in schema table_schema = p.default_schema.get_table("tabular") @@ -759,14 +761,14 @@ def _make_scd2_r(table_: Any) -> DltResource: if item_type == "pandas": table = orig_table orig_table = table.copy(deep=True) - info = p.run(_make_scd2_r(table), loader_file_format=destination_config.file_format) + info = p.run(_make_scd2_r(table), **destination_config.run_kwargs) assert_load_info(info) # no changes (hopefully hash is deterministic) assert load_table_counts(p, "tabular")["tabular"] == 100 # change single row orig_table.iloc[0, 0] = "Duck 🦆!" - info = p.run(_make_scd2_r(orig_table), loader_file_format=destination_config.file_format) + info = p.run(_make_scd2_r(orig_table), **destination_config.run_kwargs) assert_load_info(info) # on row changed assert load_table_counts(p, "tabular")["tabular"] == 101 @@ -796,7 +798,7 @@ def r(data): {"nk": 1, "c1": "foo", "c2": [1], "row_hash": "mocked_hash_1"}, {"nk": 2, "c1": "bar", "c2": [2, 3], "row_hash": "mocked_hash_2"}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) ts_1 = get_load_package_created_at(p, info) table = p.default_schema.get_table("dim_test") @@ -809,7 +811,7 @@ def r(data): dim_snap = [ {"nk": 1, "c1": "foo_upd", "c2": [1], "row_hash": "mocked_hash_1_upd"}, ] - info = p.run(r(dim_snap), loader_file_format=destination_config.file_format) + info = p.run(r(dim_snap), **destination_config.run_kwargs) assert_load_info(info) ts_2 = get_load_package_created_at(p, info) diff --git a/tests/load/pipeline/test_snowflake_pipeline.py b/tests/load/pipeline/test_snowflake_pipeline.py index 0203a39147..87c6f337a1 100644 --- a/tests/load/pipeline/test_snowflake_pipeline.py +++ b/tests/load/pipeline/test_snowflake_pipeline.py @@ -1,15 +1,23 @@ +from copy import deepcopy import os import pytest from pytest_mock import MockerFixture import dlt +from dlt.common.destination.exceptions import DestinationHasFailedJobs from dlt.common.utils import uniq_id from dlt.destinations.exceptions import DatabaseUndefinedRelation +from tests.cases import assert_all_data_types_row +from tests.load.pipeline.test_pipelines import simple_nested_pipeline from tests.load.snowflake.test_snowflake_client import QUERY_TAG -from tests.pipeline.utils import assert_load_info -from tests.load.utils import destinations_configs, DestinationTestConfiguration +from tests.pipeline.utils import assert_load_info, assert_query_data +from tests.load.utils import ( + destinations_configs, + DestinationTestConfiguration, + drop_active_pipeline_data, +) # mark all tests as essential, do not remove pytestmark = pytest.mark.essential @@ -41,7 +49,9 @@ def test_snowflake_case_sensitive_identifiers( assert destination_client.capabilities.casefold_identifier is str # load some case sensitive data - info = pipeline.run([{"Id": 1, "Capital": 0.0}], table_name="Expenses") + info = pipeline.run( + [{"Id": 1, "Capital": 0.0}], table_name="Expenses", **destination_config.run_kwargs + ) assert_load_info(info) tag_query_spy.assert_not_called() with pipeline.sql_client() as client: @@ -76,6 +86,73 @@ def test_snowflake_query_tagging( os.environ["DESTINATION__SNOWFLAKE__QUERY_TAG"] = QUERY_TAG tag_query_spy = mocker.spy(SnowflakeSqlClient, "_tag_session") pipeline = destination_config.setup_pipeline("test_snowflake_case_sensitive_identifiers") - info = pipeline.run([1, 2, 3], table_name="digits") + info = pipeline.run([1, 2, 3], table_name="digits", **destination_config.run_kwargs) assert_load_info(info) assert tag_query_spy.call_count == 2 + + +# do not remove - it allows us to filter tests by destination +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["snowflake"]), + ids=lambda x: x.name, +) +def test_snowflake_custom_stage(destination_config: DestinationTestConfiguration) -> None: + """Using custom stage name instead of the table stage""" + os.environ["DESTINATION__SNOWFLAKE__STAGE_NAME"] = "my_non_existing_stage" + pipeline, data = simple_nested_pipeline(destination_config, f"custom_stage_{uniq_id()}", False) + info = pipeline.run(data(), **destination_config.run_kwargs) + with pytest.raises(DestinationHasFailedJobs) as f_jobs: + info.raise_on_failed_jobs() + assert "MY_NON_EXISTING_STAGE" in f_jobs.value.failed_jobs[0].failed_message + + drop_active_pipeline_data() + + # NOTE: this stage must be created in DLT_DATA database for this test to pass! + # CREATE STAGE MY_CUSTOM_LOCAL_STAGE; + # GRANT READ, WRITE ON STAGE DLT_DATA.PUBLIC.MY_CUSTOM_LOCAL_STAGE TO ROLE DLT_LOADER_ROLE; + stage_name = "PUBLIC.MY_CUSTOM_LOCAL_STAGE" + os.environ["DESTINATION__SNOWFLAKE__STAGE_NAME"] = stage_name + pipeline, data = simple_nested_pipeline(destination_config, f"custom_stage_{uniq_id()}", False) + info = pipeline.run(data(), **destination_config.run_kwargs) + assert_load_info(info) + + load_id = info.loads_ids[0] + + # Get a list of the staged files and verify correct number of files in the "load_id" dir + with pipeline.sql_client() as client: + staged_files = client.execute_sql(f'LIST @{stage_name}/"{load_id}"') + assert len(staged_files) == 3 + # check data of one table to ensure copy was done successfully + tbl_name = client.make_qualified_table_name("lists") + assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ["a", None, None]) + + +# do not remove - it allows us to filter tests by destination +@pytest.mark.parametrize( + "destination_config", + destinations_configs(default_sql_configs=True, subset=["snowflake"]), + ids=lambda x: x.name, +) +def test_snowflake_delete_file_after_copy(destination_config: DestinationTestConfiguration) -> None: + """Using keep_staged_files = false option to remove staged files after copy""" + os.environ["DESTINATION__SNOWFLAKE__KEEP_STAGED_FILES"] = "FALSE" + + pipeline, data = simple_nested_pipeline( + destination_config, f"delete_staged_files_{uniq_id()}", False + ) + + info = pipeline.run(data(), **destination_config.run_kwargs) + assert_load_info(info) + + load_id = info.loads_ids[0] + + with pipeline.sql_client() as client: + # no files are left in table stage + stage_name = client.make_qualified_table_name("%lists") + staged_files = client.execute_sql(f'LIST @{stage_name}/"{load_id}"') + assert len(staged_files) == 0 + + # ensure copy was done + tbl_name = client.make_qualified_table_name("lists") + assert_query_data(pipeline, f"SELECT value FROM {tbl_name}", ["a", None, None]) diff --git a/tests/load/pipeline/test_stage_loading.py b/tests/load/pipeline/test_stage_loading.py index 6c4f6dfec8..07f31fecea 100644 --- a/tests/load/pipeline/test_stage_loading.py +++ b/tests/load/pipeline/test_stage_loading.py @@ -6,7 +6,6 @@ from dlt.common.storages.configuration import FilesystemConfiguration from dlt.common.utils import uniq_id from dlt.common.schema.typing import TDataType -from dlt.destinations.impl.filesystem.filesystem import FilesystemClient from tests.load.pipeline.test_merge_disposition import github from tests.pipeline.utils import load_table_counts, assert_load_info @@ -55,7 +54,7 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: pipeline_name="test_stage_loading_5", dataset_name="test_staging_load" + uniq_id() ) - info = pipeline.run(github(), loader_file_format=destination_config.file_format) + info = pipeline.run(github(), **destination_config.run_kwargs) assert_load_info(info) # checks if remote_url is set correctly on copy jobs metrics = info.metrics[info.loads_ids[0]][0] @@ -78,9 +77,9 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: num_sql_jobs = 0 if destination_config.supports_merge: num_sql_jobs += 1 - # sql job is used to copy parquet to Athena Iceberg table (_dlt_pipeline_state) - if destination_config.destination == "athena" and destination_config.table_format == "iceberg": - num_sql_jobs += 1 + # sql job is used to copy parquet to Athena Iceberg table (_dlt_pipeline_state) + # if destination_config.destination == "athena": + # num_sql_jobs += 1 assert len(package_info.jobs["completed_jobs"]) == num_jobs + num_sql_jobs assert ( len( @@ -139,7 +138,7 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: if destination_config.supports_merge: # test merging in some changed values - info = pipeline.run(load_modified_issues, loader_file_format=destination_config.file_format) + info = pipeline.run(load_modified_issues, **destination_config.run_kwargs) assert_load_info(info) assert pipeline.default_schema.tables["issues"]["write_disposition"] == "merge" merge_counts = load_table_counts( @@ -171,7 +170,7 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: info = pipeline.run( github().load_issues, write_disposition="append", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_load_info(info) assert pipeline.default_schema.tables["issues"]["write_disposition"] == "append" @@ -185,7 +184,7 @@ def test_staging_load(destination_config: DestinationTestConfiguration) -> None: info = pipeline.run( github().load_issues, write_disposition="replace", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_load_info(info) assert pipeline.default_schema.tables["issues"]["write_disposition"] == "replace" @@ -214,12 +213,18 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati table_name: str = resource.table_name # type: ignore[assignment] # load the data, files stay on the stage after the load - info = pipeline.run(resource) + info = pipeline.run( + resource, + **destination_config.run_kwargs, + ) assert_load_info(info) # load the data without truncating of the staging, should see two files on staging pipeline.destination.config_params["truncate_tables_on_staging_destination_before_load"] = False - info = pipeline.run(resource) + info = pipeline.run( + resource, + **destination_config.run_kwargs, + ) assert_load_info(info) # check there are two staging files _, staging_client = pipeline._get_destination_clients(pipeline.default_schema) @@ -239,7 +244,10 @@ def test_truncate_staging_dataset(destination_config: DestinationTestConfigurati # load the data with truncating, so only new file is on the staging pipeline.destination.config_params["truncate_tables_on_staging_destination_before_load"] = True - info = pipeline.run(resource) + info = pipeline.run( + resource, + **destination_config.run_kwargs, + ) assert_load_info(info) # check that table exists in the destination with pipeline.sql_client() as sql_client: @@ -322,7 +330,10 @@ def my_resource(): def my_source(): return my_resource - info = pipeline.run(my_source(), loader_file_format=destination_config.file_format) + info = pipeline.run( + my_source(), + **destination_config.run_kwargs, + ) assert_load_info(info) with pipeline.sql_client() as sql_client: diff --git a/tests/load/pipeline/test_write_disposition_changes.py b/tests/load/pipeline/test_write_disposition_changes.py index ba2f6bf172..1a799059d0 100644 --- a/tests/load/pipeline/test_write_disposition_changes.py +++ b/tests/load/pipeline/test_write_disposition_changes.py @@ -34,7 +34,7 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): data_with_subtables(10), table_name="items", write_disposition="merge", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_data_table_counts(pipeline, {"items": 100, "items__sub_items": 100}) assert pipeline.default_schema._normalizers_config["json"]["config"]["propagation"]["tables"][ @@ -45,7 +45,7 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): data_with_subtables(10), table_name="items", write_disposition="merge", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_load_info(info) assert_data_table_counts( @@ -63,7 +63,7 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): data_with_subtables(10), table_name="items", write_disposition="append", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_load_info(info) assert_data_table_counts( @@ -81,7 +81,7 @@ def test_switch_from_merge(destination_config: DestinationTestConfiguration): data_with_subtables(10), table_name="items", write_disposition="replace", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_load_info(info) assert_data_table_counts(pipeline, {"items": 100, "items__sub_items": 100}) @@ -110,7 +110,7 @@ def source(): s, table_name="items", write_disposition="append", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_data_table_counts(pipeline, {"items": 100, "items__sub_items": 100}) @@ -137,7 +137,7 @@ def source(): s, table_name="items", write_disposition="merge", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) return @@ -148,7 +148,7 @@ def source(): s, table_name="items", write_disposition="merge", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) return @@ -156,7 +156,7 @@ def source(): s, table_name="items", write_disposition="merge", - loader_file_format=destination_config.file_format, + **destination_config.run_kwargs, ) assert_load_info(info) assert_data_table_counts( diff --git a/tests/load/sources/sql_database/test_sql_database_source.py b/tests/load/sources/sql_database/test_sql_database_source.py index 58382877ee..b0f7fdf1de 100644 --- a/tests/load/sources/sql_database/test_sql_database_source.py +++ b/tests/load/sources/sql_database/test_sql_database_source.py @@ -23,7 +23,7 @@ load_tables_to_dicts, ) from tests.load.sources.sql_database.test_helpers import mock_json_column -from tests.utils import data_item_length +from tests.utils import data_item_length, load_table_counts try: @@ -959,7 +959,7 @@ def dummy_source(): channel_rows = load_tables_to_dicts(pipeline, "chat_channel")["chat_channel"] assert channel_rows and all(row["active"] for row in channel_rows) - # unfiltred table loads all rows + # unfiltered table loads all rows assert_row_counts(pipeline, sql_source_db, ["chat_message"]) @@ -969,18 +969,18 @@ def assert_row_counts( tables: Optional[List[str]] = None, include_views: bool = False, ) -> None: - with pipeline.sql_client() as c: - if not tables: - tables = [ - tbl_name - for tbl_name, info in sql_source_db.table_infos.items() - if include_views or not info["is_view"] - ] - for table in tables: - info = sql_source_db.table_infos[table] - with c.execute_query(f"SELECT count(*) FROM {table}") as cur: - row = cur.fetchone() - assert row[0] == info["row_count"] + if not tables: + tables = [ + tbl_name + for tbl_name, info in sql_source_db.table_infos.items() + if include_views or not info["is_view"] + ] + dest_counts = load_table_counts(pipeline, *tables) + for table in tables: + info = sql_source_db.table_infos[table] + assert ( + dest_counts[table] == info["row_count"] + ), f"Table {table} counts do not match with the source" def assert_precision_columns( diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py index 1a92a20f1e..5c80f8d7fa 100644 --- a/tests/load/synapse/test_synapse_table_builder.py +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -119,7 +119,7 @@ def test_create_table_with_column_hint( # Case: table with hint, but client does not have indexes enabled. mod_update = deepcopy(TABLE_UPDATE) - mod_update[0][hint] = True # type: ignore[typeddict-unknown-key] + mod_update[0][hint] = True sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql, dialect="tsql") assert f" {attr} " not in sql diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index 72c5772668..f0d0f8cdde 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -14,8 +14,8 @@ from dlt.common.destination.reference import RunnableLoadJob, TDestination from dlt.common.schema.utils import ( fill_hints_from_parent_and_clone_table, - get_child_tables, - get_top_level_table, + get_nested_tables, + get_root_table, ) from dlt.destinations.impl.filesystem.configuration import FilesystemDestinationClientConfiguration @@ -156,7 +156,7 @@ def test_get_completed_table_chain_single_job_per_table() -> None: for table_name, table in schema.tables.items(): schema.tables[table_name] = fill_hints_from_parent_and_clone_table(schema.tables, table) - top_job_table = get_top_level_table(schema.tables, "event_user") + top_job_table = get_root_table(schema.tables, "event_user") all_jobs = load.load_storage.normalized_packages.list_all_jobs_with_states(load_id) assert get_completed_table_chain(schema, all_jobs, top_job_table) is None # fake being completed @@ -172,7 +172,7 @@ def test_get_completed_table_chain_single_job_per_table() -> None: == 1 ) # actually complete - loop_top_job_table = get_top_level_table(schema.tables, "event_loop_interrupted") + loop_top_job_table = get_root_table(schema.tables, "event_loop_interrupted") load.load_storage.normalized_packages.start_job( load_id, "event_loop_interrupted.839c6e6b514e427687586ccc65bf133f.0.jsonl" ) @@ -549,7 +549,7 @@ def test_completed_loop_with_delete_completed() -> None: @pytest.mark.parametrize("to_truncate", [True, False]) -def test_truncate_table_before_load_on_stanging(to_truncate) -> None: +def test_truncate_table_before_load_on_staging(to_truncate) -> None: load = setup_loader( client_config=DummyClientConfiguration( truncate_tables_on_staging_destination_before_load=to_truncate @@ -559,7 +559,7 @@ def test_truncate_table_before_load_on_stanging(to_truncate) -> None: destination_client = load.get_destination_client(schema) assert ( destination_client.should_truncate_table_before_load_on_staging_destination( # type: ignore - schema.tables["_dlt_version"] + schema.tables["_dlt_version"]["name"] ) == to_truncate ) @@ -679,7 +679,7 @@ def test_extend_table_chain() -> None: assert tables == user_chain - {"event_user__parse_data__entities"} # exclude the whole chain tables = _extend_tables_with_table_chain( - schema, ["event_user"], ["event_user"], lambda table: table["name"] not in entities_chain + schema, ["event_user"], ["event_user"], lambda table_name: table_name not in entities_chain ) assert tables == user_chain - entities_chain # ask for tables that are not top @@ -753,7 +753,7 @@ def test_get_completed_table_chain_cases() -> None: assert chain == [event_user, event_user_entities] # merge and replace do not require whole chain to be in jobs - user_chain = get_child_tables(schema.tables, "event_user") + user_chain = get_nested_tables(schema.tables, "event_user") for w_d in ["merge", "replace"]: event_user["write_disposition"] = w_d # type:ignore[typeddict-item] @@ -848,11 +848,17 @@ def test_init_client_truncate_tables() -> None: "event_bot", } - replace_ = lambda table: table["write_disposition"] == "replace" - merge_ = lambda table: table["write_disposition"] == "merge" + replace_ = ( + lambda table_name: client.prepare_load_table(table_name)["write_disposition"] + == "replace" + ) + merge_ = ( + lambda table_name: client.prepare_load_table(table_name)["write_disposition"] + == "merge" + ) # set event_bot chain to merge - bot_chain = get_child_tables(schema.tables, "event_bot") + bot_chain = get_nested_tables(schema.tables, "event_bot") for w_d in ["merge", "replace"]: initialize_storage.reset_mock() update_stored_schema.reset_mock() @@ -1062,7 +1068,7 @@ def setup_loader( staging_system_config = FilesystemDestinationClientConfiguration()._bind_dataset_name( dataset_name="dummy" ) - staging_system_config.as_staging = True + staging_system_config.as_staging_destination = True os.makedirs(REMOTE_FILESYSTEM) staging = filesystem(bucket_url=REMOTE_FILESYSTEM) # patch destination to provide client_config diff --git a/tests/load/test_job_client.py b/tests/load/test_job_client.py index 06b70a49da..7af7588789 100644 --- a/tests/load/test_job_client.py +++ b/tests/load/test_job_client.py @@ -576,7 +576,7 @@ def test_load_with_all_types( client.schema._bump_version() client.update_stored_schema() - should_load_to_staging = client.should_load_data_to_staging_dataset(client.schema.tables[table_name]) # type: ignore[attr-defined] + should_load_to_staging = client.should_load_data_to_staging_dataset(table_name) # type: ignore[attr-defined] if should_load_to_staging: with client.with_staging_dataset(): # type: ignore[attr-defined] # create staging for merge dataset @@ -665,7 +665,7 @@ def test_write_dispositions( with io.BytesIO() as f: write_dataset(client, f, [data_row], column_schemas) query = f.getvalue() - if client.should_load_data_to_staging_dataset(client.schema.tables[table_name]): # type: ignore[attr-defined] + if client.should_load_data_to_staging_dataset(table_name): # type: ignore[attr-defined] # load to staging dataset on merge with client.with_staging_dataset(): # type: ignore[attr-defined] expect_load_file(client, file_storage, query, t) @@ -722,7 +722,7 @@ def test_get_resumed_job(client: SqlJobClientBase, file_storage: FileStorage) -> # now try to retrieve the job # TODO: we should re-create client instance as this call is intended to be run after some disruption ie. stopped loader process r_job = client.create_load_job( - client.schema.get_table(user_table_name), + client.prepare_load_table(user_table_name), file_storage.make_full_path(job.file_name()), uniq_id(), restore=True, diff --git a/tests/load/utils.py b/tests/load/utils.py index 5427904d52..0eaf68d8f8 100644 --- a/tests/load/utils.py +++ b/tests/load/utils.py @@ -53,6 +53,7 @@ from dlt.destinations.sql_client import SqlClientBase from dlt.destinations.job_client_impl import SqlJobClientBase +from dlt.pipeline.exceptions import SqlClientNotAvailable from tests.utils import ( ACTIVE_DESTINATIONS, IMPLEMENTED_DESTINATIONS, @@ -153,7 +154,7 @@ class DestinationTestConfiguration: staging_use_msi: bool = False extra_info: Optional[str] = None supports_merge: bool = True # TODO: take it from client base class - force_iceberg: bool = False + force_iceberg: bool = None # used only to test deprecation supports_dbt: bool = True disable_compression: bool = False dev_mode: bool = False @@ -228,6 +229,19 @@ def attach_pipeline(self, pipeline_name: str, **kwargs) -> dlt.Pipeline: pipeline = dlt.attach(pipeline_name, **kwargs) return pipeline + def supports_sql_client(self, pipeline: dlt.Pipeline) -> bool: + """Checks if destination supports SQL queries""" + try: + pipeline.sql_client() + return True + except SqlClientNotAvailable: + return False + + @property + def run_kwargs(self): + """Returns a dict of kwargs to be passed to pipeline.run method: currently file and table format""" + return dict(loader_file_format=self.file_format, table_format=self.table_format) + def destinations_configs( default_sql_configs: bool = False, @@ -241,11 +255,10 @@ def destinations_configs( bucket_subset: Sequence[str] = (), exclude: Sequence[str] = (), bucket_exclude: Sequence[str] = (), - file_format: Union[TLoaderFileFormat, Sequence[TLoaderFileFormat]] = None, - table_format: Union[TTableFormat, Sequence[TTableFormat]] = None, + with_file_format: Union[TLoaderFileFormat, Sequence[TLoaderFileFormat]] = None, + with_table_format: Union[TTableFormat, Sequence[TTableFormat]] = None, supports_merge: Optional[bool] = None, supports_dbt: Optional[bool] = None, - force_iceberg: Optional[bool] = None, ) -> List[DestinationTestConfiguration]: # sanity check for item in subset: @@ -270,7 +283,6 @@ def destinations_configs( destination="athena", file_format="parquet", bucket_url=AWS_BUCKET, - force_iceberg=True, supports_merge=True, supports_dbt=False, table_format="iceberg", @@ -587,21 +599,21 @@ def destinations_configs( for conf in destination_configs if conf.destination != "filesystem" or conf.bucket_url not in bucket_exclude ] - if file_format: - if not isinstance(file_format, Sequence): - file_format = [file_format] + if with_file_format: + if not isinstance(with_file_format, Sequence): + with_file_format = [with_file_format] destination_configs = [ conf for conf in destination_configs - if conf.file_format and conf.file_format in file_format + if conf.file_format and conf.file_format in with_file_format ] - if table_format: - if not isinstance(table_format, Sequence): - table_format = [table_format] + if with_table_format: + if not isinstance(with_table_format, Sequence): + with_table_format = [with_table_format] destination_configs = [ conf for conf in destination_configs - if conf.table_format and conf.table_format in table_format + if conf.table_format and conf.table_format in with_table_format ] if supports_merge is not None: destination_configs = [ @@ -617,11 +629,6 @@ def destinations_configs( conf for conf in destination_configs if conf.name not in EXCLUDED_DESTINATION_CONFIGURATIONS ] - if force_iceberg is not None: - destination_configs = [ - conf for conf in destination_configs if conf.force_iceberg is force_iceberg - ] - # add marks destination_configs = [ cast( @@ -759,6 +766,8 @@ def prepare_table( else: user_table_name = table_name client.schema.update_table(new_table(user_table_name, columns=list(user_table.values()))) + print(client.schema.to_pretty_yaml()) + client.verify_schema([user_table_name]) client.schema._bump_version() client.update_stored_schema() return user_table_name diff --git a/tests/load/weaviate/test_weaviate_client.py b/tests/load/weaviate/test_weaviate_client.py index 0a249db0fd..b7851f271d 100644 --- a/tests/load/weaviate/test_weaviate_client.py +++ b/tests/load/weaviate/test_weaviate_client.py @@ -121,6 +121,7 @@ def test_case_sensitive_properties_create(client: WeaviateClient) -> None: ) client.schema._bump_version() with pytest.raises(SchemaIdentifierNormalizationCollision) as clash_ex: + client.verify_schema() client.update_stored_schema() assert clash_ex.value.identifier_type == "column" assert clash_ex.value.identifier_name == "coL1" @@ -170,6 +171,7 @@ def test_case_sensitive_properties_add(client: WeaviateClient) -> None: ) client.schema._bump_version() with pytest.raises(SchemaIdentifierNormalizationCollision): + client.verify_schema() client.update_stored_schema() # _, table_columns = client.get_storage_table("ColClass") diff --git a/tests/normalize/test_max_nesting.py b/tests/normalize/test_max_nesting.py index 5def1617dc..ec44e1c4db 100644 --- a/tests/normalize/test_max_nesting.py +++ b/tests/normalize/test_max_nesting.py @@ -8,7 +8,7 @@ from tests.common.utils import json_case_path -TOP_LEVEL_TABLES = ["bot_events"] +ROOT_TABLES = ["bot_events"] ALL_TABLES_FOR_RASA_EVENT = [ "bot_events", @@ -37,8 +37,8 @@ def rasa_event_bot_metadata(): @pytest.mark.parametrize( "nesting_level,expected_num_tables,expected_table_names", ( - (0, 1, TOP_LEVEL_TABLES), - (1, 1, TOP_LEVEL_TABLES), + (0, 1, ROOT_TABLES), + (1, 1, ROOT_TABLES), (2, 3, ALL_TABLES_FOR_RASA_EVENT_NESTING_LEVEL_2), (5, 8, ALL_TABLES_FOR_RASA_EVENT), (15, 8, ALL_TABLES_FOR_RASA_EVENT), diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index 319055184a..ea8ed4550c 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -45,7 +45,7 @@ def test_simulate_default_naming_convention_change() -> None: # mock the mod # from dlt.common.normalizers import utils - with patch("dlt.common.normalizers.utils.DEFAULT_NAMING_MODULE", "duck_case"): + with patch("dlt.common.schema.normalizers.DEFAULT_NAMING_MODULE", "duck_case"): duck_pipeline = dlt.pipeline("simulated_duck_case", destination="duckdb") assert duck_pipeline.naming.name() == "duck_case" print(airtable_emojis().schema.naming.name()) diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 535d5d28e4..be8d274eb0 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -398,7 +398,7 @@ def test_destination_staging_config(environment: Any) -> None: staging_config = fs_dest.configuration(initial_config) # type: ignore[arg-type] # Ensure that as_staging flag is set in the final resolved conifg - assert staging_config.as_staging is True + assert staging_config.as_staging_destination is True def test_destination_factory_defaults_resolve_from_config(environment: Any) -> None: diff --git a/tests/pipeline/utils.py b/tests/pipeline/utils.py index d605fa9893..4fb9e2215b 100644 --- a/tests/pipeline/utils.py +++ b/tests/pipeline/utils.py @@ -6,6 +6,7 @@ import dlt from dlt.common import json, sleep +from dlt.common.configuration.utils import auto_cast from dlt.common.data_types import py_type_to_sc_type from dlt.common.pipeline import LoadInfo from dlt.common.schema.utils import get_table_format @@ -147,7 +148,7 @@ def _load_file(client: FSClientBase, filepath) -> List[Dict[str, Any]]: cols = lines[0][15:-2].split(",") for line in lines[2:]: if line: - values = line[1:-3].split(",") + values = map(auto_cast, line[1:-3].split(",")) result.append(dict(zip(cols, values))) # load parquet diff --git a/tests/sources/rest_api/test_rest_api_pipeline_template.py b/tests/sources/rest_api/test_rest_api_pipeline_template.py index ef30b63a7f..cd5cca0b10 100644 --- a/tests/sources/rest_api/test_rest_api_pipeline_template.py +++ b/tests/sources/rest_api/test_rest_api_pipeline_template.py @@ -1,9 +1,9 @@ +import os import dlt import pytest from dlt.common.typing import TSecretStrValue -# NOTE: needs github secrets to work @pytest.mark.parametrize( "example_name", ( @@ -16,5 +16,8 @@ def test_all_examples(example_name: str) -> None: # reroute token location from secrets github_token: TSecretStrValue = dlt.secrets.get("sources.github.access_token") + if not github_token: + # try to get GITHUB TOKEN which is available on github actions, fallback to None if not available + github_token = os.environ.get("GITHUB_TOKEN", None) # type: ignore dlt.secrets["sources.rest_api_pipeline.github.access_token"] = github_token getattr(rest_api_pipeline, example_name)()