Skip to content

Commit

Permalink
Fix: ensure custom session can be provided to rest client (#1396)
Browse files Browse the repository at this point in the history
* fix: ensure custom session can be provided to rest client

* fix: move request client retry to correct central req method used in all codepaths

* chore: use adapter mock to replicate production code path more accurately

* chore: rename session warn func and add docstring

* fix: linting err

* creates explicit session in rest client tests

* allows custom sessions in oauth2 jwt of rest client

* adds NotResolved type annotation that excludes a type from being resolved in configspec

* fixes weaviate test

---------

Co-authored-by: Marcin Rudolf <[email protected]>
  • Loading branch information
z3z1ma and rudolfix authored May 27, 2024
1 parent b94c807 commit 4415988
Show file tree
Hide file tree
Showing 14 changed files with 216 additions and 81 deletions.
9 changes: 8 additions & 1 deletion dlt/common/configuration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from .specs.base_configuration import configspec, is_valid_hint, is_secret_hint, resolve_type
from .specs.base_configuration import (
configspec,
is_valid_hint,
is_secret_hint,
resolve_type,
NotResolved,
)
from .specs import known_sections
from .resolve import resolve_configuration, inject_section
from .inject import with_config, last_config, get_fun_spec, create_resolved_partial
Expand All @@ -15,6 +21,7 @@
"configspec",
"is_valid_hint",
"is_secret_hint",
"NotResolved",
"resolve_type",
"known_sections",
"resolve_configuration",
Expand Down
6 changes: 3 additions & 3 deletions dlt/common/configuration/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
StrAny,
TSecretValue,
get_all_types_of_class_in_union,
is_final_type,
is_optional_type,
is_union_type,
)
Expand All @@ -21,6 +20,7 @@
is_context_inner_hint,
is_base_configuration_inner_hint,
is_valid_hint,
is_hint_not_resolved,
)
from dlt.common.configuration.specs.config_section_context import ConfigSectionContext
from dlt.common.configuration.specs.exceptions import NativeValueError
Expand Down Expand Up @@ -194,7 +194,7 @@ def _resolve_config_fields(
if explicit_values:
explicit_value = explicit_values.get(key)
else:
if is_final_type(hint):
if is_hint_not_resolved(hint):
# for final and other not-resolved fields the default value is treated like an explicit value
explicit_value = default_value
else:
Expand Down Expand Up @@ -258,7 +258,7 @@ def _resolve_config_fields(
unresolved_fields[key] = traces
# set resolved value in config
if default_value != current_value:
if not is_final_type(hint):
if not is_hint_not_resolved(hint):
# ignore final types
setattr(config, key, current_value)

Expand Down
38 changes: 37 additions & 1 deletion dlt/common/configuration/specs/base_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ClassVar,
TypeVar,
)
from typing_extensions import get_args, get_origin, dataclass_transform
from typing_extensions import get_args, get_origin, dataclass_transform, Annotated, TypeAlias
from functools import wraps

if TYPE_CHECKING:
Expand All @@ -29,8 +29,11 @@
TDtcField = dataclasses.Field

from dlt.common.typing import (
AnyType,
TAnyClass,
extract_inner_type,
is_annotated,
is_final_type,
is_optional_type,
is_union_type,
)
Expand All @@ -48,6 +51,34 @@
_C = TypeVar("_C", bound="CredentialsConfiguration")


class NotResolved:
    """Marker used in type annotations to indicate types that should not be resolved.

    Intended to be placed in `Annotated` metadata, e.g. `Annotated[str, NotResolved()]`.
    Passing `False` switches the marker off, so the instance evaluates as falsy.
    """

    def __init__(self, not_resolved: bool = True):
        self.not_resolved = not_resolved

    def __bool__(self) -> bool:
        # __bool__ must return a real bool; coerce so that truthy/falsy non-bool
        # arguments (e.g. 1 or 0) do not raise TypeError when the marker is evaluated
        return bool(self.not_resolved)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.not_resolved!r})"


def is_hint_not_resolved(hint: AnyType) -> bool:
    """Tells whether `hint` marks a field that must NOT be resolved.

    Final hints are never resolved. Annotated hints are not resolved when their
    metadata carries a truthy `NotResolved` marker, like
    >>> Annotated[str, NotResolved()]
    Only the first `NotResolved` marker found in the metadata is taken into account.
    """
    if is_final_type(hint):
        return True

    if not is_annotated(hint):
        return False

    # first argument is the annotated type, the remaining ones are the metadata
    _, *metadata = get_args(hint)
    marker = next((item for item in metadata if isinstance(item, NotResolved)), None)
    return marker is not None and bool(marker)


