Skip to content

Commit

Permalink
Actually refresh token and improve unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jason Parraga <[email protected]>
  • Loading branch information
Sovietaced committed Nov 22, 2024
1 parent 5e76d3b commit 9364cac
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 2 deletions.
32 changes: 32 additions & 0 deletions flytekit/clients/auth/authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ def __init__(
session: typing.Optional[requests.Session] = None,
):
cfg = cfg_store.get_client_config()
self._auth_client = None
self._authorization_endpoint = cfg.authorization_endpoint
self._redirect_uri = cfg.redirect_uri
self._audience = audience or cfg.audience
self._client_id = cfg.client_id
self._device_auth_endpoint = cfg.device_authorization_endpoint
Expand All @@ -293,7 +296,36 @@ def __init__(
verify=verify,
)

def _initialize_auth_client(self):
if not self._auth_client:
self._set_header_key(self._header_key)
self._auth_client = AuthorizationClient(
endpoint=self._endpoint,
redirect_uri=self._redirect_uri,
client_id=self._client_id,
audience=self._audience,
scopes=self._scopes,
auth_endpoint=self._authorization_endpoint,
token_endpoint=self._token_endpoint,
verify=self._verify,
session=self._session,
refresh_access_token_params={},
)

def refresh_credentials(self):
self._initialize_auth_client()
if self._creds:
"""We have an access token so lets try to refresh it"""
try:
self._creds = self._auth_client.refresh_access_token(self._creds)
if self._creds:
KeyringStore.store(self._creds)
return
except AccessTokenNotFoundError:
logging.warning("Failed to refresh token. Kicking off a full authorization flow.")
KeyringStore.delete(self._endpoint)

"""Fall back to device flow"""
resp = token_client.get_device_code(
self._device_auth_endpoint,
self._client_id,
Expand Down
48 changes: 46 additions & 2 deletions tests/flytekit/unit/clients/auth/test_authenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
PKCEAuthenticator,
StaticClientConfigStore,
)
from flytekit.clients.auth.exceptions import AuthenticationError
from flytekit.clients.auth.exceptions import AuthenticationError, AccessTokenNotFoundError
from flytekit.clients.auth.keyring import Credentials
from flytekit.clients.auth.token_client import DeviceCodeResponse

ENDPOINT = "example.com"
Expand Down Expand Up @@ -95,7 +96,8 @@ def test_client_creds_authenticator(mock_session):
@patch("flytekit.clients.auth.authenticator.KeyringStore")
@patch("flytekit.clients.auth.token_client.get_device_code")
@patch("flytekit.clients.auth.token_client.poll_token_endpoint")
def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock):
@patch("flytekit.clients.auth.auth_client.AuthorizationClient.refresh_access_token")
def test_device_flow_authenticator(mock_refresh: MagicMock, poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock):
with pytest.raises(AuthenticationError):
DeviceCodeAuthenticator(ENDPOINT, static_cfg_store, audience="x", verify=True)

Expand All @@ -114,9 +116,51 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock,

device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0)
poll_mock.return_value = ("access", "refresh", 100)
mock_refresh.side_effect = AccessTokenNotFoundError("test") # ensure refresh token fails

authn.refresh_credentials()
assert authn._creds

# assert calls made to mocks
poll_mock.assert_called()
device_mock.assert_called()
mock_refresh.assert_called()

@patch("flytekit.clients.auth.authenticator.KeyringStore")
@patch("flytekit.clients.auth.token_client.get_device_code")
@patch("flytekit.clients.auth.token_client.poll_token_endpoint")
@patch("flytekit.clients.auth.auth_client.AuthorizationClient.refresh_access_token")
def test_device_flow_authenticator_refresh_token(mock_refresh: MagicMock, poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock):
with pytest.raises(AuthenticationError):
DeviceCodeAuthenticator(ENDPOINT, static_cfg_store, audience="x", verify=True)

cfg_store = StaticClientConfigStore(
ClientConfig(
token_endpoint="token_endpoint",
authorization_endpoint="auth_endpoint",
redirect_uri="redirect_uri",
client_id="client",
device_authorization_endpoint="dev",
)
)
authn = DeviceCodeAuthenticator(
ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000", verify=False
)

mock_refresh.return_value = Credentials(
access_token="access", refresh_token="refresh", expires_in=100
)

authn.refresh_credentials()
assert authn._creds

# Full login flow should not happen
poll_mock.assert_not_called()
device_mock.assert_not_called()

# assert calls made to mocks
mock_refresh.assert_called()


@patch("flytekit.clients.auth.token_client.requests.Session")
def test_client_creds_authenticator_with_custom_scopes(mock_session):
Expand Down

0 comments on commit 9364cac

Please sign in to comment.