diff --git a/src/azure-cli-core/azure/cli/core/_profile.py b/src/azure-cli-core/azure/cli/core/_profile.py index c83fa9e22c6..7b5cc6c750d 100644 --- a/src/azure-cli-core/azure/cli/core/_profile.py +++ b/src/azure-cli-core/azure/cli/core/_profile.py @@ -862,7 +862,12 @@ def _create_identity_instance(cli_ctx, *args, **kwargs): # Only enable encryption for Windows (for now). fallback = sys.platform.startswith('win32') + + # EXPERIMENTAL: Use core.encrypt_token_cache=False to turn off token cache encryption. # encrypt_token_cache affects both MSAL token cache and service principal entries. encrypt = cli_ctx.config.getboolean('core', 'encrypt_token_cache', fallback=fallback) - return Identity(*args, encrypt=encrypt, **kwargs) + # EXPERIMENTAL: Use core.use_msal_http_cache=False to turn off MSAL HTTP cache. + use_msal_http_cache = cli_ctx.config.getboolean('core', 'use_msal_http_cache', fallback=True) + + return Identity(*args, encrypt=encrypt, use_msal_http_cache=use_msal_http_cache, **kwargs) diff --git a/src/azure-cli-core/azure/cli/core/auth/identity.py b/src/azure-cli-core/azure/cli/core/auth/identity.py index ad230ca25cf..ac6ee9c262c 100644 --- a/src/azure-cli-core/azure/cli/core/auth/identity.py +++ b/src/azure-cli-core/azure/cli/core/auth/identity.py @@ -5,9 +5,11 @@ import json import os +import pickle import re from azure.cli.core._environment import get_config_dir +from azure.cli.core.decorators import retry from msal import PublicClientApplication from knack.log import get_logger @@ -45,7 +47,7 @@ class Identity: # pylint: disable=too-many-instance-attributes # It follows singleton pattern so that _secret_file is read only once. _service_principal_store_instance = None - def __init__(self, authority, tenant_id=None, client_id=None, encrypt=False): + def __init__(self, authority, tenant_id=None, client_id=None, encrypt=False, use_msal_http_cache=True): """ :param authority: Authentication authority endpoint. For example, - AAD: https://login.microsoftonline.com @@ -58,7 +60,8 @@ def __init__(self, authority, tenant_id=None, client_id=None, encrypt=False): self.authority = authority self.tenant_id = tenant_id self.client_id = client_id or AZURE_CLI_CLIENT_ID - self.encrypt = encrypt + self._encrypt = encrypt + self._use_msal_http_cache = use_msal_http_cache # Build the authority in MSAL style self._msal_authority, self._is_adfs = _get_authority_url(authority, tenant_id) @@ -80,7 +83,7 @@ def _msal_app_kwargs(self): if not Identity._msal_token_cache: Identity._msal_token_cache = self._load_msal_token_cache() - if not Identity._msal_http_cache: + if self._use_msal_http_cache and not Identity._msal_http_cache: Identity._msal_http_cache = self._load_msal_http_cache() return { @@ -100,25 +103,46 @@ def _msal_app(self): def _load_msal_token_cache(self): # Store for user token persistence - cache = load_persisted_token_cache(self._token_cache_file, self.encrypt) + cache = load_persisted_token_cache(self._token_cache_file, self._encrypt) return cache + @retry() + def __load_msal_http_cache(self): + """Load MSAL HTTP cache with retry. If it still fails at last, raise the original exception as-is.""" + logger.debug("__load_msal_http_cache: %s", self._http_cache_file) + try: + with open(self._http_cache_file, 'rb') as f: + return pickle.load(f) + except FileNotFoundError: + # The cache file has not been created. This is expected. + logger.debug("%s not found. Using a fresh one.", self._http_cache_file) + return {} + + def _dump_msal_http_cache(self): + logger.debug("_dump_msal_http_cache: %s", self._http_cache_file) + with open(self._http_cache_file, 'wb') as f: + # At this point, an empty cache file will be created. Loading this cache file will + # trigger EOFError. This can be simulated by adding time.sleep(30) here. + # So, during loading, EOFError is ignored. + pickle.dump(self._msal_http_cache, f) + def _load_msal_http_cache(self): import atexit - import pickle logger.debug("_load_msal_http_cache: %s", self._http_cache_file) try: - with open(self._http_cache_file, 'rb') as f: - persisted_http_cache = pickle.load(f) - except (pickle.UnpicklingError, FileNotFoundError) as ex: - logger.debug("Failed to load MSAL HTTP cache: %s", ex) + persisted_http_cache = self.__load_msal_http_cache() + except (pickle.UnpicklingError, EOFError) as ex: + # We still get exception after retry: + # - pickle.UnpicklingError is caused by corrupted cache file, perhaps due to concurrent writes. + # - EOFError is caused by empty cache file created by other az instance, but hasn't been filled yet. + logger.debug("Failed to load MSAL HTTP cache: %s. Using a fresh one.", ex) persisted_http_cache = {} # Ignore a non-exist or corrupted http_cache - atexit.register(lambda: pickle.dump( - # When exit, flush it back to the file. - # If 2 processes write at the same time, the cache will be corrupted, - # but that is fine. Subsequent runs would reach eventual consistency. - persisted_http_cache, open(self._http_cache_file, 'wb'))) + + # When exiting, flush it back to the file. + # If 2 processes write at the same time, the cache will be corrupted, + # but that is fine. Subsequent runs would reach eventual consistency. + atexit.register(self._dump_msal_http_cache) return persisted_http_cache @@ -128,7 +152,7 @@ def _service_principal_store(self): The instance is lazily created. """ if not Identity._service_principal_store_instance: - store = load_secret_store(self._secret_file, self.encrypt) + store = load_secret_store(self._secret_file, self._encrypt) Identity._service_principal_store_instance = ServicePrincipalStore(store) return Identity._service_principal_store_instance diff --git a/src/azure-cli-core/azure/cli/core/auth/persistence.py b/src/azure-cli-core/azure/cli/core/auth/persistence.py index c22a2124f97..eb51a82660c 100644 --- a/src/azure-cli-core/azure/cli/core/auth/persistence.py +++ b/src/azure-cli-core/azure/cli/core/auth/persistence.py @@ -8,13 +8,14 @@ import json import sys -import time from msal_extensions import (FilePersistenceWithDataProtection, KeychainPersistence, LibsecretPersistence, FilePersistence, PersistedTokenCache, CrossPlatLock) from msal_extensions.persistence import PersistenceNotFound from knack.log import get_logger +from azure.cli.core.decorators import retry + logger = get_logger(__name__) @@ -60,27 +61,9 @@ def save(self, content): with CrossPlatLock(self._lock_file): self._persistence.save(json.dumps(content, indent=4)) - def _load(self): + @retry() + def load(self): try: return json.loads(self._persistence.load()) except PersistenceNotFound: return [] - - def load(self): - # Use optimistic locking rather than CrossPlatLock, so that multiple processes can - # read the same file at the same time. - retry = 3 - for attempt in range(1, retry + 1): - try: - return self._load() - except Exception: # pylint: disable=broad-except - # Presumably other processes are writing the file, causing dirty read - if attempt < retry: - logger.debug("Unable to load secret store in No. %d attempt", attempt) - import traceback - logger.debug(traceback.format_exc()) - time.sleep(0.5) - else: - raise # End of retry. Re-raise the exception as-is. - - return [] # Not really reachable here. Just to keep pylint happy. diff --git a/src/azure-cli-core/azure/cli/core/decorators.py b/src/azure-cli-core/azure/cli/core/decorators.py index 55a1d88e11e..d59d0316912 100644 --- a/src/azure-cli-core/azure/cli/core/decorators.py +++ b/src/azure-cli-core/azure/cli/core/decorators.py @@ -16,6 +16,9 @@ from knack.log import get_logger +logger = get_logger(__name__) + + # pylint: disable=too-few-public-methods class Completer: @@ -81,3 +84,30 @@ def _wrapped_func(*args, **kwargs): return fallback_return return _wrapped_func return _decorator + + +def retry(retry_times=3, interval=0.5, exceptions=Exception): + """Use optimistic locking to call a function, so that multiple processes can + access the same resource (such as a file) at the same time. + + :param retry_times: Times to retry. + :param interval: Interval between retries. + :param exceptions: Exceptions that can be ignored. Use a tuple if multiple exceptions should be ignored. + """ + def _decorator(func): + @wraps(func) + def _wrapped_func(*args, **kwargs): + for attempt in range(1, retry_times + 1): + try: + return func(*args, **kwargs) + except exceptions: # pylint: disable=broad-except + if attempt < retry_times: + logger.debug("%s failed in No. %d attempt", func, attempt) + import traceback + import time + logger.debug(traceback.format_exc()) + time.sleep(interval) + else: + raise # End of retry. Re-raise the exception as-is. + return _wrapped_func + return _decorator