Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

supports snowflake container default credentials #1115

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions dlt/common/configuration/specs/aws_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,12 @@ def on_partial(self) -> None:
session = self._to_botocore_session()
if self._from_session(session) and not self.is_partial():
self.resolve()
self._set_default_credentials(session.get_credentials())

def to_session_credentials(self) -> Dict[str, str]:
"""Return configured or new aws session token"""
if self.aws_session_token and self.aws_access_key_id and self.aws_secret_access_key:
return dict(
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
)
return super().to_session_credentials()
sess = self._to_botocore_session()
client = sess.create_client("sts")
token = client.get_session_token()
Expand Down
96 changes: 87 additions & 9 deletions dlt/destinations/impl/snowflake/configuration.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
import base64
import binascii

from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING

from dlt.common.configuration.specs import CredentialsWithDefault
from dlt.common.libs.sql_alchemy import URL

from dlt import version
from dlt.common import logger
from dlt.common.exceptions import MissingDependencyException
from dlt.common.typing import TSecretStrValue
from dlt.common.configuration.specs import ConnectionStringCredentials
Expand Down Expand Up @@ -52,41 +53,67 @@ def _read_private_key(private_key: str, password: Optional[str] = None) -> bytes


@configspec
class SnowflakeCredentials(ConnectionStringCredentials):
class SnowflakeCredentialsWithoutDefaults(ConnectionStringCredentials):
drivername: Final[str] = "snowflake" # type: ignore[misc]
password: Optional[TSecretStrValue] = None
token: Optional[TSecretStrValue] = None
username: Optional[str] = None
host: str = None
"""Snowflake account name"""
database: str = None
warehouse: Optional[str] = None
role: Optional[str] = None
authenticator: Optional[str] = None
private_key: Optional[TSecretStrValue] = None
private_key_passphrase: Optional[TSecretStrValue] = None
_hostname: Optional[str] = None
"Snowflake host, present in Snowflake Container Services"

__config_gen_annotations__: ClassVar[List[str]] = ["password", "warehouse", "role"]

def is_partial(self) -> bool:
    """Report whether these credentials are still incomplete.

    The credentials count as partial when the parent class says so, or when
    none of `password`, `token` and `private_key` has been provided.
    """
    parent_partial = super().is_partial()
    if parent_partial:
        return parent_partial
    # at least one secret (password, token or private key) is required
    return not (self.password or self.token or self.private_key)

def parse_native_representation(self, native_value: Any) -> None:
super().parse_native_representation(native_value)
self.warehouse = self.query.get("warehouse")
self.role = self.query.get("role")
self.token = self.query.get("token") # type: ignore
self.authenticator = self.query.get("authenticator")
self.private_key = self.query.get("private_key") # type: ignore
self.private_key_passphrase = self.query.get("private_key_passphrase") # type: ignore
if not self.is_partial() and (self.password or self.private_key):
if not self.is_partial():
self.resolve()

def on_resolved(self) -> None:
if not self.password and not self.private_key:
if not self.password and not self.private_key and not self.token:
raise ConfigurationValueError(
"Please specify password or private_key or oauth token. SnowflakeCredentials"
" supports password and private key authentication and one of those must be"
" specified. It also recognizes a login token if passed via `token` query parameter"
" and with `authenticator` set to `oauth`."
)
if (self.password or self.private_key) and not self.username:
raise ConfigurationValueError(
"Please specify password or private_key. SnowflakeCredentials supports password and"
" private key authentication and one of those must be specified."
"Please provide username when using password / private key authentication."
)
if self.token and self.authenticator != "oauth":
logger.warning(
"Login token was specified but authenticator is not set to oauth so it will be"
" ignored"
)

def to_url(self) -> URL:
query = dict(self.query or {})
if self.warehouse and "warehouse" not in query:
if self.warehouse:
query["warehouse"] = self.warehouse
if self.role and "role" not in query:
if self.role:
query["role"] = self.role
if self.authenticator:
query["authenticator"] = self.authenticator

return URL.create(
self.drivername,
self.username,
Expand All @@ -113,8 +140,59 @@ def to_connector_params(self) -> Dict[str, Any]:
)
if self.authenticator:
conn_params["authenticator"] = self.authenticator
if self.authenticator == "oauth":
conn_params["token"] = self.login_token()
conn_params["host"] = self._hostname

return conn_params

def login_token(self) -> str:
    """Return the login token to put into the connection parameters.

    This implementation simply hands back the configured `token` value.
    """
    configured_token = self.token
    return configured_token


