diff --git a/dlt/common/configuration/specs/aws_credentials.py b/dlt/common/configuration/specs/aws_credentials.py index ee7360e2cb..e416213cd2 100644 --- a/dlt/common/configuration/specs/aws_credentials.py +++ b/dlt/common/configuration/specs/aws_credentials.py @@ -53,15 +53,12 @@ def on_partial(self) -> None: session = self._to_botocore_session() if self._from_session(session) and not self.is_partial(): self.resolve() + self._set_default_credentials(session.get_credentials()) def to_session_credentials(self) -> Dict[str, str]: """Return configured or new aws session token""" if self.aws_session_token and self.aws_access_key_id and self.aws_secret_access_key: - return dict( - aws_access_key_id=self.aws_access_key_id, - aws_secret_access_key=self.aws_secret_access_key, - aws_session_token=self.aws_session_token, - ) + return super().to_session_credentials() sess = self._to_botocore_session() client = sess.create_client("sts") token = client.get_session_token() diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 4f97f08700..1ef2de5559 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -1,11 +1,12 @@ +import os import base64 -import binascii - from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING +from dlt.common.configuration.specs import CredentialsWithDefault from dlt.common.libs.sql_alchemy import URL from dlt import version +from dlt.common import logger from dlt.common.exceptions import MissingDependencyException from dlt.common.typing import TSecretStrValue from dlt.common.configuration.specs import ConnectionStringCredentials @@ -52,41 +53,67 @@ def _read_private_key(private_key: str, password: Optional[str] = None) -> bytes @configspec -class SnowflakeCredentials(ConnectionStringCredentials): +class SnowflakeCredentialsWithoutDefaults(ConnectionStringCredentials): drivername: Final[str] = "snowflake" # type: ignore[misc] password: Optional[TSecretStrValue] = None + token: Optional[TSecretStrValue] = None + username: Optional[str] = None host: str = None + """Snowflake account name""" database: str = None warehouse: Optional[str] = None role: Optional[str] = None authenticator: Optional[str] = None private_key: Optional[TSecretStrValue] = None private_key_passphrase: Optional[TSecretStrValue] = None + _hostname: Optional[str] = None + "Snowflake host, present in Snowflake Container Services" __config_gen_annotations__: ClassVar[List[str]] = ["password", "warehouse", "role"] + def is_partial(self) -> bool: + return super().is_partial() or ( + not self.password and not self.token and not self.private_key + ) + def parse_native_representation(self, native_value: Any) -> None: super().parse_native_representation(native_value) self.warehouse = self.query.get("warehouse") self.role = self.query.get("role") + self.token = self.query.get("token") # type: ignore + self.authenticator = self.query.get("authenticator") self.private_key = self.query.get("private_key") # type: ignore self.private_key_passphrase = self.query.get("private_key_passphrase") # type: ignore - if not self.is_partial() and (self.password or self.private_key): + if not self.is_partial(): self.resolve() def on_resolved(self) -> None: - if not self.password and not self.private_key: + if not self.password and not self.private_key and not self.token: + raise ConfigurationValueError( + "Please specify password or private_key or oauth token. SnowflakeCredentials" + " supports password and private key authentication and one of those must be" + " specified. It also recognizes a login token if passed via `token` query parameter" + " and with `authenticator` set to `oauth`." + ) + if (self.password or self.private_key) and not self.username: raise ConfigurationValueError( - "Please specify password or private_key. SnowflakeCredentials supports password and" - " private key authentication and one of those must be specified." + "Please provide username when using password / private key authentication." + ) + if self.token and self.authenticator != "oauth": + logger.warning( + "Login token was specified but authenticator is not set to oauth so it will be" + " ignored" ) def to_url(self) -> URL: query = dict(self.query or {}) - if self.warehouse and "warehouse" not in query: + if self.warehouse: query["warehouse"] = self.warehouse - if self.role and "role" not in query: + if self.role: query["role"] = self.role + if self.authenticator: + query["authenticator"] = self.authenticator + return URL.create( self.drivername, self.username, @@ -113,8 +140,59 @@ def to_connector_params(self) -> Dict[str, Any]: ) if self.authenticator: conn_params["authenticator"] = self.authenticator + if self.authenticator == "oauth": + conn_params["token"] = self.login_token() + conn_params["host"] = self._hostname + return conn_params + def login_token(self) -> str: + """A method returning fresh login token""" + return self.token + + +@configspec +class SnowflakeCredentials(SnowflakeCredentialsWithoutDefaults, CredentialsWithDefault): + LOGIN_TOKEN_PATH: ClassVar[str] = "/snowflake/session/token" + + def on_partial(self) -> None: + logger.info( + "Snowflake credentials could not be resolved, looking for login token and env variables" + ) + token = self._from_token() + if token and not self.is_partial(): + logger.info("Login token and env variable found, switching to oauth authenticator") + self.resolve() + self._set_default_credentials(token) + + def _from_token(self) -> str: + token = self.login_token() + if token: + if self.private_key or self.password: + logger.warning( + "Password or private key were provided to SnowflakeCredentials and oauth2" + " token was found so they will be ignored." + ) + self.private_key = None + self.password = None + # set authenticator + self.host, self._hostname = os.getenv("SNOWFLAKE_ACCOUNT"), os.getenv("SNOWFLAKE_HOST") + self.authenticator = "oauth" + # NOTE: set token here so configuration resolves. still we'll be reading it each time connection params are created + self.token = token # type: ignore[assignment] + return token + + def login_token(self) -> str: + """ + Read the login token supplied automatically by Snowflake. These tokens + are short lived and should always be read right before creating any new connection. + """ + try: + with open(SnowflakeCredentials.LOGIN_TOKEN_PATH, "r", encoding="utf-8") as f: + return f.read() + except Exception: + return self.token + @configspec class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration): diff --git a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md index a65eaec267..b2ece0aa4a 100644 --- a/docs/website/docs/dlt-ecosystem/destinations/snowflake.md +++ b/docs/website/docs/dlt-ecosystem/destinations/snowflake.md @@ -75,6 +75,7 @@ Snowflake destination accepts three authentication types: - password authentication - [key pair authentication](https://docs.snowflake.com/en/user-guide/key-pair-auth) - external authentication +- default (Snowflake Container Services) authentication The **password authentication** is not any different from other databases like Postgres or Redshift. `dlt` follows the same syntax as the [SQLAlchemy dialect](https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#required-parameters). @@ -112,6 +113,18 @@ token="..." or in the connection string as query parameters. Refer to Snowflake [OAuth](https://docs.snowflake.com/en/user-guide/oauth-intro) for more details. +In **Snowflake Container Services (default) authentication**, `dlt` will use oauth token present in the container and environment variables to +create a valid connection. You must specify database name ie.: + ```toml +[destination.snowflake.credentials] +database = "dlt_data" +``` +or +```toml +destination.snowflake.credentials="snowflake:///dlt_data" +``` +Find more details [here](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/additional-considerations-services-jobs). + ## Write disposition All write dispositions are supported. diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index d0ca4de41b..138702b56a 100644 --- a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -1,12 +1,14 @@ import os +from unittest.mock import patch import pytest from pathlib import Path from dlt.common.libs.sql_alchemy import make_url +from dlt.common.storages.file_storage import FileStorage pytest.importorskip("snowflake") from dlt.common.configuration.resolve import resolve_configuration -from dlt.common.configuration.exceptions import ConfigurationValueError +from dlt.common.configuration.exceptions import ConfigFieldMissingException, ConfigurationValueError from dlt.common.utils import digest128 from dlt.destinations.impl.snowflake.configuration import ( @@ -14,6 +16,7 @@ SnowflakeCredentials, ) +from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT, test_storage from tests.common.configuration.utils import environment @@ -38,6 +41,41 @@ def test_connection_string_with_all_params() -> None: assert make_url(creds.to_native_representation()) == expected +def test_setting_authenticator() -> None: + # no password and user are allowed + url = "snowflake://host1/db1?authenticator=oauth&role=role1&token=A&warehouse=warehouse1" + creds = SnowflakeCredentials() + creds.parse_native_representation(url) + assert creds.authenticator == "oauth" + assert creds.token == "A" + assert creds.to_native_representation() == url + + params = creds.to_connector_params() + assert params["authenticator"] == "oauth" + assert params["token"] == "A" + + # change token + creds.token = "B" # type: ignore[assignment] + params = creds.to_connector_params() + assert params["token"] == "B" + + +def test_connection_value_errors() -> None: + # user must be present when password/key are present + url = "snowflake://:pass1@host1/db1?warehouse=warehouse1&role=role1&private_key=cGs%3D&private_key_passphrase=paphr" + creds = SnowflakeCredentials() + with pytest.raises(ConfigurationValueError) as c_ex: + creds.parse_native_representation(url) + assert "user" in str(c_ex) + + # password / key / token must be present + url = "snowflake://user@host1/db1?warehouse=warehouse1&role=role1" + creds = SnowflakeCredentials() + creds.parse_native_representation(url) + # this is partial + assert creds.is_partial() + + def test_to_connector_params() -> None: # PEM key pkey_str = Path("./tests/common/cases/secrets/encrypted-private-key").read_text("utf8") @@ -125,6 +163,54 @@ def test_snowflake_credentials_native_value(environment) -> None: assert c.private_key == "pk" +def test_snowflake_default_credentials(environment, test_storage: FileStorage) -> None: + with pytest.raises(ConfigFieldMissingException): + resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake:///db1") + + token_path = os.path.join(TEST_STORAGE_ROOT, "token") + test_storage.save("token", "TOK1") + + with patch.object(SnowflakeCredentials, "LOGIN_TOKEN_PATH", token_path): + with pytest.raises(ConfigFieldMissingException) as mi_ex: + resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake:///db1") + # only host missing + assert mi_ex.value.fields == ["host"] + + # set missing env + os.environ["SNOWFLAKE_ACCOUNT"] = "accnt" + os.environ["SNOWFLAKE_HOST"] = "hostname" + c = resolve_configuration(SnowflakeCredentials(), explicit_value={"database": "db1"}) + assert c.host == "accnt" + assert c._hostname == "hostname" + assert c.database == "db1" + assert c.token == "TOK1" + assert c.authenticator == "oauth" + assert not c.is_partial() + assert c.has_default_credentials() + + # check conn url + params = c.to_connector_params() + # no password, user etc. are set + assert params == { + "user": None, + "password": None, + "account": "accnt", + "database": "db1", + "warehouse": None, + "role": None, + "private_key": None, + "authenticator": "oauth", + "token": "TOK1", + "host": "hostname", + } + + # change token file + test_storage.save("token", "TOK2") + params = c.to_connector_params() + # token updated + assert params["token"] == "TOK2" + + def test_snowflake_configuration() -> None: # def empty fingerprint assert SnowflakeClientConfiguration().fingerprint() == ""