From 947d74d5aef96d4c9b4bcb52b73e323cc7ffc8b4 Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 15:18:20 -0400 Subject: [PATCH 01/20] Require `database` & `schema` --- .../fluent/schemas/SnowflakeDatasource.json | 4 +- .../datasource/fluent/snowflake_datasource.py | 100 ++++++- .../datasource/fluent/sources.pyi | 12 +- pyproject.toml | 4 + .../datasource/fluent/great_expectations.yml | 2 +- .../fluent/integration/test_connections.py | 18 +- .../integration/test_sql_datasources.py | 5 +- .../fluent/test_snowflake_datasource.py | 277 +++++++++++++++++- .../end_to_end/test_snowflake_datasource.py | 27 +- 9 files changed, 403 insertions(+), 46 deletions(-) diff --git a/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json b/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json index afa246a16277..b3a5269d113d 100644 --- a/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json +++ b/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json @@ -587,7 +587,9 @@ "required": [ "account", "user", - "password" + "password", + "database", + "schema" ], "additionalProperties": false } diff --git a/great_expectations/datasource/fluent/snowflake_datasource.py b/great_expectations/datasource/fluent/snowflake_datasource.py index 73bd26e261fc..95bdb1165d00 100644 --- a/great_expectations/datasource/fluent/snowflake_datasource.py +++ b/great_expectations/datasource/fluent/snowflake_datasource.py @@ -1,7 +1,17 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Final, Literal, Optional, Type, Union +import urllib.parse +from typing import ( + TYPE_CHECKING, + Any, + Final, + Iterable, + Literal, + Optional, + Type, + Union, +) from great_expectations._docs_decorators import public_api from great_expectations.compatibility import pydantic @@ -26,6 +36,11 @@ LOGGER: Final[logging.Logger] = logging.getLogger(__name__) +REQUIRED_QUERY_PARAMS: Final[Iterable[str]] = { # errors 
will be thrown if any of these are missing + "database", + "schema", +} + class _UrlPasswordError(pydantic.UrlError): """ @@ -45,6 +60,18 @@ class _UrlDomainError(pydantic.UrlError): msg_template = "URL domain invalid" +class _UrlMissingQueryError(pydantic.UrlError): + """ + Custom Pydantic error for missing query parameters in SnowflakeDsn. + """ + + def __init__(self, **ctx: Any) -> None: + super().__init__(**ctx) + + code = "url.query" + msg_template = "URL query param missing" + + class SnowflakeDsn(AnyUrl): allowed_schemes = { "snowflake", @@ -80,10 +107,8 @@ class ConnectionDetails(FluentBaseModel): account: str user: str password: Union[ConfigStr, str] - database: Optional[str] = None - schema_: Optional[str] = pydantic.Field( - None, alias="schema" - ) # schema is a reserved attr in BaseModel + database: str + schema_: str = pydantic.Field(..., alias="schema") # schema is a reserved attr in BaseModel warehouse: Optional[str] = None role: Optional[str] = None numpy: bool = False @@ -107,6 +132,36 @@ class SnowflakeDatasource(SQLDatasource): # TODO: add props for account, user, password, etc? + @property + def schema_(self) -> str | None: + """ + Convenience property to get the `schema` regardless of the connection string format. + + `schema_` to avoid conflict with Pydantic models schema property. 
+ """ + if isinstance(self.connection_string, ConnectionDetails): + return self.connection_string.schema_ + elif isinstance(self.connection_string, SnowflakeDsn): + # extra database and schema query parameters for the url + for key, value in self.connection_string.query_params(): + if key.lower() == "schema": + return value + # TODO: attempt to parse schema from a ConfigStr + return None + + @property + def database(self) -> str | None: + """Convenience property to get the `database` regardless of the connection string format.""" + if isinstance(self.connection_string, ConnectionDetails): + return self.connection_string.database + elif isinstance(self.connection_string, SnowflakeDsn): + # extra database and schema query parameters for the url + for key, value in self.connection_string.query_params(): + if key.lower() == "database": + return value + # TODO: attempt to parse database from a ConfigStr + return None + @pydantic.root_validator(pre=True) def _convert_root_connection_detail_fields(cls, values: dict) -> dict: """ @@ -151,6 +206,41 @@ def _check_xor_input_args(cls, values: dict) -> dict: "Must provide either a connection string or a combination of account, user, and password." # noqa: E501 ) + @pydantic.validator("connection_string") + def _check_for_required_query_params( + cls, connection_string: ConnectionDetails | SnowflakeDsn | ConfigStr + ) -> ConnectionDetails | SnowflakeDsn | ConfigStr: + """ + If connection_string is a SnowflakeDsn, + check for required query parameters according to `REQUIRED_QUERY_PARAMS`. + """ + if not isinstance(connection_string, (SnowflakeDsn, ConfigStr)): + return connection_string + + missing_keys: set[str] = set(REQUIRED_QUERY_PARAMS) + if isinstance(connection_string, ConfigStr): + query_str = connection_string.template_str.partition("?")[2] + # best effort: query could be part of the config substitution. + # Have to check this when adding assets. 
+ if not query_str: + LOGGER.info(f"Unable to validate query parameters for {connection_string}") + return connection_string + else: + query_str = connection_string.query + + if query_str: + query_params: dict[str, list[str]] = urllib.parse.parse_qs(query_str) + + for key in REQUIRED_QUERY_PARAMS: + if key in query_params: + missing_keys.remove(key) + + if missing_keys: + raise _UrlMissingQueryError( + msg=f"missing {', '.join(sorted(missing_keys))}", + ) + return connection_string + class Config: @staticmethod def schema_extra(schema: dict, model: type[SnowflakeDatasource]) -> None: diff --git a/great_expectations/datasource/fluent/sources.pyi b/great_expectations/datasource/fluent/sources.pyi index 49b93080d02a..5c044cb6053e 100644 --- a/great_expectations/datasource/fluent/sources.pyi +++ b/great_expectations/datasource/fluent/sources.pyi @@ -609,8 +609,8 @@ class _SourceFactories: account: str = ..., user: str = ..., password: Union[ConfigStr, str] = ..., - database: Optional[str] = ..., - schema: Optional[str] = ..., + database: str = ..., + schema: str = ..., warehouse: Optional[str] = ..., role: Optional[str] = ..., numpy: bool = ..., @@ -647,8 +647,8 @@ class _SourceFactories: account: str = ..., user: str = ..., password: Union[ConfigStr, str] = ..., - database: Optional[str] = ..., - schema: Optional[str] = ..., + database: str = ..., + schema: str = ..., warehouse: Optional[str] = ..., role: Optional[str] = ..., numpy: bool = ..., @@ -685,8 +685,8 @@ class _SourceFactories: account: str = ..., user: str = ..., password: Union[ConfigStr, str] = ..., - database: Optional[str] = ..., - schema: Optional[str] = ..., + database: str = ..., + schema: str = ..., warehouse: Optional[str] = ..., role: Optional[str] = ..., numpy: bool = ..., diff --git a/pyproject.toml b/pyproject.toml index a8225d797e9b..73c466349706 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -459,6 +459,10 @@ filterwarnings = [ "ignore:stream argument is deprecated. 
Use stream parameter in request directly:DeprecationWarning", # We likely won't be updating to `marhsmallow` 4, these errors should be filtered out "error::marshmallow.warnings.RemovedInMarshmallow4Warning", + # pkg_resources is deprecated as an API, but third party libraries still use it + "once: pkg_resources is deprecated as an API.:DeprecationWarning", + 'once: Deprecated call to `pkg_resources.declare_namespace\(.*\)`', + # --------------------------------------- Great Expectations Warnings ---------------------------------- # This warning is for configuring the result_format parameter at the Validator-level, which will not be persisted, diff --git a/tests/datasource/fluent/great_expectations.yml b/tests/datasource/fluent/great_expectations.yml index b16f65a0377f..ba1b6a031a7e 100644 --- a/tests/datasource/fluent/great_expectations.yml +++ b/tests/datasource/fluent/great_expectations.yml @@ -198,7 +198,7 @@ fluent_datasources: abs_container: "this_is_always_required" my_snowflake_ds: type: snowflake - connection_string: "snowflake://user_login_name:password@account_identifier" + connection_string: "snowflake://user_login_name:password@account_identifier?database=testdb&schema=public" assets: my_table_asset_wo_partitioners: id: d8b22f50-d3f9-4d04-9b4c-cfed86b157ff diff --git a/tests/datasource/fluent/integration/test_connections.py b/tests/datasource/fluent/integration/test_connections.py index 40820416262d..202942b1802d 100644 --- a/tests/datasource/fluent/integration/test_connections.py +++ b/tests/datasource/fluent/integration/test_connections.py @@ -24,28 +24,20 @@ class TestSnowflake: "connection_string", [ param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&database=ci&schema=public", id="missing role", ), param( - 
"snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci_no_select", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci_no_select&database=ci&schema=public", id="role wo select", ), - param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}?warehouse=ci&role=ci", - id="missing database + schema", - ), - param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci?warehouse=ci&role=ci", - id="missing schema", - ), ], ) def test_un_queryable_asset_should_raise_error( self, context: DataContext, connection_string: str ): """ - A SnowflakeDatasource can successfully connect even if things like database, schema, warehouse, and role are omitted. + A SnowflakeDatasource can successfully connect even if things like warehouse, and role are omitted. However, if we try to add an asset that is not queryable with the current datasource connection details, then we should expect a TestConnectionError. 
https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#connection-parameters @@ -84,11 +76,11 @@ def test_un_queryable_asset_should_raise_error( "connection_string", [ param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci&database=ci&schema=public", id="full connection string", ), param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?role=ci", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?role=ci&database=ci&schema=public", id="missing warehouse", ), ], diff --git a/tests/datasource/fluent/integration/test_sql_datasources.py b/tests/datasource/fluent/integration/test_sql_datasources.py index 1a2980ec14f3..fdf38fe1f4bc 100644 --- a/tests/datasource/fluent/integration/test_sql_datasources.py +++ b/tests/datasource/fluent/integration/test_sql_datasources.py @@ -333,8 +333,9 @@ def snowflake_ds( pytest.skip("no snowflake credentials") ds = context.data_sources.add_snowflake( "snowflake", - connection_string="snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci", - # NOTE: uncomment this and set SNOWFLAKE_USER to run tests against your own snowflake account # noqa: E501 + connection_string="snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public" + f"?warehouse=ci&role=ci&database=ci&schema={RAND_SCHEMA}", + # NOTE: uncomment this and set SNOWFLAKE_USER to run tests against your own snowflake account # connection_string="snowflake://${SNOWFLAKE_USER}@${SNOWFLAKE_CI_ACCOUNT}/DEMO_DB/RESTAURANTS?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser", ) return ds diff --git a/tests/datasource/fluent/test_snowflake_datasource.py b/tests/datasource/fluent/test_snowflake_datasource.py index 40d02119a551..d717f8f5de16 100644 --- 
a/tests/datasource/fluent/test_snowflake_datasource.py +++ b/tests/datasource/fluent/test_snowflake_datasource.py @@ -26,16 +26,28 @@ def seed_env_vars(monkeypatch: pytest.MonkeyPatch) -> None: "config_kwargs", [ param( - {"connection_string": "snowflake://my_user:password@my_account"}, + { + "connection_string": "snowflake://my_user:password@my_account?numpy=True&schema=s_public&database=d_public" + }, id="connection_string str", ), param({"connection_string": "${MY_CONN_STR}"}, id="connection_string ConfigStr"), + param( + {"connection_string": "${MY_CONN_STR}"}, + id="connection_string ConfigStr missing query params", + ), + param( + {"connection_string": "${MY_CONN_STR}?database=my_db&schema=my_schema"}, + id="connection_string ConfigStr with required query params", + ), param( { "connection_string": { "user": "my_user", "password": "password", "account": "my_account", + "schema": "s_public", + "database": "d_public", } }, id="connection_string dict", @@ -46,12 +58,20 @@ def seed_env_vars(monkeypatch: pytest.MonkeyPatch) -> None: "user": "my_user", "password": "${MY_PASSWORD}", "account": "my_account", + "schema": "s_public", + "database": "d_public", } }, id="connection_string dict with password ConfigStr", ), param( - {"user": "my_user", "password": "password", "account": "my_account"}, + { + "user": "my_user", + "password": "password", + "account": "my_account", + "schema": "s_public", + "database": "d_public", + }, id="old config format - top level keys", ), ], @@ -72,21 +92,212 @@ def test_valid_config( @pytest.mark.unit @pytest.mark.parametrize( - "connection_string, connect_args", + ["connection_string", "expected_errors"], + [ + pytest.param( + "snowflake://my_user:password@my_account", + [ + { + "ctx": {"msg": "missing database, schema"}, + "loc": ("connection_string",), + "msg": "URL query param missing", + "type": "value_error.url.query", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of" + " 
account, user, and password.", + "type": "value_error", + }, + ], + id="missing database + schema", + ), + pytest.param( + "snowflake://${my_user}:${password}@my_account?numpy=True", + [ + { + "ctx": {"msg": "missing database, schema"}, + "loc": ("connection_string",), + "msg": "URL query param missing", + "type": "value_error.url.query", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account," + " user, and password.", + "type": "value_error", + }, + ], + id="ConfigStr missing database + schema", + ), + pytest.param( + "snowflake://my_user:password@my_account?database=my_db", + [ + { + "ctx": {"msg": "missing schema"}, + "loc": ("connection_string",), + "msg": "URL query param missing", + "type": "value_error.url.query", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", + "type": "value_error", + }, + ], + id="missing schema", + ), + pytest.param( + "snowflake://${my_user}:${password}@my_account?database=my_db", + [ + { + "ctx": {"msg": "missing schema"}, + "loc": ("connection_string",), + "msg": "URL query param missing", + "type": "value_error.url.query", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", + "type": "value_error", + }, + ], + id="ConfigStr missing schema", + ), + pytest.param( + "snowflake://my_user:password@my_account?schema=my_schema", + [ + { + "ctx": {"msg": "missing database"}, + "loc": ("connection_string",), + "msg": "URL query param missing", + "type": "value_error.url.query", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", + "type": "value_error", + }, + ], + id="missing database", + ), + pytest.param( + "snowflake://my_user:${password}@my_account?schema=my_schema", + [ + { + "ctx": {"msg": "missing database"}, + "loc": 
("connection_string",), + "msg": "URL query param missing", + "type": "value_error.url.query", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", + "type": "value_error", + }, + ], + id="ConfigStr missing database", + ), + ], +) +def test_missing_required_query_params( + connection_string: str, + expected_errors: list[dict], # TODO: use pydantic error dict +): + with pytest.raises(pydantic.ValidationError) as exc_info: + _ = SnowflakeDatasource( + name="my_sf_ds", + connection_string=connection_string, + ) + assert exc_info.value.errors() == expected_errors + + +@pytest.mark.unit +@pytest.mark.parametrize( + "connection_string, connect_args, expected_errors", [ pytest.param( - "snowflake://:@", - {"account": "my_account", "user": "my_user", "password": "123456"}, + "snowflake://my_user:password@my_account?numpy=True&schema=foo&database=bar", + { + "account": "my_account", + "user": "my_user", + "password": "123456", + "schema": "foo", + "database": "bar", + }, + [ + { + "loc": ("__root__",), + "msg": "Cannot provide both a connection string and a combination of" + " account, user, and password.", + "type": "value_error", + } + ], id="both connection_string and connect_args", ), pytest.param(None, {}, id="neither connection_string nor connect_args"), pytest.param( None, - {"account": "my_account", "user": "my_user"}, + {}, + [ + { + "loc": ("connection_string",), + "msg": "none is not an allowed value", + "type": "type_error.none.not_allowed", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", + "type": "value_error", + }, + ], + id="neither connection_string nor connect_args", + ), + pytest.param( + None, + { + "account": "my_account", + "user": "my_user", + "schema": "foo", + "database": "bar", + }, + [ + { + "loc": ("connection_string", "password"), + "msg": "field required", + "type": 
"value_error.missing", + }, + { + "loc": ("connection_string",), + "msg": f"""expected string or bytes-like object{"" if python_version < (3, 11) else ", got 'dict'"}""", + "type": "type_error", + }, + { + "loc": ("connection_string",), + "msg": "str type expected", + "type": "type_error.str", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", + "type": "value_error", + }, + ], id="incomplete connect_args", ), pytest.param( - {"connection_string": {"account": "my_account", "user": "my_user"}}, + { + "account": "my_account", + "user": "my_user", + "schema": "foo", + "database": "bar", + }, {}, id="incomplete connection_string dict connect_args", ), @@ -196,11 +407,61 @@ def test_invalid_connection_string_raises_dsn_error( @pytest.mark.skipif(True if not snowflake else False, reason="snowflake is not installed") @pytest.mark.unit def test_get_execution_engine_succeeds(): - connection_string = "snowflake://my_user:password@my_account" + connection_string = "snowflake://my_user:password@my_account?database=foo&schema=bar" datasource = SnowflakeDatasource(name="my_snowflake", connection_string=connection_string) # testing that this doesn't raise an exception datasource.get_execution_engine() +@pytest.mark.snowflake +@pytest.mark.parametrize( + "connection_string", + [ + param( + "snowflake://my_user:password@my_account?numpy=True&database=foo&schema=bar", + id="connection_string str", + ), + param( + { + "user": "my_user", + "password": "password", + "account": "my_account", + "database": "foo", + "schema": "bar", + }, + id="connection_string dict", + ), + ], +) +@pytest.mark.parametrize( + "context_fixture_name,expected_query_param", + [ + param( + "empty_file_context", + "great_expectations_core", + id="file context", + ), + param( + "empty_cloud_context_fluent", + "great_expectations_platform", + id="cloud context", + ), + ], +) +def 
test_get_engine_correctly_sets_application_query_param( + request, + context_fixture_name: str, + expected_query_param: str, + connection_string: str | dict, +): + context = request.getfixturevalue(context_fixture_name) + my_sf_ds = SnowflakeDatasource(name="my_sf_ds", connection_string=connection_string) + my_sf_ds._data_context = context + + sql_engine = my_sf_ds.get_engine() + application_query_param = sql_engine.url.query.get("application") + assert application_query_param == expected_query_param + + if __name__ == "__main__": pytest.main([__file__, "-vv"]) diff --git a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py index adae99dcf60f..caa4b60cdb14 100644 --- a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py +++ b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py @@ -2,7 +2,7 @@ import os import uuid -from typing import TYPE_CHECKING, Iterator +from typing import TYPE_CHECKING, Final, Iterator import pytest @@ -23,13 +23,21 @@ from great_expectations.validator.validator import Validator from tests.integration.cloud.end_to_end.conftest import TableFactory +RANDOM_SCHEMA: Final[str] = f"i{uuid.uuid4().hex}" + @pytest.fixture(scope="module") def connection_string() -> str: if os.getenv("SNOWFLAKE_CI_USER_PASSWORD") and os.getenv("SNOWFLAKE_CI_ACCOUNT"): - return "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci" + return ( + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public" + f"?database=ci&schema={RANDOM_SCHEMA}&warehouse=ci&role=ci" + ) elif os.getenv("SNOWFLAKE_USER") and os.getenv("SNOWFLAKE_CI_ACCOUNT"): - return "snowflake://${SNOWFLAKE_USER}@${SNOWFLAKE_CI_ACCOUNT}/DEMO_DB?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser" + return ( + "snowflake://${SNOWFLAKE_USER}@${SNOWFLAKE_CI_ACCOUNT}/DEMO_DB" + 
f"?database=ci&schema={RANDOM_SCHEMA}&warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser" + ) else: pytest.skip("no snowflake credentials") @@ -88,18 +96,17 @@ def table_asset( datasource: SnowflakeDatasource, asset_name: str, table_factory: TableFactory, -) -> TableAsset: - schema_name = f"i{uuid.uuid4().hex}" + get_missing_data_asset_error_type: type[Exception], +) -> Iterator[TableAsset]: table_name = f"i{uuid.uuid4().hex}" table_factory( gx_engine=datasource.get_execution_engine(), table_names={table_name}, - schema_name=schema_name, + schema_name=RANDOM_SCHEMA, ) - return datasource.add_table_asset( - name=asset_name, - schema_name=schema_name, - table_name=table_name, + asset_name = f"i{uuid.uuid4().hex}" + _ = datasource.add_table_asset( + name=asset_name, table_name=table_name, schema_name=RANDOM_SCHEMA ) From fe408169a9756643afe890b6f83f5cd412419b05 Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 15:24:41 -0400 Subject: [PATCH 02/20] use `schema` when creating `TableAsset` --- .../datasource/fluent/__init__.py | 1 + .../datasource/fluent/config_str.py | 11 +- .../datasource/fluent/interfaces.py | 8 + .../fluent/schemas/SnowflakeDatasource.json | 4 +- .../datasource/fluent/snowflake_datasource.py | 220 +++++++++++-- tests/conftest.py | 17 + .../datasource/fluent/great_expectations.yml | 2 +- .../fluent/integration/test_connections.py | 4 +- .../integration/test_sql_datasources.py | 23 +- .../fluent/test_snowflake_datasource.py | 308 ++++++++++++------ .../end_to_end/test_snowflake_datasource.py | 6 +- 11 files changed, 472 insertions(+), 132 deletions(-) diff --git a/great_expectations/datasource/fluent/__init__.py b/great_expectations/datasource/fluent/__init__.py index edbeef86d536..4052427ef418 100644 --- a/great_expectations/datasource/fluent/__init__.py +++ b/great_expectations/datasource/fluent/__init__.py @@ -8,6 +8,7 @@ Sorter, BatchMetadata, GxDatasourceWarning, + GxContextWarning, TestConnectionError, ) from 
great_expectations.datasource.fluent.invalid_datasource import ( diff --git a/great_expectations/datasource/fluent/config_str.py b/great_expectations/datasource/fluent/config_str.py index 7eea69c447c4..1714910c0966 100644 --- a/great_expectations/datasource/fluent/config_str.py +++ b/great_expectations/datasource/fluent/config_str.py @@ -54,8 +54,15 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self._display()!r})" @classmethod - def _validate_template_str_format(cls, v): - if TEMPLATE_STR_REGEX.search(v): + def str_contains_config_template(cls, v: str) -> bool: + """ + Returns True if the input string contains a config template string. + """ + return TEMPLATE_STR_REGEX.search(v) is not None + + @classmethod + def _validate_template_str_format(cls, v: str) -> str | None: + if cls.str_contains_config_template(v): return v raise ValueError( cls.__name__ diff --git a/great_expectations/datasource/fluent/interfaces.py b/great_expectations/datasource/fluent/interfaces.py index 4e3e5a949361..ed8c57029b4e 100644 --- a/great_expectations/datasource/fluent/interfaces.py +++ b/great_expectations/datasource/fluent/interfaces.py @@ -181,6 +181,14 @@ class GxDatasourceWarning(UserWarning): """ +class GxContextWarning(GxDatasourceWarning): + """ + Warning related to a Datasource that with a missing context. + Usually because the Datasource was created directly rather than using a + `context.sources` factory method. 
+ """ + + class GxSerializationWarning(GxDatasourceWarning): pass diff --git a/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json b/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json index b3a5269d113d..c65e6e8e57ac 100644 --- a/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json +++ b/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json @@ -538,7 +538,7 @@ }, "ConnectionDetails": { "title": "ConnectionDetails", - "description": "Information needed to connect to a Snowflake database.\nAlternative to a connection string.", + "description": "Information needed to connect to a Snowflake database.\nAlternative to a connection string.\n\nhttps://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#additional-connection-parameters", "type": "object", "properties": { "account": { @@ -564,10 +564,12 @@ }, "database": { "title": "Database", + "description": "`database` that the Datasource is mapped to.", "type": "string" }, "schema": { "title": "Schema", + "description": "`schema` that the Datasource is mapped to.", "type": "string" }, "warehouse": { diff --git a/great_expectations/datasource/fluent/snowflake_datasource.py b/great_expectations/datasource/fluent/snowflake_datasource.py index 95bdb1165d00..7c8651806253 100644 --- a/great_expectations/datasource/fluent/snowflake_datasource.py +++ b/great_expectations/datasource/fluent/snowflake_datasource.py @@ -1,7 +1,9 @@ from __future__ import annotations +import functools import logging import urllib.parse +import warnings from typing import ( TYPE_CHECKING, Any, @@ -13,12 +15,13 @@ Union, ) -from great_expectations._docs_decorators import public_api +from great_expectations._docs_decorators import deprecated_method_or_class, public_api from great_expectations.compatibility import pydantic from great_expectations.compatibility.pydantic import AnyUrl, errors from great_expectations.compatibility.snowflake import URL from 
great_expectations.compatibility.sqlalchemy import sqlalchemy as sa from great_expectations.compatibility.typing_extensions import override +from great_expectations.datasource.fluent import GxContextWarning, GxDatasourceWarning from great_expectations.datasource.fluent.config_str import ( ConfigStr, _check_config_substitutions_needed, @@ -27,20 +30,74 @@ FluentBaseModel, SQLDatasource, SQLDatasourceError, + TableAsset, ) if TYPE_CHECKING: from great_expectations.compatibility import sqlalchemy from great_expectations.compatibility.pydantic.networks import Parts + from great_expectations.datasource.fluent.interfaces import ( + BatchMetadata, + SortersDefinition, + ) from great_expectations.execution_engine import SqlAlchemyExecutionEngine LOGGER: Final[logging.Logger] = logging.getLogger(__name__) REQUIRED_QUERY_PARAMS: Final[Iterable[str]] = { # errors will be thrown if any of these are missing - "database", - "schema", + # TODO: require warehouse and role + # "warehouse", + # "role", } +MISSING: Final = object() # sentinel value to indicate missing values + + +def _extract_query_section(url: str) -> str | None: + """ + Extracts the query section of a URL if it exists. + + snowflake://user:password@account?warehouse=warehouse&role=role + """ + return url.partition("?")[2] + + +@functools.lru_cache(maxsize=4) +def _extract_path_sections(url_path: str) -> dict[str, str]: + """ + Extracts the database and schema from the path of a URL. + + Raises UrlPathError if the path is missing database/schema. 
+ + snowflake://user:password@account/database/schema + """ + try: + _, database, schema, *_ = url_path.split("/") + except (ValueError, AttributeError) as e: + LOGGER.info(f"Unable to split path - {e!r}") + raise UrlPathError() from e + if not database: + raise UrlPathError(msg="missing database") + if not schema: + raise UrlPathError(msg="missing schema") + return {"database": database, "schema": schema} + + +def _get_config_substituted_connection_string( + datasource: SnowflakeDatasource, + warning_msg: str = "Unable to perform config substitution", +) -> str | None: + if not isinstance(datasource.connection_string, ConfigStr): + raise TypeError("Config substitution is only supported for `ConfigStr`") + if not datasource._data_context: + warnings.warn( + f"{warning_msg} for {datasource.connection_string.template_str}." + " Likely missing a context.", + category=GxContextWarning, + ) + return None + return datasource.connection_string.get_config_value(datasource._data_context.config_provider) + class _UrlPasswordError(pydantic.UrlError): """ @@ -60,6 +117,18 @@ class _UrlDomainError(pydantic.UrlError): msg_template = "URL domain invalid" +class UrlPathError(pydantic.UrlError): + """ + Custom Pydantic error for missing path in SnowflakeDsn. + """ + + code = "url.path" + msg_template = "URL path missing database/schema" + + def __init__(self, **ctx: Any) -> None: + super().__init__(**ctx) + + class _UrlMissingQueryError(pydantic.UrlError): """ Custom Pydantic error for missing query parameters in SnowflakeDsn. 
@@ -95,20 +164,64 @@ def validate_parts(cls, parts: Parts, validate_port: bool = True) -> Parts: if domain is None: raise _UrlDomainError() - return AnyUrl.validate_parts(parts=parts, validate_port=validate_port) + validated_parts = AnyUrl.validate_parts(parts=parts, validate_port=validate_port) + + path: str = parts["path"] + # raises UrlPathError if path is missing database/schema + _extract_path_sections(path) + + return validated_parts + + @property + def params(self) -> dict[str, list[str]]: + """The query parameters as a dictionary.""" + if not self.query: + return {} + return urllib.parse.parse_qs(self.query) + + @property + def account_identifier(self) -> str: + """Alias for host.""" + assert self.host + return self.host + + @property + def database(self) -> str: + assert self.path + return self.path.split("/")[1] + + @property + def schema_(self) -> str: + assert self.path + return self.path.split("/")[2] + + @property + def warehouse(self) -> str | None: + return self.params.get("warehouse", [None])[0] + + @property + def role(self) -> str | None: + return self.params.get("role", [None])[0] class ConnectionDetails(FluentBaseModel): """ Information needed to connect to a Snowflake database. Alternative to a connection string. + + https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#additional-connection-parameters """ account: str user: str password: Union[ConfigStr, str] - database: str - schema_: str = pydantic.Field(..., alias="schema") # schema is a reserved attr in BaseModel + database: str = pydantic.Field( + ..., + description="`database` that the Datasource is mapped to.", + ) + schema_: str = pydantic.Field( + ..., alias="schema", description="`schema` that the Datasource is mapped to." + ) # schema is a reserved attr in BaseModel warehouse: Optional[str] = None role: Optional[str] = None numpy: bool = False @@ -139,28 +252,83 @@ def schema_(self) -> str | None: `schema_` to avoid conflict with Pydantic models schema property. 
""" - if isinstance(self.connection_string, ConnectionDetails): + if isinstance(self.connection_string, (ConnectionDetails, SnowflakeDsn)): return self.connection_string.schema_ - elif isinstance(self.connection_string, SnowflakeDsn): - # extra database and schema query parameters for the url - for key, value in self.connection_string.query_params(): - if key.lower() == "schema": - return value - # TODO: attempt to parse schema from a ConfigStr - return None + + subbed_str: str | None = _get_config_substituted_connection_string( + self, warning_msg="Unable to determine schema" + ) + if not subbed_str: + return None + url_path: str = urllib.parse.urlparse(subbed_str).path + return _extract_path_sections(url_path)["schema"] @property def database(self) -> str | None: """Convenience property to get the `database` regardless of the connection string format.""" - if isinstance(self.connection_string, ConnectionDetails): + if isinstance(self.connection_string, (ConnectionDetails, SnowflakeDsn)): return self.connection_string.database - elif isinstance(self.connection_string, SnowflakeDsn): - # extra database and schema query parameters for the url - for key, value in self.connection_string.query_params(): - if key.lower() == "database": - return value - # TODO: attempt to parse database from a ConfigStr - return None + + subbed_str: str | None = _get_config_substituted_connection_string( + self, warning_msg="Unable to determine database" + ) + if not subbed_str: + return None + url_path: str = urllib.parse.urlparse(subbed_str).path + return _extract_path_sections(url_path)["database"] + + @deprecated_method_or_class( + version="0.18.16", + message="`schema_name` is deprecated." 
" The schema now comes from the datasource.", + ) + @public_api + @override + def add_table_asset( # noqa: PLR0913 + self, + name: str, + table_name: str = "", + schema_name: Optional[str] = MISSING, # type: ignore[assignment] # sentinel value + order_by: Optional[SortersDefinition] = None, + batch_metadata: Optional[BatchMetadata] = None, + ) -> TableAsset: + """Adds a table asset to this datasource. + + Args: + name: The name of this table asset. + table_name: The table where the data resides. + schema_name: The schema that holds the table. Will use the datasource schema if not + provided. + order_by: A list of Sorters or Sorter strings. + batch_metadata: BatchMetadata we want to associate with this DataAsset and all batches + derived from it. + + Returns: + The table asset that is added to the datasource. + The type of this object will match the necessary type for this datasource. + """ + if schema_name is MISSING: + # using MISSING to indicate that the user did not provide a value + schema_name = self.schema_ + else: + # deprecated-v0.18.16 + warnings.warn( + "The `schema_name argument` is deprecated and will be removed in a future release." + " The schema now comes from the datasource.", + category=DeprecationWarning, + ) + if schema_name != self.schema_: + warnings.warn( + f"schema_name {schema_name} does not match datasource schema {self.schema_}", + category=GxDatasourceWarning, + ) + + return super().add_table_asset( + name=name, + table_name=table_name, + schema_name=schema_name, + order_by=order_by, + batch_metadata=batch_metadata, + ) @pydantic.root_validator(pre=True) def _convert_root_connection_detail_fields(cls, values: dict) -> dict: @@ -219,10 +387,10 @@ def _check_for_required_query_params( missing_keys: set[str] = set(REQUIRED_QUERY_PARAMS) if isinstance(connection_string, ConfigStr): - query_str = connection_string.template_str.partition("?")[2] - # best effort: query could be part of the config substitution. 
- # Have to check this when adding assets. - if not query_str: + query_str = _extract_query_section(connection_string.template_str) + # best effort: query could be part of the config substitution. Have to check this when + # adding assets. + if not query_str or ConfigStr.str_contains_config_template(query_str): LOGGER.info(f"Unable to validate query parameters for {connection_string}") return connection_string else: diff --git a/tests/conftest.py b/tests/conftest.py index c8f1e63206ff..14948c26edaf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2670,3 +2670,20 @@ def filter_gx_datasource_warnings() -> Generator[None, None, None]: with warnings.catch_warnings(): warnings.simplefilter("ignore", category=GxDatasourceWarning) yield + + +@pytest.fixture(scope="function") +def param_id(request: pytest.FixtureRequest) -> str: + """Return the parameter id of the current test. + + Example: + + ```python + @pytest.mark.parametrize("my_param", ["a", "b", "c"], ids=lambda x: x.upper()) + def test_something(param_id: str, my_param: str): + assert my_param != param_id + assert my_param.upper() == param_id + ``` + """ + raw_name: str = request.node.name + return raw_name.split("[")[1].split("]")[0] diff --git a/tests/datasource/fluent/great_expectations.yml b/tests/datasource/fluent/great_expectations.yml index ba1b6a031a7e..77947860c19f 100644 --- a/tests/datasource/fluent/great_expectations.yml +++ b/tests/datasource/fluent/great_expectations.yml @@ -198,7 +198,7 @@ fluent_datasources: abs_container: "this_is_always_required" my_snowflake_ds: type: snowflake - connection_string: "snowflake://user_login_name:password@account_identifier?database=testdb&schema=public" + connection_string: "snowflake://user_login_name:password@account_identifier/database/public" assets: my_table_asset_wo_partitioners: id: d8b22f50-d3f9-4d04-9b4c-cfed86b157ff diff --git a/tests/datasource/fluent/integration/test_connections.py b/tests/datasource/fluent/integration/test_connections.py 
index 202942b1802d..f3e246bbb7d0 100644 --- a/tests/datasource/fluent/integration/test_connections.py +++ b/tests/datasource/fluent/integration/test_connections.py @@ -24,11 +24,11 @@ class TestSnowflake: "connection_string", [ param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&database=ci&schema=public", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci", id="missing role", ), param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci_no_select&database=ci&schema=public", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci_no_select", id="role wo select", ), ], diff --git a/tests/datasource/fluent/integration/test_sql_datasources.py b/tests/datasource/fluent/integration/test_sql_datasources.py index fdf38fe1f4bc..c7edeb0f1c36 100644 --- a/tests/datasource/fluent/integration/test_sql_datasources.py +++ b/tests/datasource/fluent/integration/test_sql_datasources.py @@ -6,6 +6,7 @@ import shutil import sys import uuid +import warnings from pprint import pformat as pf from typing import ( TYPE_CHECKING, @@ -333,8 +334,8 @@ def snowflake_ds( pytest.skip("no snowflake credentials") ds = context.data_sources.add_snowflake( "snowflake", - connection_string="snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public" - f"?warehouse=ci&role=ci&database=ci&schema={RAND_SCHEMA}", + connection_string="snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci" + f"/{RAND_SCHEMA}?warehouse=ci&role=ci", # NOTE: uncomment this and set SNOWFLAKE_USER to run tests against your own snowflake account # connection_string="snowflake://${SNOWFLAKE_USER}@${SNOWFLAKE_CI_ACCOUNT}/DEMO_DB/RESTAURANTS?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser", ) @@ -452,7 +453,7 @@ def test_snowflake( if not snowflake_ds: pytest.skip("no snowflake datasource") 
# create table - schema = get_random_identifier_name() + schema = RAND_SCHEMA table_factory( gx_engine=snowflake_ds.get_execution_engine(), table_names={table_name}, @@ -462,7 +463,7 @@ def test_snowflake( table_names: list[str] = inspect(snowflake_ds.get_engine()).get_table_names(schema=schema) print(f"snowflake tables:\n{pf(table_names)}))") - snowflake_ds.add_table_asset(asset_name, table_name=table_name, schema_name=schema) + snowflake_ds.add_table_asset(asset_name, table_name=table_name) @pytest.mark.sqlite def test_sqlite( @@ -483,6 +484,9 @@ def test_sqlite( sqlite_ds.add_table_asset(asset_name, table_name=table_name) + @pytest.mark.filterwarnings( # snowflake `add_table_asset` raises warning on passing a schema + "once::great_expectations.datasource.fluent.GxDatasourceWarning" + ) @pytest.mark.parametrize( "datasource_type,schema", [ @@ -519,8 +523,12 @@ def test_checkpoint_run( schema=schema, ) - asset = datasource.add_table_asset(asset_name, table_name=table_name, schema_name=schema) - batch_definition = asset.add_batch_definition_whole_table("whole table!") + with warnings.catch_warnings(): + # passing a schema to snowflake tables is deprecated + warnings.simplefilter("once", DeprecationWarning) + asset = datasource.add_table_asset( + asset_name, table_name=table_name, schema_name=schema + ) suite = context.suites.add(ExpectationSuite(name=f"{datasource.name}-{asset.name}")) suite.add_expectation(gxe.ExpectColumnValuesToNotBeNull(column="name", mostly=1)) @@ -712,6 +720,9 @@ def _raw_query_check_column_exists( return True +@pytest.mark.filterwarnings( + "once::DeprecationWarning" +) # snowflake `add_table_asset` raises warning on passing a schema @pytest.mark.parametrize( "column_name", [ diff --git a/tests/datasource/fluent/test_snowflake_datasource.py b/tests/datasource/fluent/test_snowflake_datasource.py index d717f8f5de16..6b9836496f6d 100644 --- a/tests/datasource/fluent/test_snowflake_datasource.py +++ 
b/tests/datasource/fluent/test_snowflake_datasource.py @@ -1,5 +1,9 @@ from __future__ import annotations +from pprint import pformat as pf +from sys import version_info as python_version +from typing import TYPE_CHECKING, Final, Sequence + import pytest import sqlalchemy as sa from pytest import param @@ -7,6 +11,7 @@ from great_expectations.compatibility import pydantic from great_expectations.compatibility.snowflake import snowflake from great_expectations.data_context import AbstractDataContext +from great_expectations.datasource.fluent import GxContextWarning from great_expectations.datasource.fluent.config_str import ConfigStr from great_expectations.datasource.fluent.snowflake_datasource import ( SnowflakeDatasource, @@ -14,56 +19,88 @@ ) from great_expectations.execution_engine import SqlAlchemyExecutionEngine +if TYPE_CHECKING: + from pytest.mark.structures import ParameterSet + +VALID_DS_CONFIG_PARAMS: Final[Sequence[ParameterSet]] = [ + param( + { + "connection_string": "snowflake://my_user:password@my_account/d_public/s_public?numpy=True" + }, + id="connection_string str", + ), + param( + {"connection_string": "${MY_CONN_STR_PARTIAL}@${MY_PATH}?${MY_QUERY_PARAMS}"}, + id="connection_string ConfigStr with required query params", + ), + param( + {"connection_string": "${MY_CONN_STR_FULL}"}, + id="connection_string ConfigStr - required params part of sub", + ), + param( + {"connection_string": "${MY_CONN_STR_MIN}?${MY_QUERY_PARAMS}"}, + id="connection_string ConfigStr - dedicated query params sub", + ), + param( + { + "connection_string": { + "user": "my_user", + "password": "password", + "account": "my_account", + "schema": "s_public", + "database": "d_public", + } + }, + id="connection_string dict", + ), + param( + { + "connection_string": { + "user": "my_user", + "password": "${MY_PASSWORD}", + "account": "my_account", + "schema": "s_public", + "database": "d_public", + } + }, + id="connection_string dict with password ConfigStr", + ), +] + 
@pytest.fixture def seed_env_vars(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("MY_CONN_STR", "snowflake://my_user:password@my_account") + monkeypatch.setenv("MY_CONN_STR_PARTIAL", "snowflake://my_user:password") + monkeypatch.setenv("MY_CONN_STR_MIN", "snowflake://my_user:password@my_account/my_db/my_schema") + monkeypatch.setenv( + "MY_CONN_STR_FULL", + "snowflake://my_user:password@my_account/my_db/my_schema?warehouse=my_wh&role=my_role", + ) + monkeypatch.setenv("MY_PATH", "my_account/my_db/my_schema") + monkeypatch.setenv("MY_QUERY_PARAMS", "warehouse=my_wh&role=my_role") monkeypatch.setenv("MY_PASSWORD", "my_password") +@pytest.mark.unit +def test_snowflake_dsn(): + dsn = pydantic.parse_obj_as( + SnowflakeDsn, + "snowflake://my_user:password@my_account/my_db/my_schema?role=my_role&warehouse=my_wh", + ) + assert dsn.user == "my_user" + assert dsn.password == "password" + assert dsn.account_identifier == "my_account" + assert dsn.database == "my_db" + assert dsn.schema_ == "my_schema" + assert dsn.role == "my_role" + assert dsn.warehouse == "my_wh" + + @pytest.mark.snowflake # TODO: make this a unit test @pytest.mark.parametrize( "config_kwargs", [ - param( - { - "connection_string": "snowflake://my_user:password@my_account?numpy=True&schema=s_public&database=d_public" - }, - id="connection_string str", - ), - param({"connection_string": "${MY_CONN_STR}"}, id="connection_string ConfigStr"), - param( - {"connection_string": "${MY_CONN_STR}"}, - id="connection_string ConfigStr missing query params", - ), - param( - {"connection_string": "${MY_CONN_STR}?database=my_db&schema=my_schema"}, - id="connection_string ConfigStr with required query params", - ), - param( - { - "connection_string": { - "user": "my_user", - "password": "password", - "account": "my_account", - "schema": "s_public", - "database": "d_public", - } - }, - id="connection_string dict", - ), - param( - { - "connection_string": { - "user": "my_user", - "password": "${MY_PASSWORD}", 
- "account": "my_account", - "schema": "s_public", - "database": "d_public", - } - }, - id="connection_string dict with password ConfigStr", - ), + *VALID_DS_CONFIG_PARAMS, param( { "user": "my_user", @@ -74,12 +111,19 @@ def seed_env_vars(monkeypatch: pytest.MonkeyPatch) -> None: }, id="old config format - top level keys", ), + param( + {"connection_string": "${MY_CONN_STR_MIN}"}, + id="connection_string ConfigStr missing query params", + ), ], ) def test_valid_config( - empty_file_context: AbstractDataContext, seed_env_vars: None, config_kwargs: dict + empty_file_context: AbstractDataContext, + seed_env_vars: None, + config_kwargs: dict, + param_id: str, ): - my_sf_ds_1 = SnowflakeDatasource(name="my_sf_ds_1", **config_kwargs) + my_sf_ds_1 = SnowflakeDatasource(name=f"my_sf {param_id}", **config_kwargs) assert my_sf_ds_1 my_sf_ds_1._data_context = empty_file_context # attach to enable config substitution @@ -98,121 +142,159 @@ def test_valid_config( "snowflake://my_user:password@my_account", [ { - "ctx": {"msg": "missing database, schema"}, "loc": ("connection_string",), - "msg": "URL query param missing", - "type": "value_error.url.query", + "msg": "value is not a valid dict", + "type": "type_error.dict", + }, + { + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", + "type": "value_error", + }, + { + "loc": ("connection_string",), + "msg": "URL path missing database/schema", + "type": "value_error.url.path", }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of" - " account, user, and password.", + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", "type": "value_error", }, ], - id="missing database + schema", + id="missing path", ), pytest.param( - "snowflake://${my_user}:${password}@my_account?numpy=True", + "snowflake://my_user:password@my_account//", [ { - "ctx": 
{"msg": "missing database, schema"}, "loc": ("connection_string",), - "msg": "URL query param missing", - "type": "value_error.url.query", + "msg": "value is not a valid dict", + "type": "type_error.dict", + }, + { + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", + "type": "value_error", + }, + { + "ctx": {"msg": "missing database"}, + "loc": ("connection_string",), + "msg": "URL path missing database/schema", + "type": "value_error.url.path", }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of account," - " user, and password.", + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", "type": "value_error", }, ], - id="ConfigStr missing database + schema", + id="missing database + schema", ), pytest.param( - "snowflake://my_user:password@my_account?database=my_db", + "snowflake://my_user:password@my_account/my_db", [ { - "ctx": {"msg": "missing schema"}, "loc": ("connection_string",), - "msg": "URL query param missing", - "type": "value_error.url.query", + "msg": "value is not a valid dict", + "type": "type_error.dict", + }, + { + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", + "type": "value_error", + }, + { + "loc": ("connection_string",), + "msg": "URL path missing database/schema", + "type": "value_error.url.path", }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of" - " account, user, and password.", + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", "type": "value_error", }, ], id="missing schema", ), pytest.param( - "snowflake://${my_user}:${password}@my_account?database=my_db", + "snowflake://my_user:password@my_account/my_db/", [ + { + "loc": 
("connection_string",), + "msg": "value is not a valid dict", + "type": "type_error.dict", + }, + { + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", + "type": "value_error", + }, { "ctx": {"msg": "missing schema"}, "loc": ("connection_string",), - "msg": "URL query param missing", - "type": "value_error.url.query", + "msg": "URL path missing database/schema", + "type": "value_error.url.path", }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of" - " account, user, and password.", + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", "type": "value_error", }, ], - id="ConfigStr missing schema", + id="missing schema 2", ), pytest.param( - "snowflake://my_user:password@my_account?schema=my_schema", + "snowflake://my_user:password@my_account//my_schema", [ { - "ctx": {"msg": "missing database"}, "loc": ("connection_string",), - "msg": "URL query param missing", - "type": "value_error.url.query", + "msg": "value is not a valid dict", + "type": "type_error.dict", }, { - "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of" - " account, user, and password.", + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", "type": "value_error", }, - ], - id="missing database", - ), - pytest.param( - "snowflake://my_user:${password}@my_account?schema=my_schema", - [ { "ctx": {"msg": "missing database"}, "loc": ("connection_string",), - "msg": "URL query param missing", - "type": "value_error.url.query", + "msg": "URL path missing database/schema", + "type": "value_error.url.path", }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of" - " account, user, and password.", + "msg": "Must provide either a connection 
string or a combination of account, " + "user, and password.", "type": "value_error", }, ], - id="ConfigStr missing database", + id="missing database", ), ], ) -def test_missing_required_query_params( +def test_missing_required_params( connection_string: str, expected_errors: list[dict], # TODO: use pydantic error dict ): with pytest.raises(pydantic.ValidationError) as exc_info: - _ = SnowflakeDatasource( + ds = SnowflakeDatasource( name="my_sf_ds", connection_string=connection_string, ) + print(f"{ds!r}") + + print(f"\n\tErrors:\n{pf(exc_info.value.errors())}") assert exc_info.value.errors() == expected_errors @@ -221,7 +303,7 @@ def test_missing_required_query_params( "connection_string, connect_args, expected_errors", [ pytest.param( - "snowflake://my_user:password@my_account?numpy=True&schema=foo&database=bar", + "snowflake://my_user:password@my_account/foo/bar?numpy=True", { "account": "my_account", "user": "my_user", @@ -407,7 +489,7 @@ def test_invalid_connection_string_raises_dsn_error( @pytest.mark.skipif(True if not snowflake else False, reason="snowflake is not installed") @pytest.mark.unit def test_get_execution_engine_succeeds(): - connection_string = "snowflake://my_user:password@my_account?database=foo&schema=bar" + connection_string = "snowflake://my_user:password@my_account/my_db/my_schema" datasource = SnowflakeDatasource(name="my_snowflake", connection_string=connection_string) # testing that this doesn't raise an exception datasource.get_execution_engine() @@ -418,7 +500,7 @@ def test_get_execution_engine_succeeds(): "connection_string", [ param( - "snowflake://my_user:password@my_account?numpy=True&database=foo&schema=bar", + "snowflake://my_user:password@my_account/my_db/my_schema?numpy=True", id="connection_string str", ), param( @@ -454,7 +536,9 @@ def test_get_engine_correctly_sets_application_query_param( expected_query_param: str, connection_string: str | dict, ): - context = request.getfixturevalue(context_fixture_name) + context = 
request.getfixturevalue(  # TODO: fix this and make it a fixture in the root conftest + context_fixture_name + ) my_sf_ds = SnowflakeDatasource(name="my_sf_ds", connection_string=connection_string) my_sf_ds._data_context = context @@ -463,5 +547,47 @@ def test_get_engine_correctly_sets_application_query_param( assert application_query_param == expected_query_param +@pytest.mark.snowflake +@pytest.mark.parametrize("ds_config", VALID_DS_CONFIG_PARAMS) +class TestConvenienceProperties: + def test_schema( + self, + ds_config: dict, + seed_env_vars: None, + param_id: str, + ephemeral_context_with_defaults: AbstractDataContext, + ): + datasource = SnowflakeDatasource(name=param_id, **ds_config) + if isinstance(datasource.connection_string, ConfigStr): + # expect a warning if connection string is a ConfigStr + with pytest.warns(GxContextWarning): + assert ( + not datasource.schema_ + ), "Don't expect schema to be available without config_provider" + # attach context to enable config substitution + datasource._data_context = ephemeral_context_with_defaults + + assert datasource.schema_ + + def test_database( + self, + ds_config: dict, + seed_env_vars: None, + param_id: str, + ephemeral_context_with_defaults: AbstractDataContext, + ): + datasource = SnowflakeDatasource(name=param_id, **ds_config) + if isinstance(datasource.connection_string, ConfigStr): + # expect a warning if connection string is a ConfigStr + with pytest.warns(GxContextWarning): + assert ( + not datasource.database + ), "Don't expect database to be available without config_provider" + # attach context to enable config substitution + datasource._data_context = ephemeral_context_with_defaults + + assert datasource.database + + if __name__ == "__main__": pytest.main([__file__, "-vv"]) diff --git a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py index caa4b60cdb14..2103d277a9be 100644 ---
a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py +++ b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py @@ -30,13 +30,13 @@ def connection_string() -> str: if os.getenv("SNOWFLAKE_CI_USER_PASSWORD") and os.getenv("SNOWFLAKE_CI_ACCOUNT"): return ( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public" - f"?database=ci&schema={RANDOM_SCHEMA}&warehouse=ci&role=ci" + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci" + f"/{RANDOM_SCHEMA}?warehouse=ci&role=ci" ) elif os.getenv("SNOWFLAKE_USER") and os.getenv("SNOWFLAKE_CI_ACCOUNT"): return ( "snowflake://${SNOWFLAKE_USER}@${SNOWFLAKE_CI_ACCOUNT}/DEMO_DB" - f"?database=ci&schema={RANDOM_SCHEMA}&warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser" + f"/{RANDOM_SCHEMA}?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser" ) else: pytest.skip("no snowflake credentials") From d39a82a0f6b40d5216671f94227f0cfa237ed456 Mon Sep 17 00:00:00 2001 From: Gabriel Date: Tue, 4 Jun 2024 14:06:37 -0400 Subject: [PATCH 03/20] [FEATURE] Add new `ConfigUri` type (#10000) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../datasource/fluent/config_str.py | 139 +++++++++++++++++- pyproject.toml | 4 +- tests/datasource/fluent/test_config_str.py | 119 ++++++++++++++- 3 files changed, 256 insertions(+), 6 deletions(-) diff --git a/great_expectations/datasource/fluent/config_str.py b/great_expectations/datasource/fluent/config_str.py index 1714910c0966..3038a9dd6578 100644 --- a/great_expectations/datasource/fluent/config_str.py +++ b/great_expectations/datasource/fluent/config_str.py @@ -2,13 +2,22 @@ import logging import warnings -from typing import TYPE_CHECKING, Mapping - -from great_expectations.compatibility.pydantic import SecretStr +from typing import ( + TYPE_CHECKING, + ClassVar, + Literal, + Mapping, + Optional, + TypedDict, +) + +from great_expectations.compatibility.pydantic import 
AnyUrl, SecretStr, parse_obj_as from great_expectations.compatibility.typing_extensions import override from great_expectations.core.config_substitutor import TEMPLATE_STR_REGEX if TYPE_CHECKING: + from typing_extensions import Self, TypeAlias + from great_expectations.core.config_provider import _ConfigurationProvider from great_expectations.datasource.fluent import Datasource @@ -79,6 +88,130 @@ def __get_validators__(cls): yield cls.validate +UriParts: TypeAlias = ( + Literal[ # https://docs.pydantic.dev/1.10/usage/types/#url-properties + "scheme", "host", "user", "password", "port", "path", "query", "fragment", "tld" + ] +) + + +class UriPartsDict(TypedDict, total=False): + scheme: str + user: str | None + password: str | None + ipv4: str | None + ipv6: str | None + domain: str | None + port: str | None + path: str | None + query: str | None + fragment: str | None + + +class ConfigUri(AnyUrl, ConfigStr): # type: ignore[misc] # Mixin "validate" signature mismatch + """ + Special type that enables great_expectations config variable substitution for the + `user` and `password` section of a URI. + + Example: + ``` + "snowflake://${MY_USER}:${MY_PASSWORD}@account/database/schema/table" + ``` + + Note: this type is meant to be used as part of a pydantic model. + To use this outside of a model see the pydantic docs below.
+ https://docs.pydantic.dev/usage/models/#parsing-data-into-a-specified-type + """ + + ALLOWED_SUBSTITUTIONS: ClassVar[set[UriParts]] = {"user", "password"} + + min_length: int = 1 + max_length: int = 2**16 + + def __init__( # noqa: PLR0913 # for compatibility with AnyUrl + self, + template_str: str, + *, + scheme: str, + user: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + tld: Optional[str] = None, + host_type: str = "domain", + port: Optional[str] = None, + path: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, + ) -> None: + if template_str: # may have already been set in __new__ + self.template_str: str = template_str + self._secret_value = template_str # for compatibility with SecretStr + super().__init__( + template_str, + scheme=scheme, + user=user, + password=password, + host=host, + tld=tld, + host_type=host_type, + port=port, + path=path, + query=query, + fragment=fragment, + ) + + def __new__(cls: type[Self], template_str: Optional[str], **kwargs) -> Self: + """custom __new__ for compatibility with pydantic.parse_obj_as()""" + built_url = cls.build(**kwargs) if template_str is None else template_str + instance = str.__new__(cls, built_url) + instance.template_str = str(instance) + return instance + + @classmethod + @override + def validate_parts( + cls, parts: UriPartsDict, validate_port: bool = True + ) -> UriPartsDict: + """ + Ensure that only the `user` and `password` parts have config template strings. + Also validate that all parts of the URI are valid. 
+ """ + validated_parts = AnyUrl.validate_parts(parts, validate_port) + + name: UriParts + for name, part in validated_parts.items(): + if not part: + continue + if ( + cls.str_contains_config_template(part) + and name not in cls.ALLOWED_SUBSTITUTIONS + ): + raise ValueError( + f"ConfigUri - '{name}' part of URI is not allowed to be substituted" + ) + + return validated_parts + + @override + def get_config_value(self, config_provider: _ConfigurationProvider) -> AnyUrl: + """ + Resolve the config template string to its string value according to the passed + _ConfigurationProvider. + Parse the resolved URI string into an `AnyUrl` object. + """ + LOGGER.info(f"Substituting '{self}'") + raw_value = config_provider.substitute_config(self.template_str) + return parse_obj_as(AnyUrl, raw_value) + + @classmethod + def __get_validators__(cls): + # one or more validators may be yielded which will be called in the + # order to validate the input, each validator will receive as an input + # the value returned from the previous validator + yield ConfigStr._validate_template_str_format + yield cls.validate # equivalent to AnyUrl.validate + + def _check_config_substitutions_needed( datasource: Datasource, options: Mapping, diff --git a/pyproject.toml b/pyproject.toml index 73c466349706..f903e57e6ca1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -460,8 +460,8 @@ filterwarnings = [ # We likely won't be updating to `marhsmallow` 4, these errors should be filtered out "error::marshmallow.warnings.RemovedInMarshmallow4Warning", # pkg_resources is deprecated as an API, but third party libraries still use it - "once: pkg_resources is deprecated as an API.:DeprecationWarning", - 'once: Deprecated call to `pkg_resources.declare_namespace\(.*\)`', + "ignore: pkg_resources is deprecated as an API.:DeprecationWarning", + 'ignore: Deprecated call to `pkg_resources.declare_namespace\(.*\)`', # --------------------------------------- Great Expectations Warnings 
---------------------------------- diff --git a/tests/datasource/fluent/test_config_str.py b/tests/datasource/fluent/test_config_str.py index 77f9f8cc74ef..7c021fe70516 100644 --- a/tests/datasource/fluent/test_config_str.py +++ b/tests/datasource/fluent/test_config_str.py @@ -10,7 +10,11 @@ _ConfigurationProvider, _EnvironmentConfigurationProvider, ) -from great_expectations.datasource.fluent.config_str import ConfigStr, SecretStr +from great_expectations.datasource.fluent.config_str import ( + ConfigStr, + ConfigUri, + SecretStr, +) from great_expectations.datasource.fluent.fluent_base_model import FluentBaseModel from great_expectations.exceptions import MissingConfigVariableError @@ -245,3 +249,116 @@ def test_serialization( assert "my_secret" not in dumped_str assert "dont_serialize_me" not in dumped_str assert r"${MY_SECRET}" in dumped_str + + +@pytest.mark.parametrize( + "uri", + [ + "http://my_user:${MY_PW}@example.com:8000/the/path/?query=here#fragment=is;this=bit", + "http://${MY_USER}:${MY_PW}@example.com:8000/the/path/?query=here#fragment=is;this=bit", + "snowflake://my_user:${MY_PW}@account/db", + "snowflake://${MY_USER}:${MY_PW}@account/db", + "postgresql+psycopg2://my_user:${MY_PW}@host/db", + "postgresql+psycopg2://${MY_USER}:${MY_PW}@host/db", + ], +) +class TestConfigUri: + def test_parts( + self, + env_config_provider: _ConfigurationProvider, + monkeypatch: MonkeyPatch, + uri: str, + ): + monkeypatch.setenv("MY_USER", "my_user") + monkeypatch.setenv("MY_PW", "super_secret") + + parsed = pydantic.parse_obj_as(ConfigUri, uri) + + # ensure attributes are set + assert parsed.scheme + assert parsed.host + assert parsed.path + # ensure no attribute errors + _ = parsed.query + _ = parsed.fragment + + # ensure that the password (and user) are not substituted + assert parsed.password == "${MY_PW}" + assert parsed.user in ["${MY_USER}", "my_user"] + + def test_substitution( + self, + env_config_provider: _ConfigurationProvider, + monkeypatch: MonkeyPatch, 
+ uri: str, + ): + monkeypatch.setenv("MY_USER", "my_user") + monkeypatch.setenv("MY_PW", "super_secret") + + parsed = pydantic.parse_obj_as(ConfigUri, uri) + + substituted = parsed.get_config_value(env_config_provider) + + # ensure attributes are set + assert substituted.scheme + assert substituted.host + assert substituted.path + # ensure no attribute errors + _ = substituted.query + _ = substituted.fragment + + # ensure that the password (and user) are not substituted + assert substituted.password == "super_secret" + assert substituted.user == "my_user" + + def test_leakage( + self, + env_config_provider: _ConfigurationProvider, + monkeypatch: MonkeyPatch, + uri: str, + ): + """Ensure the config values are not leaked in the repr or str of the object or the component parts.""" + monkeypatch.setenv("MY_USER", "my_user") + monkeypatch.setenv("MY_PW", "super_secret") + + parsed = pydantic.parse_obj_as(ConfigUri, uri) + assert "super_secret" not in str(parsed) + assert "super_secret" not in repr(parsed) + assert parsed.password + assert "super_secret" not in parsed.password + + if "my_user" not in uri: + assert "my_user" not in str(parsed) + assert "my_user" not in repr(parsed) + assert parsed.user + assert "my_user" not in parsed.user + + +class TestConfigUriInvalid: + @pytest.mark.parametrize("uri", ["invalid_uri", "http:/example.com"]) + def test_invalid_uri(self, uri: str): + with pytest.raises(pydantic.ValidationError): + _ = pydantic.parse_obj_as(ConfigUri, uri) + + @pytest.mark.parametrize( + "uri", + [ + "${MY_SCHEME}://me:secret@account/db/schema", + "snowflake://me:secret@${MY_ACCOUNT}/db/schema", + "snowflake://me:secret@account/${MY_DB}/schema", + "snowflake://me:secret@account/db/${MY_SCHEMA}", + "snowflake://me:secret@account/db/my_schema?${MY_QUERY_PARAMS}", + "snowflake://me:secret@account/db/my_schema?role=${MY_ROLE}", + ], + ) + def test_disallowed_substitution(self, uri: str): + with pytest.raises(pydantic.ValidationError): + _ = 
pydantic.parse_obj_as(ConfigUri, uri) + + def test_no_template_str(self): + with pytest.raises(pydantic.ValidationError): + _ = pydantic.parse_obj_as(ConfigUri, "snowflake://me:password@account/db") + + +if __name__ == "__main__": + pytest.main([__file__, "-vv"]) From 9353f158fe1a96815ac0189747705d55f15a237a Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 15:28:05 -0400 Subject: [PATCH 04/20] Restrict substitutable sections for Snowflake.connection_string --- .../datasource/fluent/config_str.py | 25 ++--- .../fluent/schemas/SnowflakeDatasource.json | 5 +- .../datasource/fluent/snowflake_datasource.py | 57 +++++------ .../fluent/integration/test_connections.py | 8 +- .../integration/test_sql_datasources.py | 4 +- .../fluent/test_snowflake_datasource.py | 97 +++++++++++++++---- .../end_to_end/test_snowflake_datasource.py | 4 +- 7 files changed, 124 insertions(+), 76 deletions(-) diff --git a/great_expectations/datasource/fluent/config_str.py b/great_expectations/datasource/fluent/config_str.py index 3038a9dd6578..a30e2a8a5482 100644 --- a/great_expectations/datasource/fluent/config_str.py +++ b/great_expectations/datasource/fluent/config_str.py @@ -88,11 +88,9 @@ def __get_validators__(cls): yield cls.validate -UriParts: TypeAlias = ( - Literal[ # https://docs.pydantic.dev/1.10/usage/types/#url-properties - "scheme", "host", "user", "password", "port", "path", "query", "fragment", "tld" - ] -) +UriParts: TypeAlias = Literal[ # https://docs.pydantic.dev/1.10/usage/types/#url-properties + "scheme", "host", "user", "password", "port", "path", "query", "fragment", "tld" +] class UriPartsDict(TypedDict, total=False): @@ -169,28 +167,25 @@ def __new__(cls: type[Self], template_str: Optional[str], **kwargs) -> Self: @classmethod @override - def validate_parts( - cls, parts: UriPartsDict, validate_port: bool = True - ) -> UriPartsDict: + def validate_parts(cls, parts: UriPartsDict, validate_port: bool = True) -> UriPartsDict: """ Ensure that only the `user` 
and `password` parts have config template strings. Also validate that all parts of the URI are valid. """ - validated_parts = AnyUrl.validate_parts(parts, validate_port) + allowed_substitutions = sorted(cls.ALLOWED_SUBSTITUTIONS) - name: UriParts - for name, part in validated_parts.items(): + for name, part in parts.items(): if not part: continue if ( - cls.str_contains_config_template(part) + cls.str_contains_config_template(part) # type: ignore[arg-type] # is str and name not in cls.ALLOWED_SUBSTITUTIONS ): raise ValueError( - f"ConfigUri - '{name}' part of URI is not allowed to be substituted" + f"Only {', '.join(allowed_substitutions)} may use config substitution; '{name}'" + " substitution not allowed" ) - - return validated_parts + return AnyUrl.validate_parts(parts, validate_port) @override def get_config_value(self, config_provider: _ConfigurationProvider) -> AnyUrl: diff --git a/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json b/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json index c65e6e8e57ac..2b5100c51471 100644 --- a/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json +++ b/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json @@ -51,8 +51,9 @@ }, { "type": "string", - "writeOnly": true, - "format": "password" + "minLength": 1, + "maxLength": 65536, + "format": "uri" }, { "type": "string", diff --git a/great_expectations/datasource/fluent/snowflake_datasource.py b/great_expectations/datasource/fluent/snowflake_datasource.py index 7c8651806253..ed275a047946 100644 --- a/great_expectations/datasource/fluent/snowflake_datasource.py +++ b/great_expectations/datasource/fluent/snowflake_datasource.py @@ -24,6 +24,7 @@ from great_expectations.datasource.fluent import GxContextWarning, GxDatasourceWarning from great_expectations.datasource.fluent.config_str import ( ConfigStr, + ConfigUri, _check_config_substitutions_needed, ) from great_expectations.datasource.fluent.sql_datasource 
import ( @@ -53,15 +54,6 @@ MISSING: Final = object() # sentinel value to indicate missing values -def _extract_query_section(url: str) -> str | None: - """ - Extracts the query section of a URL if it exists. - - snowflake://user:password@account?warehouse=warehouse&role=role - """ - return url.partition("?")[2] - - @functools.lru_cache(maxsize=4) def _extract_path_sections(url_path: str) -> dict[str, str]: """ @@ -86,9 +78,9 @@ def _extract_path_sections(url_path: str) -> dict[str, str]: def _get_config_substituted_connection_string( datasource: SnowflakeDatasource, warning_msg: str = "Unable to perform config substitution", -) -> str | None: - if not isinstance(datasource.connection_string, ConfigStr): - raise TypeError("Config substitution is only supported for `ConfigStr`") +) -> AnyUrl | None: + if not isinstance(datasource.connection_string, ConfigUri): + raise TypeError("Config substitution is only supported for `ConfigUri`") if not datasource._data_context: warnings.warn( f"{warning_msg} for {datasource.connection_string.template_str}." @@ -241,7 +233,7 @@ class SnowflakeDatasource(SQLDatasource): type: Literal["snowflake"] = "snowflake" # type: ignore[assignment] # TODO: rename this to `connection` for v1? - connection_string: Union[ConnectionDetails, ConfigStr, SnowflakeDsn] # type: ignore[assignment] # Deviation from parent class as individual args are supported for connection + connection_string: Union[ConnectionDetails, ConfigUri, SnowflakeDsn] # type: ignore[assignment] # Deviation from parent class as individual args are supported for connection # TODO: add props for account, user, password, etc? 
@@ -353,15 +345,27 @@ def _convert_root_connection_detail_fields(cls, values: dict) -> dict: values["connection_string"] = connection_details return values + @pydantic.validator("connection_string", pre=True) + def _check_config_template(cls, connection_string: Any) -> Any: + """ + If connection_string has a config template, parse it as a ConfigUri, ignore other errors. + """ + if isinstance(connection_string, str): + if ConfigUri.str_contains_config_template(connection_string): + LOGGER.debug("`connection_string` contains config template") + return pydantic.parse_obj_as(ConfigUri, connection_string) + return connection_string + @pydantic.root_validator def _check_xor_input_args(cls, values: dict) -> dict: # keeping this validator isn't strictly necessary, but it provides a better error message - connection_string: str | ConnectionDetails | None = values.get("connection_string") + connection_string: str | ConfigUri | ConnectionDetails | None = values.get( + "connection_string" + ) if connection_string: # Method 1 - connection string - is_connection_string: bool = isinstance( - connection_string, (str, ConfigStr, SnowflakeDsn) - ) + if isinstance(connection_string, (str, ConfigUri)): + return values # Method 2 - individual args (account, user, and password are bare minimum) has_min_connection_detail_values: bool = isinstance( connection_string, ConnectionDetails @@ -376,28 +380,19 @@ def _check_xor_input_args(cls, values: dict) -> dict: @pydantic.validator("connection_string") def _check_for_required_query_params( - cls, connection_string: ConnectionDetails | SnowflakeDsn | ConfigStr - ) -> ConnectionDetails | SnowflakeDsn | ConfigStr: + cls, connection_string: ConnectionDetails | SnowflakeDsn | ConfigUri + ) -> ConnectionDetails | SnowflakeDsn | ConfigUri: """ If connection_string is a SnowflakeDsn, check for required query parameters according to `REQUIRED_QUERY_PARAMS`. 
""" - if not isinstance(connection_string, (SnowflakeDsn, ConfigStr)): + if not isinstance(connection_string, (SnowflakeDsn, ConfigUri)): return connection_string missing_keys: set[str] = set(REQUIRED_QUERY_PARAMS) - if isinstance(connection_string, ConfigStr): - query_str = _extract_query_section(connection_string.template_str) - # best effort: query could be part of the config substitution. Have to check this when - # adding assets. - if not query_str or ConfigStr.str_contains_config_template(query_str): - LOGGER.info(f"Unable to validate query parameters for {connection_string}") - return connection_string - else: - query_str = connection_string.query - if query_str: - query_params: dict[str, list[str]] = urllib.parse.parse_qs(query_str) + if connection_string.query: + query_params: dict[str, list[str]] = urllib.parse.parse_qs(connection_string.query) for key in REQUIRED_QUERY_PARAMS: if key in query_params: diff --git a/tests/datasource/fluent/integration/test_connections.py b/tests/datasource/fluent/integration/test_connections.py index f3e246bbb7d0..556da4d6374a 100644 --- a/tests/datasource/fluent/integration/test_connections.py +++ b/tests/datasource/fluent/integration/test_connections.py @@ -24,11 +24,11 @@ class TestSnowflake: "connection_string", [ param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?warehouse=ci", id="missing role", ), param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci_no_select", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?warehouse=ci&role=ci_no_select", id="role wo select", ), ], @@ -76,11 +76,11 @@ def test_un_queryable_asset_should_raise_error( "connection_string", [ param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci&database=ci&schema=public", + 
"snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?warehouse=ci&role=ci&database=ci&schema=public", id="full connection string", ), param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?role=ci&database=ci&schema=public", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?role=ci&database=ci&schema=public", id="missing warehouse", ), ], diff --git a/tests/datasource/fluent/integration/test_sql_datasources.py b/tests/datasource/fluent/integration/test_sql_datasources.py index c7edeb0f1c36..e7dd691613bc 100644 --- a/tests/datasource/fluent/integration/test_sql_datasources.py +++ b/tests/datasource/fluent/integration/test_sql_datasources.py @@ -334,10 +334,10 @@ def snowflake_ds( pytest.skip("no snowflake credentials") ds = context.data_sources.add_snowflake( "snowflake", - connection_string="snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci" + connection_string="snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci" f"/{RAND_SCHEMA}?warehouse=ci&role=ci", # NOTE: uncomment this and set SNOWFLAKE_USER to run tests against your own snowflake account - # connection_string="snowflake://${SNOWFLAKE_USER}@${SNOWFLAKE_CI_ACCOUNT}/DEMO_DB/RESTAURANTS?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser", + # connection_string="snowflake://${SNOWFLAKE_USER}@oca29081.us-east-1/DEMO_DB/RESTAURANTS?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser", ) return ds diff --git a/tests/datasource/fluent/test_snowflake_datasource.py b/tests/datasource/fluent/test_snowflake_datasource.py index 6b9836496f6d..e2d2b24fe1f0 100644 --- a/tests/datasource/fluent/test_snowflake_datasource.py +++ b/tests/datasource/fluent/test_snowflake_datasource.py @@ -30,16 +30,12 @@ id="connection_string str", ), param( - {"connection_string": "${MY_CONN_STR_PARTIAL}@${MY_PATH}?${MY_QUERY_PARAMS}"}, - id="connection_string ConfigStr with required query params", + 
{"connection_string": "snowflake://my_user:${MY_PASSWORD}@my_account/d_public/s_public"}, + id="connection_string ConfigStr - password sub", ), param( - {"connection_string": "${MY_CONN_STR_FULL}"}, - id="connection_string ConfigStr - required params part of sub", - ), - param( - {"connection_string": "${MY_CONN_STR_MIN}?${MY_QUERY_PARAMS}"}, - id="connection_string ConfigStr - dedicated query params sub", + {"connection_string": "snowflake://${MY_USER}:${MY_PASSWORD}@my_account/d_public/s_public"}, + id="connection_string ConfigStr - user + password sub", ), param( { @@ -70,14 +66,7 @@ @pytest.fixture def seed_env_vars(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("MY_CONN_STR_PARTIAL", "snowflake://my_user:password") - monkeypatch.setenv("MY_CONN_STR_MIN", "snowflake://my_user:password@my_account/my_db/my_schema") - monkeypatch.setenv( - "MY_CONN_STR_FULL", - "snowflake://my_user:password@my_account/my_db/my_schema?warehouse=my_wh&role=my_role", - ) - monkeypatch.setenv("MY_PATH", "my_account/my_db/my_schema") - monkeypatch.setenv("MY_QUERY_PARAMS", "warehouse=my_wh&role=my_role") + monkeypatch.setenv("MY_USER", "my_user") monkeypatch.setenv("MY_PASSWORD", "my_password") @@ -111,10 +100,6 @@ def test_snowflake_dsn(): }, id="old config format - top level keys", ), - param( - {"connection_string": "${MY_CONN_STR_MIN}"}, - id="connection_string ConfigStr missing query params", - ), ], ) def test_valid_config( @@ -138,6 +123,78 @@ def test_valid_config( @pytest.mark.parametrize( ["connection_string", "expected_errors"], [ + pytest.param( + "${MY_CONFIG_VAR}", + [ + { + "loc": ("connection_string", "__root__"), + "msg": "Only password, user may use config substitution;" + " 'domain' substitution not allowed", + "type": "value_error", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="illegal config substitution - full 
connection string", + ), + pytest.param( + "snowflake://my_user:password@${MY_CONFIG_VAR}/db/schema", + [ + { + "loc": ("connection_string", "__root__"), + "msg": "Only password, user may use config substitution;" + " 'domain' substitution not allowed", + "type": "value_error", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="illegal config substitution - account (domain)", + ), + pytest.param( + "snowflake://my_user:password@account/${MY_CONFIG_VAR}/schema", + [ + { + "loc": ("connection_string", "__root__"), + "msg": "Only password, user may use config substitution;" + " 'path' substitution not allowed", + "type": "value_error", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="illegal config substitution - database (path)", + ), + pytest.param( + "snowflake://my_user:password@account/db/${MY_CONFIG_VAR}", + [ + { + "loc": ("connection_string", "__root__"), + "msg": "Only password, user may use config substitution;" + " 'path' substitution not allowed", + "type": "value_error", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="illegal config substitution - schema (path)", + ), pytest.param( "snowflake://my_user:password@my_account", [ diff --git a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py index 2103d277a9be..1f5fa3d34ff4 100644 --- a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py +++ b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py @@ -30,12 +30,12 @@ def connection_string() -> str: if os.getenv("SNOWFLAKE_CI_USER_PASSWORD") and 
os.getenv("SNOWFLAKE_CI_ACCOUNT"): return ( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci" + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci" f"/{RANDOM_SCHEMA}?warehouse=ci&role=ci" ) elif os.getenv("SNOWFLAKE_USER") and os.getenv("SNOWFLAKE_CI_ACCOUNT"): return ( - "snowflake://${SNOWFLAKE_USER}@${SNOWFLAKE_CI_ACCOUNT}/DEMO_DB" + "snowflake://${SNOWFLAKE_USER}@oca29081.us-east-1/DEMO_DB" f"/{RANDOM_SCHEMA}?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser" ) else: From 8abc7218ef23b0f7709749df08028200646c552b Mon Sep 17 00:00:00 2001 From: Gabriel Date: Wed, 5 Jun 2024 15:36:27 -0400 Subject: [PATCH 05/20] revert change --- great_expectations/datasource/fluent/snowflake_datasource.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/great_expectations/datasource/fluent/snowflake_datasource.py b/great_expectations/datasource/fluent/snowflake_datasource.py index ed275a047946..c5ade9faf061 100644 --- a/great_expectations/datasource/fluent/snowflake_datasource.py +++ b/great_expectations/datasource/fluent/snowflake_datasource.py @@ -364,8 +364,8 @@ def _check_xor_input_args(cls, values: dict) -> dict: ) if connection_string: # Method 1 - connection string - if isinstance(connection_string, (str, ConfigUri)): - return values + is_connection_string: bool = isinstance( + connection_string, (str, ConfigStr, SnowflakeDsn) # Method 2 - individual args (account, user, and password are bare minimum) has_min_connection_detail_values: bool = isinstance( connection_string, ConnectionDetails From 3f31ec8208f7182b60fcf9866d589db787dd9baf Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 15:39:35 -0400 Subject: [PATCH 06/20] fix syntax --- great_expectations/datasource/fluent/snowflake_datasource.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/great_expectations/datasource/fluent/snowflake_datasource.py 
b/great_expectations/datasource/fluent/snowflake_datasource.py index c5ade9faf061..d8c962313b44 100644 --- a/great_expectations/datasource/fluent/snowflake_datasource.py +++ b/great_expectations/datasource/fluent/snowflake_datasource.py @@ -80,7 +80,7 @@ def _get_config_substituted_connection_string( warning_msg: str = "Unable to perform config substitution", ) -> AnyUrl | None: if not isinstance(datasource.connection_string, ConfigUri): - raise TypeError("Config substitution is only supported for `ConfigUri`") + raise TypeError("Config substitution is only supported for `ConfigUri`") # noqa: TRY003 if not datasource._data_context: warnings.warn( f"{warning_msg} for {datasource.connection_string.template_str}." @@ -366,6 +366,7 @@ def _check_xor_input_args(cls, values: dict) -> dict: # Method 1 - connection string is_connection_string: bool = isinstance( connection_string, (str, ConfigStr, SnowflakeDsn) + ) # Method 2 - individual args (account, user, and password are bare minimum) has_min_connection_detail_values: bool = isinstance( connection_string, ConnectionDetails From 34732f5af1a794dd65f5603a78f3d6a48e5f54d7 Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 16:09:57 -0400 Subject: [PATCH 07/20] fix test merge --- .../fluent/test_snowflake_datasource.py | 62 ++++++++++++++----- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/tests/datasource/fluent/test_snowflake_datasource.py b/tests/datasource/fluent/test_snowflake_datasource.py index e2d2b24fe1f0..abcbe160ecf9 100644 --- a/tests/datasource/fluent/test_snowflake_datasource.py +++ b/tests/datasource/fluent/test_snowflake_datasource.py @@ -371,14 +371,12 @@ def test_missing_required_params( [ { "loc": ("__root__",), - "msg": "Cannot provide both a connection string and a combination of" - " account, user, and password.", + "msg": "Cannot provide both a connection string and a combination of account, user, and password.", "type": "value_error", } ], id="both connection_string and 
connect_args", ), - pytest.param(None, {}, id="neither connection_string nor connect_args"), pytest.param( None, {}, @@ -390,8 +388,7 @@ def test_missing_required_params( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of" - " account, user, and password.", + "msg": "Must provide either a connection string or a combination of account, user, and password.", "type": "value_error", }, ], @@ -423,8 +420,7 @@ def test_missing_required_params( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of" - " account, user, and password.", + "msg": "Must provide either a connection string or a combination of account, user, and password.", "type": "value_error", }, ], @@ -438,15 +434,43 @@ def test_missing_required_params( "database": "bar", }, {}, + [ + { + "loc": ("connection_string", "password"), + "msg": "field required", + "type": "value_error.missing", + }, + { + "loc": ("connection_string",), + "msg": f"""expected string or bytes-like object{"" if python_version < (3, 11) else ", got 'dict'"}""", + "type": "type_error", + }, + { + "loc": ("connection_string",), + "msg": "str type expected", + "type": "type_error.str", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], id="incomplete connection_string dict connect_args", ), ], ) def test_conflicting_connection_string_and_args_raises_error( - connection_string: ConfigStr | SnowflakeDsn | None | dict, connect_args: dict + connection_string: ConfigStr | SnowflakeDsn | None | dict, + connect_args: dict, + expected_errors: list[dict], ): - with pytest.raises(ValueError): - _ = SnowflakeDatasource(connection_string=connection_string, **connect_args) + with pytest.raises(pydantic.ValidationError) as exc_info: + _ = SnowflakeDatasource( + name="my_sf_ds", connection_string=connection_string, **connect_args + ) + assert 
exc_info.value.errors() == expected_errors @pytest.mark.unit @@ -463,7 +487,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("connection_string",), - "msg": "ConfigStr - contains no config template strings in the format '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", # noqa: E501 + "msg": "ConfigStr - contains no config template strings in the format" + " '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", "type": "value_error", }, { @@ -473,7 +498,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of account, user, and password.", # noqa: E501 + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", "type": "value_error", }, ], @@ -489,7 +515,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("connection_string",), - "msg": "ConfigStr - contains no config template strings in the format '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", # noqa: E501 + "msg": "ConfigStr - contains no config template strings in the format" + " '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", "type": "value_error", }, { @@ -499,7 +526,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of account, user, and password.", # noqa: E501 + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", "type": "value_error", }, ], @@ -515,7 +543,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("connection_string",), - "msg": "ConfigStr - contains no config template strings in the format '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", # noqa: E501 + "msg": "ConfigStr - contains no config template strings in the format" + " '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", "type": "value_error", }, { @@ -525,7 +554,8 @@ def 
test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of account, user, and password.", # noqa: E501 + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", "type": "value_error", }, ], From 20b583b5869ae2edb7689cc425cdfc50167e70a0 Mon Sep 17 00:00:00 2001 From: Gabriel Date: Wed, 5 Jun 2024 16:16:45 -0400 Subject: [PATCH 08/20] revert unneeded changes --- tests/datasource/fluent/integration/test_connections.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datasource/fluent/integration/test_connections.py b/tests/datasource/fluent/integration/test_connections.py index 556da4d6374a..803057c36fea 100644 --- a/tests/datasource/fluent/integration/test_connections.py +++ b/tests/datasource/fluent/integration/test_connections.py @@ -76,11 +76,11 @@ def test_un_queryable_asset_should_raise_error( "connection_string", [ param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?warehouse=ci&role=ci&database=ci&schema=public", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?warehouse=ci&role=ci", id="full connection string", ), param( - "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?role=ci&database=ci&schema=public", + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?role=ci", id="missing warehouse", ), ], From 7b4e41e91aaa6e734b1a86137937f1a88e166eb7 Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 17:00:26 -0400 Subject: [PATCH 09/20] error on connection_string + kwargs --- .../datasource/fluent/snowflake_datasource.py | 9 ++++++++- tests/datasource/fluent/test_snowflake_datasource.py | 5 +++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/great_expectations/datasource/fluent/snowflake_datasource.py 
b/great_expectations/datasource/fluent/snowflake_datasource.py index d8c962313b44..9163b4480092 100644 --- a/great_expectations/datasource/fluent/snowflake_datasource.py +++ b/great_expectations/datasource/fluent/snowflake_datasource.py @@ -337,9 +337,16 @@ def _convert_root_connection_detail_fields(cls, values: dict) -> dict: *ConnectionDetails.__fields__.keys(), } + connection_string: Any | None = values.get("connection_string") + provided_fields = tuple(values.keys()) + connection_details = {} - for field_name in tuple(values.keys()): + for field_name in provided_fields: if field_name in connection_detail_fields: + if connection_string: + raise ValueError( # noqa: TRY003 + "Provided both connection detail keyword args and `connection_string`." + ) connection_details[field_name] = values.pop(field_name) if connection_details: values["connection_string"] = connection_details diff --git a/tests/datasource/fluent/test_snowflake_datasource.py b/tests/datasource/fluent/test_snowflake_datasource.py index abcbe160ecf9..353fd94904a8 100644 --- a/tests/datasource/fluent/test_snowflake_datasource.py +++ b/tests/datasource/fluent/test_snowflake_datasource.py @@ -371,7 +371,7 @@ def test_missing_required_params( [ { "loc": ("__root__",), - "msg": "Cannot provide both a connection string and a combination of account, user, and password.", + "msg": "Provided both connection detail keyword args and `connection_string`.", "type": "value_error", } ], @@ -388,7 +388,8 @@ def test_missing_required_params( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of account, user, and password.", + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", "type": "value_error", }, ], From 7ff0989e84cbe58993edc69af8a771535e2fc893 Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 16:56:39 -0400 Subject: [PATCH 10/20] linting ignores --- great_expectations/datasource/fluent/config_str.py | 
2 +- tests/datasource/fluent/test_config_str.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/great_expectations/datasource/fluent/config_str.py b/great_expectations/datasource/fluent/config_str.py index a30e2a8a5482..2f9784128bc9 100644 --- a/great_expectations/datasource/fluent/config_str.py +++ b/great_expectations/datasource/fluent/config_str.py @@ -181,7 +181,7 @@ def validate_parts(cls, parts: UriPartsDict, validate_port: bool = True) -> UriP cls.str_contains_config_template(part) # type: ignore[arg-type] # is str and name not in cls.ALLOWED_SUBSTITUTIONS ): - raise ValueError( + raise ValueError( # noqa: TRY003 f"Only {', '.join(allowed_substitutions)} may use config substitution; '{name}'" " substitution not allowed" ) diff --git a/tests/datasource/fluent/test_config_str.py b/tests/datasource/fluent/test_config_str.py index 7c021fe70516..29bd97a1a918 100644 --- a/tests/datasource/fluent/test_config_str.py +++ b/tests/datasource/fluent/test_config_str.py @@ -317,7 +317,10 @@ def test_leakage( monkeypatch: MonkeyPatch, uri: str, ): - """Ensure the config values are not leaked in the repr or str of the object or the component parts.""" + """ + Ensure the config values are not leaked in the repr or str of the object + or the component parts. 
+ """ monkeypatch.setenv("MY_USER", "my_user") monkeypatch.setenv("MY_PW", "super_secret") From 54ea2a61d64f011bdbe299c6950e8bf819c9bd6f Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 19:24:26 -0400 Subject: [PATCH 11/20] revert applicatioin query param tests --- .../fluent/test_snowflake_datasource.py | 52 ------------------- 1 file changed, 52 deletions(-) diff --git a/tests/datasource/fluent/test_snowflake_datasource.py b/tests/datasource/fluent/test_snowflake_datasource.py index 353fd94904a8..37ad4105956b 100644 --- a/tests/datasource/fluent/test_snowflake_datasource.py +++ b/tests/datasource/fluent/test_snowflake_datasource.py @@ -583,58 +583,6 @@ def test_get_execution_engine_succeeds(): datasource.get_execution_engine() -@pytest.mark.snowflake -@pytest.mark.parametrize( - "connection_string", - [ - param( - "snowflake://my_user:password@my_account/my_db/my_schema?numpy=True", - id="connection_string str", - ), - param( - { - "user": "my_user", - "password": "password", - "account": "my_account", - "database": "foo", - "schema": "bar", - }, - id="connection_string dict", - ), - ], -) -@pytest.mark.parametrize( - "context_fixture_name,expected_query_param", - [ - param( - "empty_file_context", - "great_expectations_core", - id="file context", - ), - param( - "empty_cloud_context_fluent", - "great_expectations_platform", - id="cloud context", - ), - ], -) -def test_get_engine_correctly_sets_application_query_param( - request, - context_fixture_name: str, - expected_query_param: str, - connection_string: str | dict, -): - context = request.getfixturevalue( # TODO: fix this and make it a fixture in the root conftest - context_fixture_name - ) - my_sf_ds = SnowflakeDatasource(name="my_sf_ds", connection_string=connection_string) - my_sf_ds._data_context = context - - sql_engine = my_sf_ds.get_engine() - application_query_param = sql_engine.url.query.get("application") - assert application_query_param == expected_query_param - - 
@pytest.mark.snowflake @pytest.mark.parametrize("ds_config", VALID_DS_CONFIG_PARAMS) class TestConvenienceProperties: From 6a52bf9a971abaae57654c8d9e6f03f4f6c601fc Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 19:33:01 -0400 Subject: [PATCH 12/20] fix bad merge --- great_expectations/datasource/fluent/snowflake_datasource.py | 2 +- tests/datasource/fluent/integration/test_sql_datasources.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/great_expectations/datasource/fluent/snowflake_datasource.py b/great_expectations/datasource/fluent/snowflake_datasource.py index 9163b4480092..89327eda697a 100644 --- a/great_expectations/datasource/fluent/snowflake_datasource.py +++ b/great_expectations/datasource/fluent/snowflake_datasource.py @@ -270,7 +270,7 @@ def database(self) -> str | None: return _extract_path_sections(url_path)["database"] @deprecated_method_or_class( - version="0.18.16", + version="1.0.0a4", message="`schema_name` is deprecated." " The schema now comes from the datasource.", ) @public_api diff --git a/tests/datasource/fluent/integration/test_sql_datasources.py b/tests/datasource/fluent/integration/test_sql_datasources.py index e7dd691613bc..8e654426e4e9 100644 --- a/tests/datasource/fluent/integration/test_sql_datasources.py +++ b/tests/datasource/fluent/integration/test_sql_datasources.py @@ -529,6 +529,7 @@ def test_checkpoint_run( asset = datasource.add_table_asset( asset_name, table_name=table_name, schema_name=schema ) + batch_definition = asset.add_batch_definition_whole_table("whole table!") suite = context.suites.add(ExpectationSuite(name=f"{datasource.name}-{asset.name}")) suite.add_expectation(gxe.ExpectColumnValuesToNotBeNull(column="name", mostly=1)) From 886602ebc24e9dee1d252a224da9d3e518223ff0 Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 19:35:43 -0400 Subject: [PATCH 13/20] linting fixes --- .../fluent/integration/test_sql_datasources.py | 2 +- 
tests/datasource/fluent/test_snowflake_datasource.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/datasource/fluent/integration/test_sql_datasources.py b/tests/datasource/fluent/integration/test_sql_datasources.py index 8e654426e4e9..c3ae0299de04 100644 --- a/tests/datasource/fluent/integration/test_sql_datasources.py +++ b/tests/datasource/fluent/integration/test_sql_datasources.py @@ -336,7 +336,7 @@ def snowflake_ds( "snowflake", connection_string="snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci" f"/{RAND_SCHEMA}?warehouse=ci&role=ci", - # NOTE: uncomment this and set SNOWFLAKE_USER to run tests against your own snowflake account + # NOTE: uncomment this and set SNOWFLAKE_USER to run tests against your own snowflake account # noqa: E501 # connection_string="snowflake://${SNOWFLAKE_USER}@oca29081.us-east-1/DEMO_DB/RESTAURANTS?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser", ) return ds diff --git a/tests/datasource/fluent/test_snowflake_datasource.py b/tests/datasource/fluent/test_snowflake_datasource.py index 37ad4105956b..c2c9a993eefa 100644 --- a/tests/datasource/fluent/test_snowflake_datasource.py +++ b/tests/datasource/fluent/test_snowflake_datasource.py @@ -411,7 +411,8 @@ def test_missing_required_params( }, { "loc": ("connection_string",), - "msg": f"""expected string or bytes-like object{"" if python_version < (3, 11) else ", got 'dict'"}""", + "msg": "expected string or bytes-like object" + f"""{"" if python_version < (3, 11) else ", got 'dict'"}""", "type": "type_error", }, { @@ -421,7 +422,8 @@ def test_missing_required_params( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of account, user, and password.", + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", "type": "value_error", }, ], @@ -443,7 +445,8 @@ def test_missing_required_params( }, { "loc": ("connection_string",), - 
"msg": f"""expected string or bytes-like object{"" if python_version < (3, 11) else ", got 'dict'"}""", + "msg": "expected string or bytes-like object" + f"""{"" if python_version < (3, 11) else ", got 'dict'"}""", "type": "type_error", }, { From 719bfed7f390af33490b2bf9f0f5fee21b422061 Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 19:46:42 -0400 Subject: [PATCH 14/20] fix cloud table_asset fixture --- .../cloud/end_to_end/test_snowflake_datasource.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py index 1f5fa3d34ff4..978f1a7230b0 100644 --- a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py +++ b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py @@ -96,17 +96,17 @@ def table_asset( datasource: SnowflakeDatasource, asset_name: str, table_factory: TableFactory, - get_missing_data_asset_error_type: type[Exception], -) -> Iterator[TableAsset]: +) -> TableAsset: table_name = f"i{uuid.uuid4().hex}" table_factory( gx_engine=datasource.get_execution_engine(), table_names={table_name}, schema_name=RANDOM_SCHEMA, ) - asset_name = f"i{uuid.uuid4().hex}" - _ = datasource.add_table_asset( - name=asset_name, table_name=table_name, schema_name=RANDOM_SCHEMA + return datasource.add_table_asset( + name=asset_name, + schema_name=RANDOM_SCHEMA, + table_name=table_name, ) From 20e879ea766e0b6bd59b28394711bad0a5dedf02 Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Wed, 5 Jun 2024 19:51:35 -0400 Subject: [PATCH 15/20] type ignores --- tests/datasource/fluent/test_snowflake_datasource.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasource/fluent/test_snowflake_datasource.py b/tests/datasource/fluent/test_snowflake_datasource.py index c2c9a993eefa..17e46d4f2d5a 100644 --- a/tests/datasource/fluent/test_snowflake_datasource.py +++ 
b/tests/datasource/fluent/test_snowflake_datasource.py @@ -20,7 +20,7 @@ from great_expectations.execution_engine import SqlAlchemyExecutionEngine if TYPE_CHECKING: - from pytest.mark.structures import ParameterSet + from pytest.mark.structures import ParameterSet # type: ignore[import-not-found] VALID_DS_CONFIG_PARAMS: Final[Sequence[ParameterSet]] = [ param( From 10e07d9dd91f1b536f08cb29f09d60a46a8c4209 Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Thu, 6 Jun 2024 12:09:56 -0400 Subject: [PATCH 16/20] don't provide `schema` name --- tests/integration/db/taxi_data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/db/taxi_data_utils.py b/tests/integration/db/taxi_data_utils.py index a4e77e108d41..7525e13e249d 100644 --- a/tests/integration/db/taxi_data_utils.py +++ b/tests/integration/db/taxi_data_utils.py @@ -114,7 +114,7 @@ def _execute_taxi_partitioning_test_cases( datasource = add_datasource( context, name=datasource_name, connection_string=connection_string ) - asset = datasource.add_table_asset(data_asset_name, table_name=table_name) + asset = datasource.add_table_asset(data_asset_name, table_name=table_name, schema_name=None) add_batch_definition_method = getattr( asset, test_case.add_batch_definition_method_name or "MAKE THIS REQUIRED" ) From 999d0d9a9ffbe93646b0a13d3d24e32f87f3b6db Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Thu, 6 Jun 2024 12:29:36 -0400 Subject: [PATCH 17/20] catch deprecation warning --- tests/integration/db/taxi_data_utils.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/integration/db/taxi_data_utils.py b/tests/integration/db/taxi_data_utils.py index 7525e13e249d..52635844ff08 100644 --- a/tests/integration/db/taxi_data_utils.py +++ b/tests/integration/db/taxi_data_utils.py @@ -1,3 +1,6 @@ +from __future__ import annotations + +import warnings from contextlib import contextmanager from typing import TYPE_CHECKING, List @@ -114,7 +117,17 @@ 
def _execute_taxi_partitioning_test_cases( datasource = add_datasource( context, name=datasource_name, connection_string=connection_string ) - asset = datasource.add_table_asset(data_asset_name, table_name=table_name, schema_name=None) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "once", + message="The `schema_name argument` is deprecated", + category=DeprecationWarning, + ) + asset = datasource.add_table_asset( + data_asset_name, table_name=table_name, schema_name=None + ) + add_batch_definition_method = getattr( asset, test_case.add_batch_definition_method_name or "MAKE THIS REQUIRED" ) From 5a430cd466a042a5f1f7593e89d05e6c40952ebd Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Thu, 6 Jun 2024 13:40:24 -0400 Subject: [PATCH 18/20] catch GxDatasourceWarning --- tests/integration/db/taxi_data_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/integration/db/taxi_data_utils.py b/tests/integration/db/taxi_data_utils.py index 52635844ff08..6b59ae841ae6 100644 --- a/tests/integration/db/taxi_data_utils.py +++ b/tests/integration/db/taxi_data_utils.py @@ -8,6 +8,7 @@ import great_expectations as gx from great_expectations.core.batch_definition import BatchDefinition +from great_expectations.datasource.fluent import GxDatasourceWarning from great_expectations.execution_engine.sqlalchemy_batch_data import ( SqlAlchemyBatchData, ) @@ -124,6 +125,12 @@ def _execute_taxi_partitioning_test_cases( message="The `schema_name argument` is deprecated", category=DeprecationWarning, ) + warnings.filterwarnings( + "once", + message="schema_name None does not match datasource schema ***", + category=GxDatasourceWarning, + ) + asset = datasource.add_table_asset( data_asset_name, table_name=table_name, schema_name=None ) From e26410c3ead6627937a43c8435551101335d837f Mon Sep 17 00:00:00 2001 From: Gabriel Gore Date: Thu, 6 Jun 2024 19:20:00 -0400 Subject: [PATCH 19/20] fix filter regex --- tests/integration/db/taxi_data_utils.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/db/taxi_data_utils.py b/tests/integration/db/taxi_data_utils.py index 6b59ae841ae6..0ff05769e47e 100644 --- a/tests/integration/db/taxi_data_utils.py +++ b/tests/integration/db/taxi_data_utils.py @@ -127,7 +127,7 @@ def _execute_taxi_partitioning_test_cases( ) warnings.filterwarnings( "once", - message="schema_name None does not match datasource schema ***", + message="schema_name None does not match datasource schema", category=GxDatasourceWarning, ) From ad6cfeee5147187906fdbb9d8a483a7565e5344d Mon Sep 17 00:00:00 2001 From: Gabriel Date: Fri, 7 Jun 2024 18:50:38 -0400 Subject: [PATCH 20/20] Update great_expectations/datasource/fluent/interfaces.py Co-authored-by: Rob Lim --- great_expectations/datasource/fluent/interfaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/great_expectations/datasource/fluent/interfaces.py b/great_expectations/datasource/fluent/interfaces.py index ed8c57029b4e..45288340446a 100644 --- a/great_expectations/datasource/fluent/interfaces.py +++ b/great_expectations/datasource/fluent/interfaces.py @@ -183,7 +183,7 @@ class GxDatasourceWarning(UserWarning): class GxContextWarning(GxDatasourceWarning): """ - Warning related to a Datasource that with a missing context. + Warning related to a Datasource with a missing context. Usually because the Datasource was created directly rather than using a `context.sources` factory method. """