Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RESTClient: implement AuthConfigBase.__bool__ + update docs #1413

Merged
merged 4 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ class AuthConfigBase(AuthBase, CredentialsConfiguration):
configurable via env variables or toml files
"""

pass
def __bool__(self) -> bool:
# This is needed to avoid AuthConfigBase-derived classes
# which do not implement CredentialsConfiguration interface
# to be evaluated as False in requests.sessions.Session.prepare_request()
return True


@configspec
Expand Down
14 changes: 8 additions & 6 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
Any,
TypeVar,
Iterable,
Union,
cast,
)
import copy
from urllib.parse import urlparse
from requests import Session as BaseSession # noqa: I251
from requests import Response, Request
from requests.auth import AuthBase

from dlt.common import jsonpath, logger

Expand Down Expand Up @@ -41,7 +43,7 @@ def __init__(
request: Request,
response: Response,
paginator: BasePaginator,
auth: AuthConfigBase,
auth: AuthBase,
):
super().__init__(__iterable)
self.request = request
Expand All @@ -57,7 +59,7 @@ class RESTClient:
Args:
base_url (str): The base URL of the API to make requests to.
headers (Optional[Dict[str, str]]): Default headers to include in all requests.
auth (Optional[AuthConfigBase]): Authentication configuration for all requests.
auth (Optional[AuthBase]): Authentication configuration for all requests.
burnash marked this conversation as resolved.
Show resolved Hide resolved
paginator (Optional[BasePaginator]): Default paginator for handling paginated responses.
data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for extracting data from responses.
session (BaseSession): HTTP session for making requests.
Expand All @@ -69,7 +71,7 @@ def __init__(
self,
base_url: str,
headers: Optional[Dict[str, str]] = None,
auth: Optional[AuthConfigBase] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
data_selector: Optional[jsonpath.TJsonPath] = None,
session: BaseSession = None,
Expand Down Expand Up @@ -105,7 +107,7 @@ def _create_request(
method: HTTPMethod,
params: Dict[str, Any],
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthConfigBase] = None,
auth: Optional[AuthBase] = None,
hooks: Optional[Hooks] = None,
) -> Request:
parsed_url = urlparse(path)
Expand Down Expand Up @@ -154,7 +156,7 @@ def paginate(
method: HTTPMethodBasic = "GET",
params: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthConfigBase] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
data_selector: Optional[jsonpath.TJsonPath] = None,
hooks: Optional[Hooks] = None,
Expand All @@ -166,7 +168,7 @@ def paginate(
method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'.
params (Optional[Dict[str, Any]]): URL parameters for the request.
json (Optional[Dict[str, Any]]): JSON payload for the request.
auth (Optional[AuthConfigBase]): Authentication configuration for the request.
auth (Optional[AuthBase): Authentication configuration for the request.
paginator (Optional[BasePaginator]): Paginator instance for handling
pagination logic.
data_selector (Optional[jsonpath.TJsonPath]): JSONPath selector for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ def resource(url: str):
# dlt biquery custom destination
# we can use the dlt provided credentials class
# to retrieve the gcp credentials from the secrets
@dlt.destination(name="bigquery", loader_file_format="parquet", batch_size=0)
@dlt.destination(
name="bigquery", loader_file_format="parquet", batch_size=0, naming_convention="snake_case"
)
def bigquery_insert(
items, table, credentials: GcpServiceAccountCredentials = dlt.secrets.value
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions docs/website/docs/general-usage/http/rest-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ The available authentication methods are defined in the `dlt.sources.helpers.res
- [APIKeyAuth](#api-key-authentication)
- [HttpBasicAuth](#http-basic-authentication)

For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthConfigBase` class.
For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library.

### Bearer token authentication

Expand Down Expand Up @@ -479,12 +479,12 @@ response = client.get("/protected/resource")

### Implementing custom authentication

You can implement custom authentication by subclassing the `AuthConfigBase` class and implementing the `__call__` method:
You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method:

```py
from dlt.sources.helpers.rest_client.auth import AuthConfigBase
from requests.auth import AuthBase

class CustomAuth(AuthConfigBase):
class CustomAuth(AuthBase):
def __init__(self, token):
self.token = token

Expand Down
46 changes: 44 additions & 2 deletions tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import pytest
from typing import Any, cast
from requests import PreparedRequest, Request
from requests.auth import AuthBase
from dlt.common.typing import TSecretStrValue
from dlt.sources.helpers.requests import Response, Request
from dlt.sources.helpers.requests import Response
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.client import Hooks
from dlt.sources.helpers.rest_client.paginators import JSONResponsePaginator
Expand Down Expand Up @@ -57,7 +59,6 @@ def test_page_context(self, rest_client: RESTClient) -> None:
for page in rest_client.paginate(
"/posts",
paginator=JSONResponsePaginator(next_url_path="next_page"),
auth=AuthConfigBase(),
):
# response that produced data
assert isinstance(page.response, Response)
Expand Down Expand Up @@ -183,3 +184,44 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient):
)

assert_pagination(list(pages_iter))

def test_custom_auth_success(self, rest_client: RESTClient):
class CustomAuthConfigBase(AuthConfigBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

class CustomAuthAuthBase(AuthBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

auth_list = [
CustomAuthConfigBase("test-token"),
CustomAuthAuthBase("test-token"),
]

for auth in auth_list:
response = rest_client.get(
"/protected/posts/bearer-token",
auth=auth,
)

assert response.status_code == 200
assert response.json()["data"][0] == {"id": 0, "title": "Post 0"}

pages_iter = rest_client.paginate(
"/protected/posts/bearer-token",
auth=auth,
)

pages_list = list(pages_iter)
assert_pagination(pages_list)

assert pages_list[0].response.request.headers["Authorization"] == "Bearer test-token"