def is_base_configuration_inner_hint(inner_hint: Type[Any]) -> bool:
    """Checks if `inner_hint` is a class deriving from BaseConfiguration."""
    # issubclass requires a class, so reject anything else first
    if not inspect.isclass(inner_hint):
        return False
    return issubclass(inner_hint, BaseConfiguration)

Expand All @@ -70,6 +101,11 @@ def is_valid_hint(hint: Type[Any]) -> bool:
if get_origin(hint) is ClassVar:
# class vars are skipped by dataclass
return True

if is_hint_not_resolved(hint):
# all hints that are not resolved are valid
return True

hint = extract_inner_type(hint)
hint = get_config_if_union_hint(hint) or hint
hint = get_origin(hint) or hint
Expand Down
16 changes: 8 additions & 8 deletions dlt/common/destination/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
Any,
TypeVar,
Generic,
Final,
)
from typing_extensions import Annotated
import datetime # noqa: 251
from copy import deepcopy
import inspect
Expand All @@ -35,7 +35,7 @@
has_column_with_prop,
get_first_column_name_with_prop,
)
from dlt.common.configuration import configspec, resolve_configuration, known_sections
from dlt.common.configuration import configspec, resolve_configuration, known_sections, NotResolved
from dlt.common.configuration.specs import BaseConfiguration, CredentialsConfiguration
from dlt.common.configuration.accessors import config
from dlt.common.destination.capabilities import DestinationCapabilitiesContext
Expand Down Expand Up @@ -78,7 +78,7 @@ class StateInfo(NamedTuple):

