Skip to content

Commit

Permalink
supports snowflake container default credentials
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Mar 19, 2024
1 parent 7e30318 commit 7a43103
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 15 deletions.
7 changes: 2 additions & 5 deletions dlt/common/configuration/specs/aws_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,12 @@ def on_partial(self) -> None:
session = self._to_botocore_session()
if self._from_session(session) and not self.is_partial():
self.resolve()
self._set_default_credentials(session.get_credentials())

def to_session_credentials(self) -> Dict[str, str]:
"""Return configured or new aws session token"""
if self.aws_session_token and self.aws_access_key_id and self.aws_secret_access_key:
return dict(
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
)
return super().to_session_credentials()
sess = self._to_botocore_session()
client = sess.create_client("sts")
token = client.get_session_token()
Expand Down
96 changes: 87 additions & 9 deletions dlt/destinations/impl/snowflake/configuration.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import base64
import binascii

from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING

from dlt.common.configuration.specs import CredentialsWithDefault
from dlt.common.libs.sql_alchemy import URL

from dlt import version
from dlt.common import logger
from dlt.common.exceptions import MissingDependencyException
from dlt.common.typing import TSecretStrValue
from dlt.common.configuration.specs import ConnectionStringCredentials
Expand Down Expand Up @@ -52,41 +53,67 @@ def _read_private_key(private_key: str, password: Optional[str] = None) -> bytes


@configspec
class SnowflakeCredentials(ConnectionStringCredentials):
class SnowflakeCredentialsWithoutDefaults(ConnectionStringCredentials):
drivername: Final[str] = "snowflake" # type: ignore[misc]
password: Optional[TSecretStrValue] = None
token: Optional[TSecretStrValue] = None
username: Optional[str] = None
host: str = None
"""Snowflake account name"""
database: str = None
warehouse: Optional[str] = None
role: Optional[str] = None
authenticator: Optional[str] = None
private_key: Optional[TSecretStrValue] = None
private_key_passphrase: Optional[TSecretStrValue] = None
_hostname: Optional[str] = None
"Snowflake host, present in Snowflake Container Services"

__config_gen_annotations__: ClassVar[List[str]] = ["password", "warehouse", "role"]

def is_partial(self) -> bool:
return super().is_partial() or (
not self.password and not self.token and not self.private_key
)

def parse_native_representation(self, native_value: Any) -> None:
super().parse_native_representation(native_value)
self.warehouse = self.query.get("warehouse")
self.role = self.query.get("role")
self.token = self.query.get("token") # type: ignore
self.authenticator = self.query.get("authenticator")
self.private_key = self.query.get("private_key") # type: ignore
self.private_key_passphrase = self.query.get("private_key_passphrase") # type: ignore
if not self.is_partial() and (self.password or self.private_key):
if not self.is_partial():
self.resolve()

def on_resolved(self) -> None:
if not self.password and not self.private_key:
if not self.password and not self.private_key and not self.token:
raise ConfigurationValueError(
"Please specify password or private_key or oauth token. SnowflakeCredentials"
" supports password and private key authentication and one of those must be"
" specified. It also recognizes a login token if passed via `token` query parameter"
" and with `authenticator` set to `oauth`."
)
if (self.password or self.private_key) and not self.username:
raise ConfigurationValueError(
"Please specify password or private_key. SnowflakeCredentials supports password and"
" private key authentication and one of those must be specified."
"Please provide username when using password / private key authentication."
)
if self.token and self.authenticator != "oauth":
logger.warning(
"Login token was specified but authenticator is not set to oauth so it will be"
" ignored"
)

def to_url(self) -> URL:
query = dict(self.query or {})
if self.warehouse and "warehouse" not in query:
if self.warehouse:
query["warehouse"] = self.warehouse
if self.role and "role" not in query:
if self.role:
query["role"] = self.role
if self.authenticator:
query["authenticator"] = self.authenticator

return URL.create(
self.drivername,
self.username,
Expand All @@ -113,8 +140,59 @@ def to_connector_params(self) -> Dict[str, Any]:
)
if self.authenticator:
conn_params["authenticator"] = self.authenticator
if self.authenticator == "oauth":
conn_params["token"] = self.login_token()
conn_params["host"] = self._hostname

return conn_params

def login_token(self) -> str:
"""A method returning fresh login token"""
return self.token


@configspec
class SnowflakeCredentials(SnowflakeCredentialsWithoutDefaults, CredentialsWithDefault):
LOGIN_TOKEN_PATH: ClassVar[str] = "/snowflake/session/token"

def on_partial(self) -> None:
logger.info(
"Snowflake credentials could not be resolved, looking for login token and env variables"
)
token = self._from_token()
if token and not self.is_partial():
logger.info("Login token and env variable found, switching to oauth authenticator")
self.resolve()
self._set_default_credentials(token)

def _from_token(self) -> str:
token = self.login_token()
if token:
if self.private_key or self.password:
logger.warning(
"Password or private key were provided to SnowflakeCredentials and oauth2"
" token was found so they will be ignored."
)
self.private_key = None
self.password = None
# set authenticator
self.host, self._hostname = os.getenv("SNOWFLAKE_ACCOUNT"), os.getenv("SNOWFLAKE_HOST")
self.authenticator = "oauth"
# NOTE: set token here so configuration resolves. still we'll be reading it each time connection params are created
self.token = token # type: ignore[assignment]
return token

