Skip to content

Commit

Permalink
[Identity] Validate identity config for MICredential (Azure#36950)
Browse files Browse the repository at this point in the history
ManagedIdentityCredential now validates the inputs for client_id and identity_config
to ensure no mutually exclusive values are given.


Signed-off-by: Paul Van Eck <[email protected]>
  • Loading branch information
pvaneck authored Aug 21, 2024
1 parent 3ae4c8d commit 513989e
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 46 deletions.
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

### Other Changes

- Added identity config validation to `ManagedIdentityCredential` to avoid non-deterministic states (e.g. both `resource_id` and `object_id` are specified). ([#36950](https://github.com/Azure/azure-sdk-for-python/pull/36950))

## 1.18.0b2 (2024-08-09)

### Features Added
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# ------------------------------------
import logging
import os
from typing import Optional, TYPE_CHECKING, Any
from typing import Optional, TYPE_CHECKING, Any, Mapping

from azure.core.credentials import AccessToken
from .. import CredentialUnavailableError
Expand All @@ -17,6 +17,22 @@
_LOGGER = logging.getLogger(__name__)


def validate_identity_config(client_id: Optional[str], identity_config: Optional[Mapping[str, str]]) -> None:
if identity_config:
if client_id:
if any(key in identity_config for key in ("object_id", "resource_id", "client_id")):
raise ValueError(
"identity_config must not contain 'object_id', 'resource_id', or 'client_id' when 'client_id' is "
"provided as a keyword argument."
)
# Only one of these keys should be present if one is present.
valid_keys = {"object_id", "resource_id", "client_id"}
if len(identity_config.keys() & valid_keys) > 1:
raise ValueError(
f"identity_config must not contain more than one of the following keys: {', '.join(valid_keys)}"
)


class ManagedIdentityCredential:
"""Authenticates with an Azure managed identity in any hosting environment which supports managed identities.
Expand All @@ -42,59 +58,66 @@ class ManagedIdentityCredential:
:caption: Create a ManagedIdentityCredential.
"""

def __init__(self, **kwargs: Any) -> None:
self._credential = None # type: Optional[TokenCredential]
def __init__(
self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any
) -> None:
validate_identity_config(client_id, identity_config)
self._credential: Optional[TokenCredential] = None
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)
if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
if os.environ.get(EnvironmentVariables.IDENTITY_HEADER):
if os.environ.get(EnvironmentVariables.IDENTITY_SERVER_THUMBPRINT):
_LOGGER.info("%s will use Service Fabric managed identity", self.__class__.__name__)
from .service_fabric import ServiceFabricCredential

self._credential = ServiceFabricCredential(**kwargs)
self._credential = ServiceFabricCredential(
client_id=client_id, identity_config=identity_config, **kwargs
)
else:
_LOGGER.info("%s will use App Service managed identity", self.__class__.__name__)
from .app_service import AppServiceCredential

self._credential = AppServiceCredential(**kwargs)
self._credential = AppServiceCredential(
client_id=client_id, identity_config=identity_config, **kwargs
)
elif os.environ.get(EnvironmentVariables.IMDS_ENDPOINT):
_LOGGER.info("%s will use Azure Arc managed identity", self.__class__.__name__)
from .azure_arc import AzureArcCredential

self._credential = AzureArcCredential(**kwargs)
self._credential = AzureArcCredential(client_id=client_id, identity_config=identity_config, **kwargs)
elif os.environ.get(EnvironmentVariables.MSI_ENDPOINT):
if os.environ.get(EnvironmentVariables.MSI_SECRET):
_LOGGER.info("%s will use Azure ML managed identity", self.__class__.__name__)
from .azure_ml import AzureMLCredential

self._credential = AzureMLCredential(**kwargs)
self._credential = AzureMLCredential(client_id=client_id, identity_config=identity_config, **kwargs)
else:
_LOGGER.info("%s will use Cloud Shell managed identity", self.__class__.__name__)
from .cloud_shell import CloudShellCredential

self._credential = CloudShellCredential(**kwargs)
self._credential = CloudShellCredential(client_id=client_id, identity_config=identity_config, **kwargs)
elif (
all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS)
and not exclude_workload_identity
):
_LOGGER.info("%s will use workload identity", self.__class__.__name__)
from .workload_identity import WorkloadIdentityCredential

client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
if not client_id:
workload_client_id = client_id or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
if not workload_client_id:
raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument')

self._credential = WorkloadIdentityCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=client_id,
client_id=workload_client_id,
file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
**kwargs
**kwargs,
)
else:
from .imds import ImdsCredential

_LOGGER.info("%s will use IMDS", self.__class__.__name__)
self._credential = ImdsCredential(**kwargs)
self._credential = ImdsCredential(client_id=client_id, identity_config=identity_config, **kwargs)

