Skip to content

Commit

Permalink
fix: few fixes on token exchanges & first tests - wip
Browse files Browse the repository at this point in the history
  • Loading branch information
julien-pinchelimouroux committed Dec 18, 2024
1 parent c22b5e0 commit 149ff12
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 43 deletions.
135 changes: 135 additions & 0 deletions tests/http_api/test_authentication_configs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from datetime import datetime, timedelta
from typing import Any

import pytest
from pytest_mock import MockFixture

from toucan_connectors.http_api.authentication_configs import HttpOauth2SecretsKeeper, OAuth2Config, OAuth2SecretData


@pytest.fixture
def client_id() -> str:
return "client_id_xxx_123"


@pytest.fixture
def client_secret() -> str:
return "client_secret_yyy_345"


@pytest.fixture
def authentication_url() -> str:
return "https//oauth2.backend.server/authorize"


@pytest.fixture
def token_url() -> str:
return "https//oauth2.backend.server/fetch/token"


@pytest.fixture
def scope() -> str:
return "ACCESS::READ, ACCESS::WRITE"


@pytest.fixture
def oauth2_authentication_config(
client_id: str, client_secret: str, authentication_url: str, token_url: str, scope: str
) -> OAuth2Config:
return OAuth2Config(
kind="OAuth2Config",
client_id=client_id,
client_secret=client_secret,
authentication_url=authentication_url,
token_url=token_url,
scope=scope,
)


class KeyAlreadyExistsInDatabase(Exception):
pass


_SAVED_CONTENT = {}


def _fake_saver(key: str, value: dict[str, Any]) -> None:
if key in _SAVED_CONTENT:
raise KeyAlreadyExistsInDatabase(f"key={key}")
_SAVED_CONTENT[key] = value


def _fake_loader(key: str) -> dict[str, Any] | None:
return _SAVED_CONTENT.get(key, None)


def _fake_remover(key: str) -> None:
if key in _SAVED_CONTENT:
_SAVED_CONTENT.pop(key)


@pytest.fixture
def secret_keeper() -> HttpOauth2SecretsKeeper:
return HttpOauth2SecretsKeeper(save_callback=_fake_saver, delete_callback=_fake_remover, load_callback=_fake_loader)


def test_secret_names(oauth2_authentication_config: OAuth2Config) -> None:
assert oauth2_authentication_config.secrets_names() == ["client_secret"]


def test_authenticate_session_with_valid_access_token(
oauth2_authentication_config: OAuth2Config, secret_keeper: HttpOauth2SecretsKeeper
) -> None:
auth_flow_id = "my_secret_key"
oauth2_authentication_config.set_secret_keeper(secret_keeper=secret_keeper)
oauth2_authentication_config.set_auth_flow_id(auth_flow_id=auth_flow_id)

# Create and save fake secret data
secrets = OAuth2SecretData(
access_token="my_awesome_token",
refresh_token="my_awesome_refresh_token",
expires_at=(datetime.now() + timedelta(0, 3600)).timestamp(),
).model_dump()
secret_keeper.save(auth_flow_id, secrets)

retrieved_session = oauth2_authentication_config.authenticate_session()
assert retrieved_session.headers["Authorization"] == "Bearer my_awesome_token"


def test_authenticate_session_with_expired_access_token(
oauth2_authentication_config: OAuth2Config,
secret_keeper: HttpOauth2SecretsKeeper,
client_id: str,
client_secret: str,
token_url: str,
mocker: MockFixture,
) -> None:
auth_flow_id = "my_secret_key"
oauth2_authentication_config.set_secret_keeper(secret_keeper=secret_keeper)
oauth2_authentication_config.set_auth_flow_id(auth_flow_id=auth_flow_id)

# Create and save fake secret data
secrets = OAuth2SecretData(
access_token="my_awesome_token",
refresh_token="my_awesome_refresh_token",
expires_at=(datetime.now() - timedelta(0, 3600)).timestamp(),
).model_dump()
secret_keeper.save(auth_flow_id, secrets)

# Expects a call to refresh the expired access token
mocked_client = mocker.MagicMock(name="mocked_client")
mocked_client.refresh_token.return_value = {
"access_token": "new_access_token",
"expires_in": 3920,
"scope": "ACCESS::READ, ACCESS::WRITE",
"token_type": "Bearer",
}
mock = mocker.patch("toucan_connectors.http_api.authentication_configs.oauth_client", return_value=mocked_client)
retrieved_session = oauth2_authentication_config.authenticate_session()
assert mock.call_count == 1
assert mock.call_args[1] == {"client_id": client_id, "client_secret": client_secret}
assert mocked_client.refresh_token.call_count == 1
assert mocked_client.refresh_token.call_args[0][0] == token_url
assert mocked_client.refresh_token.call_args[1]["refresh_token"] == "my_awesome_refresh_token"

