From 5b8421e7227059583a02bee5086f6e8b55c5e0f5 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 9 Dec 2024 17:49:31 +0100 Subject: [PATCH 1/3] Fix validation error in for custom auth classes --- dlt/common/typing.py | 15 +++++++++++++++ dlt/sources/rest_api/config_setup.py | 6 +++++- .../configurations/test_custom_auth_config.py | 17 ++++++++++++++++- 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/dlt/common/typing.py b/dlt/common/typing.py index a3364d1b07..c8080b548d 100644 --- a/dlt/common/typing.py +++ b/dlt/common/typing.py @@ -484,3 +484,18 @@ def decorator( return func return decorator + + +def add_value_to_literal(literal: Any, value: Any) -> None: + """Extends a Literal at runtime with a new value. + + Args: + literal (Type[Any]): Literal to extend + value (Any): Value to add + + """ + type_args = get_args(literal) + + if value not in type_args: + type_args += (value,) + literal.__args__ = type_args diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index d03a4fd59b..3ce81f3aa7 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -20,6 +20,7 @@ from dlt.common.configuration import resolve_configuration from dlt.common.schema.utils import merge_columns from dlt.common.utils import update_dict_nested, exclude_keys +from dlt.common.typing import add_value_to_literal from dlt.common import jsonpath from dlt.extract.incremental import Incremental @@ -64,6 +65,7 @@ ResponseActionDict, Endpoint, EndpointResource, + AuthType, ) @@ -153,6 +155,8 @@ def register_auth( ) AUTH_MAP[auth_name] = auth_class + add_value_to_literal(AuthType, auth_name) + def get_auth_class(auth_type: str) -> Type[AuthConfigBase]: try: @@ -285,7 +289,7 @@ def build_resource_dependency_graph( resolved_param_map[resource_name] = None break assert isinstance(endpoint_resource["endpoint"], dict) - # connect transformers to resources via resolved params + # find resolved parameters to connect dependent resources resolved_params = _find_resolved_params(endpoint_resource["endpoint"]) # set of resources in resolved params diff --git a/tests/sources/rest_api/configurations/test_custom_auth_config.py b/tests/sources/rest_api/configurations/test_custom_auth_config.py index 1a5a2e58a3..132bd67e88 100644 --- a/tests/sources/rest_api/configurations/test_custom_auth_config.py +++ b/tests/sources/rest_api/configurations/test_custom_auth_config.py @@ -5,7 +5,7 @@ from dlt.sources import rest_api from dlt.sources.helpers.rest_client.auth import APIKeyAuth, OAuth2ClientCredentials -from dlt.sources.rest_api.typing import ApiKeyAuthConfig, AuthConfig +from dlt.sources.rest_api.typing import ApiKeyAuthConfig, AuthConfig, RESTAPIConfig class CustomOAuth2(OAuth2ClientCredentials): @@ -77,3 +77,18 @@ class NotAuthConfigBase: "not_an_auth_config_base", NotAuthConfigBase # type: ignore ) assert e.match("Invalid auth: NotAuthConfigBase.") + + def test_validate_config_raises_no_error(self, custom_auth_config: AuthConfig) -> None: + rest_api.config_setup.register_auth("custom_oauth_2", CustomOAuth2) + + valid_config: RESTAPIConfig = { + "client": { + "base_url": "https://example.com", + "auth": custom_auth_config, + }, + "resources": ["test"], + } + + rest_api.rest_api_source(valid_config) + + del rest_api.config_setup.AUTH_MAP["custom_oauth_2"] From 4a6351e3b75875f10fb25b7da0cc28ad51c8ff9c Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Mon, 9 Dec 2024 18:28:05 +0100 Subject: [PATCH 2/3] Add paginator type support --- dlt/sources/rest_api/config_setup.py | 2 ++ .../configurations/test_custom_auth_config.py | 2 +- .../configurations/test_custom_paginator_config.py | 12 +++++++++++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/dlt/sources/rest_api/config_setup.py b/dlt/sources/rest_api/config_setup.py index 3ce81f3aa7..bf62c6c4f7 100644 --- a/dlt/sources/rest_api/config_setup.py +++ b/dlt/sources/rest_api/config_setup.py @@ -66,6 +66,7 @@ Endpoint, EndpointResource, AuthType, + PaginatorType, ) @@ -105,6 +106,7 @@ def register_paginator( "Your custom paginator has to be a subclass of BasePaginator" ) PAGINATOR_MAP[paginator_name] = paginator_class + add_value_to_literal(PaginatorType, paginator_name) def get_paginator_class(paginator_name: str) -> Type[BasePaginator]: diff --git a/tests/sources/rest_api/configurations/test_custom_auth_config.py b/tests/sources/rest_api/configurations/test_custom_auth_config.py index 132bd67e88..52cdb95735 100644 --- a/tests/sources/rest_api/configurations/test_custom_auth_config.py +++ b/tests/sources/rest_api/configurations/test_custom_auth_config.py @@ -78,7 +78,7 @@ class NotAuthConfigBase: ) assert e.match("Invalid auth: NotAuthConfigBase.") - def test_validate_config_raises_no_error(self, custom_auth_config: AuthConfig) -> None: + def test_valid_config_raises_no_error(self, custom_auth_config: AuthConfig) -> None: rest_api.config_setup.register_auth("custom_oauth_2", CustomOAuth2) valid_config: RESTAPIConfig = { diff --git a/tests/sources/rest_api/configurations/test_custom_paginator_config.py b/tests/sources/rest_api/configurations/test_custom_paginator_config.py index f8ac060218..975ab10176 100644 --- a/tests/sources/rest_api/configurations/test_custom_paginator_config.py +++ b/tests/sources/rest_api/configurations/test_custom_paginator_config.py @@ -4,7 +4,7 @@ from dlt.sources import rest_api from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator -from dlt.sources.rest_api.typing import PaginatorConfig +from dlt.sources.rest_api.typing import PaginatorConfig, RESTAPIConfig class CustomPaginator(JSONLinkPaginator): @@ -67,3 +67,13 @@ class NotAPaginator: with pytest.raises(ValueError) as e: rest_api.config_setup.register_paginator("not_a_paginator", NotAPaginator) # type: ignore[arg-type] assert e.match("Invalid paginator: NotAPaginator.") + + def test_test_valid_config_raises_no_error(self, custom_paginator_config) -> None: + rest_api.config_setup.register_paginator("custom_paginator", CustomPaginator) + + valid_config: RESTAPIConfig = { + "client": {"base_url": "https://example.com", "paginator": custom_paginator_config}, + "resources": ["test"], + } + + rest_api.rest_api_source(valid_config) From e03801d7ff00a8fc6c870e10180b1f44e9765e71 Mon Sep 17 00:00:00 2001 From: Anton Burnashev Date: Thu, 12 Dec 2024 11:03:14 +0100 Subject: [PATCH 3/3] Add test for add_value_to_literal --- tests/common/test_typing.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/common/test_typing.py b/tests/common/test_typing.py index 2749e3ebb1..e81c3e7fa2 100644 --- a/tests/common/test_typing.py +++ b/tests/common/test_typing.py @@ -43,6 +43,7 @@ is_union_type, is_annotated, is_callable_type, + add_value_to_literal, ) @@ -293,3 +294,19 @@ def test_secret_type() -> None: assert TSecretStrValue("x_str") == "x_str" assert TSecretStrValue({}) == "{}" + + +def test_add_value_to_literal() -> None: + TestLiteral = Literal["red", "blue"] + + add_value_to_literal(TestLiteral, "green") + + assert get_args(TestLiteral) == ("red", "blue", "green") + + add_value_to_literal(TestLiteral, "red") + assert get_args(TestLiteral) == ("red", "blue", "green") + + TestSingleLiteral = Literal["red"] + add_value_to_literal(TestSingleLiteral, "green") + add_value_to_literal(TestSingleLiteral, "blue") + assert get_args(TestSingleLiteral) == ("red", "green", "blue")