Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RESTClient #1141

Merged
merged 15 commits into from
Mar 25, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 14 additions & 55 deletions dlt/sources/helpers/rest_client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ class AuthConfigBase(AuthBase, CredentialsConfiguration):

@configspec
class BearerTokenAuth(AuthConfigBase):
type: Final[Literal["http"]] = "http" # noqa: A003
scheme: Literal["bearer"] = "bearer"
token: TSecretStrValue

def __init__(self, token: TSecretStrValue = secrets.value) -> None:
self.token = token
token: TSecretStrValue = None

def parse_native_representation(self, value: Any) -> None:
if isinstance(value, str):
Expand All @@ -67,21 +62,10 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:

@configspec
class APIKeyAuth(AuthConfigBase):
type: Final[Literal["apiKey"]] = "apiKey" # noqa: A003
name: str = "Authorization"
api_key: TSecretStrValue
api_key: TSecretStrValue = None
location: TApiKeyLocation = "header"

def __init__(
self,
name: str = config.value,
api_key: TSecretStrValue = secrets.value,
location: TApiKeyLocation = "header",
) -> None:
self.name = name
self.api_key = api_key
self.location = location

def parse_native_representation(self, value: Any) -> None:
if isinstance(value, str):
self.api_key = cast(TSecretStrValue, value)
Expand All @@ -104,16 +88,8 @@ def __call__(self, request: PreparedRequest) -> PreparedRequest:

@configspec
class HttpBasicAuth(AuthConfigBase):
type: Final[Literal["http"]] = "http" # noqa: A003
scheme: Literal["basic"] = "basic"
username: str
password: TSecretStrValue

def __init__(
self, username: str = config.value, password: TSecretStrValue = secrets.value
) -> None:
self.username = username
self.password = password
username: str = ""
burnash marked this conversation as resolved.
Show resolved Hide resolved
password: TSecretStrValue = None

def parse_native_representation(self, value: Any) -> None:
if isinstance(value, Iterable) and not isinstance(value, str):
Expand All @@ -138,11 +114,7 @@ class OAuth2AuthBase(AuthConfigBase):
"""Base class for oauth2 authenticators. requires access_token"""

# TODO: Separate class for flows (implicit, authorization_code, client_credentials, etc)
type: Final[Literal["oauth2"]] = "oauth2" # noqa: A003
access_token: TSecretStrValue

def __init__(self, access_token: TSecretStrValue = secrets.value) -> None:
self.access_token = access_token
access_token: TSecretStrValue = None

def parse_native_representation(self, value: Any) -> None:
if isinstance(value, str):
Expand All @@ -164,33 +136,20 @@ class OAuthJWTAuth(BearerTokenAuth):
"""This is a form of Bearer auth, actually there's not standard way to declare it in openAPI"""

format: Final[Literal["JWT"]] = "JWT" # noqa: A003
client_id: str
private_key: TSecretStrValue
auth_endpoint: str
scopes: Optional[str] = None
client_id: str = None
private_key: TSecretStrValue = None
auth_endpoint: str = None
scopes: Optional[Union[str, List[str]]] = None
headers: Optional[Dict[str, str]] = None
private_key_passphrase: Optional[TSecretStrValue] = None
default_token_expiration: int = 3600

def __init__(
self,
client_id: str = config.value,
private_key: TSecretStrValue = secrets.value,
auth_endpoint: str = config.value,
scopes: Optional[Union[str, List[str]]] = None,
headers: Optional[Dict[str, str]] = None,
private_key_passphrase: Optional[TSecretStrValue] = None,
default_token_expiration: int = 3600,
):
self.client_id = client_id
self.private_key = private_key
self.private_key_passphrase = private_key_passphrase
self.auth_endpoint = auth_endpoint
self.scopes = scopes if isinstance(scopes, str) else " ".join(scopes)
self.headers = headers
def __post_init__(self) -> None:
self.scopes = (
self.scopes if isinstance(self.scopes, str) else " ".join(self.scopes)
)
self.token = None
self.token_expiry: Optional[pendulum.DateTime] = None
self.default_token_expiration = default_token_expiration

def __call__(self, r: PreparedRequest) -> PreparedRequest:
if self.token is None or self.is_token_expired():
Expand Down Expand Up @@ -229,7 +188,7 @@ def create_jwt_payload(self) -> Dict[str, Union[str, int]]:
"aud": self.auth_endpoint,
"exp": math.floor((now.add(hours=1)).timestamp()),
"iat": math.floor(now.timestamp()),
"scope": self.scopes,
"scope": cast(str, self.scopes),
}

def load_private_key(self) -> PrivateKeyTypes:
Expand Down
Loading