assert retrieved_session.headers["Authorization"] == "Bearer new_access_token"
92 changes: 56 additions & 36 deletions toucan_connectors/http_api/authentication_configs.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging
from abc import ABC, abstractmethod
from typing import Any, Literal, Callable

from pydantic import BaseModel, Field, SecretStr, ValidationError
from datetime import datetime, timedelta
from typing import Any, Callable, Literal
from urllib import parse as url_parse

from toucan_connectors.common import UI_HIDDEN
from pydantic import BaseModel, Field, SecretStr, ValidationError
from requests import Session

from toucan_connectors.common import UI_HIDDEN
from toucan_connectors.json_wrapper import JsonWrapper
from toucan_connectors.oauth2_connector.oauth2connector import oauth_client

Expand All @@ -32,20 +32,15 @@ def _extract_expires_at_from_token_response(token_response: dict[str, Any], defa
try:
return datetime.strptime(token_response["expires_at"], "%FT%T")
except ValueError as exc:
_LOGGER.error(
"Can't parse the oauth2 token expiration date",
exc_info=exc
)
_LOGGER.error("Can't parse the oauth2 token expiration date", exc_info=exc)
raise
return datetime.fromtimestamp(int(token_response["expires_at"]))
else:
return (
datetime.now() + timedelta(0, int(token_response.get("expires_in", default_lifetime)))
)
return datetime.now() + timedelta(0, int(token_response.get("expires_in", default_lifetime)))


class OAuth2SecretData(BaseModel):
workflow_token: str
workflow_token: str | None = None
access_token: str
refresh_token: str
expires_at: float
Expand All @@ -57,15 +52,14 @@ class HttpOauth2SecretsKeeper(BaseModel):
delete_callback: Callable[[str], None]
load_callback: Callable[[str], dict[str, Any] | None]

def save(self, key: str, value: dict) -> None:
def save(self, key: str, value: dict[str, Any]) -> None:
"""Save secrets in a secrets repository"""
expires_at = _extract_expires_at_from_token_response(
token_response=value,
default_lifetime=self.token_lifetime_seconds
)
_LOGGER.info(f"EXPIRES AT: {expires_at}")
value["expires_at"] = _extract_expires_at_from_token_response(
token_response=value, default_lifetime=self.default_token_lifetime_seconds
).timestamp()
try:
secret_data = OAuth2SecretData(**value, expires_at=expires_at.timestamp())
_LOGGER.info(f"DATA = {value}")
secret_data = OAuth2SecretData(**value)
except ValidationError as exc:
_LOGGER.error(f"Can't instantiate oauth secret data with value_keys={list(value.keys())}", exc_info=exc)
raise
Expand All @@ -75,9 +69,9 @@ def save(self, key: str, value: dict) -> None:
# save new secrets
self.save_callback(key, secret_data.model_dump())

def load(self, key: str) -> Any:
def load(self, key: str) -> dict[str, Any] | None:
"""Load secrets from the secrets repository"""
self.load_callback(key)
return self.load_callback(key)


class AuthenticationConfig(BaseModel, ABC):
Expand All @@ -95,10 +89,17 @@ def is_oauth_config() -> bool:

class BaseOAuth2Config(AuthenticationConfig, ABC):
"""Base class for OAuth2 authentication configs"""

client_id: str
authentication_url: str
token_url: str
scope: str
additional_auth_params: dict = Field(
default_factory=dict,
title="Additional authentication params",
description="A JSON object that represents additional arguments that must be passed as query params"
" to the Oauth2 backend during token exchanges",
)

# Mandatory hidden fields for oauth2 dance which must be setup by the backend and not by the end-user
_auth_flow_id: str | None = None
Expand All @@ -121,12 +122,23 @@ def retrieve_token(self, response_params: dict[str, Any]):
def secrets_names(self) -> list[str]:
pass

def set_secret_keeper(self, secret_keeper: HttpOauth2SecretsKeeper):
self._secrets_keeper = secret_keeper

def set_redirect_uri(self, redirect_uri: str):
self._redirect_uri = redirect_uri

def set_auth_flow_id(self, auth_flow_id: str):
self._auth_flow_id = auth_flow_id


class OAuth2Config(BaseOAuth2Config):
"""Authorization code configuration type"""

kind: Literal["OAuth2Config"] = Field(..., **UI_HIDDEN)

# Allows to instantiate authentication config without secrets
client_secret: SecretStr | None
client_secret: SecretStr | None = None

def secrets_names(self) -> list[str]:
return ["client_secret"]
Expand All @@ -150,11 +162,14 @@ def build_authorization_uri(self, **kwargs):
redirect_uri=self._redirect_uri,
scope=self.scope,
)
state = {"workflow_token": generate_token(), **kwargs}
uri, state = client.create_authorization_url(self.authorization_url, state=JsonWrapper.dumps(state))
_LOGGER.info(f"AUTHORIZATION URL: {uri}")
workflow_token = generate_token()
state = {"workflow_token": workflow_token, **kwargs}
uri, state = client.create_authorization_url(
self.authentication_url, state=JsonWrapper.dumps(state), **self.additional_auth_params
)