def login_token(self) -> str:
"""
Read the login token supplied automatically by Snowflake. These tokens
are short lived and should always be read right before creating any new connection.
"""
try:
with open(SnowflakeCredentials.LOGIN_TOKEN_PATH, "r", encoding="utf-8") as f:
return f.read()
except Exception:
return self.token


@configspec
class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration):
Expand Down
13 changes: 13 additions & 0 deletions docs/website/docs/dlt-ecosystem/destinations/snowflake.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Snowflake destination accepts three authentication types:
- password authentication
- [key pair authentication](https://docs.snowflake.com/en/user-guide/key-pair-auth)
- external authentication
- default (Snowflake Container Services) authentication

The **password authentication** is not any different from other databases like Postgres or Redshift. `dlt` follows the same syntax as the [SQLAlchemy dialect](https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#required-parameters).

Expand Down Expand Up @@ -112,6 +113,18 @@ token="..."
or in the connection string as query parameters.
Refer to Snowflake [OAuth](https://docs.snowflake.com/en/user-guide/oauth-intro) for more details.

In **Snowflake Container Services (default) authentication**, `dlt` will use oauth token present in the container and environment variables to
create a valid connection. You must specify database name ie.:
```toml
[destination.snowflake.credentials]
database = "dlt_data"
```
or
```toml
destination.snowflake.credentials="snowflake:///dlt_data"
```
Find more details [here](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/additional-considerations-services-jobs).

## Write disposition
All write dispositions are supported.

Expand Down
88 changes: 87 additions & 1 deletion tests/load/snowflake/test_snowflake_configuration.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
import os
from unittest.mock import patch
import pytest
from pathlib import Path
from dlt.common.libs.sql_alchemy import make_url
from dlt.common.storages.file_storage import FileStorage

pytest.importorskip("snowflake")

from dlt.common.configuration.resolve import resolve_configuration
from dlt.common.configuration.exceptions import ConfigurationValueError
from dlt.common.configuration.exceptions import ConfigFieldMissingException, ConfigurationValueError
from dlt.common.utils import digest128

from dlt.destinations.impl.snowflake.configuration import (
SnowflakeClientConfiguration,
SnowflakeCredentials,
)

from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT, test_storage
from tests.common.configuration.utils import environment


Expand All @@ -38,6 +41,41 @@ def test_connection_string_with_all_params() -> None:
assert make_url(creds.to_native_representation()) == expected


def test_setting_authenticator() -> None:
# no password and user are allowed
url = "snowflake://host1/db1?authenticator=oauth&role=role1&token=A&warehouse=warehouse1"
creds = SnowflakeCredentials()
creds.parse_native_representation(url)
assert creds.authenticator == "oauth"
assert creds.token == "A"
assert creds.to_native_representation() == url

params = creds.to_connector_params()
assert params["authenticator"] == "oauth"
assert params["token"] == "A"

# change token
creds.token = "B" # type: ignore[assignment]
params = creds.to_connector_params()
assert params["token"] == "B"


def test_connection_value_errors() -> None:
# user must be present when password/key are present
url = "snowflake://:pass1@host1/db1?warehouse=warehouse1&role=role1&private_key=cGs%3D&private_key_passphrase=paphr"
creds = SnowflakeCredentials()
with pytest.raises(ConfigurationValueError) as c_ex:
creds.parse_native_representation(url)
assert "user" in str(c_ex)

# password / key / token must be present
url = "snowflake://user@host1/db1?warehouse=warehouse1&role=role1"
creds = SnowflakeCredentials()
creds.parse_native_representation(url)
# this is partial
assert creds.is_partial()


def test_to_connector_params() -> None:
# PEM key
pkey_str = Path("./tests/common/cases/secrets/encrypted-private-key").read_text("utf8")
Expand Down Expand Up @@ -125,6 +163,54 @@ def test_snowflake_credentials_native_value(environment) -> None:
assert c.private_key == "pk"


def test_snowflake_default_credentials(environment, test_storage: FileStorage) -> None:
with pytest.raises(ConfigFieldMissingException):
resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake:///db1")

token_path = os.path.join(TEST_STORAGE_ROOT, "token")
test_storage.save("token", "TOK1")

with patch.object(SnowflakeCredentials, "LOGIN_TOKEN_PATH", token_path):
with pytest.raises(ConfigFieldMissingException) as mi_ex:
resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake:///db1")
# only host missing
assert mi_ex.value.fields == ["host"]

# set missing env
os.environ["SNOWFLAKE_ACCOUNT"] = "accnt"
os.environ["SNOWFLAKE_HOST"] = "hostname"
c = resolve_configuration(SnowflakeCredentials(), explicit_value={"database": "db1"})
assert c.host == "accnt"
assert c._hostname == "hostname"
assert c.database == "db1"
assert c.token == "TOK1"
assert c.authenticator == "oauth"
assert not c.is_partial()
assert c.has_default_credentials()

# check conn url
params = c.to_connector_params()
# no password, user etc. are set
assert params == {
"user": None,
"password": None,
"account": "accnt",
"database": "db1",
"warehouse": None,
"role": None,
"private_key": None,
"authenticator": "oauth",
"token": "TOK1",
"host": "hostname",
}

# change token file
test_storage.save("token", "TOK2")
params = c.to_connector_params()
# token updated
assert params["token"] == "TOK2"


def test_snowflake_configuration() -> None:
# def empty fingerprint
assert SnowflakeClientConfiguration().fingerprint() == ""
Expand Down

0 comments on commit 7a43103

Please sign in to comment.