def __enter__(self) -> "ManagedIdentityCredential":
if self._credential:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
from typing import Any, Optional, Dict, cast, Union
from typing import Any, Optional, Dict, cast, Union, Mapping
import abc
import time
import logging
Expand All @@ -23,10 +23,12 @@ class MsalManagedIdentityClient(abc.ABC): # pylint:disable=client-accepts-api-v
"""Base class for managed identity client wrapping MSAL ManagedIdentityClient."""

# pylint:disable=missing-client-constructor-parameter-credential
def __init__(self, **kwargs: Any) -> None:
self._settings = kwargs
def __init__(
self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any
) -> None:
self._settings = {"client_id": client_id, "identity_config": identity_config or {}}
self._client = MsalClient(**kwargs)
managed_identity = self.get_managed_identity(**kwargs)
managed_identity = self.get_managed_identity()
self._msal_client = msal.ManagedIdentityClient(managed_identity, http_client=self._client)

def __enter__(self) -> "MsalManagedIdentityClient":
Expand Down Expand Up @@ -56,20 +58,17 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # pylint:
error_message = self.get_unavailable_message(error_desc)
raise CredentialUnavailableError(error_message)

def get_managed_identity(
self, **kwargs: Any
) -> Union[msal.UserAssignedManagedIdentity, msal.SystemAssignedManagedIdentity]:
def get_managed_identity(self) -> Union[msal.UserAssignedManagedIdentity, msal.SystemAssignedManagedIdentity]:
"""
Get the managed identity configuration.
:keyword str client_id: The client ID of the user-assigned managed identity.
:keyword dict identity_config: The identity configuration.
:rtype: msal.UserAssignedManagedIdentity or msal.SystemAssignedManagedIdentity
:return: The managed identity configuration.
"""
if "client_id" in kwargs and kwargs["client_id"]:
return msal.UserAssignedManagedIdentity(client_id=kwargs["client_id"])
identity_config = kwargs.pop("identity_config", None) or {}

if "client_id" in self._settings and self._settings["client_id"]:
return msal.UserAssignedManagedIdentity(client_id=self._settings["client_id"])
identity_config = cast(Dict, self._settings.get("identity_config")) or {}
if "client_id" in identity_config and identity_config["client_id"]:
return msal.UserAssignedManagedIdentity(client_id=identity_config["client_id"])
if "resource_id" in identity_config and identity_config["resource_id"]:
Expand Down Expand Up @@ -154,5 +153,5 @@ def __getstate__(self) -> Dict[str, Any]: # pylint:disable=client-method-name-n
def __setstate__(self, state: Dict[str, Any]) -> None: # pylint:disable=client-method-name-no-double-underscore
self.__dict__.update(state)
# Re-create the unpickable entries
managed_identity = self.get_managed_identity(**self._settings)
managed_identity = self.get_managed_identity()
self._msal_client = msal.ManagedIdentityClient(managed_identity, http_client=self._client)
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
# ------------------------------------
import logging
import os
from typing import TYPE_CHECKING, Optional, Any
from typing import TYPE_CHECKING, Optional, Any, Mapping

from azure.core.credentials import AccessToken
from .._internal import AsyncContextManager
from .._internal.decorators import log_get_token_async
from ... import CredentialUnavailableError
from ..._constants import EnvironmentVariables
from ..._credentials.managed_identity import validate_identity_config

