From 9364cacde8065d482e0151dd2a81f93e3ccf335a Mon Sep 17 00:00:00 2001 From: Jason Parraga Date: Fri, 22 Nov 2024 12:33:34 -0800 Subject: [PATCH] Actually refresh token and improve unit tests Signed-off-by: Jason Parraga --- flytekit/clients/auth/authenticator.py | 32 +++++++++++++ .../unit/clients/auth/test_authenticator.py | 48 ++++++++++++++++++- 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index dc50f209ba..432718d1e7 100644 --- a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -272,6 +272,9 @@ def __init__( session: typing.Optional[requests.Session] = None, ): cfg = cfg_store.get_client_config() + self._auth_client = None + self._authorization_endpoint = cfg.authorization_endpoint + self._redirect_uri = cfg.redirect_uri self._audience = audience or cfg.audience self._client_id = cfg.client_id self._device_auth_endpoint = cfg.device_authorization_endpoint @@ -293,7 +296,36 @@ def __init__( verify=verify, ) + def _initialize_auth_client(self): + if not self._auth_client: + self._set_header_key(self._header_key) + self._auth_client = AuthorizationClient( + endpoint=self._endpoint, + redirect_uri=self._redirect_uri, + client_id=self._client_id, + audience=self._audience, + scopes=self._scopes, + auth_endpoint=self._authorization_endpoint, + token_endpoint=self._token_endpoint, + verify=self._verify, + session=self._session, + refresh_access_token_params={}, + ) + def refresh_credentials(self): + self._initialize_auth_client() + if self._creds: + """We have an access token so lets try to refresh it""" + try: + self._creds = self._auth_client.refresh_access_token(self._creds) + if self._creds: + KeyringStore.store(self._creds) + return + except AccessTokenNotFoundError: + logging.warning("Failed to refresh token. Kicking off a full authorization flow.") + KeyringStore.delete(self._endpoint) + + """Fall back to device flow""" resp = token_client.get_device_code( self._device_auth_endpoint, self._client_id, diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index f8e7559221..b7149d0cf5 100644 --- a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -12,7 +12,8 @@ PKCEAuthenticator, StaticClientConfigStore, ) -from flytekit.clients.auth.exceptions import AuthenticationError +from flytekit.clients.auth.exceptions import AuthenticationError, AccessTokenNotFoundError +from flytekit.clients.auth.keyring import Credentials from flytekit.clients.auth.token_client import DeviceCodeResponse ENDPOINT = "example.com" @@ -95,7 +96,8 @@ def test_client_creds_authenticator(mock_session): @patch("flytekit.clients.auth.authenticator.KeyringStore") @patch("flytekit.clients.auth.token_client.get_device_code") @patch("flytekit.clients.auth.token_client.poll_token_endpoint") -def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock): +@patch("flytekit.clients.auth.auth_client.AuthorizationClient.refresh_access_token") +def test_device_flow_authenticator(mock_refresh: MagicMock, poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock): with pytest.raises(AuthenticationError): DeviceCodeAuthenticator(ENDPOINT, static_cfg_store, audience="x", verify=True) @@ -114,9 +116,51 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, device_mock.return_value = DeviceCodeResponse("x", "y", "s", 1000, 0) poll_mock.return_value = ("access", "refresh", 100) + mock_refresh.side_effect = AccessTokenNotFoundError("test") # ensure refresh token fails + + authn.refresh_credentials() + assert authn._creds + + # assert calls made to mocks + poll_mock.assert_called() + device_mock.assert_called() + mock_refresh.assert_called() + +@patch("flytekit.clients.auth.authenticator.KeyringStore") +@patch("flytekit.clients.auth.token_client.get_device_code") +@patch("flytekit.clients.auth.token_client.poll_token_endpoint") +@patch("flytekit.clients.auth.auth_client.AuthorizationClient.refresh_access_token") +def test_device_flow_authenticator_refresh_token(mock_refresh: MagicMock, poll_mock: MagicMock, device_mock: MagicMock, mock_keyring: MagicMock): + with pytest.raises(AuthenticationError): + DeviceCodeAuthenticator(ENDPOINT, static_cfg_store, audience="x", verify=True) + + cfg_store = StaticClientConfigStore( + ClientConfig( + token_endpoint="token_endpoint", + authorization_endpoint="auth_endpoint", + redirect_uri="redirect_uri", + client_id="client", + device_authorization_endpoint="dev", + ) + ) + authn = DeviceCodeAuthenticator( + ENDPOINT, cfg_store, audience="x", http_proxy_url="http://my-proxy:9000", verify=False + ) + + mock_refresh.return_value = Credentials( + access_token="access", refresh_token="refresh", expires_in=100 + ) + authn.refresh_credentials() assert authn._creds + # Full login flow should not happen + poll_mock.assert_not_called() + device_mock.assert_not_called() + + # assert calls made to mocks + mock_refresh.assert_called() + @patch("flytekit.clients.auth.token_client.requests.Session") def test_client_creds_authenticator_with_custom_scopes(mock_session):