tmp_oauth_secrets = {
"workflow_token": state["workflow_token"],
"workflow_token": workflow_token,
"access_token": "__UNKNOWN__",
"refresh_token": "__UNKNOWN__",
}
Expand All @@ -172,32 +187,37 @@ def _get_access_token(self):
expires_at = oauth_token_info["expires_at"]
if datetime.fromtimestamp(expires_at) < datetime.now():
client = oauth_client(
client_id=self.config.client_id,
client_id=self.client_id,
client_secret=self.client_secret.get_secret_value(),
)
new_token = client.refresh_token(self.token_url, refresh_token=oauth_token_info["refresh_token"])
self._secrets_keeper.save(self._auth_flow_id, new_token)
self._secrets_keeper.save(
self._auth_flow_id,
# refresh call doesn't always contain the refresh_token
new_token | {"refresh_token": oauth_token_info["refresh_token"]},
)
return self._secrets_keeper.load(self._auth_flow_id)["access_token"]

def retrieve_token(self, response_params: dict[str, Any]):
def retrieve_token(self, authorization_response: str):
url = url_parse.urlparse(authorization_response)
url_params = url_parse.parse_qs(url.query)
client = oauth_client(
client_id=self.config.client_id,
client_id=self.client_id,
client_secret=self.client_secret.get_secret_value(),
redirect_uri=self._redirect_uri,
)
saved_flow = self._secrets_keeper.load(self._auth_flow_id)
if saved_flow is None:
raise MissingOauthWorkflowError()

# Verify the oauth2 workflow token
assert JsonWrapper.loads(saved_flow["state"])["token"] == JsonWrapper.loads(
response_params["state"][0]
)["token"]
assert saved_flow["workflow_token"] == JsonWrapper.loads(url_params["state"][0])["workflow_token"]

token = client.fetch_token(
self.token_url,
client_id=self.config.client_id,
client_secret=self.client_secret.get_secret_value(),
authorization_response=authorization_response,
# Some oauth applications needs redirect_uri in fetch_token params.
# authorization_response does not carry it natively.
body=url_parse.urlencode({"redirect_uri": self._redirect_uri}),
)
self._secrets_keeper.save(self._auth_flow_id, token)

Expand Down
6 changes: 2 additions & 4 deletions toucan_connectors/http_api/http_api_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from pydantic import AnyHttpUrl, BaseModel, Field, FilePath

from toucan_connectors.http_api.authentication_configs import AuthenticationConfig
from toucan_connectors.http_api.http_api_data_source import HttpAPIDataSource, apply_pagination_to_data_source

try:
Expand All @@ -16,13 +15,12 @@
from requests.exceptions import HTTPError
from xmltodict import parse

from toucan_connectors.http_api.authentication_configs import HttpAuthenticationConfig
from toucan_connectors.http_api.pagination_configs import (
NoopPaginationConfig,
extract_pagination_info_from_result,
)

from toucan_connectors.http_api.authentication_configs import HttpAuthenticationConfig

CONNECTOR_OK = True
except ImportError as exc: # pragma: no cover
getLogger(__name__).warning(f"Missing dependencies for {__name__}: {exc}")
Expand Down Expand Up @@ -86,7 +84,7 @@ class HttpAPIConnector(ToucanConnector, data_source_model=HttpAPIDataSource):
None,
title="Authentication type",
deprecated=True,
description="Deprecated authentication config. Please use 'Authentication' section."
description="Deprecated authentication config. Please use 'Authentication' section.",
) # Deprecated

authentication: HttpAuthenticationConfig | None = Field(
Expand Down
7 changes: 6 additions & 1 deletion toucan_connectors/json_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import json
import logging
from enum import Enum

from pydantic import SecretStr
from pydantic import AnyHttpUrl, SecretStr


def custom_json_serializer(obj):
if isinstance(obj, SecretStr):
return obj.get_secret_value()
if isinstance(obj, Enum):
return obj.value
if isinstance(obj, AnyHttpUrl):
return str(obj)
return obj


Expand Down
4 changes: 2 additions & 2 deletions toucan_connectors/oauth2_connector/oauth2connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_access_token(self) -> str:
if is_expired:
if "refresh_token" not in token:
raise NoOAuth2RefreshToken
client = _client(
client = oauth_client(
client_id=self.config.client_id,
client_secret=self.config.client_secret.get_secret_value(),
)
Expand All @@ -148,7 +148,7 @@ def get_access_data(self):
if "instance_url" not in access_data:
raise NoInstanceUrl

client = _client(
client = oauth_client(
client_id=self.config.client_id,
client_secret=self.config.client_secret.get_secret_value(),
)
Expand Down

0 comments on commit 149ff12

Please sign in to comment.