if TYPE_CHECKING:
from azure.core.credentials_async import AsyncTokenCredential
Expand Down Expand Up @@ -43,8 +44,11 @@ class ManagedIdentityCredential(AsyncContextManager):
:caption: Create a ManagedIdentityCredential.
"""

def __init__(self, **kwargs: Any) -> None:
self._credential = None # type: Optional[AsyncTokenCredential]
def __init__(
self, *, client_id: Optional[str] = None, identity_config: Optional[Mapping[str, str]] = None, **kwargs: Any
) -> None:
validate_identity_config(client_id, identity_config)
self._credential: Optional[AsyncTokenCredential] = None
exclude_workload_identity = kwargs.pop("_exclude_workload_identity_credential", False)

if os.environ.get(EnvironmentVariables.IDENTITY_ENDPOINT):
Expand All @@ -53,55 +57,54 @@ def __init__(self, **kwargs: Any) -> None:
_LOGGER.info("%s will use Service Fabric managed identity", self.__class__.__name__)
from .service_fabric import ServiceFabricCredential

self._credential = ServiceFabricCredential(**kwargs)
self._credential = ServiceFabricCredential(
client_id=client_id, identity_config=identity_config, **kwargs
)
else:
_LOGGER.info("%s will use App Service managed identity", self.__class__.__name__)
from .app_service import AppServiceCredential

self._credential = AppServiceCredential(**kwargs)
self._credential = AppServiceCredential(
client_id=client_id, identity_config=identity_config, **kwargs
)
elif os.environ.get(EnvironmentVariables.IMDS_ENDPOINT):
_LOGGER.info("%s will use Azure Arc managed identity", self.__class__.__name__)
from .azure_arc import AzureArcCredential

self._credential = AzureArcCredential(**kwargs)
else:
_LOGGER.info("%s will use Cloud Shell managed identity", self.__class__.__name__)
from .cloud_shell import CloudShellCredential

self._credential = CloudShellCredential(**kwargs)
self._credential = AzureArcCredential(client_id=client_id, identity_config=identity_config, **kwargs)
elif os.environ.get(EnvironmentVariables.MSI_ENDPOINT):
if os.environ.get(EnvironmentVariables.MSI_SECRET):
_LOGGER.info("%s will use Azure ML managed identity", self.__class__.__name__)
from .azure_ml import AzureMLCredential

self._credential = AzureMLCredential(**kwargs)
self._credential = AzureMLCredential(client_id=client_id, identity_config=identity_config, **kwargs)
else:
_LOGGER.info("%s will use Cloud Shell managed identity", self.__class__.__name__)
from .cloud_shell import CloudShellCredential

self._credential = CloudShellCredential(**kwargs)
self._credential = CloudShellCredential(client_id=client_id, identity_config=identity_config, **kwargs)
elif (
all(os.environ.get(var) for var in EnvironmentVariables.WORKLOAD_IDENTITY_VARS)
and not exclude_workload_identity
):
_LOGGER.info("%s will use workload identity", self.__class__.__name__)
from .workload_identity import WorkloadIdentityCredential

client_id = kwargs.pop("client_id", None) or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
if not client_id:
workload_client_id = client_id or os.environ.get(EnvironmentVariables.AZURE_CLIENT_ID)
if not workload_client_id:
raise ValueError('Configure the environment with a client ID or pass a value for "client_id" argument')

self._credential = WorkloadIdentityCredential(
tenant_id=os.environ[EnvironmentVariables.AZURE_TENANT_ID],
client_id=client_id,
client_id=workload_client_id,
file=os.environ[EnvironmentVariables.AZURE_FEDERATED_TOKEN_FILE],
**kwargs
)
else:
from .imds import ImdsCredential

_LOGGER.info("%s will use IMDS", self.__class__.__name__)
self._credential = ImdsCredential(**kwargs)
self._credential = ImdsCredential(client_id=client_id, identity_config=identity_config, **kwargs)

async def __aenter__(self) -> "ManagedIdentityCredential":
if self._credential:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
from typing import Any, cast, Optional, TypeVar
from types import TracebackType
from typing import Any, cast, Optional, TypeVar, Type

from azure.core.credentials import AccessToken
from . import AsyncContextManager
Expand Down Expand Up @@ -34,9 +35,14 @@ async def __aenter__(self: T) -> T:
await self._client.__aenter__()
return self

async def __aexit__(self, *args):
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_value: Optional[BaseException] = None,
traceback: Optional[TracebackType] = None,
) -> None:
if self._client:
await self._client.__aexit__(*args)
await self._client.__aexit__(exc_type, exc_value, traceback)

async def close(self) -> None:
await self.__aexit__()
Expand Down
21 changes: 21 additions & 0 deletions sdk/identity/azure-identity/tests/test_managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,3 +944,24 @@ def test_token_exchange_tenant_id(tmpdir):
credential = ManagedIdentityCredential(transport=transport)
token = credential.get_token(scope, tenant_id="tenant_id")
assert token.token == access_token


def test_validate_identity_config():
ManagedIdentityCredential()
ManagedIdentityCredential(client_id="foo")
ManagedIdentityCredential(identity_config={"foo": "bar"})
ManagedIdentityCredential(identity_config={"client_id": "foo"})
ManagedIdentityCredential(identity_config={"object_id": "foo"})
ManagedIdentityCredential(identity_config={"resource_id": "foo"})
ManagedIdentityCredential(identity_config={"foo": "bar"}, client_id="foo")

with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"client_id": "foo"}, client_id="foo")
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "bar"}, client_id="bar")
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"resource_id": "bar"}, client_id="bar")
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "bar", "resource_id": "foo"})
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "bar", "client_id": "foo"})
21 changes: 21 additions & 0 deletions sdk/identity/azure-identity/tests/test_managed_identity_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,3 +1213,24 @@ async def test_token_exchange_tenant_id(tmpdir):
credential = ManagedIdentityCredential(transport=transport)
token = await credential.get_token(scope, tenant_id="tenant_id")
assert token.token == access_token


def test_validate_identity_config():
ManagedIdentityCredential()
ManagedIdentityCredential(client_id="foo")
ManagedIdentityCredential(identity_config={"foo": "bar"})
ManagedIdentityCredential(identity_config={"client_id": "foo"})
ManagedIdentityCredential(identity_config={"object_id": "foo"})
ManagedIdentityCredential(identity_config={"resource_id": "foo"})
ManagedIdentityCredential(identity_config={"foo": "bar"}, client_id="foo")

with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"client_id": "foo"}, client_id="foo")
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "bar"}, client_id="bar")
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"resource_id": "bar"}, client_id="bar")
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "bar", "resource_id": "foo"})
with pytest.raises(ValueError):
ManagedIdentityCredential(identity_config={"object_id": "bar", "client_id": "foo"})

0 comments on commit 513989e

Please sign in to comment.