Skip to content

Commit

Permalink
Add RESTClient: (#1141)
Browse files Browse the repository at this point in the history
* Add RESTClient and tests

* Add PyJWT

* Add initial version of `rest_client.paginate()`

* Export `rest_client.paginate` to `helpers.requests` module

* Fix the typing error

* Use dlt.common.json

* Add dependency checks for PyJWT and cryptography in auth module

* Remove unused imports and check_connection function from rest_client utils

* Refactor pagination assertion into a standalone function

* Move `paginate` function test to new file `test_requests_paginate.py`

* Remove PyJWT from deps

* Remove explicit initializers and meta fields from configspec classes

* Implement lazy loading for jwt and cryptography in auth

* Set username default to None

* Add PyJWT to dev dependencies
  • Loading branch information
burnash authored Mar 25, 2024
1 parent 28434b6 commit b8fb7fd
Show file tree
Hide file tree
Showing 19 changed files with 1,857 additions and 33 deletions.
4 changes: 3 additions & 1 deletion dlt/sources/helpers/requests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from requests.exceptions import ChunkedEncodingError
from dlt.sources.helpers.requests.retry import Client
from dlt.sources.helpers.requests.session import Session
from dlt.sources.helpers.rest_client import paginate
from dlt.common.configuration.specs import RunConfiguration

client = Client()

get, post, put, patch, delete, options, head, request = (
get, post, put, patch, delete, options, head, request, paginate = (
client.get,
client.post,
client.put,
Expand All @@ -28,6 +29,7 @@
client.options,
client.head,
client.request,
paginate,
)


Expand Down
46 changes: 46 additions & 0 deletions dlt/sources/helpers/rest_client/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import Optional, Dict, Iterator, Union, Any

from dlt.common import jsonpath

from .client import RESTClient # noqa: F401
from .client import PageData
from .auth import AuthConfigBase
from .paginators import BasePaginator
from .typing import HTTPMethodBasic, Hooks


def paginate(
url: str,
method: HTTPMethodBasic = "GET",
headers: Optional[Dict[str, str]] = None,
params: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: AuthConfigBase = None,
paginator: Optional[BasePaginator] = None,
data_selector: Optional[jsonpath.TJsonPath] = None,
hooks: Optional[Hooks] = None,
) -> Iterator[PageData[Any]]:
"""
Paginate over a REST API endpoint.
Args:
url: URL to paginate over.
**kwargs: Keyword arguments to pass to `RESTClient.paginate`.
Returns:
Iterator[Page]: Iterator over pages.
"""
client = RESTClient(
base_url=url,
headers=headers,
)
return client.paginate(
path="",
method=method,
params=params,
json=json,
auth=auth,
paginator=paginator,
data_selector=data_selector,
hooks=hooks,
)
215 changes: 215 additions & 0 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from base64 import b64encode
import math
from typing import (
List,
Dict,
Final,
Literal,
Optional,
Union,
Any,
cast,
Iterable,
TYPE_CHECKING,
)
from dlt.sources.helpers import requests
from requests.auth import AuthBase
from requests import PreparedRequest # noqa: I251
import pendulum

from dlt.common.exceptions import MissingDependencyException

from dlt.common import logger
from dlt.common.configuration.specs.base_configuration import configspec
from dlt.common.configuration.specs import CredentialsConfiguration
from dlt.common.configuration.specs.exceptions import NativeValueError
from dlt.common.typing import TSecretStrValue

if TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
else:
PrivateKeyTypes = Any

TApiKeyLocation = Literal[
"header", "cookie", "query", "param"
] # Alias for scheme "in" field


class AuthConfigBase(AuthBase, CredentialsConfiguration):
"""Authenticator base which is both `requests` friendly AuthBase and dlt SPEC
configurable via env variables or toml files
"""

pass


@configspec
class BearerTokenAuth(AuthConfigBase):
token: TSecretStrValue = None

def parse_native_representation(self, value: Any) -> None:
if isinstance(value, str):
self.token = cast(TSecretStrValue, value)
else:
raise NativeValueError(
type(self),
value,
f"BearerTokenAuth token must be a string, got {type(value)}",
)

def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.token}"
return request


@configspec
class APIKeyAuth(AuthConfigBase):
name: str = "Authorization"
api_key: TSecretStrValue = None
location: TApiKeyLocation = "header"

def parse_native_representation(self, value: Any) -> None:
if isinstance(value, str):
self.api_key = cast(TSecretStrValue, value)
else:
raise NativeValueError(
type(self),
value,
f"APIKeyAuth api_key must be a string, got {type(value)}",
)

def __call__(self, request: PreparedRequest) -> PreparedRequest:
if self.location == "header":
request.headers[self.name] = self.api_key
elif self.location in ["query", "param"]:
request.prepare_url(request.url, {self.name: self.api_key})
elif self.location == "cookie":
raise NotImplementedError()
return request


@configspec
class HttpBasicAuth(AuthConfigBase):
username: str = None
password: TSecretStrValue = None

def parse_native_representation(self, value: Any) -> None:
if isinstance(value, Iterable) and not isinstance(value, str):
value = list(value)
if len(value) == 2:
self.username, self.password = value
return
raise NativeValueError(
type(self),
value,
f"HttpBasicAuth username and password must be a tuple of two strings, got {type(value)}",
)

def __call__(self, request: PreparedRequest) -> PreparedRequest:
encoded = b64encode(f"{self.username}:{self.password}".encode()).decode()
request.headers["Authorization"] = f"Basic {encoded}"
return request


@configspec
class OAuth2AuthBase(AuthConfigBase):
"""Base class for oauth2 authenticators. requires access_token"""

# TODO: Separate class for flows (implicit, authorization_code, client_credentials, etc)
access_token: TSecretStrValue = None

def parse_native_representation(self, value: Any) -> None:
if isinstance(value, str):
self.access_token = cast(TSecretStrValue, value)
else:
raise NativeValueError(
type(self),
value,
f"OAuth2AuthBase access_token must be a string, got {type(value)}",
)

def __call__(self, request: PreparedRequest) -> PreparedRequest:
request.headers["Authorization"] = f"Bearer {self.access_token}"
return request


@configspec
class OAuthJWTAuth(BearerTokenAuth):
"""This is a form of Bearer auth, actually there's not standard way to declare it in openAPI"""

format: Final[Literal["JWT"]] = "JWT" # noqa: A003
client_id: str = None
private_key: TSecretStrValue = None
auth_endpoint: str = None
scopes: Optional[Union[str, List[str]]] = None
headers: Optional[Dict[str, str]] = None
private_key_passphrase: Optional[TSecretStrValue] = None
default_token_expiration: int = 3600

def __post_init__(self) -> None:
self.scopes = (
self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
)
self.token = None
self.token_expiry: Optional[pendulum.DateTime] = None

def __call__(self, r: PreparedRequest) -> PreparedRequest:
if self.token is None or self.is_token_expired():
self.obtain_token()
r.headers["Authorization"] = f"Bearer {self.token}"
return r

def is_token_expired(self) -> bool:
return not self.token_expiry or pendulum.now() >= self.token_expiry

def obtain_token(self) -> None:
try:
import jwt
except ModuleNotFoundError:
raise MissingDependencyException("dlt OAuth helpers", ["PyJWT"])

payload = self.create_jwt_payload()
data = {
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": jwt.encode(
payload, self.load_private_key(), algorithm="RS256"
),
}

logger.debug(f"Obtaining token from {self.auth_endpoint}")

response = requests.post(self.auth_endpoint, headers=self.headers, data=data)
response.raise_for_status()

token_response = response.json()
self.token = token_response["access_token"]
self.token_expiry = pendulum.now().add(
seconds=token_response.get("expires_in", self.default_token_expiration)
)

def create_jwt_payload(self) -> Dict[str, Union[str, int]]:
now = pendulum.now()
return {
"iss": self.client_id,
"sub": self.client_id,
"aud": self.auth_endpoint,
"exp": math.floor((now.add(hours=1)).timestamp()),
"iat": math.floor(now.timestamp()),
"scope": cast(str, self.scopes),
}

def load_private_key(self) -> "PrivateKeyTypes":
try:
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
except ModuleNotFoundError:
raise MissingDependencyException("dlt OAuth helpers", ["cryptography"])

private_key_bytes = self.private_key.encode("utf-8")
return serialization.load_pem_private_key(
private_key_bytes,
password=self.private_key_passphrase.encode("utf-8")
if self.private_key_passphrase
else None,
backend=default_backend(),
)
Loading

0 comments on commit b8fb7fd

Please sign in to comment.