@configspec
class SnowflakeCredentials(SnowflakeCredentialsWithoutDefaults, CredentialsWithDefault):
    """Snowflake credentials with a login-token based fallback.

    When the regular fields leave the credentials partial, `on_partial` tries to
    complete them from the token file at `LOGIN_TOKEN_PATH` together with the
    `SNOWFLAKE_ACCOUNT` and `SNOWFLAKE_HOST` environment variables, switching the
    authenticator to `oauth`.
    """

    LOGIN_TOKEN_PATH: ClassVar[str] = "/snowflake/session/token"

    def on_partial(self) -> None:
        """Try to complete partial credentials from the login token and environment."""
        logger.info(
            "Snowflake credentials could not be resolved, looking for login token and env variables"
        )
        token = self._from_token()
        # the token alone is not enough: host (account) and database must be known as well
        if token and not self.is_partial():
            logger.info("Login token and env variable found, switching to oauth authenticator")
            self.resolve()
            self._set_default_credentials(token)

    def _from_token(self) -> str:
        """Switch this instance to oauth authentication if a login token is available.

        Returns:
            The token obtained from `login_token()`; a falsy value means no token
            was found and the instance was left untouched.
        """
        token = self.login_token()
        if token:
            if self.private_key or self.password:
                logger.warning(
                    "Password or private key were provided to SnowflakeCredentials and oauth2"
                    " token was found so they will be ignored."
                )
                self.private_key = None
                self.password = None
            # set authenticator
            self.host, self._hostname = os.getenv("SNOWFLAKE_ACCOUNT"), os.getenv("SNOWFLAKE_HOST")
            self.authenticator = "oauth"
            # NOTE: set token here so configuration resolves. still we'll be reading it each time connection params are created
            self.token = token  # type: ignore[assignment]
        return token

    def login_token(self) -> str:
        """
        Read the login token supplied automatically by Snowflake. These tokens
        are short lived and should always be read right before creating any new connection.

        Falls back to the configured `token` when the token file cannot be read.
        """
        try:
            with open(SnowflakeCredentials.LOGIN_TOKEN_PATH, "r", encoding="utf-8") as f:
                return f.read()
        except (OSError, UnicodeError):
            # missing, unreadable or undecodable token file: we are most likely not
            # running where the file is provided, so use the configured token instead.
            # Other exception types indicate programming errors and are not swallowed.
            return self.token


