diff --git a/dlt/common/__init__.py b/dlt/common/__init__.py index 466fd7c546..0c8a09ec3e 100644 --- a/dlt/common/__init__.py +++ b/dlt/common/__init__.py @@ -1,8 +1,8 @@ +from dlt.common import logger from dlt.common.arithmetics import Decimal from dlt.common.wei import Wei from dlt.common.pendulum import pendulum from dlt.common.json import json from dlt.common.runtime.signals import sleep -from dlt.common.runtime import logger __all__ = ["Decimal", "Wei", "pendulum", "json", "sleep", "logger"] diff --git a/dlt/common/configuration/resolve.py b/dlt/common/configuration/resolve.py index b398f0463a..ebfa7b6b89 100644 --- a/dlt/common/configuration/resolve.py +++ b/dlt/common/configuration/resolve.py @@ -76,7 +76,7 @@ def initialize_credentials(hint: Any, initial_value: Any) -> CredentialsConfigur first_credentials: CredentialsConfiguration = None for idx, spec in enumerate(specs_in_union): try: - credentials = spec(initial_value) + credentials = spec.from_init_value(initial_value) if credentials.is_resolved(): return credentials # keep first credentials in the union to return in case all of the match but not resolve @@ -88,7 +88,7 @@ def initialize_credentials(hint: Any, initial_value: Any) -> CredentialsConfigur return first_credentials else: assert issubclass(hint, CredentialsConfiguration) - return hint(initial_value) # type: ignore + return hint.from_init_value(initial_value) # type: ignore def inject_section( diff --git a/dlt/common/configuration/specs/api_credentials.py b/dlt/common/configuration/specs/api_credentials.py index fd7ae8cb09..918cd4ee45 100644 --- a/dlt/common/configuration/specs/api_credentials.py +++ b/dlt/common/configuration/specs/api_credentials.py @@ -6,9 +6,9 @@ @configspec class OAuth2Credentials(CredentialsConfiguration): - client_id: str - client_secret: TSecretValue - refresh_token: Optional[TSecretValue] + client_id: str = None + client_secret: TSecretValue = None + refresh_token: Optional[TSecretValue] = None scopes: Optional[List[str]] = None token: Optional[TSecretValue] = None diff --git a/dlt/common/configuration/specs/aws_credentials.py b/dlt/common/configuration/specs/aws_credentials.py index ee7360e2cb..ee49e79e40 100644 --- a/dlt/common/configuration/specs/aws_credentials.py +++ b/dlt/common/configuration/specs/aws_credentials.py @@ -121,3 +121,9 @@ def parse_native_representation(self, native_value: Any) -> None: self.__is_resolved__ = True except Exception: raise InvalidBoto3Session(self.__class__, native_value) + + @classmethod + def from_session(cls, botocore_session: Any) -> "AwsCredentials": + self = cls() + self.parse_native_representation(botocore_session) + return self diff --git a/dlt/common/configuration/specs/base_configuration.py b/dlt/common/configuration/specs/base_configuration.py index 62abf42f27..06fb97fcdd 100644 --- a/dlt/common/configuration/specs/base_configuration.py +++ b/dlt/common/configuration/specs/base_configuration.py @@ -2,6 +2,7 @@ import inspect import contextlib import dataclasses +import warnings from collections.abc import Mapping as C_Mapping from typing import ( @@ -19,7 +20,7 @@ ClassVar, TypeVar, ) -from typing_extensions import get_args, get_origin +from typing_extensions import get_args, get_origin, dataclass_transform from functools import wraps if TYPE_CHECKING: @@ -44,6 +45,7 @@ _F_BaseConfiguration: Any = type(object) _F_ContainerInjectableContext: Any = type(object) _T = TypeVar("_T", bound="BaseConfiguration") +_C = TypeVar("_C", bound="CredentialsConfiguration") def is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool: @@ -106,18 +108,26 @@ def is_secret_hint(hint: Type[Any]) -> bool: @overload -def configspec(cls: Type[TAnyClass]) -> Type[TAnyClass]: ... +def configspec(cls: Type[TAnyClass], init: bool = True) -> Type[TAnyClass]: ... @overload -def configspec(cls: None = ...) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: ... +def configspec( + cls: None = ..., init: bool = True +) -> Callable[[Type[TAnyClass]], Type[TAnyClass]]: ... +@dataclass_transform(eq_default=False, field_specifiers=(dataclasses.Field, dataclasses.field)) def configspec( - cls: Optional[Type[Any]] = None, + cls: Optional[Type[Any]] = None, init: bool = True ) -> Union[Type[TAnyClass], Callable[[Type[TAnyClass]], Type[TAnyClass]]]: """Converts (via derivation) any decorated class to a Python dataclass that may be used as a spec to resolve configurations + __init__ method is synthesized by default. `init` flag is ignored if the decorated class implements custom __init__ as well as + when any of base classes has no synthesized __init__ + + All fields must have default values. This decorator will add `None` default values that miss one. + In comparison the Python dataclass, a spec implements full dictionary interface for its attributes, allows instance creation from ie. strings or other types (parsing, deserialization) and control over configuration resolution process. See `BaseConfiguration` and CredentialsConfiguration` for more information. @@ -142,6 +152,10 @@ def wrap(cls: Type[TAnyClass]) -> Type[TAnyClass]: # get all annotations without corresponding attributes and set them to None for ann in cls.__annotations__: if not hasattr(cls, ann) and not ann.startswith(("__", "_abc_")): + warnings.warn( + f"Missing default value for field {ann} on {cls.__name__}. None assumed. All" + " fields in configspec must have default." + ) setattr(cls, ann, None) # get all attributes without corresponding annotations for att_name, att_value in list(cls.__dict__.items()): @@ -177,17 +191,18 @@ def default_factory(att_value=att_value): # type: ignore[no-untyped-def] # We don't want to overwrite user's __init__ method # Create dataclass init only when not defined in the class - # (never put init on BaseConfiguration itself) - try: - is_base = cls is BaseConfiguration - except NameError: - is_base = True - init = False - base_params = getattr(cls, "__dataclass_params__", None) - if not is_base and (base_params and base_params.init or cls.__init__ is object.__init__): - init = True + # NOTE: any class without synthesized __init__ breaks the creation chain + has_default_init = super(cls, cls).__init__ == cls.__init__ # type: ignore[misc] + base_params = getattr(cls, "__dataclass_params__", None) # cls.__init__ is object.__init__ + synth_init = init and ((not base_params or base_params.init) and has_default_init) + if synth_init != init and has_default_init: + warnings.warn( + f"__init__ method will not be generated on {cls.__name__} because bas class didn't" + " synthesize __init__. Please correct `init` flag in confispec decorator. You are" + " probably receiving incorrect __init__ signature for type checking" + ) # do not generate repr as it may contain secret values - return dataclasses.dataclass(cls, init=init, eq=False, repr=False) # type: ignore + return dataclasses.dataclass(cls, init=synth_init, eq=False, repr=False) # type: ignore # called with parenthesis if cls is None: @@ -198,12 +213,14 @@ def default_factory(att_value=att_value): # type: ignore[no-untyped-def] @configspec class BaseConfiguration(MutableMapping[str, Any]): - __is_resolved__: bool = dataclasses.field(default=False, init=False, repr=False) + __is_resolved__: bool = dataclasses.field(default=False, init=False, repr=False, compare=False) """True when all config fields were resolved and have a specified value type""" - __section__: str = dataclasses.field(default=None, init=False, repr=False) - """Obligatory section used by config providers when searching for keys, always present in the search path""" - __exception__: Exception = dataclasses.field(default=None, init=False, repr=False) + __exception__: Exception = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) """Holds the exception that prevented the full resolution""" + __section__: ClassVar[str] = None + """Obligatory section used by config providers when searching for keys, always present in the search path""" __config_gen_annotations__: ClassVar[List[str]] = [] """Additional annotations for config generator, currently holds a list of fields of interest that have defaults""" __dataclass_fields__: ClassVar[Dict[str, TDtcField]] @@ -342,9 +359,10 @@ def call_method_in_mro(config, method_name: str) -> None: class CredentialsConfiguration(BaseConfiguration): """Base class for all credentials. Credentials are configurations that may be stored only by providers supporting secrets.""" - __section__: str = "credentials" + __section__: ClassVar[str] = "credentials" - def __init__(self, init_value: Any = None) -> None: + @classmethod + def from_init_value(cls: Type[_C], init_value: Any = None) -> _C: """Initializes credentials from `init_value` Init value may be a native representation of the credentials or a dict. In case of native representation (for example a connection string or JSON with service account credentials) @@ -353,14 +371,10 @@ def __init__(self, init_value: Any = None) -> None: Credentials will be marked as resolved if all required fields are set. """ - if init_value is None: - return - elif isinstance(init_value, C_Mapping): - self.update(init_value) - else: - self.parse_native_representation(init_value) - if not self.is_partial(): - self.resolve() + # create an instance + self = cls() + self._apply_init_value(init_value) + return self def to_native_credentials(self) -> Any: """Returns native credentials object. @@ -369,6 +383,16 @@ def to_native_credentials(self) -> Any: """ return self.to_native_representation() + def _apply_init_value(self, init_value: Any = None) -> None: + if isinstance(init_value, C_Mapping): + self.update(init_value) + elif init_value is not None: + self.parse_native_representation(init_value) + else: + return + if not self.is_partial(): + self.resolve() + def __str__(self) -> str: """Get string representation of credentials to be displayed, with all secret parts removed""" return super().__str__() diff --git a/dlt/common/configuration/specs/config_providers_context.py b/dlt/common/configuration/specs/config_providers_context.py index 860e7414de..642634fb0a 100644 --- a/dlt/common/configuration/specs/config_providers_context.py +++ b/dlt/common/configuration/specs/config_providers_context.py @@ -1,4 +1,5 @@ import contextlib +import dataclasses import io from typing import ClassVar, List @@ -28,7 +29,7 @@ class ConfigProvidersConfiguration(BaseConfiguration): only_toml_fragments: bool = True # always look in providers - __section__ = known_sections.PROVIDERS + __section__: ClassVar[str] = known_sections.PROVIDERS @configspec @@ -37,8 +38,12 @@ class ConfigProvidersContext(ContainerInjectableContext): global_affinity: ClassVar[bool] = True - providers: List[ConfigProvider] - context_provider: ConfigProvider + providers: List[ConfigProvider] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) + context_provider: ConfigProvider = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) def __init__(self) -> None: super().__init__() diff --git a/dlt/common/configuration/specs/config_section_context.py b/dlt/common/configuration/specs/config_section_context.py index a656a2b0fe..1e6cd56155 100644 --- a/dlt/common/configuration/specs/config_section_context.py +++ b/dlt/common/configuration/specs/config_section_context.py @@ -8,7 +8,7 @@ class ConfigSectionContext(ContainerInjectableContext): TMergeFunc = Callable[["ConfigSectionContext", "ConfigSectionContext"], None] - pipeline_name: Optional[str] + pipeline_name: Optional[str] = None sections: Tuple[str, ...] = () merge_style: TMergeFunc = None source_state_key: str = None @@ -70,13 +70,3 @@ def __str__(self) -> str: super().__str__() + f": {self.pipeline_name} {self.sections}@{self.merge_style} state['{self.source_state_key}']" ) - - if TYPE_CHECKING: - # provide __init__ signature when type checking - def __init__( - self, - pipeline_name: str = None, - sections: Tuple[str, ...] = (), - merge_style: TMergeFunc = None, - source_state_key: str = None, - ) -> None: ... diff --git a/dlt/common/configuration/specs/connection_string_credentials.py b/dlt/common/configuration/specs/connection_string_credentials.py index 9dd6f00942..2691c5d886 100644 --- a/dlt/common/configuration/specs/connection_string_credentials.py +++ b/dlt/common/configuration/specs/connection_string_credentials.py @@ -1,14 +1,15 @@ -from typing import Any, ClassVar, Dict, List, Optional +import dataclasses +from typing import Any, ClassVar, Dict, List, Optional, Union + from dlt.common.libs.sql_alchemy import URL, make_url from dlt.common.configuration.specs.exceptions import InvalidConnectionString - from dlt.common.typing import TSecretValue from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration, configspec @configspec class ConnectionStringCredentials(CredentialsConfiguration): - drivername: str = None + drivername: str = dataclasses.field(default=None, init=False, repr=False, compare=False) database: str = None password: Optional[TSecretValue] = None username: str = None @@ -18,6 +19,11 @@ class ConnectionStringCredentials(CredentialsConfiguration): __config_gen_annotations__: ClassVar[List[str]] = ["port", "password", "host"] + def __init__(self, connection_string: Union[str, Dict[str, Any]] = None) -> None: + """Initializes the credentials from SQLAlchemy like connection string or from dict holding connection string elements""" + super().__init__() + self._apply_init_value(connection_string) + def parse_native_representation(self, native_value: Any) -> None: if not isinstance(native_value, str): raise InvalidConnectionString(self.__class__, native_value, self.drivername) diff --git a/dlt/common/configuration/specs/gcp_credentials.py b/dlt/common/configuration/specs/gcp_credentials.py index 431f35c8d0..4d81a493a3 100644 --- a/dlt/common/configuration/specs/gcp_credentials.py +++ b/dlt/common/configuration/specs/gcp_credentials.py @@ -1,5 +1,6 @@ +import dataclasses import sys -from typing import Any, Final, List, Tuple, Union, Dict +from typing import Any, ClassVar, Final, List, Tuple, Union, Dict from dlt.common import json, pendulum from dlt.common.configuration.specs.api_credentials import OAuth2Credentials @@ -22,8 +23,12 @@ @configspec class GcpCredentials(CredentialsConfiguration): - token_uri: Final[str] = "https://oauth2.googleapis.com/token" - auth_uri: Final[str] = "https://accounts.google.com/o/oauth2/auth" + token_uri: Final[str] = dataclasses.field( + default="https://oauth2.googleapis.com/token", init=False, repr=False, compare=False + ) + auth_uri: Final[str] = dataclasses.field( + default="https://accounts.google.com/o/oauth2/auth", init=False, repr=False, compare=False + ) project_id: str = None @@ -69,7 +74,9 @@ def to_gcs_credentials(self) -> Dict[str, Any]: class GcpServiceAccountCredentialsWithoutDefaults(GcpCredentials): private_key: TSecretValue = None client_email: str = None - type: Final[str] = "service_account" # noqa: A003 + type: Final[str] = dataclasses.field( # noqa: A003 + default="service_account", init=False, repr=False, compare=False + ) def parse_native_representation(self, native_value: Any) -> None: """Accepts ServiceAccountCredentials as native value. In other case reverts to serialized services.json""" @@ -121,8 +128,10 @@ def __str__(self) -> str: @configspec class GcpOAuthCredentialsWithoutDefaults(GcpCredentials, OAuth2Credentials): # only desktop app supported - refresh_token: TSecretValue - client_type: Final[str] = "installed" + refresh_token: TSecretValue = None + client_type: Final[str] = dataclasses.field( + default="installed", init=False, repr=False, compare=False + ) def parse_native_representation(self, native_value: Any) -> None: """Accepts Google OAuth2 credentials as native value. In other case reverts to serialized oauth client secret json""" @@ -237,7 +246,7 @@ def __str__(self) -> str: @configspec class GcpDefaultCredentials(CredentialsWithDefault, GcpCredentials): - _LAST_FAILED_DEFAULT: float = 0.0 + _LAST_FAILED_DEFAULT: ClassVar[float] = 0.0 def parse_native_representation(self, native_value: Any) -> None: """Accepts google credentials as native value""" diff --git a/dlt/common/configuration/specs/known_sections.py b/dlt/common/configuration/specs/known_sections.py index 97ba85ffd6..8bd754ddd5 100644 --- a/dlt/common/configuration/specs/known_sections.py +++ b/dlt/common/configuration/specs/known_sections.py @@ -13,6 +13,9 @@ EXTRACT = "extract" """extract stage of the pipeline""" +SCHEMA = "schema" +"""schema configuration, ie. normalizers""" + PROVIDERS = "providers" """secrets and config providers""" diff --git a/dlt/common/configuration/specs/run_configuration.py b/dlt/common/configuration/specs/run_configuration.py index 54ce46ceba..b57b4abbdd 100644 --- a/dlt/common/configuration/specs/run_configuration.py +++ b/dlt/common/configuration/specs/run_configuration.py @@ -1,7 +1,7 @@ import binascii from os.path import isfile, join from pathlib import Path -from typing import Any, Optional, Tuple, IO +from typing import Any, ClassVar, Optional, IO from dlt.common.typing import TSecretStrValue from dlt.common.utils import encoding_for_mode, main_module_file_path, reveal_pseudo_secret @@ -30,7 +30,7 @@ class RunConfiguration(BaseConfiguration): """Platform connection""" dlthub_dsn: Optional[TSecretStrValue] = None - __section__ = "runtime" + __section__: ClassVar[str] = "runtime" def on_resolved(self) -> None: # generate pipeline name from the entry point script name diff --git a/dlt/common/data_writers/buffered.py b/dlt/common/data_writers/buffered.py index 24935d73ac..b10b1d14b9 100644 --- a/dlt/common/data_writers/buffered.py +++ b/dlt/common/data_writers/buffered.py @@ -1,6 +1,6 @@ import gzip import time -from typing import List, IO, Any, Optional, Type, TypeVar, Generic +from typing import ClassVar, List, IO, Any, Optional, Type, TypeVar, Generic from dlt.common.typing import TDataItem, TDataItems from dlt.common.data_writers import TLoaderFileFormat @@ -33,7 +33,7 @@ class BufferedDataWriterConfiguration(BaseConfiguration): disable_compression: bool = False _caps: Optional[DestinationCapabilitiesContext] = None - __section__ = known_sections.DATA_WRITER + __section__: ClassVar[str] = known_sections.DATA_WRITER @with_config(spec=BufferedDataWriterConfiguration) def __init__( diff --git a/dlt/common/data_writers/writers.py b/dlt/common/data_writers/writers.py index 0f3640da1e..2aadb010e0 100644 --- a/dlt/common/data_writers/writers.py +++ b/dlt/common/data_writers/writers.py @@ -4,6 +4,7 @@ IO, TYPE_CHECKING, Any, + ClassVar, Dict, List, Optional, @@ -236,7 +237,7 @@ class ParquetDataWriterConfiguration(BaseConfiguration): timestamp_timezone: str = "UTC" row_group_size: Optional[int] = None - __section__: str = known_sections.DATA_WRITER + __section__: ClassVar[str] = known_sections.DATA_WRITER class ParquetDataWriter(DataWriter): diff --git a/dlt/common/destination/capabilities.py b/dlt/common/destination/capabilities.py index 36a9cc3b6e..7a64f32ea3 100644 --- a/dlt/common/destination/capabilities.py +++ b/dlt/common/destination/capabilities.py @@ -30,22 +30,22 @@ class DestinationCapabilitiesContext(ContainerInjectableContext): """Injectable destination capabilities required for many Pipeline stages ie. normalize""" - preferred_loader_file_format: TLoaderFileFormat - supported_loader_file_formats: List[TLoaderFileFormat] - preferred_staging_file_format: Optional[TLoaderFileFormat] - supported_staging_file_formats: List[TLoaderFileFormat] - escape_identifier: Callable[[str], str] - escape_literal: Callable[[Any], Any] - decimal_precision: Tuple[int, int] - wei_precision: Tuple[int, int] - max_identifier_length: int - max_column_identifier_length: int - max_query_length: int - is_max_query_length_in_bytes: bool - max_text_data_type_length: int - is_max_text_data_type_length_in_bytes: bool - supports_transactions: bool - supports_ddl_transactions: bool + preferred_loader_file_format: TLoaderFileFormat = None + supported_loader_file_formats: List[TLoaderFileFormat] = None + preferred_staging_file_format: Optional[TLoaderFileFormat] = None + supported_staging_file_formats: List[TLoaderFileFormat] = None + escape_identifier: Callable[[str], str] = None + escape_literal: Callable[[Any], Any] = None + decimal_precision: Tuple[int, int] = None + wei_precision: Tuple[int, int] = None + max_identifier_length: int = None + max_column_identifier_length: int = None + max_query_length: int = None + is_max_query_length_in_bytes: bool = None + max_text_data_type_length: int = None + is_max_text_data_type_length_in_bytes: bool = None + supports_transactions: bool = None + supports_ddl_transactions: bool = None naming_convention: str = "snake_case" alter_add_multi_column: bool = True supports_truncate_command: bool = True diff --git a/dlt/common/destination/reference.py b/dlt/common/destination/reference.py index 738c07bdc7..ddcc5d1146 100644 --- a/dlt/common/destination/reference.py +++ b/dlt/common/destination/reference.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import dataclasses from importlib import import_module from types import TracebackType from typing import ( @@ -11,7 +12,6 @@ Iterable, Type, Union, - TYPE_CHECKING, List, ContextManager, Dict, @@ -48,11 +48,11 @@ from dlt.common.schema.exceptions import UnknownTableException from dlt.common.storages import FileStorage from dlt.common.storages.load_storage import ParsedLoadJobFileName -from dlt.common.configuration.specs import GcpCredentials, AwsCredentialsWithoutDefaults TLoaderReplaceStrategy = Literal["truncate-and-insert", "insert-from-staging", "staging-optimized"] TDestinationConfig = TypeVar("TDestinationConfig", bound="DestinationClientConfiguration") TDestinationClient = TypeVar("TDestinationClient", bound="JobClientBase") +TDestinationDwhClient = TypeVar("TDestinationDwhClient", bound="DestinationClientDwhConfiguration") class StorageSchemaInfo(NamedTuple): @@ -75,8 +75,10 @@ class StateInfo(NamedTuple): @configspec class DestinationClientConfiguration(BaseConfiguration): - destination_type: Final[str] = None # which destination to load data to - credentials: Optional[CredentialsConfiguration] + destination_type: Final[str] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) # which destination to load data to + credentials: Optional[CredentialsConfiguration] = None destination_name: Optional[str] = ( None # name of the destination, if not set, destination_type is used ) @@ -93,28 +95,33 @@ def __str__(self) -> str: def on_resolved(self) -> None: self.destination_name = self.destination_name or self.destination_type - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[CredentialsConfiguration] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... - @configspec class DestinationClientDwhConfiguration(DestinationClientConfiguration): """Configuration of a destination that supports datasets/schemas""" - dataset_name: Final[str] = None # dataset must be final so it is not configurable + dataset_name: Final[str] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) # dataset must be final so it is not configurable """dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix""" - default_schema_name: Optional[str] = None + default_schema_name: Final[Optional[str]] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) """name of default schema to be used to name effective dataset to load data to""" replace_strategy: TLoaderReplaceStrategy = "truncate-and-insert" """How to handle replace disposition for this destination, can be classic or staging""" + def _bind_dataset_name( + self: TDestinationDwhClient, dataset_name: str, default_schema_name: str = None + ) -> TDestinationDwhClient: + """Binds the dataset and default schema name to the configuration + + This method is intended to be used internally. + """ + self.dataset_name = dataset_name # type: ignore[misc] + self.default_schema_name = default_schema_name # type: ignore[misc] + return self + def normalize_dataset_name(self, schema: Schema) -> str: """Builds full db dataset (schema) name out of configured dataset name and schema name: {dataset_name}_{schema.name}. The resulting name is normalized. @@ -136,18 +143,6 @@ def normalize_dataset_name(self, schema: Schema) -> str: else schema.naming.normalize_table_identifier(self.dataset_name) ) - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[CredentialsConfiguration] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... - @configspec class DestinationClientStagingConfiguration(DestinationClientDwhConfiguration): @@ -161,21 +156,6 @@ class DestinationClientStagingConfiguration(DestinationClientDwhConfiguration): # layout of the destination files layout: str = "{table_name}/{load_id}.{file_id}.{ext}" - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Union[AwsCredentialsWithoutDefaults, GcpCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - as_staging: bool = False, - bucket_url: str = None, - layout: str = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... - @configspec class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfiguration): @@ -183,18 +163,6 @@ class DestinationClientDwhWithStagingConfiguration(DestinationClientDwhConfigura staging_config: Optional[DestinationClientStagingConfiguration] = None """configuration of the staging, if present, injected at runtime""" - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[CredentialsConfiguration] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - staging_config: Optional[DestinationClientStagingConfiguration] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... TLoadJobState = Literal["running", "failed", "retry", "completed"] diff --git a/dlt/common/runtime/logger.py b/dlt/common/logger.py similarity index 84% rename from dlt/common/runtime/logger.py rename to dlt/common/logger.py index 9dd8ce4e3a..02412248c3 100644 --- a/dlt/common/runtime/logger.py +++ b/dlt/common/logger.py @@ -2,12 +2,7 @@ import logging import traceback from logging import LogRecord, Logger -from typing import Any, Iterator, Protocol - -from dlt.common.json import json -from dlt.common.runtime.exec_info import dlt_version_info -from dlt.common.typing import StrAny, StrStr -from dlt.common.configuration.specs import RunConfiguration +from typing import Any, Mapping, Iterator, Protocol DLT_LOGGER_NAME = "dlt" LOGGER: Logger = None @@ -32,7 +27,7 @@ def wrapper(msg: str, *args: Any, **kwargs: Any) -> None: return wrapper -def metrics(name: str, extra: StrAny, stacklevel: int = 1) -> None: +def metrics(name: str, extra: Mapping[str, Any], stacklevel: int = 1) -> None: """Forwards metrics call to LOGGER""" if LOGGER: LOGGER.info(name, extra=extra, stacklevel=stacklevel) @@ -46,15 +41,6 @@ def suppress_and_warn() -> Iterator[None]: LOGGER.warning("Suppressed exception", exc_info=True) -def init_logging(config: RunConfiguration) -> None: - global LOGGER - - version = dlt_version_info(config.pipeline_name) - LOGGER = _init_logging( - DLT_LOGGER_NAME, config.log_level, config.log_format, config.pipeline_name, version - ) - - def is_logging() -> bool: return LOGGER is not None @@ -75,6 +61,8 @@ def pretty_format_exception() -> str: class _MetricsFormatter(logging.Formatter): def format(self, record: LogRecord) -> str: # noqa: A003 + from dlt.common.json import json + s = super(_MetricsFormatter, self).format(record) # dump metrics dictionary nicely if "metrics" in record.__dict__: @@ -83,7 +71,7 @@ def format(self, record: LogRecord) -> str: # noqa: A003 def _init_logging( - logger_name: str, level: str, fmt: str, component: str, version: StrStr + logger_name: str, level: str, fmt: str, component: str, version: Mapping[str, str] ) -> Logger: if logger_name == "root": logging.basicConfig(level=level) @@ -102,7 +90,7 @@ def _init_logging( from dlt.common.runtime import json_logging class _CustomJsonFormatter(json_logging.JSONLogFormatter): - version: StrStr = None + version: Mapping[str, str] = None def _format_log_object(self, record: LogRecord) -> Any: json_log_object = super(_CustomJsonFormatter, self)._format_log_object(record) diff --git a/dlt/common/normalizers/configuration.py b/dlt/common/normalizers/configuration.py index adeefe2237..54b725db1f 100644 --- a/dlt/common/normalizers/configuration.py +++ b/dlt/common/normalizers/configuration.py @@ -1,8 +1,7 @@ -import dataclasses -from typing import Optional, TYPE_CHECKING +from typing import ClassVar, Optional, TYPE_CHECKING from dlt.common.configuration import configspec -from dlt.common.configuration.specs import BaseConfiguration +from dlt.common.configuration.specs import BaseConfiguration, known_sections from dlt.common.destination import DestinationCapabilitiesContext from dlt.common.normalizers.typing import TJSONNormalizer from dlt.common.typing import DictStrAny @@ -11,7 +10,7 @@ @configspec class NormalizersConfiguration(BaseConfiguration): # always in section - __section__: str = "schema" + __section__: ClassVar[str] = known_sections.SCHEMA naming: Optional[str] = None json_normalizer: Optional[DictStrAny] = None @@ -32,7 +31,3 @@ def on_resolved(self) -> None: self.json_normalizer["config"][ "max_nesting" ] = self.destination_capabilities.max_table_nesting - - if TYPE_CHECKING: - - def __init__(self, naming: str = None, json_normalizer: TJSONNormalizer = None) -> None: ... diff --git a/dlt/common/pipeline.py b/dlt/common/pipeline.py index 57dda11c39..7c117d4612 100644 --- a/dlt/common/pipeline.py +++ b/dlt/common/pipeline.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import dataclasses import os import datetime # noqa: 251 import humanize @@ -553,8 +554,12 @@ def __call__( @configspec class PipelineContext(ContainerInjectableContext): - _deferred_pipeline: Callable[[], SupportsPipeline] - _pipeline: SupportsPipeline + _deferred_pipeline: Callable[[], SupportsPipeline] = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) + _pipeline: SupportsPipeline = dataclasses.field( + default=None, init=False, repr=False, compare=False + ) can_create_default: ClassVar[bool] = True @@ -592,14 +597,10 @@ def __init__(self, deferred_pipeline: Callable[..., SupportsPipeline] = None) -> @configspec class StateInjectableContext(ContainerInjectableContext): - state: TPipelineState + state: TPipelineState = None can_create_default: ClassVar[bool] = False - if TYPE_CHECKING: - - def __init__(self, state: TPipelineState = None) -> None: ... - def pipeline_state( container: Container, initial_default: TPipelineState = None diff --git a/dlt/common/runners/configuration.py b/dlt/common/runners/configuration.py index c5de2353f4..5857e1799f 100644 --- a/dlt/common/runners/configuration.py +++ b/dlt/common/runners/configuration.py @@ -16,13 +16,3 @@ class PoolRunnerConfiguration(BaseConfiguration): """# how many threads/processes in the pool""" run_sleep: float = 0.1 """how long to sleep between runs with workload, seconds""" - - if TYPE_CHECKING: - - def __init__( - self, - pool_type: TPoolType = None, - start_method: str = None, - workers: int = None, - run_sleep: float = 0.1, - ) -> None: ... diff --git a/dlt/common/runtime/init.py b/dlt/common/runtime/init.py index dc1430a527..5354dee4ff 100644 --- a/dlt/common/runtime/init.py +++ b/dlt/common/runtime/init.py @@ -5,8 +5,17 @@ _RUN_CONFIGURATION: RunConfiguration = None +def init_logging(config: RunConfiguration) -> None: + from dlt.common import logger + from dlt.common.runtime.exec_info import dlt_version_info + + version = dlt_version_info(config.pipeline_name) + logger.LOGGER = logger._init_logging( + logger.DLT_LOGGER_NAME, config.log_level, config.log_format, config.pipeline_name, version + ) + + def initialize_runtime(config: RunConfiguration) -> None: - from dlt.common.runtime.logger import init_logging from dlt.common.runtime.telemetry import start_telemetry from dlt.sources.helpers import requests diff --git a/dlt/common/runtime/prometheus.py b/dlt/common/runtime/prometheus.py index 1b233ffa9b..07c960efe7 100644 --- a/dlt/common/runtime/prometheus.py +++ b/dlt/common/runtime/prometheus.py @@ -3,7 +3,7 @@ from prometheus_client.metrics import MetricWrapperBase from dlt.common.configuration.specs import RunConfiguration -from dlt.common.runtime import logger +from dlt.common import logger from dlt.common.runtime.exec_info import dlt_version_info from dlt.common.typing import DictStrAny, StrAny diff --git a/dlt/common/runtime/segment.py b/dlt/common/runtime/segment.py index e302767fcc..70b81fb4f4 100644 --- a/dlt/common/runtime/segment.py +++ b/dlt/common/runtime/segment.py @@ -10,7 +10,7 @@ from typing import Literal, Optional from dlt.common.configuration.paths import get_dlt_data_dir -from dlt.common.runtime import logger +from dlt.common import logger from dlt.common.managed_thread_pool import ManagedThreadPool from dlt.common.configuration.specs import RunConfiguration diff --git a/dlt/common/runtime/signals.py b/dlt/common/runtime/signals.py index 8e64c8ba64..8d1cb3803e 100644 --- a/dlt/common/runtime/signals.py +++ b/dlt/common/runtime/signals.py @@ -2,8 +2,9 @@ import signal from contextlib import contextmanager from threading import Event -from typing import Any, TYPE_CHECKING, Iterator +from typing import Any, Iterator +from dlt.common import logger from dlt.common.exceptions import SignalReceivedException _received_signal: int = 0 @@ -11,11 +12,6 @@ def signal_receiver(sig: int, frame: Any) -> None: - if not TYPE_CHECKING: - from dlt.common.runtime import logger - else: - logger: Any = None - global _received_signal logger.info(f"Signal {sig} received") @@ -64,9 +60,5 @@ def delayed_signals() -> Iterator[None]: signal.signal(signal.SIGINT, original_sigint_handler) signal.signal(signal.SIGTERM, original_sigterm_handler) else: - if not TYPE_CHECKING: - from dlt.common.runtime import logger - else: - logger: Any = None logger.info("Running in daemon thread, signals not enabled") yield diff --git a/dlt/common/runtime/slack.py b/dlt/common/runtime/slack.py index 15da89f333..b1e090098d 100644 --- a/dlt/common/runtime/slack.py +++ b/dlt/common/runtime/slack.py @@ -1,8 +1,9 @@ import requests -from dlt.common import json, logger def send_slack_message(incoming_hook: str, message: str, is_markdown: bool = True) -> None: + from dlt.common import json, logger + """Sends a `message` to Slack `incoming_hook`, by default formatted as markdown.""" r = requests.post( incoming_hook, diff --git a/dlt/common/storages/configuration.py b/dlt/common/storages/configuration.py index 2cbe7c78d5..d0100c335d 100644 --- a/dlt/common/storages/configuration.py +++ b/dlt/common/storages/configuration.py @@ -31,24 +31,11 @@ class SchemaStorageConfiguration(BaseConfiguration): True # remove default values when exporting schema ) - if TYPE_CHECKING: - - def __init__( - self, - schema_volume_path: str = None, - import_schema_path: str = None, - export_schema_path: str = None, - ) -> None: ... - @configspec class NormalizeStorageConfiguration(BaseConfiguration): normalize_volume_path: str = None # path to volume where normalized loader files will be stored - if TYPE_CHECKING: - - def __init__(self, normalize_volume_path: str = None) -> None: ... - @configspec class LoadStorageConfiguration(BaseConfiguration): @@ -59,12 +46,6 @@ class LoadStorageConfiguration(BaseConfiguration): False # if set to true the folder with completed jobs will be deleted ) - if TYPE_CHECKING: - - def __init__( - self, load_volume_path: str = None, delete_completed_jobs: bool = None - ) -> None: ... - FileSystemCredentials = Union[ AwsCredentials, GcpServiceAccountCredentials, AzureCredentials, GcpOAuthCredentials @@ -96,7 +77,7 @@ class FilesystemConfiguration(BaseConfiguration): bucket_url: str = None # should be a union of all possible credentials as found in PROTOCOL_CREDENTIALS - credentials: FileSystemCredentials + credentials: FileSystemCredentials = None read_only: bool = False """Indicates read only filesystem access. Will enable caching""" @@ -144,14 +125,3 @@ def __str__(self) -> str: new_netloc += f":{url.port}" return url._replace(netloc=new_netloc).geturl() return self.bucket_url - - if TYPE_CHECKING: - - def __init__( - self, - bucket_url: str, - credentials: FileSystemCredentials = None, - read_only: bool = False, - kwargs: Optional[DictStrAny] = None, - client_kwargs: Optional[DictStrAny] = None, - ) -> None: ... diff --git a/dlt/common/storages/load_package.py b/dlt/common/storages/load_package.py index b9c36143ac..3b8af424ee 100644 --- a/dlt/common/storages/load_package.py +++ b/dlt/common/storages/load_package.py @@ -627,8 +627,8 @@ def filter_jobs_for_table( @configspec class LoadPackageStateInjectableContext(ContainerInjectableContext): - storage: PackageStorage - load_id: str + storage: PackageStorage = None + load_id: str = None can_create_default: ClassVar[bool] = False global_affinity: ClassVar[bool] = False @@ -640,10 +640,6 @@ def on_resolved(self) -> None: self.state_save_lock = threading.Lock() self.state = self.storage.get_load_package_state(self.load_id) - if TYPE_CHECKING: - - def __init__(self, load_id: str, storage: PackageStorage) -> None: ... - def load_package() -> TLoadPackage: """Get full load package state present in current context. Across all threads this will be the same in memory dict.""" diff --git a/dlt/destinations/decorators.py b/dlt/destinations/decorators.py index 62d059c4a6..a920d336a2 100644 --- a/dlt/destinations/decorators.py +++ b/dlt/destinations/decorators.py @@ -1,6 +1,6 @@ import functools -from typing import Any, Type, Optional, Callable, Union, cast +from typing import Any, Type, Optional, Callable, Union from typing_extensions import Concatenate from dlt.common.typing import AnyFun @@ -13,7 +13,6 @@ CustomDestinationClientConfiguration, ) from dlt.common.destination import TLoaderFileFormat -from dlt.common.destination.reference import Destination from dlt.common.typing import TDataItems from dlt.common.schema import TTableSchema diff --git a/dlt/destinations/impl/athena/configuration.py b/dlt/destinations/impl/athena/configuration.py index 6b985f284a..59dfeee4ec 100644 --- a/dlt/destinations/impl/athena/configuration.py +++ b/dlt/destinations/impl/athena/configuration.py @@ -1,4 +1,5 @@ -from typing import ClassVar, Final, List, Optional, TYPE_CHECKING +import dataclasses +from typing import ClassVar, Final, List, Optional from dlt.common.configuration import configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration @@ -7,7 +8,7 @@ @configspec class AthenaClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "athena" # type: ignore[misc] + destination_type: Final[str] = dataclasses.field(default="athena", init=False, repr=False, compare=False) # type: ignore[misc] query_result_bucket: str = None credentials: AwsCredentials = None athena_work_group: Optional[str] = None @@ -23,19 +24,3 @@ def __str__(self) -> str: return str(self.staging_config.credentials) else: return "[no staging set]" - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[AwsCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - athena_work_group: Optional[str] = None, - aws_data_catalog: Optional[str] = None, - supports_truncate_command: bool = False, - force_iceberg: Optional[bool] = False, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/bigquery/configuration.py b/dlt/destinations/impl/bigquery/configuration.py index 3c4a71c0df..a6686c3f2d 100644 --- a/dlt/destinations/impl/bigquery/configuration.py +++ b/dlt/destinations/impl/bigquery/configuration.py @@ -1,5 +1,6 @@ +import dataclasses import warnings -from typing import TYPE_CHECKING, ClassVar, List, Optional, Final +from typing import ClassVar, List, Final from dlt.common.configuration import configspec from dlt.common.configuration.specs import GcpServiceAccountCredentials @@ -10,7 +11,7 @@ @configspec class BigQueryClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "bigquery" # type: ignore + destination_type: Final[str] = dataclasses.field(default="bigquery", init=False, repr=False, compare=False) # type: ignore credentials: GcpServiceAccountCredentials = None location: str = "US" @@ -38,31 +39,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.project_id: return digest128(self.credentials.project_id) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[GcpServiceAccountCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - location: str = "US", - http_timeout: float = 15.0, - file_upload_timeout: float = 30 * 60.0, - retry_deadline: float = 60.0, - destination_name: str = None, - environment: str = None - ) -> None: - super().__init__( - credentials=credentials, - dataset_name=dataset_name, - default_schema_name=default_schema_name, - destination_name=destination_name, - environment=environment, - ) - self.retry_deadline = retry_deadline - self.file_upload_timeout = file_upload_timeout - self.http_timeout = http_timeout - self.location = location - ... diff --git a/dlt/destinations/impl/databricks/configuration.py b/dlt/destinations/impl/databricks/configuration.py index 924047e30f..3bd2d12a5a 100644 --- a/dlt/destinations/impl/databricks/configuration.py +++ b/dlt/destinations/impl/databricks/configuration.py @@ -1,3 +1,4 @@ +import dataclasses from typing import ClassVar, Final, Optional, Any, Dict, List from dlt.common.typing import TSecretStrValue @@ -40,8 +41,8 @@ def to_connector_params(self) -> Dict[str, Any]: @configspec class DatabricksClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "databricks" # type: ignore[misc] - credentials: DatabricksCredentials + destination_type: Final[str] = dataclasses.field(default="databricks", init=False, repr=False, compare=False) # type: ignore[misc] + credentials: DatabricksCredentials = None def __str__(self) -> str: """Return displayable destination location""" diff --git a/dlt/destinations/impl/destination/configuration.py b/dlt/destinations/impl/destination/configuration.py index f123ba69b3..30e54a8313 100644 --- a/dlt/destinations/impl/destination/configuration.py +++ b/dlt/destinations/impl/destination/configuration.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Optional, Final, Callable, Union, Any +import dataclasses +from typing import Optional, Final, Callable, Union from typing_extensions import ParamSpec from dlt.common.configuration import configspec @@ -16,19 +17,9 @@ @configspec class CustomDestinationClientConfiguration(DestinationClientConfiguration): - destination_type: Final[str] = "destination" # type: ignore + destination_type: Final[str] = dataclasses.field(default="destination", init=False, repr=False, compare=False) # type: ignore destination_callable: Optional[Union[str, TDestinationCallable]] = None # noqa: A003 loader_file_format: TLoaderFileFormat = "puae-jsonl" batch_size: int = 10 skip_dlt_columns_and_tables: bool = True max_table_nesting: int = 0 - - if TYPE_CHECKING: - - def __init__( - self, - *, - loader_file_format: TLoaderFileFormat = "puae-jsonl", - batch_size: int = 10, - destination_callable: Union[TDestinationCallable, str] = None, - ) -> None: ... diff --git a/dlt/destinations/impl/duckdb/configuration.py b/dlt/destinations/impl/duckdb/configuration.py index 8cb88c43b5..70d91dcb56 100644 --- a/dlt/destinations/impl/duckdb/configuration.py +++ b/dlt/destinations/impl/duckdb/configuration.py @@ -1,7 +1,8 @@ import os +import dataclasses import threading from pathvalidate import is_valid_filepath -from typing import Any, ClassVar, Final, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, ClassVar, Final, List, Optional, Tuple, TYPE_CHECKING, Type, Union from dlt.common import logger from dlt.common.configuration import configspec @@ -13,12 +14,17 @@ ) from dlt.common.typing import TSecretValue +try: + from duckdb import DuckDBPyConnection +except ModuleNotFoundError: + DuckDBPyConnection = Type[Any] # type: ignore[assignment,misc] + DUCK_DB_NAME = "%s.duckdb" DEFAULT_DUCK_DB_NAME = DUCK_DB_NAME % "quack" LOCAL_STATE_KEY = "duckdb_database" -@configspec +@configspec(init=False) class DuckDbBaseCredentials(ConnectionStringCredentials): password: Optional[TSecretValue] = None host: Optional[str] = None @@ -95,7 +101,7 @@ def __del__(self) -> None: @configspec class DuckDbCredentials(DuckDbBaseCredentials): - drivername: Final[str] = "duckdb" # type: ignore + drivername: Final[str] = dataclasses.field(default="duckdb", init=False, repr=False, compare=False) # type: ignore username: Optional[str] = None __config_gen_annotations__: ClassVar[List[str]] = [] @@ -193,30 +199,31 @@ def _path_from_pipeline(self, default_path: str) -> Tuple[str, bool]: def _conn_str(self) -> str: return self.database + def __init__(self, conn_or_path: Union[str, DuckDBPyConnection] = None) -> None: + """Access to duckdb database at a given path or from duckdb connection""" + self._apply_init_value(conn_or_path) + @configspec class DuckDbClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "duckdb" # type: ignore - credentials: DuckDbCredentials + destination_type: Final[str] = dataclasses.field(default="duckdb", init=False, repr=False, compare=False) # type: ignore + credentials: DuckDbCredentials = None create_indexes: bool = ( False # should unique indexes be created, this slows loading down massively ) - if TYPE_CHECKING: - try: - from duckdb import DuckDBPyConnection - except ModuleNotFoundError: - DuckDBPyConnection = Any # type: ignore[assignment,misc] - - def __init__( - self, - *, - credentials: Union[DuckDbCredentials, str, DuckDBPyConnection] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - create_indexes: bool = False, - staging_config: Optional[DestinationClientStagingConfiguration] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... + def __init__( + self, + *, + credentials: Union[DuckDbCredentials, str, DuckDBPyConnection] = None, + create_indexes: bool = False, + destination_name: str = None, + environment: str = None, + ) -> None: + super().__init__( + credentials=credentials, # type: ignore[arg-type] + destination_name=destination_name, + environment=environment, + ) + self.create_indexes = create_indexes diff --git a/dlt/destinations/impl/dummy/configuration.py b/dlt/destinations/impl/dummy/configuration.py index cce0dfa8ed..a9fdb1f47d 100644 --- a/dlt/destinations/impl/dummy/configuration.py +++ b/dlt/destinations/impl/dummy/configuration.py @@ -1,4 +1,5 @@ -from typing import TYPE_CHECKING, Optional, Final +import dataclasses +from typing import Final from dlt.common.configuration import configspec from dlt.common.destination import TLoaderFileFormat @@ -16,7 +17,7 @@ def __str__(self) -> str: @configspec class DummyClientConfiguration(DestinationClientConfiguration): - destination_type: Final[str] = "dummy" # type: ignore + destination_type: Final[str] = dataclasses.field(default="dummy", init=False, repr=False, compare=False) # type: ignore loader_file_format: TLoaderFileFormat = "jsonl" fail_schema_update: bool = False fail_prob: float = 0.0 @@ -30,22 +31,3 @@ class DummyClientConfiguration(DestinationClientConfiguration): create_followup_jobs: bool = False credentials: DummyClientCredentials = None - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[CredentialsConfiguration] = None, - loader_file_format: TLoaderFileFormat = None, - fail_schema_update: bool = None, - fail_prob: float = None, - retry_prob: float = None, - completed_prob: float = None, - exception_prob: float = None, - timeout: float = None, - fail_in_init: bool = None, - create_followup_jobs: bool = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/filesystem/configuration.py b/dlt/destinations/impl/filesystem/configuration.py index 93e5537aab..1521222180 100644 --- a/dlt/destinations/impl/filesystem/configuration.py +++ b/dlt/destinations/impl/filesystem/configuration.py @@ -1,6 +1,5 @@ -from urllib.parse import urlparse - -from typing import Final, Type, Optional, Any, TYPE_CHECKING +import dataclasses +from typing import Final, Type, Optional from dlt.common.configuration import configspec, resolve_type from dlt.common.destination.reference import ( @@ -12,22 +11,9 @@ @configspec class FilesystemDestinationClientConfiguration(FilesystemConfiguration, DestinationClientStagingConfiguration): # type: ignore[misc] - destination_type: Final[str] = "filesystem" # type: ignore + destination_type: Final[str] = dataclasses.field(default="filesystem", init=False, repr=False, compare=False) # type: ignore @resolve_type("credentials") def resolve_credentials_type(self) -> Type[CredentialsConfiguration]: # use known credentials or empty credentials for unknown protocol return self.PROTOCOL_CREDENTIALS.get(self.protocol) or Optional[CredentialsConfiguration] # type: ignore[return-value] - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[Any] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - bucket_url: str = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/motherduck/configuration.py b/dlt/destinations/impl/motherduck/configuration.py index 35f02f709a..3179295c54 100644 --- a/dlt/destinations/impl/motherduck/configuration.py +++ b/dlt/destinations/impl/motherduck/configuration.py @@ -1,4 +1,5 @@ -from typing import Any, ClassVar, Final, List, TYPE_CHECKING, Optional +import dataclasses +from typing import Any, ClassVar, Final, List from dlt.common.configuration import configspec from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration @@ -12,9 +13,9 @@ MOTHERDUCK_DRIVERNAME = "md" -@configspec +@configspec(init=False) class MotherDuckCredentials(DuckDbBaseCredentials): - drivername: Final[str] = "md" # type: ignore + drivername: Final[str] = dataclasses.field(default="md", init=False, repr=False, compare=False) # type: ignore username: str = "motherduck" read_only: bool = False # open database read/write @@ -57,8 +58,8 @@ def on_resolved(self) -> None: @configspec class MotherDuckClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "motherduck" # type: ignore - credentials: MotherDuckCredentials + destination_type: Final[str] = dataclasses.field(default="motherduck", init=False, repr=False, compare=False) # type: ignore + credentials: MotherDuckCredentials = None create_indexes: bool = ( False # should unique indexes be created, this slows loading down massively @@ -70,19 +71,6 @@ def fingerprint(self) -> str: return digest128(self.credentials.password) return "" - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[MotherDuckCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - create_indexes: Optional[bool] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... - class MotherduckLocalVersionNotSupported(DestinationTerminalException): def __init__(self, duckdb_version: str) -> None: diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index 45c448fab7..1d085f40c1 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -1,4 +1,5 @@ -from typing import Final, ClassVar, Any, List, Dict, Optional, TYPE_CHECKING +import dataclasses +from typing import Final, ClassVar, Any, List, Dict from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec @@ -10,11 +11,11 @@ from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -@configspec +@configspec(init=False) class MsSqlCredentials(ConnectionStringCredentials): - drivername: Final[str] = "mssql" # type: ignore - password: TSecretValue - host: str + drivername: Final[str] = dataclasses.field(default="mssql", init=False, repr=False, compare=False) # type: ignore + password: TSecretValue = None + host: str = None port: int = 1433 connect_timeout: int = 15 driver: str = None @@ -90,8 +91,8 @@ def to_odbc_dsn(self) -> str: @configspec class MsSqlClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "mssql" # type: ignore - credentials: MsSqlCredentials + destination_type: Final[str] = dataclasses.field(default="mssql", init=False, repr=False, compare=False) # type: ignore + credentials: MsSqlCredentials = None create_indexes: bool = False @@ -100,16 +101,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: Optional[MsSqlCredentials] = None, - dataset_name: str = None, - default_schema_name: Optional[str] = None, - create_indexes: Optional[bool] = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index 109d422650..0d12abbac7 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -1,6 +1,7 @@ -from typing import Final, ClassVar, Any, List, TYPE_CHECKING -from dlt.common.libs.sql_alchemy import URL +import dataclasses +from typing import Final, ClassVar, Any, List, TYPE_CHECKING, Union +from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials from dlt.common.utils import digest128 @@ -9,11 +10,11 @@ from dlt.common.destination.reference import DestinationClientDwhWithStagingConfiguration -@configspec +@configspec(init=False) class PostgresCredentials(ConnectionStringCredentials): - drivername: Final[str] = "postgresql" # type: ignore - password: TSecretValue - host: str + drivername: Final[str] = dataclasses.field(default="postgresql", init=False, repr=False, compare=False) # type: ignore + password: TSecretValue = None + host: str = None port: int = 5432 connect_timeout: int = 15 @@ -33,8 +34,8 @@ def to_url(self) -> URL: @configspec class PostgresClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "postgres" # type: ignore - credentials: PostgresCredentials + destination_type: Final[str] = dataclasses.field(default="postgres", init=False, repr=False, compare=False) # type: ignore + credentials: PostgresCredentials = None create_indexes: bool = True @@ -43,16 +44,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - credentials: PostgresCredentials = None, - dataset_name: str = None, - default_schema_name: str = None, - create_indexes: bool = True, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/qdrant/configuration.py b/dlt/destinations/impl/qdrant/configuration.py index 23637dee33..d589537742 100644 --- a/dlt/destinations/impl/qdrant/configuration.py +++ b/dlt/destinations/impl/qdrant/configuration.py @@ -1,3 +1,4 @@ +import dataclasses from typing import Optional, Final from dlt.common.configuration import configspec @@ -15,7 +16,7 @@ class QdrantCredentials(CredentialsConfiguration): # If `None` - use default values for `host` and `port` location: Optional[str] = None # API key for authentication in Qdrant Cloud. Default: `None` - api_key: Optional[str] + api_key: Optional[str] = None def __str__(self) -> str: return self.location or "localhost" @@ -47,12 +48,14 @@ class QdrantClientOptions(BaseConfiguration): @configspec class QdrantClientConfiguration(DestinationClientDwhConfiguration): - destination_type: Final[str] = "qdrant" # type: ignore + destination_type: Final[str] = dataclasses.field(default="qdrant", init=False, repr=False, compare=False) # type: ignore + # Qdrant connection credentials + credentials: QdrantCredentials = None # character for the dataset separator dataset_separator: str = "_" # make it optional so empty dataset is allowed - dataset_name: Final[Optional[str]] = None # type: ignore[misc] + dataset_name: Final[Optional[str]] = dataclasses.field(default=None, init=False, repr=False, compare=False) # type: ignore[misc] # Batch size for generating embeddings embedding_batch_size: int = 32 @@ -67,10 +70,7 @@ class QdrantClientConfiguration(DestinationClientDwhConfiguration): upload_max_retries: int = 3 # Qdrant client options - options: QdrantClientOptions - - # Qdrant connection credentials - credentials: QdrantCredentials + options: QdrantClientOptions = None # FlagEmbedding model to use # Find the list here. https://qdrant.github.io/fastembed/examples/Supported_Models/. diff --git a/dlt/destinations/impl/redshift/configuration.py b/dlt/destinations/impl/redshift/configuration.py index 2a6ade4a4f..72d7f70a9f 100644 --- a/dlt/destinations/impl/redshift/configuration.py +++ b/dlt/destinations/impl/redshift/configuration.py @@ -1,4 +1,5 @@ -from typing import Final, Optional, TYPE_CHECKING +import dataclasses +from typing import Final, Optional from dlt.common.typing import TSecretValue from dlt.common.configuration import configspec @@ -10,7 +11,7 @@ ) -@configspec +@configspec(init=False) class RedshiftCredentials(PostgresCredentials): port: int = 5439 password: TSecretValue = None @@ -20,8 +21,8 @@ class RedshiftCredentials(PostgresCredentials): @configspec class RedshiftClientConfiguration(PostgresClientConfiguration): - destination_type: Final[str] = "redshift" # type: ignore - credentials: RedshiftCredentials + destination_type: Final[str] = dataclasses.field(default="redshift", init=False, repr=False, compare=False) # type: ignore + credentials: RedshiftCredentials = None staging_iam_role: Optional[str] = None def fingerprint(self) -> str: @@ -29,17 +30,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - destination_type: str = None, - credentials: PostgresCredentials = None, - dataset_name: str = None, - default_schema_name: str = None, - staging_iam_role: str = None, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 4f97f08700..5a1f7a65a9 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -1,11 +1,9 @@ +import dataclasses import base64 -import binascii - -from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING - -from dlt.common.libs.sql_alchemy import URL +from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING, Union from dlt import version +from dlt.common.libs.sql_alchemy import URL from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs import ConnectionStringCredentials @@ -51,9 +49,9 @@ def _read_private_key(private_key: str, password: Optional[str] = None) -> bytes ) -@configspec +@configspec(init=False) class SnowflakeCredentials(ConnectionStringCredentials): - drivername: Final[str] = "snowflake" # type: ignore[misc] + drivername: Final[str] = dataclasses.field(default="snowflake", init=False, repr=False, compare=False) # type: ignore[misc] password: Optional[TSecretStrValue] = None host: str = None database: str = None @@ -118,8 +116,8 @@ def to_connector_params(self) -> Dict[str, Any]: @configspec class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration): - destination_type: Final[str] = "snowflake" # type: ignore[misc] - credentials: SnowflakeCredentials + destination_type: Final[str] = dataclasses.field(default="snowflake", init=False, repr=False, compare=False) # type: ignore[misc] + credentials: SnowflakeCredentials = None stage_name: Optional[str] = None """Use an existing named stage instead of the default. Default uses the implicit table stage per table""" @@ -131,18 +129,3 @@ def fingerprint(self) -> str: if self.credentials and self.credentials.host: return digest128(self.credentials.host) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - destination_type: str = None, - credentials: SnowflakeCredentials = None, - dataset_name: str = None, - default_schema_name: str = None, - stage_name: str = None, - keep_staged_files: bool = True, - destination_name: str = None, - environment: str = None, - ) -> None: ... diff --git a/dlt/destinations/impl/synapse/configuration.py b/dlt/destinations/impl/synapse/configuration.py index bb1ba632dc..37b932cd67 100644 --- a/dlt/destinations/impl/synapse/configuration.py +++ b/dlt/destinations/impl/synapse/configuration.py @@ -1,9 +1,8 @@ +import dataclasses +from dlt import version from typing import Final, Any, List, Dict, Optional, ClassVar -from dlt.common import logger from dlt.common.configuration import configspec -from dlt.common.schema.typing import TSchemaTables -from dlt.common.schema.utils import get_write_disposition from dlt.destinations.impl.mssql.configuration import ( MsSqlCredentials, @@ -14,9 +13,9 @@ from dlt.destinations.impl.synapse.synapse_adapter import TTableIndexType -@configspec +@configspec(init=False) class SynapseCredentials(MsSqlCredentials): - drivername: Final[str] = "synapse" # type: ignore + drivername: Final[str] = dataclasses.field(default="synapse", init=False, repr=False, compare=False) # type: ignore # LongAsMax keyword got introduced in ODBC Driver 18 for SQL Server. SUPPORTED_DRIVERS: ClassVar[List[str]] = ["ODBC Driver 18 for SQL Server"] @@ -32,8 +31,8 @@ def _get_odbc_dsn_dict(self) -> Dict[str, Any]: @configspec class SynapseClientConfiguration(MsSqlClientConfiguration): - destination_type: Final[str] = "synapse" # type: ignore - credentials: SynapseCredentials + destination_type: Final[str] = dataclasses.field(default="synapse", init=False, repr=False, compare=False) # type: ignore + credentials: SynapseCredentials = None # While Synapse uses CLUSTERED COLUMNSTORE INDEX tables by default, we use # HEAP tables (no indexing) by default. HEAP is a more robust choice, because diff --git a/dlt/destinations/impl/synapse/factory.py b/dlt/destinations/impl/synapse/factory.py index b7eddd6ef7..100878ae05 100644 --- a/dlt/destinations/impl/synapse/factory.py +++ b/dlt/destinations/impl/synapse/factory.py @@ -16,6 +16,11 @@ class synapse(Destination[SynapseClientConfiguration, "SynapseClient"]): spec = SynapseClientConfiguration + # TODO: implement as property everywhere and makes sure not accessed as class property + # @property + # def spec(self) -> t.Type[SynapseClientConfiguration]: + # return SynapseClientConfiguration + def capabilities(self) -> DestinationCapabilitiesContext: return capabilities() diff --git a/dlt/destinations/impl/weaviate/configuration.py b/dlt/destinations/impl/weaviate/configuration.py index 5014e69163..90fb7ce5ce 100644 --- a/dlt/destinations/impl/weaviate/configuration.py +++ b/dlt/destinations/impl/weaviate/configuration.py @@ -1,5 +1,5 @@ -from typing import Dict, Literal, Optional, Final, TYPE_CHECKING -from dataclasses import field +import dataclasses +from typing import Dict, Literal, Optional, Final from urllib.parse import urlparse from dlt.common.configuration import configspec @@ -13,7 +13,7 @@ @configspec class WeaviateCredentials(CredentialsConfiguration): url: str = "http://localhost:8080" - api_key: Optional[str] + api_key: Optional[str] = None additional_headers: Optional[Dict[str, str]] = None def __str__(self) -> str: @@ -24,7 +24,7 @@ def __str__(self) -> str: @configspec class WeaviateClientConfiguration(DestinationClientDwhConfiguration): - destination_type: Final[str] = "weaviate" # type: ignore + destination_type: Final[str] = dataclasses.field(default="weaviate", init=False, repr=False, compare=False) # type: ignore # make it optional so empty dataset is allowed dataset_name: Optional[str] = None # type: ignore[misc] @@ -39,9 +39,9 @@ class WeaviateClientConfiguration(DestinationClientDwhConfiguration): dataset_separator: str = "_" - credentials: WeaviateCredentials + credentials: WeaviateCredentials = None vectorizer: str = "text2vec-openai" - module_config: Dict[str, Dict[str, str]] = field( + module_config: Dict[str, Dict[str, str]] = dataclasses.field( default_factory=lambda: { "text2vec-openai": { "model": "ada", @@ -58,26 +58,3 @@ def fingerprint(self) -> str: hostname = urlparse(self.credentials.url).hostname return digest128(hostname) return "" - - if TYPE_CHECKING: - - def __init__( - self, - *, - destination_type: str = None, - credentials: WeaviateCredentials = None, - name: str = None, - environment: str = None, - dataset_name: str = None, - default_schema_name: str = None, - batch_size: int = None, - batch_workers: int = None, - batch_consistency: TWeaviateBatchConsistency = None, - batch_retries: int = None, - conn_timeout: float = None, - read_timeout: float = None, - startup_period: int = None, - dataset_separator: str = None, - vectorizer: str = None, - module_config: Dict[str, Dict[str, str]] = None, - ) -> None: ... diff --git a/dlt/destinations/sql_jobs.py b/dlt/destinations/sql_jobs.py index 215bcf9fe5..91be3a60c9 100644 --- a/dlt/destinations/sql_jobs.py +++ b/dlt/destinations/sql_jobs.py @@ -1,7 +1,7 @@ from typing import Any, Callable, List, Sequence, Tuple, cast, TypedDict, Optional import yaml -from dlt.common.runtime.logger import pretty_format_exception +from dlt.common.logger import pretty_format_exception from dlt.common.schema.typing import TTableSchema, TSortOrder from dlt.common.schema.utils import ( diff --git a/dlt/extract/decorators.py b/dlt/extract/decorators.py index e5525519ec..28a2aca633 100644 --- a/dlt/extract/decorators.py +++ b/dlt/extract/decorators.py @@ -72,27 +72,19 @@ class SourceSchemaInjectableContext(ContainerInjectableContext): """A context containing the source schema, present when dlt.source/resource decorated function is executed""" - schema: Schema + schema: Schema = None can_create_default: ClassVar[bool] = False - if TYPE_CHECKING: - - def __init__(self, schema: Schema = None) -> None: ... - @configspec class SourceInjectableContext(ContainerInjectableContext): """A context containing the source schema, present when dlt.resource decorated function is executed""" - source: DltSource + source: DltSource = None can_create_default: ClassVar[bool] = False - if TYPE_CHECKING: - - def __init__(self, source: DltSource = None) -> None: ... - TSourceFunParams = ParamSpec("TSourceFunParams") TResourceFunParams = ParamSpec("TResourceFunParams") diff --git a/dlt/extract/pipe_iterator.py b/dlt/extract/pipe_iterator.py index 145b517802..1edd9bd039 100644 --- a/dlt/extract/pipe_iterator.py +++ b/dlt/extract/pipe_iterator.py @@ -2,6 +2,7 @@ import types from typing import ( AsyncIterator, + ClassVar, Dict, Sequence, Union, @@ -16,7 +17,11 @@ from dlt.common.configuration import configspec from dlt.common.configuration.inject import with_config -from dlt.common.configuration.specs import BaseConfiguration, ContainerInjectableContext +from dlt.common.configuration.specs import ( + BaseConfiguration, + ContainerInjectableContext, + known_sections, +) from dlt.common.configuration.container import Container from dlt.common.exceptions import PipelineException from dlt.common.source import unset_current_pipe_name, set_current_pipe_name @@ -48,7 +53,7 @@ class PipeIteratorConfiguration(BaseConfiguration): copy_on_fork: bool = False next_item_mode: str = "fifo" - __section__ = "extract" + __section__: ClassVar[str] = known_sections.EXTRACT def __init__( self, diff --git a/dlt/helpers/dbt/configuration.py b/dlt/helpers/dbt/configuration.py index 4cd3f3a0f4..70fa4d1ac5 100644 --- a/dlt/helpers/dbt/configuration.py +++ b/dlt/helpers/dbt/configuration.py @@ -19,7 +19,7 @@ class DBTRunnerConfiguration(BaseConfiguration): package_additional_vars: Optional[StrAny] = None - runtime: RunConfiguration + runtime: RunConfiguration = None def on_resolved(self) -> None: if not self.package_profiles_dir: diff --git a/dlt/helpers/dbt/runner.py b/dlt/helpers/dbt/runner.py index 388b81b2ee..7b1f79dc77 100644 --- a/dlt/helpers/dbt/runner.py +++ b/dlt/helpers/dbt/runner.py @@ -11,7 +11,7 @@ from dlt.common.runners import Venv from dlt.common.runners.stdout import iter_stdout_with_result from dlt.common.typing import StrAny, TSecretValue -from dlt.common.runtime.logger import is_json_logging +from dlt.common.logger import is_json_logging from dlt.common.storages import FileStorage from dlt.common.git import git_custom_key_command, ensure_remote_head, force_clone_repo from dlt.common.utils import with_custom_environ diff --git a/dlt/helpers/dbt_cloud/configuration.py b/dlt/helpers/dbt_cloud/configuration.py index aac94b2f4a..3c95d53431 100644 --- a/dlt/helpers/dbt_cloud/configuration.py +++ b/dlt/helpers/dbt_cloud/configuration.py @@ -9,13 +9,13 @@ class DBTCloudConfiguration(BaseConfiguration): api_token: TSecretValue = TSecretValue("") - account_id: Optional[str] - job_id: Optional[str] - project_id: Optional[str] - environment_id: Optional[str] - run_id: Optional[str] + account_id: Optional[str] = None + job_id: Optional[str] = None + project_id: Optional[str] = None + environment_id: Optional[str] = None + run_id: Optional[str] = None cause: str = "Triggered via API" - git_sha: Optional[str] - git_branch: Optional[str] - schema_override: Optional[str] + git_sha: Optional[str] = None + git_branch: Optional[str] = None + schema_override: Optional[str] = None diff --git a/dlt/load/configuration.py b/dlt/load/configuration.py index 0a84e3c331..97cf23fdfc 100644 --- a/dlt/load/configuration.py +++ b/dlt/load/configuration.py @@ -18,13 +18,3 @@ class LoaderConfiguration(PoolRunnerConfiguration): def on_resolved(self) -> None: self.pool_type = "none" if self.workers == 1 else "thread" - - if TYPE_CHECKING: - - def __init__( - self, - pool_type: TPoolType = "thread", - workers: int = None, - raise_on_failed_jobs: bool = False, - _load_storage_config: LoadStorageConfiguration = None, - ) -> None: ... diff --git a/dlt/load/load.py b/dlt/load/load.py index a0909fa2d0..f02a21f98e 100644 --- a/dlt/load/load.py +++ b/dlt/load/load.py @@ -15,16 +15,15 @@ SupportsPipeline, WithStepInfo, ) -from dlt.common.schema.utils import get_child_tables, get_top_level_table +from dlt.common.schema.utils import get_top_level_table from dlt.common.storages.load_storage import LoadPackageInfo, ParsedLoadJobFileName, TJobState from dlt.common.storages.load_package import LoadPackageStateInjectableContext from dlt.common.runners import TRunMetrics, Runnable, workermethod, NullExecutor from dlt.common.runtime.collector import Collector, NULL_COLLECTOR -from dlt.common.runtime.logger import pretty_format_exception +from dlt.common.logger import pretty_format_exception from dlt.common.exceptions import TerminalValueError from dlt.common.configuration.container import Container -from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.schema import Schema, TSchemaTables +from dlt.common.schema import Schema from dlt.common.storages import LoadStorage from dlt.common.destination.reference import ( DestinationClientDwhConfiguration, diff --git a/dlt/normalize/configuration.py b/dlt/normalize/configuration.py index 3949a07fa8..5676d23569 100644 --- a/dlt/normalize/configuration.py +++ b/dlt/normalize/configuration.py @@ -18,18 +18,14 @@ class ItemsNormalizerConfiguration(BaseConfiguration): add_dlt_load_id: bool = False """When true, items to be normalized will have `_dlt_load_id` column added with the current load ID.""" - if TYPE_CHECKING: - - def __init__(self, add_dlt_id: bool = None, add_dlt_load_id: bool = None) -> None: ... - @configspec class NormalizeConfiguration(PoolRunnerConfiguration): pool_type: TPoolType = "process" destination_capabilities: DestinationCapabilitiesContext = None # injectable - _schema_storage_config: SchemaStorageConfiguration - _normalize_storage_config: NormalizeStorageConfiguration - _load_storage_config: LoadStorageConfiguration + _schema_storage_config: SchemaStorageConfiguration = None + _normalize_storage_config: NormalizeStorageConfiguration = None + _load_storage_config: LoadStorageConfiguration = None json_normalizer: ItemsNormalizerConfiguration = ItemsNormalizerConfiguration( add_dlt_id=True, add_dlt_load_id=True @@ -41,14 +37,3 @@ class NormalizeConfiguration(PoolRunnerConfiguration): def on_resolved(self) -> None: self.pool_type = "none" if self.workers == 1 else "process" - - if TYPE_CHECKING: - - def __init__( - self, - pool_type: TPoolType = "process", - workers: int = None, - _schema_storage_config: SchemaStorageConfiguration = None, - _normalize_storage_config: NormalizeStorageConfiguration = None, - _load_storage_config: LoadStorageConfiguration = None, - ) -> None: ... diff --git a/dlt/pipeline/configuration.py b/dlt/pipeline/configuration.py index 7aa54541c0..d7ffca6e89 100644 --- a/dlt/pipeline/configuration.py +++ b/dlt/pipeline/configuration.py @@ -27,7 +27,7 @@ class PipelineConfiguration(BaseConfiguration): full_refresh: bool = False """When set to True, each instance of the pipeline with the `pipeline_name` starts from scratch when run and loads the data to a separate dataset.""" progress: Optional[str] = None - runtime: RunConfiguration + runtime: RunConfiguration = None def on_resolved(self) -> None: if not self.pipeline_name: diff --git a/dlt/pipeline/pipeline.py b/dlt/pipeline/pipeline.py index efb6ae078b..de1f7afced 100644 --- a/dlt/pipeline/pipeline.py +++ b/dlt/pipeline/pipeline.py @@ -12,6 +12,7 @@ Optional, Sequence, Tuple, + Type, cast, get_type_hints, ContextManager, @@ -1122,17 +1123,16 @@ def _get_destination_client_initial_config( ) if issubclass(client_spec, DestinationClientStagingConfiguration): - return client_spec( - dataset_name=self.dataset_name, - default_schema_name=default_schema_name, + spec: DestinationClientDwhConfiguration = client_spec( credentials=credentials, as_staging=as_staging, ) - return client_spec( - dataset_name=self.dataset_name, - default_schema_name=default_schema_name, - credentials=credentials, - ) + else: + spec = client_spec( + credentials=credentials, + ) + spec._bind_dataset_name(self.dataset_name, default_schema_name) + return spec return client_spec(credentials=credentials) diff --git a/dlt/pipeline/platform.py b/dlt/pipeline/platform.py index c8014d5ae7..0955e91b51 100644 --- a/dlt/pipeline/platform.py +++ b/dlt/pipeline/platform.py @@ -6,7 +6,7 @@ from dlt.pipeline.trace import PipelineTrace, PipelineStepTrace, TPipelineStep, SupportsPipeline from dlt.common import json -from dlt.common.runtime import logger +from dlt.common import logger from dlt.common.pipeline import LoadInfo from dlt.common.schema.typing import TStoredSchema diff --git a/dlt/pipeline/trace.py b/dlt/pipeline/trace.py index 5679884b0b..b610d1751f 100644 --- a/dlt/pipeline/trace.py +++ b/dlt/pipeline/trace.py @@ -14,7 +14,7 @@ from dlt.common.configuration.utils import _RESOLVED_TRACES from dlt.common.configuration.container import Container from dlt.common.exceptions import ExceptionTrace, ResourceNameNotAvailable -from dlt.common.runtime.logger import suppress_and_warn +from dlt.common.logger import suppress_and_warn from dlt.common.runtime.exec_info import TExecutionContext, get_execution_context from dlt.common.pipeline import ( ExtractInfo, diff --git a/docs/examples/chess_production/chess.py b/docs/examples/chess_production/chess.py index d9e138187f..e2d0b9c10d 100644 --- a/docs/examples/chess_production/chess.py +++ b/docs/examples/chess_production/chess.py @@ -6,6 +6,7 @@ from dlt.common.typing import StrAny, TDataItems from dlt.sources.helpers.requests import client + @dlt.source def chess( chess_url: str = dlt.config.value, @@ -56,6 +57,7 @@ def players_games(username: Any) -> Iterator[TDataItems]: MAX_PLAYERS = 5 + def load_data_with_retry(pipeline, data): try: for attempt in Retrying( @@ -65,9 +67,7 @@ def load_data_with_retry(pipeline, data): reraise=True, ): with attempt: - logger.info( - f"Running the pipeline, attempt={attempt.retry_state.attempt_number}" - ) + logger.info(f"Running the pipeline, attempt={attempt.retry_state.attempt_number}") load_info = pipeline.run(data) logger.info(str(load_info)) @@ -89,9 +89,7 @@ def load_data_with_retry(pipeline, data): # print the information on the first load package and all jobs inside logger.info(f"First load package info: {load_info.load_packages[0]}") # print the information on the first completed job in first load package - logger.info( - f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}" - ) + logger.info(f"First completed job info: {load_info.load_packages[0].jobs['completed_jobs'][0]}") # check for schema updates: schema_updates = [p.schema_update for p in load_info.load_packages] @@ -149,4 +147,4 @@ def load_data_with_retry(pipeline, data): ) # get data for a few famous players data = chess(chess_url="https://api.chess.com/pub/", max_players=MAX_PLAYERS) - load_data_with_retry(pipeline, data) \ No newline at end of file + load_data_with_retry(pipeline, data) diff --git a/docs/examples/connector_x_arrow/load_arrow.py b/docs/examples/connector_x_arrow/load_arrow.py index 06ca4e17b3..b3c654cef9 100644 --- a/docs/examples/connector_x_arrow/load_arrow.py +++ b/docs/examples/connector_x_arrow/load_arrow.py @@ -3,6 +3,7 @@ import dlt from dlt.sources.credentials import ConnectionStringCredentials + def read_sql_x( conn_str: ConnectionStringCredentials = dlt.secrets.value, query: str = dlt.config.value, @@ -14,6 +15,7 @@ def read_sql_x( protocol="binary", ) + def genome_resource(): # create genome resource with merge on `upid` primary key genome = dlt.resource( diff --git a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py index fcd964a980..624888f70a 100644 --- a/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py +++ b/docs/examples/custom_destination_bigquery/custom_destination_bigquery.py @@ -15,6 +15,7 @@ # format: "your-project.your_dataset.your_table" BIGQUERY_TABLE_ID = "chat-analytics-rasa-ci.ci_streaming_insert.natural-disasters" + # dlt sources @dlt.resource(name="natural_disasters") def resource(url: str): @@ -38,6 +39,7 @@ def resource(url: str): ) yield table + # dlt biquery custom destination # we can use the dlt provided credentials class # to retrieve the gcp credentials from the secrets @@ -58,6 +60,7 @@ def bigquery_insert( load_job = client.load_table_from_file(f, BIGQUERY_TABLE_ID, job_config=job_config) load_job.result() # Waits for the job to complete. + if __name__ == "__main__": # run the pipeline and print load results pipeline = dlt.pipeline( @@ -68,4 +71,4 @@ def bigquery_insert( ) load_info = pipeline.run(resource(url=OWID_DISASTERS_URL)) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/examples/google_sheets/google_sheets.py b/docs/examples/google_sheets/google_sheets.py index 8a93df9970..1ba330e4ca 100644 --- a/docs/examples/google_sheets/google_sheets.py +++ b/docs/examples/google_sheets/google_sheets.py @@ -9,6 +9,7 @@ ) from dlt.common.typing import DictStrAny, StrAny + def _initialize_sheets( credentials: Union[GcpOAuthCredentials, GcpServiceAccountCredentials] ) -> Any: @@ -16,6 +17,7 @@ def _initialize_sheets( service = build("sheets", "v4", credentials=credentials.to_native_credentials()) return service + @dlt.source def google_spreadsheet( spreadsheet_id: str, @@ -55,6 +57,7 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: for name in sheet_names ] + if __name__ == "__main__": pipeline = dlt.pipeline(destination="duckdb") # see example.secrets.toml to where to put credentials @@ -67,4 +70,4 @@ def get_sheet(sheet_name: str) -> Iterator[DictStrAny]: sheet_names=range_names, ) ) - print(info) \ No newline at end of file + print(info) diff --git a/docs/examples/incremental_loading/zendesk.py b/docs/examples/incremental_loading/zendesk.py index 4b8597886a..6113f98793 100644 --- a/docs/examples/incremental_loading/zendesk.py +++ b/docs/examples/incremental_loading/zendesk.py @@ -6,12 +6,11 @@ from dlt.common.typing import TAnyDateTime from dlt.sources.helpers.requests import client + @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -113,6 +112,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create dlt pipeline pipeline = dlt.pipeline( @@ -120,4 +120,4 @@ def get_pages( ) load_info = pipeline.run(zendesk_support()) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/docs/examples/nested_data/nested_data.py b/docs/examples/nested_data/nested_data.py index 3464448de6..7f85f0522e 100644 --- a/docs/examples/nested_data/nested_data.py +++ b/docs/examples/nested_data/nested_data.py @@ -13,6 +13,7 @@ CHUNK_SIZE = 10000 + # You can limit how deep dlt goes when generating child tables. # By default, the library will descend and generate child tables # for all nested lists, without a limit. @@ -81,6 +82,7 @@ def load_documents(self) -> Iterator[TDataItem]: while docs_slice := list(islice(cursor, CHUNK_SIZE)): yield map_nested_in_place(convert_mongo_objs, docs_slice) + def convert_mongo_objs(value: Any) -> Any: if isinstance(value, (ObjectId, Decimal128)): return str(value) diff --git a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py index 8f7833e7d7..e7f57853ed 100644 --- a/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py +++ b/docs/examples/pdf_to_weaviate/pdf_to_weaviate.py @@ -4,6 +4,7 @@ from dlt.destinations.impl.weaviate import weaviate_adapter from PyPDF2 import PdfReader + @dlt.resource(selected=False) def list_files(folder_path: str): folder_path = os.path.abspath(folder_path) @@ -15,6 +16,7 @@ def list_files(folder_path: str): "mtime": os.path.getmtime(file_path), } + @dlt.transformer(primary_key="page_id", write_disposition="merge") def pdf_to_text(file_item, separate_pages: bool = False): if not separate_pages: @@ -28,6 +30,7 @@ def pdf_to_text(file_item, separate_pages: bool = False): page_item["page_id"] = file_item["file_name"] + "_" + str(page_no) yield page_item + pipeline = dlt.pipeline(pipeline_name="pdf_to_text", destination="weaviate") # this constructs a simple pipeline that: (1) reads files from "invoices" folder (2) filters only those ending with ".pdf" @@ -51,4 +54,4 @@ def pdf_to_text(file_item, separate_pages: bool = False): client = weaviate.Client("http://localhost:8080") # get text of all the invoices in InvoiceText class we just created above -print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) \ No newline at end of file +print(client.query.get("InvoiceText", ["text", "file_name", "mtime", "page_id"]).do()) diff --git a/docs/examples/qdrant_zendesk/qdrant.py b/docs/examples/qdrant_zendesk/qdrant.py index 300d8dc6ad..bd0cbafc99 100644 --- a/docs/examples/qdrant_zendesk/qdrant.py +++ b/docs/examples/qdrant_zendesk/qdrant.py @@ -10,13 +10,12 @@ from dlt.common.configuration.inject import with_config + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk @dlt.source(max_table_nesting=2) def zendesk_support( credentials: Dict[str, str] = dlt.secrets.value, - start_date: Optional[TAnyDateTime] = pendulum.datetime( # noqa: B008 - year=2000, month=1, day=1 - ), + start_date: Optional[TAnyDateTime] = pendulum.datetime(year=2000, month=1, day=1), # noqa: B008 end_date: Optional[TAnyDateTime] = None, ): """ @@ -80,6 +79,7 @@ def _parse_date_or_none(value: Optional[str]) -> Optional[pendulum.DateTime]: return None return ensure_pendulum_datetime(value) + # modify dates to return datetime objects instead def _fix_date(ticket): ticket["updated_at"] = _parse_date_or_none(ticket["updated_at"]) @@ -87,6 +87,7 @@ def _fix_date(ticket): ticket["due_at"] = _parse_date_or_none(ticket["due_at"]) return ticket + # function from: https://github.com/dlt-hub/verified-sources/tree/master/sources/zendesk def get_pages( url: str, @@ -127,6 +128,7 @@ def get_pages( if not response_json["end_of_stream"]: get_url = response_json["next_page"] + if __name__ == "__main__": # create a pipeline with an appropriate name pipeline = dlt.pipeline( @@ -146,7 +148,6 @@ def get_pages( print(load_info) - # running the Qdrant client to connect to your Qdrant database @with_config(sections=("destination", "qdrant", "credentials")) diff --git a/docs/examples/transformers/pokemon.py b/docs/examples/transformers/pokemon.py index 2181c33259..ca32c570ef 100644 --- a/docs/examples/transformers/pokemon.py +++ b/docs/examples/transformers/pokemon.py @@ -1,6 +1,7 @@ import dlt from dlt.sources.helpers import requests + @dlt.source(max_table_nesting=2) def source(pokemon_api_url: str): """""" @@ -44,6 +45,7 @@ def species(pokemon_details): return (pokemon_list | pokemon, pokemon_list | pokemon | species) + if __name__ == "__main__": # build duck db pipeline pipeline = dlt.pipeline( @@ -52,4 +54,4 @@ def species(pokemon_details): # the pokemon_list resource does not need to be loaded load_info = pipeline.run(source("https://pokeapi.co/api/v2/pokemon")) - print(load_info) \ No newline at end of file + print(load_info) diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index a883f76ddb..5fbcd86d92 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -126,18 +126,14 @@ class MockProdConfiguration(RunConfiguration): @configspec class FieldWithNoDefaultConfiguration(RunConfiguration): - no_default: str - - if TYPE_CHECKING: - - def __init__(self, no_default: str = None, sentry_dsn: str = None) -> None: ... + no_default: str = None @configspec class InstrumentedConfiguration(BaseConfiguration): - head: str - tube: List[str] - heels: str + head: str = None + tube: List[str] = None + heels: str = None def to_native_representation(self) -> Any: return self.head + ">" + ">".join(self.tube) + ">" + self.heels @@ -156,63 +152,50 @@ def on_resolved(self) -> None: if self.head > self.heels: raise RuntimeError("Head over heels") - if TYPE_CHECKING: - - def __init__(self, head: str = None, tube: List[str] = None, heels: str = None) -> None: ... - @configspec class EmbeddedConfiguration(BaseConfiguration): - default: str - instrumented: InstrumentedConfiguration - sectioned: SectionedConfiguration - - if TYPE_CHECKING: - - def __init__( - self, - default: str = None, - instrumented: InstrumentedConfiguration = None, - sectioned: SectionedConfiguration = None, - ) -> None: ... + default: str = None + instrumented: InstrumentedConfiguration = None + sectioned: SectionedConfiguration = None @configspec class EmbeddedOptionalConfiguration(BaseConfiguration): - instrumented: Optional[InstrumentedConfiguration] + instrumented: Optional[InstrumentedConfiguration] = None @configspec class EmbeddedSecretConfiguration(BaseConfiguration): - secret: SecretConfiguration + secret: SecretConfiguration = None @configspec class NonTemplatedComplexTypesConfiguration(BaseConfiguration): - list_val: list # type: ignore[type-arg] - tuple_val: tuple # type: ignore[type-arg] - dict_val: dict # type: ignore[type-arg] + list_val: list = None # type: ignore[type-arg] + tuple_val: tuple = None # type: ignore[type-arg] + dict_val: dict = None # type: ignore[type-arg] @configspec class DynamicConfigA(BaseConfiguration): - field_for_a: str + field_for_a: str = None @configspec class DynamicConfigB(BaseConfiguration): - field_for_b: str + field_for_b: str = None @configspec class DynamicConfigC(BaseConfiguration): - field_for_c: str + field_for_c: str = None @configspec class ConfigWithDynamicType(BaseConfiguration): - discriminator: str - embedded_config: BaseConfiguration + discriminator: str = None + embedded_config: BaseConfiguration = None @resolve_type("embedded_config") def resolve_embedded_type(self) -> Type[BaseConfiguration]: @@ -240,8 +223,8 @@ def resolve_c_type(self) -> Type[BaseConfiguration]: @configspec class SubclassConfigWithDynamicType(ConfigWithDynamicType): - is_number: bool - dynamic_type_field: Any + is_number: bool = None + dynamic_type_field: Any = None @resolve_type("embedded_config") def resolve_embedded_type(self) -> Type[BaseConfiguration]: @@ -937,11 +920,7 @@ def test_is_valid_hint() -> None: def test_configspec_auto_base_config_derivation() -> None: @configspec class AutoBaseDerivationConfiguration: - auto: str - - if TYPE_CHECKING: - - def __init__(self, auto: str = None) -> None: ... + auto: str = None assert issubclass(AutoBaseDerivationConfiguration, BaseConfiguration) assert hasattr(AutoBaseDerivationConfiguration, "auto") diff --git a/tests/common/configuration/test_container.py b/tests/common/configuration/test_container.py index 9521f5960d..eddd0b21dc 100644 --- a/tests/common/configuration/test_container.py +++ b/tests/common/configuration/test_container.py @@ -20,19 +20,15 @@ @configspec class InjectableTestContext(ContainerInjectableContext): - current_value: str + current_value: str = None def parse_native_representation(self, native_value: Any) -> None: raise ValueError(native_value) - if TYPE_CHECKING: - - def __init__(self, current_value: str = None) -> None: ... - @configspec class EmbeddedWithInjectableContext(BaseConfiguration): - injected: InjectableTestContext + injected: InjectableTestContext = None @configspec @@ -47,12 +43,12 @@ class GlobalTestContext(InjectableTestContext): @configspec class EmbeddedWithNoDefaultInjectableContext(BaseConfiguration): - injected: NoDefaultInjectableContext + injected: NoDefaultInjectableContext = None @configspec class EmbeddedWithNoDefaultInjectableOptionalContext(BaseConfiguration): - injected: Optional[NoDefaultInjectableContext] + injected: Optional[NoDefaultInjectableContext] = None @pytest.fixture() diff --git a/tests/common/configuration/test_credentials.py b/tests/common/configuration/test_credentials.py index ae9b96e903..7c184c16e5 100644 --- a/tests/common/configuration/test_credentials.py +++ b/tests/common/configuration/test_credentials.py @@ -158,6 +158,34 @@ def test_connection_string_resolved_from_native_representation_env(environment: assert c.host == "aws.12.1" +def test_connection_string_from_init() -> None: + c = ConnectionStringCredentials("postgres://loader:pass@localhost:5432/dlt_data?a=b&c=d") + assert c.drivername == "postgres" + assert c.is_resolved() + assert not c.is_partial() + + c = ConnectionStringCredentials( + { + "drivername": "postgres", + "username": "loader", + "password": "pass", + "host": "localhost", + "port": 5432, + "database": "dlt_data", + "query": {"a": "b", "c": "d"}, + } + ) + assert c.drivername == "postgres" + assert c.username == "loader" + assert c.password == "pass" + assert c.host == "localhost" + assert c.port == 5432 + assert c.database == "dlt_data" + assert c.query == {"a": "b", "c": "d"} + assert c.is_resolved() + assert not c.is_partial() + + def test_gcp_service_credentials_native_representation(environment) -> None: with pytest.raises(InvalidGoogleNativeCredentialsType): GcpServiceAccountCredentials().parse_native_representation(1) diff --git a/tests/common/configuration/test_inject.py b/tests/common/configuration/test_inject.py index c6ab8aa756..1aa52c1919 100644 --- a/tests/common/configuration/test_inject.py +++ b/tests/common/configuration/test_inject.py @@ -167,7 +167,7 @@ def test_inject_with_sections() -> None: def test_inject_spec_in_func_params() -> None: @configspec class TestConfig(BaseConfiguration): - base_value: str + base_value: str = None # if any of args (ie. `init` below) is an instance of SPEC, we use it as initial value @@ -179,7 +179,7 @@ def test_spec_arg(base_value=dlt.config.value, init: TestConfig = None): spec = get_fun_spec(test_spec_arg) assert spec == TestConfig # call function with init, should resolve even if we do not provide the base_value in config - assert test_spec_arg(init=TestConfig(base_value="A")) == "A" # type: ignore[call-arg] + assert test_spec_arg(init=TestConfig(base_value="A")) == "A" def test_inject_with_sections_and_sections_context() -> None: @@ -272,7 +272,7 @@ def test_sections(value=dlt.config.value): def test_base_spec() -> None: @configspec class BaseParams(BaseConfiguration): - str_str: str + str_str: str = None @with_config(base=BaseParams) def f_explicit_base(str_str=dlt.config.value, opt: bool = True): diff --git a/tests/common/configuration/test_sections.py b/tests/common/configuration/test_sections.py index 9e0bc7e26d..bf6780e087 100644 --- a/tests/common/configuration/test_sections.py +++ b/tests/common/configuration/test_sections.py @@ -25,33 +25,33 @@ @configspec class SingleValConfiguration(BaseConfiguration): - sv: str + sv: str = None @configspec class EmbeddedConfiguration(BaseConfiguration): - sv_config: Optional[SingleValConfiguration] + sv_config: Optional[SingleValConfiguration] = None @configspec class EmbeddedWithSectionedConfiguration(BaseConfiguration): - embedded: SectionedConfiguration + embedded: SectionedConfiguration = None @configspec class EmbeddedIgnoredConfiguration(BaseConfiguration): # underscore prevents the field name to be added to embedded sections - _sv_config: Optional[SingleValConfiguration] + _sv_config: Optional[SingleValConfiguration] = None @configspec class EmbeddedIgnoredWithSectionedConfiguration(BaseConfiguration): - _embedded: SectionedConfiguration + _embedded: SectionedConfiguration = None @configspec class EmbeddedWithIgnoredEmbeddedConfiguration(BaseConfiguration): - ignored_embedded: EmbeddedIgnoredWithSectionedConfiguration + ignored_embedded: EmbeddedIgnoredWithSectionedConfiguration = None def test_sectioned_configuration(environment: Any, env_provider: ConfigProvider) -> None: diff --git a/tests/common/configuration/test_spec_union.py b/tests/common/configuration/test_spec_union.py index 4892967ab7..b1e316734d 100644 --- a/tests/common/configuration/test_spec_union.py +++ b/tests/common/configuration/test_spec_union.py @@ -26,8 +26,8 @@ def auth(self): @configspec class ZenEmailCredentials(ZenCredentials): - email: str - password: TSecretValue + email: str = None + password: TSecretValue = None def parse_native_representation(self, native_value: Any) -> None: assert isinstance(native_value, str) @@ -44,8 +44,8 @@ def auth(self): @configspec class ZenApiKeyCredentials(ZenCredentials): - api_key: str - api_secret: TSecretValue + api_key: str = None + api_secret: TSecretValue = None def parse_native_representation(self, native_value: Any) -> None: assert isinstance(native_value, str) @@ -62,14 +62,14 @@ def auth(self): @configspec class ZenConfig(BaseConfiguration): - credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials] + credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials] = None some_option: bool = False @configspec class ZenConfigOptCredentials: # add none to union to make it optional - credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials, None] + credentials: Union[ZenApiKeyCredentials, ZenEmailCredentials, None] = None some_option: bool = False @@ -200,10 +200,10 @@ class GoogleAnalyticsCredentialsOAuth(GoogleAnalyticsCredentialsBase): This class is used to store credentials Google Analytics """ - client_id: str - client_secret: TSecretValue - project_id: TSecretValue - refresh_token: TSecretValue + client_id: str = None + client_secret: TSecretValue = None + project_id: TSecretValue = None + refresh_token: TSecretValue = None access_token: Optional[TSecretValue] = None diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index fcec881521..4f2219716a 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -42,12 +42,12 @@ @configspec class EmbeddedWithGcpStorage(BaseConfiguration): - gcp_storage: GcpServiceAccountCredentialsWithoutDefaults + gcp_storage: GcpServiceAccountCredentialsWithoutDefaults = None @configspec class EmbeddedWithGcpCredentials(BaseConfiguration): - credentials: GcpServiceAccountCredentialsWithoutDefaults + credentials: GcpServiceAccountCredentialsWithoutDefaults = None def test_secrets_from_toml_secrets(toml_providers: ConfigProvidersContext) -> None: @@ -378,7 +378,7 @@ def test_write_value(toml_providers: ConfigProvidersContext) -> None: # dict creates only shallow dict so embedded credentials will fail creds = WithCredentialsConfiguration() - creds.credentials = SecretCredentials({"secret_value": "***** ***"}) + creds.credentials = SecretCredentials(secret_value=TSecretValue("***** ***")) with pytest.raises(ValueError): provider.set_value("written_creds", dict(creds), None) diff --git a/tests/common/configuration/utils.py b/tests/common/configuration/utils.py index 73643561dc..670dcac87a 100644 --- a/tests/common/configuration/utils.py +++ b/tests/common/configuration/utils.py @@ -3,6 +3,7 @@ import datetime # noqa: I251 from typing import ( Any, + ClassVar, Iterator, List, Optional, @@ -71,19 +72,15 @@ class SecretCredentials(CredentialsConfiguration): @configspec class WithCredentialsConfiguration(BaseConfiguration): - credentials: SecretCredentials + credentials: SecretCredentials = None @configspec class SectionedConfiguration(BaseConfiguration): - __section__ = "DLT_TEST" + __section__: ClassVar[str] = "DLT_TEST" password: str = None - if TYPE_CHECKING: - - def __init__(self, password: str = None) -> None: ... - @pytest.fixture(scope="function") def environment() -> Any: diff --git a/tests/common/reflection/test_reflect_spec.py b/tests/common/reflection/test_reflect_spec.py index 092d25b717..952d0fc596 100644 --- a/tests/common/reflection/test_reflect_spec.py +++ b/tests/common/reflection/test_reflect_spec.py @@ -314,7 +314,7 @@ def f_kw_defaults_args( def test_reflect_custom_base() -> None: @configspec class BaseParams(BaseConfiguration): - str_str: str + str_str: str = None def _f_1(str_str=dlt.config.value, p_def: bool = True): pass diff --git a/tests/common/runtime/test_logging.py b/tests/common/runtime/test_logging.py index 19f67fe899..5ff92f7d94 100644 --- a/tests/common/runtime/test_logging.py +++ b/tests/common/runtime/test_logging.py @@ -3,7 +3,7 @@ from dlt.common import logger from dlt.common.runtime import exec_info -from dlt.common.runtime.logger import is_logging +from dlt.common.logger import is_logging from dlt.common.typing import StrStr, DictStrStr from dlt.common.configuration import configspec from dlt.common.configuration.specs import RunConfiguration diff --git a/tests/common/runtime/test_telemetry.py b/tests/common/runtime/test_telemetry.py index eece36aae7..e67f7e8360 100644 --- a/tests/common/runtime/test_telemetry.py +++ b/tests/common/runtime/test_telemetry.py @@ -35,16 +35,6 @@ class SentryLoggerConfiguration(RunConfiguration): class SentryLoggerCriticalConfiguration(SentryLoggerConfiguration): log_level: str = "CRITICAL" - if TYPE_CHECKING: - - def __init__( - self, - pipeline_name: str = "logger", - sentry_dsn: str = "https://sentry.io", - dlthub_telemetry_segment_write_key: str = "TLJiyRkGVZGCi2TtjClamXpFcxAA1rSB", - log_level: str = "CRITICAL", - ) -> None: ... - def test_sentry_log_level() -> None: from dlt.common.runtime.sentry import _get_sentry_log_level diff --git a/tests/common/test_destination.py b/tests/common/test_destination.py index 5240b889f3..24b0928463 100644 --- a/tests/common/test_destination.py +++ b/tests/common/test_destination.py @@ -78,7 +78,7 @@ def test_import_destination_config() -> None: dest = Destination.from_reference(ref="dlt.destinations.duckdb", environment="stage") assert dest.destination_type == "dlt.destinations.duckdb" assert dest.config_params["environment"] == "stage" - config = dest.configuration(dest.spec(dataset_name="dataset")) # type: ignore + config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore assert config.destination_type == "duckdb" assert config.destination_name == "duckdb" assert config.environment == "stage" @@ -87,7 +87,7 @@ def test_import_destination_config() -> None: dest = Destination.from_reference(ref=None, destination_name="duckdb", environment="production") assert dest.destination_type == "dlt.destinations.duckdb" assert dest.config_params["environment"] == "production" - config = dest.configuration(dest.spec(dataset_name="dataset")) # type: ignore + config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore assert config.destination_type == "duckdb" assert config.destination_name == "duckdb" assert config.environment == "production" @@ -98,7 +98,7 @@ def test_import_destination_config() -> None: ) assert dest.destination_type == "dlt.destinations.duckdb" assert dest.config_params["environment"] == "devel" - config = dest.configuration(dest.spec(dataset_name="dataset")) # type: ignore + config = dest.configuration(dest.spec()._bind_dataset_name(dataset_name="dataset")) # type: ignore assert config.destination_type == "duckdb" assert config.destination_name == "my_destination" assert config.environment == "devel" @@ -112,63 +112,63 @@ def test_normalize_dataset_name() -> None: # with schema name appended assert ( - DestinationClientDwhConfiguration( - dataset_name="ban_ana_dataset", default_schema_name="default" - ).normalize_dataset_name(Schema("banana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="ban_ana_dataset", default_schema_name="default") + .normalize_dataset_name(Schema("banana")) == "ban_ana_dataset_banana" ) # without schema name appended assert ( - DestinationClientDwhConfiguration( - dataset_name="ban_ana_dataset", default_schema_name="default" - ).normalize_dataset_name(Schema("default")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="ban_ana_dataset", default_schema_name="default") + .normalize_dataset_name(Schema("default")) == "ban_ana_dataset" ) # dataset name will be normalized (now it is up to destination to normalize this) assert ( - DestinationClientDwhConfiguration( - dataset_name="BaNaNa", default_schema_name="default" - ).normalize_dataset_name(Schema("banana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="BaNaNa", default_schema_name="default") + .normalize_dataset_name(Schema("banana")) == "ba_na_na_banana" ) # empty schemas are invalid with pytest.raises(ValueError): - DestinationClientDwhConfiguration( - dataset_name="banana_dataset", default_schema_name=None + DestinationClientDwhConfiguration()._bind_dataset_name( + dataset_name="banana_dataset" ).normalize_dataset_name(Schema(None)) with pytest.raises(ValueError): - DestinationClientDwhConfiguration( + DestinationClientDwhConfiguration()._bind_dataset_name( dataset_name="banana_dataset", default_schema_name="" ).normalize_dataset_name(Schema("")) # empty dataset name is valid! assert ( - DestinationClientDwhConfiguration( - dataset_name="", default_schema_name="ban_schema" - ).normalize_dataset_name(Schema("schema_ana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="", default_schema_name="ban_schema") + .normalize_dataset_name(Schema("schema_ana")) == "_schema_ana" ) # empty dataset name is valid! assert ( - DestinationClientDwhConfiguration( - dataset_name="", default_schema_name="schema_ana" - ).normalize_dataset_name(Schema("schema_ana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="", default_schema_name="schema_ana") + .normalize_dataset_name(Schema("schema_ana")) == "" ) # None dataset name is valid! assert ( - DestinationClientDwhConfiguration( - dataset_name=None, default_schema_name="ban_schema" - ).normalize_dataset_name(Schema("schema_ana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name=None, default_schema_name="ban_schema") + .normalize_dataset_name(Schema("schema_ana")) == "_schema_ana" ) # None dataset name is valid! assert ( - DestinationClientDwhConfiguration( - dataset_name=None, default_schema_name="schema_ana" - ).normalize_dataset_name(Schema("schema_ana")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name=None, default_schema_name="schema_ana") + .normalize_dataset_name(Schema("schema_ana")) is None ) @@ -176,9 +176,9 @@ def test_normalize_dataset_name() -> None: schema = Schema("barbapapa") schema._schema_name = "BarbaPapa" assert ( - DestinationClientDwhConfiguration( - dataset_name="set", default_schema_name="default" - ).normalize_dataset_name(schema) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="set", default_schema_name="default") + .normalize_dataset_name(schema) == "set_barba_papa" ) @@ -186,8 +186,8 @@ def test_normalize_dataset_name() -> None: def test_normalize_dataset_name_none_default_schema() -> None: # if default schema is None, suffix is not added assert ( - DestinationClientDwhConfiguration( - dataset_name="ban_ana_dataset", default_schema_name=None - ).normalize_dataset_name(Schema("default")) + DestinationClientDwhConfiguration() + ._bind_dataset_name(dataset_name="ban_ana_dataset", default_schema_name=None) + .normalize_dataset_name(Schema("default")) == "ban_ana_dataset" ) diff --git a/tests/destinations/test_custom_destination.py b/tests/destinations/test_custom_destination.py index 7280ec419b..cfefceac88 100644 --- a/tests/destinations/test_custom_destination.py +++ b/tests/destinations/test_custom_destination.py @@ -455,7 +455,7 @@ def my_gcp_sink( def test_destination_with_spec() -> None: @configspec class MyDestinationSpec(CustomDestinationClientConfiguration): - my_predefined_val: str + my_predefined_val: str = None # check destination without additional config params @dlt.destination(spec=MyDestinationSpec) diff --git a/tests/helpers/dbt_tests/local/utils.py b/tests/helpers/dbt_tests/local/utils.py index 7097140a83..8fd3dba44f 100644 --- a/tests/helpers/dbt_tests/local/utils.py +++ b/tests/helpers/dbt_tests/local/utils.py @@ -40,7 +40,9 @@ def setup_rasa_runner( runner = create_runner( Venv.restore_current(), # credentials are exported to env in setup_rasa_runner_client - DestinationClientDwhConfiguration(dataset_name=dataset_name or FIXTURES_DATASET_NAME), + DestinationClientDwhConfiguration()._bind_dataset_name( + dataset_name=dataset_name or FIXTURES_DATASET_NAME + ), TEST_STORAGE_ROOT, package_profile_name=profile_name, config=C, diff --git a/tests/helpers/providers/test_google_secrets_provider.py b/tests/helpers/providers/test_google_secrets_provider.py index d6d94774b9..00c54b5705 100644 --- a/tests/helpers/providers/test_google_secrets_provider.py +++ b/tests/helpers/providers/test_google_secrets_provider.py @@ -1,6 +1,5 @@ -import dlt from dlt import TSecretValue -from dlt.common import logger +from dlt.common.runtime.init import init_logging from dlt.common.configuration.specs import GcpServiceAccountCredentials from dlt.common.configuration.providers import GoogleSecretsProvider from dlt.common.configuration.accessors import secrets @@ -24,7 +23,7 @@ def test_regular_keys() -> None: - logger.init_logging(RunConfiguration()) + init_logging(RunConfiguration()) # copy bigquery credentials into providers credentials c = resolve_configuration( GcpServiceAccountCredentials(), sections=(known_sections.DESTINATION, "bigquery") diff --git a/tests/load/bigquery/test_bigquery_client.py b/tests/load/bigquery/test_bigquery_client.py index ac17bb8316..a97b612ad0 100644 --- a/tests/load/bigquery/test_bigquery_client.py +++ b/tests/load/bigquery/test_bigquery_client.py @@ -203,7 +203,8 @@ def test_get_oauth_access_token() -> None: def test_bigquery_configuration() -> None: config = resolve_configuration( - BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "bigquery"), ) assert config.location == "US" assert config.get_location() == "US" @@ -215,7 +216,8 @@ def test_bigquery_configuration() -> None: # credential location is deprecated os.environ["CREDENTIALS__LOCATION"] = "EU" config = resolve_configuration( - BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "bigquery"), ) assert config.location == "US" assert config.credentials.location == "EU" @@ -223,17 +225,21 @@ def test_bigquery_configuration() -> None: assert config.get_location() == "EU" os.environ["LOCATION"] = "ATLANTIS" config = resolve_configuration( - BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "bigquery"), ) assert config.get_location() == "ATLANTIS" os.environ["DESTINATION__FILE_UPLOAD_TIMEOUT"] = "20000" config = resolve_configuration( - BigQueryClientConfiguration(dataset_name="dataset"), sections=("destination", "bigquery") + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset"), + sections=("destination", "bigquery"), ) assert config.file_upload_timeout == 20000.0 # default fingerprint is empty - assert BigQueryClientConfiguration(dataset_name="dataset").fingerprint() == "" + assert ( + BigQueryClientConfiguration()._bind_dataset_name(dataset_name="dataset").fingerprint() == "" + ) def test_bigquery_job_errors(client: BigQueryClient, file_storage: FileStorage) -> None: diff --git a/tests/load/bigquery/test_bigquery_table_builder.py b/tests/load/bigquery/test_bigquery_table_builder.py index a223de9b26..fd58a6e033 100644 --- a/tests/load/bigquery/test_bigquery_table_builder.py +++ b/tests/load/bigquery/test_bigquery_table_builder.py @@ -4,10 +4,6 @@ from dlt.destinations.impl.bigquery.bigquery_adapter import ( PARTITION_HINT, CLUSTER_HINT, - TABLE_DESCRIPTION_HINT, - ROUND_HALF_EVEN_HINT, - ROUND_HALF_AWAY_FROM_ZERO_HINT, - TABLE_EXPIRATION_HINT, ) import google @@ -17,9 +13,12 @@ import dlt from dlt.common.configuration import resolve_configuration -from dlt.common.configuration.specs import GcpServiceAccountCredentialsWithoutDefaults +from dlt.common.configuration.specs import ( + GcpServiceAccountCredentialsWithoutDefaults, + GcpServiceAccountCredentials, +) from dlt.common.pendulum import pendulum -from dlt.common.schema import Schema, TColumnHint +from dlt.common.schema import Schema from dlt.common.utils import custom_environ from dlt.common.utils import uniq_id from dlt.destinations.exceptions import DestinationSchemaWillNotUpdate @@ -53,13 +52,13 @@ def test_configuration() -> None: @pytest.fixture def gcp_client(empty_schema: Schema) -> BigQueryClient: # return a client without opening connection - creds = GcpServiceAccountCredentialsWithoutDefaults() + creds = GcpServiceAccountCredentials() creds.project_id = "test_project_id" # noinspection PydanticTypeChecker return BigQueryClient( empty_schema, - BigQueryClientConfiguration( - dataset_name=f"test_{uniq_id()}", credentials=creds # type: ignore[arg-type] + BigQueryClientConfiguration(credentials=creds)._bind_dataset_name( + dataset_name=f"test_{uniq_id()}" ), ) diff --git a/tests/load/databricks/test_databricks_configuration.py b/tests/load/databricks/test_databricks_configuration.py index 9127e39be4..8d30d05e42 100644 --- a/tests/load/databricks/test_databricks_configuration.py +++ b/tests/load/databricks/test_databricks_configuration.py @@ -17,7 +17,9 @@ def test_databricks_credentials_to_connector_params(): # JSON encoded dict of extra args os.environ["CREDENTIALS__CONNECTION_PARAMETERS"] = '{"extra_a": "a", "extra_b": "b"}' - config = resolve_configuration(DatabricksClientConfiguration(dataset_name="my-dataset")) + config = resolve_configuration( + DatabricksClientConfiguration()._bind_dataset_name(dataset_name="my-dataset") + ) credentials = config.credentials diff --git a/tests/load/duckdb/test_duckdb_client.py b/tests/load/duckdb/test_duckdb_client.py index ef151833e4..3deed7a77d 100644 --- a/tests/load/duckdb/test_duckdb_client.py +++ b/tests/load/duckdb/test_duckdb_client.py @@ -31,7 +31,9 @@ def test_duckdb_open_conn_default() -> None: delete_quack_db() try: get_resolved_traces().clear() - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset")) + c = resolve_configuration( + DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_dataset") + ) # print(str(c.credentials)) # print(str(os.getcwd())) # print(get_resolved_traces()) @@ -52,11 +54,15 @@ def test_duckdb_open_conn_default() -> None: def test_duckdb_database_path() -> None: # resolve without any path provided - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset")) + c = resolve_configuration( + DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_dataset") + ) assert c.credentials._conn_str().lower() == os.path.abspath("quack.duckdb").lower() # resolve without any path but with pipeline context p = dlt.pipeline(pipeline_name="quack_pipeline") - c = resolve_configuration(DuckDbClientConfiguration(dataset_name="test_dataset")) + c = resolve_configuration( + DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_dataset") + ) # still cwd db_path = os.path.abspath(os.path.join(".", "quack_pipeline.duckdb")) assert c.credentials._conn_str().lower() == db_path.lower() @@ -75,7 +81,9 @@ def test_duckdb_database_path() -> None: # test special :pipeline: path to create in pipeline folder c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=":pipeline:") + DuckDbClientConfiguration(credentials=":pipeline:")._bind_dataset_name( + dataset_name="test_dataset" + ) ) db_path = os.path.abspath(os.path.join(p.working_dir, DEFAULT_DUCK_DB_NAME)) assert c.credentials._conn_str().lower() == db_path.lower() @@ -90,8 +98,8 @@ def test_duckdb_database_path() -> None: db_path = "_storage/test_quack.duckdb" c = resolve_configuration( DuckDbClientConfiguration( - dataset_name="test_dataset", credentials="duckdb:///_storage/test_quack.duckdb" - ) + credentials="duckdb:///_storage/test_quack.duckdb" + )._bind_dataset_name(dataset_name="test_dataset") ) assert c.credentials._conn_str().lower() == os.path.abspath(db_path).lower() conn = c.credentials.borrow_conn(read_only=False) @@ -102,7 +110,9 @@ def test_duckdb_database_path() -> None: # provide absolute path db_path = os.path.abspath("_storage/abs_test_quack.duckdb") c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=f"duckdb:///{db_path}") + DuckDbClientConfiguration(credentials=f"duckdb:///{db_path}")._bind_dataset_name( + dataset_name="test_dataset", + ) ) assert os.path.isabs(c.credentials.database) assert c.credentials._conn_str().lower() == db_path.lower() @@ -114,7 +124,9 @@ def test_duckdb_database_path() -> None: # set just path as credentials db_path = "_storage/path_test_quack.duckdb" c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path) + DuckDbClientConfiguration(credentials=db_path)._bind_dataset_name( + dataset_name="test_dataset" + ) ) assert c.credentials._conn_str().lower() == os.path.abspath(db_path).lower() conn = c.credentials.borrow_conn(read_only=False) @@ -124,7 +136,9 @@ def test_duckdb_database_path() -> None: db_path = os.path.abspath("_storage/abs_path_test_quack.duckdb") c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=db_path) + DuckDbClientConfiguration(credentials=db_path)._bind_dataset_name( + dataset_name="test_dataset" + ) ) assert os.path.isabs(c.credentials.database) assert c.credentials._conn_str().lower() == db_path.lower() @@ -138,7 +152,9 @@ def test_duckdb_database_path() -> None: with pytest.raises(duckdb.IOException): c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=TEST_STORAGE_ROOT) + DuckDbClientConfiguration(credentials=TEST_STORAGE_ROOT)._bind_dataset_name( + dataset_name="test_dataset" + ) ) conn = c.credentials.borrow_conn(read_only=False) @@ -225,7 +241,7 @@ def test_external_duckdb_database() -> None: # pass explicit in memory database conn = duckdb.connect(":memory:") c = resolve_configuration( - DuckDbClientConfiguration(dataset_name="test_dataset", credentials=conn) + DuckDbClientConfiguration(credentials=conn)._bind_dataset_name(dataset_name="test_dataset") ) assert c.credentials._conn_borrows == 0 assert c.credentials._conn is conn diff --git a/tests/load/duckdb/test_duckdb_table_builder.py b/tests/load/duckdb/test_duckdb_table_builder.py index 9b12e04f77..542b18993c 100644 --- a/tests/load/duckdb/test_duckdb_table_builder.py +++ b/tests/load/duckdb/test_duckdb_table_builder.py @@ -14,7 +14,10 @@ @pytest.fixture def client(empty_schema: Schema) -> DuckDbClient: # return client without opening connection - return DuckDbClient(empty_schema, DuckDbClientConfiguration(dataset_name="test_" + uniq_id())) + return DuckDbClient( + empty_schema, + DuckDbClientConfiguration()._bind_dataset_name(dataset_name="test_" + uniq_id()), + ) def test_create_table(client: DuckDbClient) -> None: @@ -89,7 +92,9 @@ def test_create_table_with_hints(client: DuckDbClient) -> None: # same thing with indexes client = DuckDbClient( client.schema, - DuckDbClientConfiguration(dataset_name="test_" + uniq_id(), create_indexes=True), + DuckDbClientConfiguration(create_indexes=True)._bind_dataset_name( + dataset_name="test_" + uniq_id() + ), ) sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql) diff --git a/tests/load/duckdb/test_motherduck_client.py b/tests/load/duckdb/test_motherduck_client.py index d57cf58f53..ba60e0de6d 100644 --- a/tests/load/duckdb/test_motherduck_client.py +++ b/tests/load/duckdb/test_motherduck_client.py @@ -19,13 +19,15 @@ def test_motherduck_database() -> None: # os.environ.pop("HOME", None) cred = MotherDuckCredentials("md:///?token=TOKEN") + print(dict(cred)) assert cred.password == "TOKEN" cred = MotherDuckCredentials() cred.parse_native_representation("md:///?token=TOKEN") assert cred.password == "TOKEN" config = resolve_configuration( - MotherDuckClientConfiguration(dataset_name="test"), sections=("destination", "motherduck") + MotherDuckClientConfiguration()._bind_dataset_name(dataset_name="test"), + sections=("destination", "motherduck"), ) # connect con = config.credentials.borrow_conn(read_only=False) diff --git a/tests/load/filesystem/test_aws_credentials.py b/tests/load/filesystem/test_aws_credentials.py index 7a0d42eb6d..62c2e3cd85 100644 --- a/tests/load/filesystem/test_aws_credentials.py +++ b/tests/load/filesystem/test_aws_credentials.py @@ -45,7 +45,7 @@ def test_aws_credentials_from_botocore(environment: Dict[str, str]) -> None: session = botocore.session.get_session() region_name = "eu-central-1" # session.get_config_variable('region') - c = AwsCredentials(session) + c = AwsCredentials.from_session(session) assert c.profile_name is None assert c.aws_access_key_id == "fake_access_key" assert c.region_name == region_name @@ -83,7 +83,7 @@ def test_aws_credentials_from_boto3(environment: Dict[str, str]) -> None: session = boto3.Session() - c = AwsCredentials(session) + c = AwsCredentials.from_session(session) assert c.profile_name is None assert c.aws_access_key_id == "fake_access_key" assert c.region_name == session.region_name diff --git a/tests/load/filesystem/utils.py b/tests/load/filesystem/utils.py index 1232be5c43..6e697fdef9 100644 --- a/tests/load/filesystem/utils.py +++ b/tests/load/filesystem/utils.py @@ -14,7 +14,7 @@ from dlt.common.configuration.container import Container from dlt.common.configuration.specs.config_section_context import ConfigSectionContext -from dlt.common.destination.reference import LoadJob, TDestination +from dlt.common.destination.reference import LoadJob from dlt.common.pendulum import timedelta, __utcnow from dlt.destinations import filesystem from dlt.destinations.impl.filesystem.filesystem import FilesystemClient @@ -24,11 +24,11 @@ def setup_loader(dataset_name: str) -> Load: - destination: TDestination = filesystem() # type: ignore[assignment] - config = filesystem.spec(dataset_name=dataset_name) + destination = filesystem() + config = destination.spec()._bind_dataset_name(dataset_name=dataset_name) # setup loader with Container().injectable_context(ConfigSectionContext(sections=("filesystem",))): - return Load(destination, initial_client_config=config) + return Load(destination, initial_client_config=config) # type: ignore[arg-type] @contextmanager diff --git a/tests/load/mssql/test_mssql_table_builder.py b/tests/load/mssql/test_mssql_table_builder.py index 75f46e8905..1b4a77a2ab 100644 --- a/tests/load/mssql/test_mssql_table_builder.py +++ b/tests/load/mssql/test_mssql_table_builder.py @@ -17,7 +17,9 @@ def client(empty_schema: Schema) -> MsSqlClient: # return client without opening connection return MsSqlClient( empty_schema, - MsSqlClientConfiguration(dataset_name="test_" + uniq_id(), credentials=MsSqlCredentials()), + MsSqlClientConfiguration(credentials=MsSqlCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() + ), ) diff --git a/tests/load/postgres/test_postgres_client.py b/tests/load/postgres/test_postgres_client.py index daabf6fc51..896e449b28 100644 --- a/tests/load/postgres/test_postgres_client.py +++ b/tests/load/postgres/test_postgres_client.py @@ -62,6 +62,10 @@ def test_postgres_credentials_native_value(environment) -> None: assert c.is_resolved() assert c.password == "loader" + c = PostgresCredentials("postgres://loader:loader@localhost/dlt_data") + assert c.password == "loader" + assert c.database == "dlt_data" + def test_postgres_credentials_timeout() -> None: # test postgres timeout diff --git a/tests/load/postgres/test_postgres_table_builder.py b/tests/load/postgres/test_postgres_table_builder.py index fde9d82cf7..0ab1343a3b 100644 --- a/tests/load/postgres/test_postgres_table_builder.py +++ b/tests/load/postgres/test_postgres_table_builder.py @@ -2,7 +2,6 @@ from copy import deepcopy import sqlfluff -from dlt.common.schema.utils import new_table from dlt.common.utils import uniq_id from dlt.common.schema import Schema @@ -20,8 +19,8 @@ def client(empty_schema: Schema) -> PostgresClient: # return client without opening connection return PostgresClient( empty_schema, - PostgresClientConfiguration( - dataset_name="test_" + uniq_id(), credentials=PostgresCredentials() + PostgresClientConfiguration(credentials=PostgresCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() ), ) @@ -97,10 +96,9 @@ def test_create_table_with_hints(client: PostgresClient) -> None: client = PostgresClient( client.schema, PostgresClientConfiguration( - dataset_name="test_" + uniq_id(), create_indexes=False, credentials=PostgresCredentials(), - ), + )._bind_dataset_name(dataset_name="test_" + uniq_id()), ) sql = client._get_table_update_sql("event_test_table", mod_update, False)[0] sqlfluff.parse(sql, dialect="postgres") diff --git a/tests/load/redshift/test_redshift_table_builder.py b/tests/load/redshift/test_redshift_table_builder.py index c6981e5553..bc132c7818 100644 --- a/tests/load/redshift/test_redshift_table_builder.py +++ b/tests/load/redshift/test_redshift_table_builder.py @@ -20,8 +20,8 @@ def client(empty_schema: Schema) -> RedshiftClient: # return client without opening connection return RedshiftClient( empty_schema, - RedshiftClientConfiguration( - dataset_name="test_" + uniq_id(), credentials=RedshiftCredentials() + RedshiftClientConfiguration(credentials=RedshiftCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() ), ) diff --git a/tests/load/snowflake/test_snowflake_table_builder.py b/tests/load/snowflake/test_snowflake_table_builder.py index 1e80a61f1c..5d7108803e 100644 --- a/tests/load/snowflake/test_snowflake_table_builder.py +++ b/tests/load/snowflake/test_snowflake_table_builder.py @@ -21,7 +21,9 @@ def snowflake_client(empty_schema: Schema) -> SnowflakeClient: creds = SnowflakeCredentials() return SnowflakeClient( empty_schema, - SnowflakeClientConfiguration(dataset_name="test_" + uniq_id(), credentials=creds), + SnowflakeClientConfiguration(credentials=creds)._bind_dataset_name( + dataset_name="test_" + uniq_id() + ), ) diff --git a/tests/load/synapse/test_synapse_table_builder.py b/tests/load/synapse/test_synapse_table_builder.py index 871ceecf96..8575835820 100644 --- a/tests/load/synapse/test_synapse_table_builder.py +++ b/tests/load/synapse/test_synapse_table_builder.py @@ -25,8 +25,8 @@ def client(empty_schema: Schema) -> SynapseClient: # return client without opening connection client = SynapseClient( empty_schema, - SynapseClientConfiguration( - dataset_name="test_" + uniq_id(), credentials=SynapseCredentials() + SynapseClientConfiguration(credentials=SynapseCredentials())._bind_dataset_name( + dataset_name="test_" + uniq_id() ), ) assert client.config.create_indexes is False @@ -39,8 +39,8 @@ def client_with_indexes_enabled(empty_schema: Schema) -> SynapseClient: client = SynapseClient( empty_schema, SynapseClientConfiguration( - dataset_name="test_" + uniq_id(), credentials=SynapseCredentials(), create_indexes=True - ), + credentials=SynapseCredentials(), create_indexes=True + )._bind_dataset_name(dataset_name="test_" + uniq_id()), ) assert client.config.create_indexes is True return client diff --git a/tests/load/test_dummy_client.py b/tests/load/test_dummy_client.py index d7884abcf0..c5e4f874fc 100644 --- a/tests/load/test_dummy_client.py +++ b/tests/load/test_dummy_client.py @@ -6,8 +6,7 @@ from typing import List from dlt.common.exceptions import TerminalException, TerminalValueError -from dlt.common.schema.typing import TWriteDisposition -from dlt.common.storages import FileStorage, LoadStorage, PackageStorage, ParsedLoadJobFileName +from dlt.common.storages import FileStorage, PackageStorage, ParsedLoadJobFileName from dlt.common.storages.load_package import LoadJobInfo from dlt.common.storages.load_storage import JobWithUnsupportedWriterException from dlt.common.destination.reference import LoadJob, TDestination @@ -814,7 +813,9 @@ def setup_loader( if filesystem_staging: # do not accept jsonl to not conflict with filesystem destination client_config = client_config or DummyClientConfiguration(loader_file_format="reference") - staging_system_config = FilesystemDestinationClientConfiguration(dataset_name="dummy") + staging_system_config = FilesystemDestinationClientConfiguration()._bind_dataset_name( + dataset_name="dummy" + ) staging_system_config.as_staging = True os.makedirs(REMOTE_FILESYSTEM) staging = filesystem(bucket_url=REMOTE_FILESYSTEM) diff --git a/tests/pipeline/test_dlt_versions.py b/tests/pipeline/test_dlt_versions.py index 8906958e0c..4c7357cb06 100644 --- a/tests/pipeline/test_dlt_versions.py +++ b/tests/pipeline/test_dlt_versions.py @@ -24,8 +24,8 @@ from tests.utils import TEST_STORAGE_ROOT, test_storage -if sys.version_info > (3, 11): - pytest.skip("Does not run on Python 3.12 and later", allow_module_level=True) +# if sys.version_info > (3, 11): +# pytest.skip("Does not run on Python 3.12 and later", allow_module_level=True) GITHUB_PIPELINE_NAME = "dlt_github_pipeline" @@ -70,7 +70,7 @@ def test_pipeline_with_dlt_update(test_storage: FileStorage) -> None: } # check loads table without attaching to pipeline duckdb_cfg = resolve_configuration( - DuckDbClientConfiguration(dataset_name=GITHUB_DATASET), + DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb"), ) with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: @@ -189,7 +189,7 @@ def test_load_package_with_dlt_update(test_storage: FileStorage) -> None: venv = Venv.restore_current() print(venv.run_script("../tests/pipeline/cases/github_pipeline/github_load.py")) duckdb_cfg = resolve_configuration( - DuckDbClientConfiguration(dataset_name=GITHUB_DATASET), + DuckDbClientConfiguration()._bind_dataset_name(dataset_name=GITHUB_DATASET), sections=("destination", "duckdb"), ) with DuckDbSqlClient(GITHUB_DATASET, duckdb_cfg.credentials) as client: diff --git a/tests/utils.py b/tests/utils.py index 924f44de73..00523486ea 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -19,7 +19,7 @@ ConfigProvidersContext, ) from dlt.common.pipeline import PipelineContext -from dlt.common.runtime.logger import init_logging +from dlt.common.runtime.init import init_logging from dlt.common.runtime.telemetry import start_telemetry, stop_telemetry from dlt.common.schema import Schema from dlt.common.storages import FileStorage