@configspec
class DestinationClientConfiguration(BaseConfiguration):
destination_type: Final[str] = dataclasses.field(
destination_type: Annotated[str, NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
) # which destination to load data to
credentials: Optional[CredentialsConfiguration] = None
Expand All @@ -103,11 +103,11 @@ def on_resolved(self) -> None:
class DestinationClientDwhConfiguration(DestinationClientConfiguration):
"""Configuration of a destination that supports datasets/schemas"""

dataset_name: Final[str] = dataclasses.field(
dataset_name: Annotated[str, NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
) # dataset must be final so it is not configurable
) # dataset cannot be resolved
"""dataset name in the destination to load data to, for schemas that are not default schema, it is used as dataset prefix"""
default_schema_name: Final[Optional[str]] = dataclasses.field(
default_schema_name: Annotated[Optional[str], NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
)
"""name of default schema to be used to name effective dataset to load data to"""
Expand All @@ -121,8 +121,8 @@ def _bind_dataset_name(
This method is intended to be used internally.
"""
self.dataset_name = dataset_name # type: ignore[misc]
self.default_schema_name = default_schema_name # type: ignore[misc]
self.dataset_name = dataset_name
self.default_schema_name = default_schema_name
return self

def normalize_dataset_name(self, schema: Schema) -> str:
Expand Down
7 changes: 5 additions & 2 deletions dlt/destinations/impl/qdrant/configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import dataclasses
from typing import Optional, Final
from typing_extensions import Annotated

from dlt.common.configuration import configspec
from dlt.common.configuration import configspec, NotResolved
from dlt.common.configuration.specs.base_configuration import (
BaseConfiguration,
CredentialsConfiguration,
Expand Down Expand Up @@ -55,7 +56,9 @@ class QdrantClientConfiguration(DestinationClientDwhConfiguration):
dataset_separator: str = "_"

# make it optional so empty dataset is allowed
dataset_name: Final[Optional[str]] = dataclasses.field(default=None, init=False, repr=False, compare=False) # type: ignore[misc]
dataset_name: Annotated[Optional[str], NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
)

# Batch size for generating embeddings
embedding_batch_size: int = 32
Expand Down
7 changes: 5 additions & 2 deletions dlt/destinations/impl/weaviate/configuration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import dataclasses
from typing import Dict, Literal, Optional, Final
from typing_extensions import Annotated
from urllib.parse import urlparse

from dlt.common.configuration import configspec
from dlt.common.configuration import configspec, NotResolved
from dlt.common.configuration.specs.base_configuration import CredentialsConfiguration
from dlt.common.destination.reference import DestinationClientDwhConfiguration
from dlt.common.utils import digest128
Expand All @@ -26,7 +27,9 @@ def __str__(self) -> str:
class WeaviateClientConfiguration(DestinationClientDwhConfiguration):
destination_type: Final[str] = dataclasses.field(default="weaviate", init=False, repr=False, compare=False) # type: ignore
# make it optional so empty dataset is allowed
dataset_name: Optional[str] = None # type: ignore[misc]
dataset_name: Annotated[Optional[str], NotResolved()] = dataclasses.field(
default=None, init=False, repr=False, compare=False
)

batch_size: int = 100
batch_workers: int = 1
Expand Down
2 changes: 1 addition & 1 deletion dlt/sources/helpers/requests/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _make_session(self) -> Session:
session.mount("http://", self._adapter)
session.mount("https://", self._adapter)
retry = _make_retry(**self._retry_kwargs)
session.request = retry.wraps(session.request) # type: ignore[method-assign]
session.send = retry.wraps(session.send) # type: ignore[method-assign]
return session

@property
Expand Down
16 changes: 12 additions & 4 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from base64 import b64encode
import dataclasses
import math
from typing import (
List,
Expand All @@ -12,12 +13,13 @@
Iterable,
TYPE_CHECKING,
)
from typing_extensions import Annotated
from requests.auth import AuthBase
from requests import PreparedRequest # noqa: I251
from requests import PreparedRequest, Session as BaseSession # noqa: I251

from dlt.common import logger
from dlt.common.exceptions import MissingDependencyException
from dlt.common.configuration.specs.base_configuration import configspec
from dlt.common.configuration.specs.base_configuration import configspec, NotResolved
from dlt.common.configuration.specs import CredentialsConfiguration
from dlt.common.configuration.specs.exceptions import NativeValueError
from dlt.common.pendulum import pendulum
Expand Down Expand Up @@ -146,19 +148,25 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:
class OAuthJWTAuth(BearerTokenAuth):
"""This is a form of Bearer auth; there is no standard way to declare it in OpenAPI."""

format: Final[Literal["JWT"]] = "JWT" # noqa: A003
format: Final[Literal["JWT"]] = dataclasses.field( # noqa: A003
default="JWT", init=False, repr=False, compare=False
)
client_id: str = None
private_key: TSecretStrValue = None
auth_endpoint: str = None
scopes: Optional[Union[str, List[str]]] = None
headers: Optional[Dict[str, str]] = None
private_key_passphrase: Optional[TSecretStrValue] = None
default_token_expiration: int = 3600
session: Annotated[BaseSession, NotResolved()] = None

def __post_init__(self) -> None:
    """Normalizes `scopes`, resets the token state and picks the session to use.

    A list of scopes is joined into a single space separated string. A string is kept
    as is, and so is `None`, because `scopes` is optional and joining `None` would raise.
    When no custom session was provided, the default system session is used.
    """
    if self.scopes is not None and not isinstance(self.scopes, str):
        self.scopes = " ".join(self.scopes)
    self.token = None
    self.token_expiry: Optional[pendulum.DateTime] = None
    # use default system session if not specified
    if self.session is None:
        self.session = requests.client.session

def __call__(self, r: PreparedRequest) -> PreparedRequest:
if self.token is None or self.is_token_expired():
Expand All @@ -183,7 +191,7 @@ def obtain_token(self) -> None:

logger.debug(f"Obtaining token from {self.auth_endpoint}")

response = requests.post(self.auth_endpoint, headers=self.headers, data=data)
response = self.session.post(self.auth_endpoint, headers=self.headers, data=data)
response.raise_for_status()

token_response = response.json()
Expand Down
23 changes: 12 additions & 11 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ def __init__(
self.auth = auth

if session:
self._validate_session_raise_for_status(session)
self.session = session
# dlt.sources.helpers.requests.session.Session
# has raise_for_status=True by default
self.session = _warn_if_raise_for_status_and_return(session)
else:
self.session = Client(raise_for_status=False).session

Expand All @@ -92,15 +93,6 @@ def __init__(

self.data_selector = data_selector

def _validate_session_raise_for_status(self, session: BaseSession) -> None:
# dlt.sources.helpers.requests.session.Session
# has raise_for_status=True by default
if getattr(self.session, "raise_for_status", False):
logger.warning(
"The session provided has raise_for_status enabled. "
"This may cause unexpected behavior."
)

def _create_request(
self,
path: str,
Expand Down Expand Up @@ -298,3 +290,12 @@ def detect_paginator(self, response: Response, data: Any) -> BasePaginator:
" instance of the paginator as some settings may not be guessed correctly."
)
return paginator


def _warn_if_raise_for_status_and_return(session: BaseSession) -> BaseSession:
    """Logs a warning if `session` has raise_for_status enabled and returns the same session.

    Sessions without a `raise_for_status` attribute are treated as not having it enabled.
    """
    raises_for_status = getattr(session, "raise_for_status", False)
    if raises_for_status:
        logger.warning(
            "The session provided has raise_for_status enabled. This may cause unexpected behavior."
        )
    return session
55 changes: 54 additions & 1 deletion tests/common/configuration/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
Optional,
Type,
Union,
TYPE_CHECKING,
)
from typing_extensions import Annotated

from dlt.common import json, pendulum, Decimal, Wei
from dlt.common.configuration.providers.provider import ConfigProvider
from dlt.common.configuration.specs.base_configuration import NotResolved, is_hint_not_resolved
from dlt.common.configuration.specs.gcp_credentials import (
GcpServiceAccountCredentialsWithoutDefaults,
)
Expand Down Expand Up @@ -917,6 +918,58 @@ def test_is_valid_hint() -> None:
assert is_valid_hint(Wei) is True
# any class type, except deriving from BaseConfiguration is wrong type
assert is_valid_hint(ConfigFieldMissingException) is False
# but Final and NotResolved-annotated hints are valid because they are never resolved
assert is_valid_hint(Final[ConfigFieldMissingException]) is True # type: ignore[arg-type]
assert is_valid_hint(Annotated[ConfigFieldMissingException, NotResolved()]) is True # type: ignore[arg-type]
assert is_valid_hint(Annotated[ConfigFieldMissingException, "REQ"]) is False # type: ignore[arg-type]


def test_is_not_resolved_hint() -> None:
    """Final and truthy NotResolved annotations are not resolved, everything else is."""
    expectations = [
        (Final[ConfigFieldMissingException], True),
        (Annotated[ConfigFieldMissingException, NotResolved()], True),
        (Annotated[ConfigFieldMissingException, NotResolved(True)], True),
        # marker explicitly switched off
        (Annotated[ConfigFieldMissingException, NotResolved(False)], False),
        # metadata that is not a NotResolved marker is ignored
        (Annotated[ConfigFieldMissingException, "REQ"], False),
        (str, False),
    ]
    for hint, expected in expectations:
        assert is_hint_not_resolved(hint) is expected


def test_not_resolved_hint() -> None:
    """Not resolved fields take only explicitly passed values and are required unless Optional."""

    class SentinelClass:
        pass

    @configspec
    class OptionalNotResolveConfiguration(BaseConfiguration):
        trace: Final[Optional[SentinelClass]] = None
        traces: Annotated[Optional[List[SentinelClass]], NotResolved()] = None

    # optional fields may stay empty
    empty_config = resolve.resolve_configuration(OptionalNotResolveConfiguration())
    assert empty_config.trace is None
    assert empty_config.traces is None

    s1 = SentinelClass()
    s2 = SentinelClass()

    # explicit values are kept as passed
    filled_config = resolve.resolve_configuration(OptionalNotResolveConfiguration(s1, [s2]))
    assert filled_config.trace is s1
    assert filled_config.traces[0] is s2

    @configspec
    class NotResolveConfiguration(BaseConfiguration):
        trace: Final[SentinelClass] = None
        traces: Annotated[List[SentinelClass], NotResolved()] = None

    # non optional fields must all be provided explicitly
    for incomplete_kwargs in ({}, {"trace": s1}, {"traces": [s2]}):
        with pytest.raises(ConfigFieldMissingException):
            resolve.resolve_configuration(NotResolveConfiguration(**incomplete_kwargs))

    c2 = resolve.resolve_configuration(NotResolveConfiguration(s1, [s2]))
    assert c2.trace is s1
    assert c2.traces[0] is s2


def test_configspec_auto_base_config_derivation() -> None:
Expand Down
Loading

0 comments on commit 4415988

Please sign in to comment.