Skip to content

Commit

Permalink
RESTClient: implement AuthConfigBase.__bool__ + update docs (#1398)
Browse files Browse the repository at this point in the history
* Fix AuthConfigBase so its instances always evaluate to True in bool context; change docs to suggest direct inheritance from AuthBase

* Add tests
  • Loading branch information
burnash authored May 24, 2024
1 parent 7c07c67 commit ceb229d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
6 changes: 5 additions & 1 deletion dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ class AuthConfigBase(AuthBase, CredentialsConfiguration):
configurable via env variables or toml files
"""

pass
def __bool__(self) -> bool:
# This is needed to avoid AuthConfigBase-derived classes
# which do not implement CredentialsConfiguration interface
# to be evaluated as False in requests.sessions.Session.prepare_request()
return True


@configspec
Expand Down
8 changes: 4 additions & 4 deletions docs/website/docs/general-usage/http/rest-client.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ The available authentication methods are defined in the `dlt.sources.helpers.res
- [APIKeyAuth](#api-key-authentication)
- [HttpBasicAuth](#http-basic-authentication)

For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthConfigBase` class.
For specific use cases, you can [implement custom authentication](#implementing-custom-authentication) by subclassing the `AuthBase` class from the Requests library.

### Bearer token authentication

Expand Down Expand Up @@ -479,12 +479,12 @@ response = client.get("/protected/resource")

### Implementing custom authentication

You can implement custom authentication by subclassing the `AuthConfigBase` class and implementing the `__call__` method:
You can implement custom authentication by subclassing the `AuthBase` class and implementing the `__call__` method:

```py
from dlt.sources.helpers.rest_client.auth import AuthConfigBase
from requests.auth import AuthBase

class CustomAuth(AuthConfigBase):
class CustomAuth(AuthBase):
def __init__(self, token):
self.token = token

Expand Down
43 changes: 42 additions & 1 deletion tests/sources/helpers/rest_client/test_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import pytest
from typing import Any, cast
from requests.auth import AuthBase
from dlt.common.typing import TSecretStrValue
from dlt.sources.helpers.requests import Response, Request
from dlt.sources.helpers.rest_client import RESTClient
Expand Down Expand Up @@ -57,7 +58,6 @@ def test_page_context(self, rest_client: RESTClient) -> None:
for page in rest_client.paginate(
"/posts",
paginator=JSONResponsePaginator(next_url_path="next_page"),
auth=AuthConfigBase(),
):
# response that produced data
assert isinstance(page.response, Response)
Expand Down Expand Up @@ -183,3 +183,44 @@ def test_oauth_jwt_auth_success(self, rest_client: RESTClient):
)

assert_pagination(list(pages_iter))

def test_custom_auth_success(self, rest_client: RESTClient):
class CustomAuthConfigBase(AuthConfigBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: Request) -> Request:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

class CustomAuthAuthBase(AuthBase):
def __init__(self, token: str):
self.token = token

def __call__(self, request: Request) -> Request:
request.headers["Authorization"] = f"Bearer {self.token}"
return request

auth_list = [
CustomAuthConfigBase("test-token"),
CustomAuthAuthBase("test-token"),
]

for auth in auth_list:
response = rest_client.get(
"/protected/posts/bearer-token",
auth=auth,
)

assert response.status_code == 200
assert response.json()["data"][0] == {"id": 0, "title": "Post 0"}

pages_iter = rest_client.paginate(
"/protected/posts/bearer-token",
auth=auth,
)

pages_list = list(pages_iter)
assert_pagination(pages_list)

assert pages_list[0].response.request.headers["Authorization"] == "Bearer test-token"

0 comments on commit ceb229d

Please sign in to comment.