Skip to content

Commit

Permalink
Re-authenticate requests that failed authentication
Browse files Browse the repository at this point in the history
  • Loading branch information
stefannica committed Dec 9, 2024
1 parent ee3e962 commit d15e791
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 58 deletions.
38 changes: 19 additions & 19 deletions src/zenml/cli/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,25 +587,6 @@ def server_list(verbose: bool = False, all: bool = False) -> None:
accessible_pro_servers = client.tenant.list(member_only=not all)
except AuthorizationException as e:
cli_utils.warning(f"ZenML Pro authorization error: {e}")
else:
if not all:
accessible_pro_servers = [
s
for s in accessible_pro_servers
if s.status == TenantStatus.AVAILABLE
]

if not accessible_pro_servers:
cli_utils.declare(
"No ZenML Pro servers that are accessible to the current "
"user could be found."
)
if not all:
cli_utils.declare(
"Hint: use the `--all` flag to show all ZenML servers, "
"including those that the client is not currently "
"authorized to access or are not running."
)

# We update the list of stored ZenML Pro servers with the ones that the
# client is a member of
Expand Down Expand Up @@ -633,6 +614,25 @@ def server_list(verbose: bool = False, all: bool = False) -> None:
stored_server.update_server_info(accessible_server)
pro_servers.append(stored_server)

if not all:
accessible_pro_servers = [
s
for s in accessible_pro_servers
if s.status == TenantStatus.AVAILABLE
]

if not accessible_pro_servers:
cli_utils.declare(
"No ZenML Pro servers that are accessible to the current "
"user could be found."
)
if not all:
cli_utils.declare(
"Hint: use the `--all` flag to show all ZenML servers, "
"including those that the client is not currently "
"authorized to access or are not running."
)

elif pro_servers:
cli_utils.warning(
"The ZenML Pro authentication has expired. Please re-login "
Expand Down
106 changes: 67 additions & 39 deletions src/zenml/zen_stores/rest_zen_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4349,46 +4349,74 @@ def _request(
{source_context.name: source_context.get().value}
)

try:
return self._handle_response(
self.session.request(
method,
url,
params=params,
verify=self.config.verify_ssl,
timeout=timeout or self.config.http_timeout,
**kwargs,
)
)
except CredentialsNotValid:
# NOTE: CredentialsNotValid is raised only when the server
# explicitly indicates that the credentials are not valid and they
# can be thrown away.

# We authenticate or re-authenticate here and then try the request
# again, this time with a valid API token in the header.
self.authenticate(
# If the last request was authenticated with an API token,
# we force a re-authentication to get a fresh token.
force=self._api_token is not None
)

try:
return self._handle_response(
self.session.request(
method,
url,
params=params,
verify=self.config.verify_ssl,
timeout=self.config.http_timeout,
**kwargs,
# If the server replies with a credentials validation (401 Unauthorized)
# error, we (re-)authenticate and retry the request here in the
# following cases:
#
# 1. initial authentication: the last request was not authenticated
# with an API token.
# 2. re-authentication: the last request was authenticated with an API
# token that was rejected by the server. This is to cover the case
# of expired tokens that can be refreshed by the client automatically
# without user intervention from other sources (e.g. API keys).
#
# NOTE: it can happen that the same request is retried here for up to
# two times: once after initial authentication and once after
# re-authentication.
re_authenticated = False
while True:
try:
return self._handle_response(
self.session.request(
method,
url,
params=params,
verify=self.config.verify_ssl,
timeout=timeout or self.config.http_timeout,
**kwargs,
)
)
)
except CredentialsNotValid as e:
raise CredentialsNotValid(
"The current credentials are no longer valid. Please log in "
"again using 'zenml login'."
) from e
except CredentialsNotValid as e:
# NOTE: CredentialsNotValid is raised only when the server
# explicitly indicates that the credentials are not valid and
# they can be thrown away or when the request is not
# authenticated at all.

if self._api_token is None:
# The last request was not authenticated with an API
# token at all. We authenticate here and then try the
# request again, this time with a valid API token in the
# header.
logger.debug(
f"The last request was not authenticated: {e}\n"
"Re-authenticating and retrying..."
)
self.authenticate()
elif not re_authenticated:
# The last request was authenticated with an API token
# that was rejected by the server. We attempt a
# re-authentication here and then retry the request.
logger.debug(
"The last request was authenticated with an API token "
f"that was rejected by the server: {e}\n"
"Re-authenticating and retrying..."
)
re_authenticated = True
self.authenticate(
# Ignore the current token and force a re-authentication
force=True
)
else:
# The last request was made after re-authenticating but
# still failed. Bailing out.
logger.debug(
f"The last request failed after re-authenticating: {e}\n"
"Bailing out..."
)
raise CredentialsNotValid(
"The current credentials are no longer valid. Please "
"log in again using 'zenml login'."
) from e

def get(
self,
Expand Down

0 comments on commit d15e791

Please sign in to comment.