@configspec
class SnowflakeClientConfiguration(DestinationClientDwhWithStagingConfiguration):
Expand Down
13 changes: 13 additions & 0 deletions docs/website/docs/dlt-ecosystem/destinations/snowflake.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Snowflake destination accepts four authentication types:
- password authentication
- [key pair authentication](https://docs.snowflake.com/en/user-guide/key-pair-auth)
- external authentication
- default (Snowflake Container Services) authentication

The **password authentication** is not any different from other databases like Postgres or Redshift. `dlt` follows the same syntax as the [SQLAlchemy dialect](https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#required-parameters).

Expand Down Expand Up @@ -112,6 +113,18 @@ token="..."
or in the connection string as query parameters.
Refer to Snowflake [OAuth](https://docs.snowflake.com/en/user-guide/oauth-intro) for more details.

In **Snowflake Container Services (default) authentication**, `dlt` uses the OAuth token present in the container, together with the `SNOWFLAKE_ACCOUNT` and `SNOWFLAKE_HOST` environment variables, to
create a valid connection. You must specify the database name, e.g.:
```toml
[destination.snowflake.credentials]
database = "dlt_data"
```
or
```toml
destination.snowflake.credentials="snowflake:///dlt_data"
```
Find more details [here](https://docs.snowflake.com/en/developer-guide/snowpark-container-services/additional-considerations-services-jobs).

## Write disposition
All write dispositions are supported.

Expand Down
88 changes: 87 additions & 1 deletion tests/load/snowflake/test_snowflake_configuration.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
import os
from unittest.mock import patch
import pytest
from pathlib import Path
from dlt.common.libs.sql_alchemy import make_url
from dlt.common.storages.file_storage import FileStorage

pytest.importorskip("snowflake")

from dlt.common.configuration.resolve import resolve_configuration
from dlt.common.configuration.exceptions import ConfigurationValueError
from dlt.common.configuration.exceptions import ConfigFieldMissingException, ConfigurationValueError
from dlt.common.utils import digest128

from dlt.destinations.impl.snowflake.configuration import (
SnowflakeClientConfiguration,
SnowflakeCredentials,
)

from tests.utils import autouse_test_storage, TEST_STORAGE_ROOT, test_storage
from tests.common.configuration.utils import environment


Expand All @@ -38,6 +41,41 @@ def test_connection_string_with_all_params() -> None:
assert make_url(creds.to_native_representation()) == expected


def test_setting_authenticator() -> None:
    """Parse an oauth connection string and check the token reaches the connector params."""
    # a connection string without user name and password is accepted
    native_url = "snowflake://host1/db1?authenticator=oauth&role=role1&token=A&warehouse=warehouse1"
    credentials = SnowflakeCredentials()
    credentials.parse_native_representation(native_url)
    assert credentials.authenticator == "oauth"
    assert credentials.token == "A"
    # the native representation round-trips unchanged
    assert credentials.to_native_representation() == native_url

    connector_params = credentials.to_connector_params()
    for key, expected in (("authenticator", "oauth"), ("token", "A")):
        assert connector_params[key] == expected

    # a token changed after parsing must show up in freshly built params
    credentials.token = "B"  # type: ignore[assignment]
    connector_params = credentials.to_connector_params()
    assert connector_params["token"] == "B"


def test_connection_value_errors() -> None:
    """Check how incomplete Snowflake connection strings are rejected or flagged."""
    # user must be present when password/key are present
    url = "snowflake://:pass1@host1/db1?warehouse=warehouse1&role=role1&private_key=cGs%3D&private_key_passphrase=paphr"
    creds = SnowflakeCredentials()
    with pytest.raises(ConfigurationValueError) as c_ex:
        creds.parse_native_representation(url)
    # check the message of the raised exception itself, not the string form of
    # the pytest ExceptionInfo wrapper
    assert "user" in str(c_ex.value)

    # password / key / token must be present
    url = "snowflake://user@host1/db1?warehouse=warehouse1&role=role1"
    creds = SnowflakeCredentials()
    creds.parse_native_representation(url)
    # this is partial
    assert creds.is_partial()


def test_to_connector_params() -> None:
# PEM key
pkey_str = Path("./tests/common/cases/secrets/encrypted-private-key").read_text("utf8")
Expand Down Expand Up @@ -125,6 +163,54 @@ def test_snowflake_credentials_native_value(environment) -> None:
assert c.private_key == "pk"


def test_snowflake_default_credentials(environment, test_storage: FileStorage) -> None:
    """Resolve credentials from a login token file plus SNOWFLAKE_* environment variables."""
    # without a token file the credentials cannot be resolved
    with pytest.raises(ConfigFieldMissingException):
        resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake:///db1")

    test_storage.save("token", "TOK1")
    token_path = os.path.join(TEST_STORAGE_ROOT, "token")

    with patch.object(SnowflakeCredentials, "LOGIN_TOKEN_PATH", token_path):
        with pytest.raises(ConfigFieldMissingException) as mi_ex:
            resolve_configuration(SnowflakeCredentials(), explicit_value="snowflake:///db1")
        # with the token in place only the host is still missing
        assert mi_ex.value.fields == ["host"]

        # provide the missing environment variables
        os.environ.update({"SNOWFLAKE_ACCOUNT": "accnt", "SNOWFLAKE_HOST": "hostname"})
        resolved = resolve_configuration(
            SnowflakeCredentials(), explicit_value={"database": "db1"}
        )
        assert resolved.host == "accnt"
        assert resolved._hostname == "hostname"
        assert resolved.database == "db1"
        assert resolved.token == "TOK1"
        assert resolved.authenticator == "oauth"
        assert not resolved.is_partial()
        assert resolved.has_default_credentials()

        # connector params carry the token and host; user, password etc. stay unset
        expected_params = {
            "user": None,
            "password": None,
            "account": "accnt",
            "database": "db1",
            "warehouse": None,
            "role": None,
            "private_key": None,
            "authenticator": "oauth",
            "token": "TOK1",
            "host": "hostname",
        }
        assert resolved.to_connector_params() == expected_params

        # rewriting the token file must be reflected in freshly built params
        test_storage.save("token", "TOK2")
        refreshed_params = resolved.to_connector_params()
        assert refreshed_params["token"] == "TOK2"


def test_snowflake_configuration() -> None:
# def empty fingerprint
assert SnowflakeClientConfiguration().fingerprint() == ""
Expand Down
Loading