Skip to content

Commit

Permalink
fix(python): Propagate tenant_id to CredentialProviderAzure if given (#20583)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Jan 7, 2025
1 parent f104170 commit 3f74df4
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 42 deletions.
24 changes: 18 additions & 6 deletions crates/polars-io/src/cloud/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ impl CloudOptions {
pub fn build_azure(&self, url: &str) -> PolarsResult<impl object_store::ObjectStore> {
use super::credential_provider::IntoCredentialProvider;

let verbose = polars_core::config::verbose();
let mut storage_account: Option<polars_utils::pl_str::PlSmallStr> = None;

// The credential provider `self.credentials` is prioritized if it is set. We also need
Expand All @@ -430,19 +431,30 @@ impl CloudOptions {
.with_url(url)
.with_retry(get_retry_config(self.max_retries));

// Prefer the one embedded in the path
storage_account = extract_adls_uri_storage_account(url)
.map(|x| x.into())
.or(storage_account);

let builder = if let Some(v) = self.credential_provider.clone() {
if verbose {
eprintln!(
"[CloudOptions::build_azure]: Using credential provider {:?}",
&v
);
}
builder.with_credentials(v.into_azure_provider())
} else if let Some(v) = storage_account
} else if let Some(v) = extract_adls_uri_storage_account(url) // Prefer the one embedded in the path
.map(|x| x.into())
.or(storage_account)
.as_deref()
.and_then(get_azure_storage_account_key)
{
if verbose {
eprintln!("[CloudOptions::build_azure]: Retrieved account key from Azure CLI")
}
builder.with_access_key(v)
} else {
if verbose {
eprintln!(
"[CloudOptions::build_azure]: Could not retrieve account key from Azure CLI"
)
}
builder
};

Expand Down
116 changes: 81 additions & 35 deletions py-polars/polars/io/cloud/credential_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ def __init__(
*,
scopes: list[str] | None = None,
storage_account: str | None = None,
tenant_id: str | None = None,
_verbose: bool = False,
) -> None:
"""
Expand All @@ -173,13 +174,16 @@ def __init__(
for this account using the Azure CLI. If this is successful, the
account keys will be used instead of
`DefaultAzureCredential.get_token()`
tenant_id
Azure tenant ID.
"""
msg = "`CredentialProviderAzure` functionality is considered unstable"
issue_unstable_warning(msg)

self._check_module_availability()

self.account_name = storage_account
self.tenant_id = tenant_id
# Done like this to bypass mypy, we don't have stubs for azure.identity
self.credential = importlib.import_module("azure.identity").__dict__[
"DefaultAzureCredential"
Expand All @@ -192,6 +196,7 @@ def __init__(
(
"CredentialProviderAzure "
f"{self.account_name = } "
f"{self.tenant_id = } "
f"{self.scopes = } "
),
file=sys.stderr,
Expand All @@ -209,19 +214,19 @@ def __call__(self) -> CredentialProviderFunctionReturn:

if self._verbose:
print(
"[CredentialProviderAzure]: retrieved account keys from Azure CLI",
"[CredentialProviderAzure]: Retrieved account key from Azure CLI",
file=sys.stderr,
)
except Exception as e:
if self._verbose:
print(
f"[CredentialProviderAzure]: failed to retrieve account keys from Azure CLI: {e}",
f"[CredentialProviderAzure]: Could not retrieve account key from Azure CLI: {e}",
file=sys.stderr,
)
else:
return creds, None # type: ignore[return-value]

token = self.credential.get_token(*self.scopes)
token = self.credential.get_token(*self.scopes, tenant_id=self.tenant_id)

return {
"bearer_token": token.token,
Expand All @@ -248,22 +253,8 @@ def _extract_adls_uri_storage_account(uri: str) -> str | None:
except IndexError:
return None

@staticmethod
def _get_azure_storage_account_key_az_cli(account_name: str) -> str:
az_cmd = [
"az",
"storage",
"account",
"keys",
"list",
"--output",
"json",
"--account-name",
account_name,
]

cmd = az_cmd if sys.platform != "win32" else ["cmd", "/C", *az_cmd]

@classmethod
def _get_azure_storage_account_key_az_cli(cls, account_name: str) -> str:
# [
# {
# "creationTime": "1970-01-01T00:00:00.000000+00:00",
Expand All @@ -279,7 +270,31 @@ def _get_azure_storage_account_key_az_cli(account_name: str) -> str:
# }
# ]

return json.loads(subprocess.check_output(cmd))[0]["value"]
return json.loads(
cls._azcli(
"storage",
"account",
"keys",
"list",
"--output",
"json",
"--account-name",
account_name,
)
)[0]["value"]

@classmethod
def _azcli_version(cls) -> str | None:
    """
    Return the Azure CLI version string, or `None` if it cannot be determined.

    Runs `az version` through `cls._azcli`, parses the output as JSON and
    returns the value stored under the `"azure-cli"` key. Any exception raised
    along the way (running the command, decoding the JSON, looking up the key)
    is swallowed and reported as `None`.
    """
    try:
        version_info = json.loads(cls._azcli("version"))
        return version_info["azure-cli"]
    except Exception:
        # Best-effort lookup: every failure is treated as "version unknown".
        return None

@staticmethod
def _azcli(*args: str) -> bytes:
    """
    Run the Azure CLI with the given arguments and return its raw stdout.

    Parameters
    ----------
    *args
        Arguments passed verbatim after the `az` executable name.

    Raises
    ------
    subprocess.CalledProcessError
        If the command exits with a non-zero status (raised by
        `subprocess.check_output`).
    """
    command = ["az", *args]

    if sys.platform == "win32":
        # On Windows the call is routed through `cmd /C` (presumably because
        # `az` is not directly executable there; confirm).
        command = ["cmd", "/C", *command]

    return subprocess.check_output(command)


class CredentialProviderGCP(CredentialProvider):
Expand Down Expand Up @@ -392,9 +407,6 @@ def _maybe_init_credential_provider(
if credential_provider != "auto":
return credential_provider

if storage_options is not None:
return None

verbose = os.getenv("POLARS_VERBOSE") == "1"

if (path := _first_scan_path(source)) is None:
Expand All @@ -406,20 +418,54 @@ def _maybe_init_credential_provider(
provider = None

try:
provider = (
CredentialProviderAWS()
if _is_aws_cloud(scheme)
else CredentialProviderAzure(
storage_account=(
CredentialProviderAzure._extract_adls_uri_storage_account(str(path))
),
# For Azure we dispatch to `azure.identity` as much as possible
if _is_azure_cloud(scheme):
tenant_id = None
storage_account = None

if storage_options is not None:
for k, v in storage_options.items():
# https://docs.rs/object_store/latest/object_store/azure/enum.AzureConfigKey.html
if k in {
"azure_storage_tenant_id",
"azure_storage_authority_id",
"azure_tenant_id",
"azure_authority_id",
"tenant_id",
"authority_id",
}:
tenant_id = v
elif k in {"azure_storage_account_name", "account_name"}:
storage_account = v
elif k in {"azure_use_azure_cli", "use_azure_cli"}:
continue
else:
# We assume some sort of access key was given, so we
# just dispatch to the rust side.
return None

storage_account = (
# Prefer the one embedded in the path
CredentialProviderAzure._extract_adls_uri_storage_account(str(path))
or storage_account
)

provider = CredentialProviderAzure(
storage_account=storage_account,
tenant_id=tenant_id,
_verbose=verbose,
)
if _is_azure_cloud(scheme)
else CredentialProviderGCP()
if _is_gcp_cloud(scheme)
else None
)
elif storage_options is not None:
return None
else:
provider = (
CredentialProviderAWS() # type: ignore[assignment]
if _is_aws_cloud(scheme)
else CredentialProviderGCP()
if _is_gcp_cloud(scheme)
else None
)

except ImportError as e:
if verbose:
msg = f"Unable to auto-select credential provider: {e}"
Expand Down
8 changes: 7 additions & 1 deletion py-polars/polars/meta/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def show_versions() -> None:

deps = _get_dependency_list()
core_properties = ("Polars", "Index type", "Platform", "Python", "LTS CPU")
keylen = max(len(x) for x in [*core_properties, *deps]) + 1
keylen = max(len(x) for x in [*core_properties, "Azure CLI", *deps]) + 1

print("--------Version info---------")
print(f"{'Polars:':{keylen}s} {get_polars_version()}")
Expand All @@ -58,6 +58,12 @@ def show_versions() -> None:
print(f"{'LTS CPU:':{keylen}s} {get_lts_cpu()}")

print("\n----Optional dependencies----")

from polars.io.cloud.credential_provider import CredentialProviderAzure

print(f"{'Azure CLI':{keylen}s} ", end="", flush=True)
print(CredentialProviderAzure._azcli_version() or "<not installed>")

for name in deps:
print(f"{name:{keylen}s} ", end="", flush=True)
print(_get_dependency_version(name))
Expand Down

0 comments on commit 3f74df4

Please sign in to comment.