diff --git a/great_expectations/datasource/fluent/__init__.py b/great_expectations/datasource/fluent/__init__.py
index edbeef86d536..4052427ef418 100644
--- a/great_expectations/datasource/fluent/__init__.py
+++ b/great_expectations/datasource/fluent/__init__.py
@@ -8,6 +8,7 @@
     Sorter,
     BatchMetadata,
     GxDatasourceWarning,
+    GxContextWarning,
     TestConnectionError,
 )
 from great_expectations.datasource.fluent.invalid_datasource import (
diff --git a/great_expectations/datasource/fluent/config_str.py b/great_expectations/datasource/fluent/config_str.py
index 7eea69c447c4..2f9784128bc9 100644
--- a/great_expectations/datasource/fluent/config_str.py
+++ b/great_expectations/datasource/fluent/config_str.py
@@ -2,13 +2,22 @@
 
 import logging
 import warnings
-from typing import TYPE_CHECKING, Mapping
-
-from great_expectations.compatibility.pydantic import SecretStr
+from typing import (
+    TYPE_CHECKING,
+    ClassVar,
+    Literal,
+    Mapping,
+    Optional,
+    TypedDict,
+)
+
+from great_expectations.compatibility.pydantic import AnyUrl, SecretStr, parse_obj_as
 from great_expectations.compatibility.typing_extensions import override
 from great_expectations.core.config_substitutor import TEMPLATE_STR_REGEX
 
 if TYPE_CHECKING:
+    from typing_extensions import Self, TypeAlias
+
     from great_expectations.core.config_provider import _ConfigurationProvider
     from great_expectations.datasource.fluent import Datasource
@@ -54,8 +63,15 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self._display()!r})"
 
     @classmethod
-    def _validate_template_str_format(cls, v):
-        if TEMPLATE_STR_REGEX.search(v):
+    def str_contains_config_template(cls, v: str) -> bool:
+        """
+        Returns True if the input string contains a config template string.
+        """
+        return TEMPLATE_STR_REGEX.search(v) is not None
+
+    @classmethod
+    def _validate_template_str_format(cls, v: str) -> str | None:
+        if cls.str_contains_config_template(v):
             return v
         raise ValueError(
             cls.__name__
@@ -72,6 +88,125 @@ def __get_validators__(cls):
         yield cls.validate
 
 
+UriParts: TypeAlias = Literal[  # https://docs.pydantic.dev/1.10/usage/types/#url-properties
+    "scheme", "host", "user", "password", "port", "path", "query", "fragment", "tld"
+]
+
+
+class UriPartsDict(TypedDict, total=False):
+    scheme: str
+    user: str | None
+    password: str | None
+    ipv4: str | None
+    ipv6: str | None
+    domain: str | None
+    port: str | None
+    path: str | None
+    query: str | None
+    fragment: str | None
+
+
+class ConfigUri(AnyUrl, ConfigStr):  # type: ignore[misc] # Mixin "validate" signature mismatch
+    """
+    Special type that enables great_expectations config variable substitution for the
+    `user` and `password` sections of a URI.
+
+    Example:
+    ```
+    "snowflake://${MY_USER}:${MY_PASSWORD}@account/database/schema/table"
+    ```
+
+    Note: this type is meant to be used as part of a pydantic model.
+    To use this outside of a model see the pydantic docs below.
+    https://docs.pydantic.dev/usage/models/#parsing-data-into-a-specified-type
+    """
+
+    ALLOWED_SUBSTITUTIONS: ClassVar[set[UriParts]] = {"user", "password"}
+
+    min_length: int = 1
+    max_length: int = 2**16
+
+    def __init__(  # noqa: PLR0913 # for compatibility with AnyUrl
+        self,
+        template_str: str,
+        *,
+        scheme: str,
+        user: Optional[str] = None,
+        password: Optional[str] = None,
+        host: Optional[str] = None,
+        tld: Optional[str] = None,
+        host_type: str = "domain",
+        port: Optional[str] = None,
+        path: Optional[str] = None,
+        query: Optional[str] = None,
+        fragment: Optional[str] = None,
+    ) -> None:
+        if template_str:  # may have already been set in __new__
+            self.template_str: str = template_str
+            self._secret_value = template_str  # for compatibility with SecretStr
+        super().__init__(
+            template_str,
+            scheme=scheme,
+            user=user,
+            password=password,
+            host=host,
+            tld=tld,
+            host_type=host_type,
+            port=port,
+            path=path,
+            query=query,
+            fragment=fragment,
+        )
+
+    def __new__(cls: type[Self], template_str: Optional[str], **kwargs) -> Self:
+        """custom __new__ for compatibility with pydantic.parse_obj_as()"""
+        built_url = cls.build(**kwargs) if template_str is None else template_str
+        instance = str.__new__(cls, built_url)
+        instance.template_str = str(instance)
+        return instance
+
+    @classmethod
+    @override
+    def validate_parts(cls, parts: UriPartsDict, validate_port: bool = True) -> UriPartsDict:
+        """
+        Ensure that only the `user` and `password` parts have config template strings.
+        Also validate that all parts of the URI are valid.
+        """
+        allowed_substitutions = sorted(cls.ALLOWED_SUBSTITUTIONS)
+
+        for name, part in parts.items():
+            if not part:
+                continue
+            if (
+                cls.str_contains_config_template(part)  # type: ignore[arg-type] # is str
+                and name not in cls.ALLOWED_SUBSTITUTIONS
+            ):
+                raise ValueError(  # noqa: TRY003
+                    f"Only {', '.join(allowed_substitutions)} may use config substitution; '{name}'"
+                    " substitution not allowed"
+                )
+        return AnyUrl.validate_parts(parts, validate_port)
+
+    @override
+    def get_config_value(self, config_provider: _ConfigurationProvider) -> AnyUrl:
+        """
+        Resolve the config template string to its string value according to the passed
+        _ConfigurationProvider.
+        Parse the resolved URI string into an `AnyUrl` object.
+        """
+        LOGGER.info(f"Substituting '{self}'")
+        raw_value = config_provider.substitute_config(self.template_str)
+        return parse_obj_as(AnyUrl, raw_value)
+
+    @classmethod
+    def __get_validators__(cls):
+        # one or more validators may be yielded which will be called in the
+        # order to validate the input, each validator will receive as an input
+        # the value returned from the previous validator
+        yield ConfigStr._validate_template_str_format
+        yield cls.validate  # equivalent to AnyUrl.validate
+
+
 def _check_config_substitutions_needed(
     datasource: Datasource,
     options: Mapping,
diff --git a/great_expectations/datasource/fluent/interfaces.py b/great_expectations/datasource/fluent/interfaces.py
index 4e3e5a949361..45288340446a 100644
--- a/great_expectations/datasource/fluent/interfaces.py
+++ b/great_expectations/datasource/fluent/interfaces.py
@@ -181,6 +181,14 @@ class GxDatasourceWarning(UserWarning):
     """
 
 
+class GxContextWarning(GxDatasourceWarning):
+    """
+    Warning related to a Datasource with a missing context.
+    Usually because the Datasource was created directly rather than using a
+    `context.sources` factory method.
+    """
+
+
 class GxSerializationWarning(GxDatasourceWarning):
     pass
diff --git a/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json b/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json
index afa246a16277..2b5100c51471 100644
--- a/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json
+++ b/great_expectations/datasource/fluent/schemas/SnowflakeDatasource.json
@@ -51,8 +51,9 @@
             },
             {
                 "type": "string",
-                "writeOnly": true,
-                "format": "password"
+                "minLength": 1,
+                "maxLength": 65536,
+                "format": "uri"
             },
             {
                 "type": "string",
@@ -538,7 +539,7 @@
         },
         "ConnectionDetails": {
             "title": "ConnectionDetails",
-            "description": "Information needed to connect to a Snowflake database.\nAlternative to a connection string.",
+            "description": "Information needed to connect to a Snowflake database.\nAlternative to a connection string.\n\nhttps://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#additional-connection-parameters",
             "type": "object",
             "properties": {
                 "account": {
@@ -564,10 +565,12 @@
                 },
                 "database": {
                     "title": "Database",
+                    "description": "`database` that the Datasource is mapped to.",
                     "type": "string"
                 },
                 "schema": {
                     "title": "Schema",
+                    "description": "`schema` that the Datasource is mapped to.",
                     "type": "string"
                 },
                 "warehouse": {
@@ -587,7 +590,9 @@
             "required": [
                 "account",
                 "user",
-                "password"
+                "password",
+                "database",
+                "schema"
             ],
             "additionalProperties": false
         }
diff --git a/great_expectations/datasource/fluent/snowflake_datasource.py b/great_expectations/datasource/fluent/snowflake_datasource.py
index 73bd26e261fc..89327eda697a 100644
--- a/great_expectations/datasource/fluent/snowflake_datasource.py
+++ b/great_expectations/datasource/fluent/snowflake_datasource.py
@@ -1,31 +1,95 @@
 from __future__ import annotations
 
+import functools
 import logging
-from typing import TYPE_CHECKING, Final, Literal, Optional, Type, Union
+import urllib.parse
+import warnings
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Final,
+    Iterable,
+    Literal,
+    Optional,
+    Type,
+    Union,
+)
 
-from great_expectations._docs_decorators import public_api
+from great_expectations._docs_decorators import deprecated_method_or_class, public_api
 from great_expectations.compatibility import pydantic
 from great_expectations.compatibility.pydantic import AnyUrl, errors
 from great_expectations.compatibility.snowflake import URL
 from great_expectations.compatibility.sqlalchemy import sqlalchemy as sa
 from great_expectations.compatibility.typing_extensions import override
+from great_expectations.datasource.fluent import GxContextWarning, GxDatasourceWarning
 from great_expectations.datasource.fluent.config_str import (
     ConfigStr,
+    ConfigUri,
     _check_config_substitutions_needed,
 )
 from great_expectations.datasource.fluent.sql_datasource import (
     FluentBaseModel,
     SQLDatasource,
     SQLDatasourceError,
+    TableAsset,
 )
 
 if TYPE_CHECKING:
     from great_expectations.compatibility import sqlalchemy
     from great_expectations.compatibility.pydantic.networks import Parts
+    from great_expectations.datasource.fluent.interfaces import (
+        BatchMetadata,
+        SortersDefinition,
+    )
     from great_expectations.execution_engine import SqlAlchemyExecutionEngine
 
 LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
 
+REQUIRED_QUERY_PARAMS: Final[Iterable[str]] = {  # errors will be thrown if any of these are missing
+    # TODO: require warehouse and role
+    # "warehouse",
+    # "role",
+}
+
+MISSING: Final = object()  # sentinel value to indicate missing values
+
+
+@functools.lru_cache(maxsize=4)
+def _extract_path_sections(url_path: str) -> dict[str, str]:
+    """
+    Extracts the database and schema from the path of a URL.
+
+    Raises UrlPathError if the path is missing database/schema.
+
+    snowflake://user:password@account/database/schema
+    """
+    try:
+        _, database, schema, *_ = url_path.split("/")
+    except (ValueError, AttributeError) as e:
+        LOGGER.info(f"Unable to split path - {e!r}")
+        raise UrlPathError() from e
+    if not database:
+        raise UrlPathError(msg="missing database")
+    if not schema:
+        raise UrlPathError(msg="missing schema")
+    return {"database": database, "schema": schema}
+
+
+def _get_config_substituted_connection_string(
+    datasource: SnowflakeDatasource,
+    warning_msg: str = "Unable to perform config substitution",
+) -> AnyUrl | None:
+    if not isinstance(datasource.connection_string, ConfigUri):
+        raise TypeError("Config substitution is only supported for `ConfigUri`")  # noqa: TRY003
+    if not datasource._data_context:
+        warnings.warn(
+            f"{warning_msg} for {datasource.connection_string.template_str}."
+            " Likely missing a context.",
+            category=GxContextWarning,
+        )
+        return None
+    return datasource.connection_string.get_config_value(datasource._data_context.config_provider)
+
 
 class _UrlPasswordError(pydantic.UrlError):
     """
@@ -45,6 +109,30 @@ class _UrlDomainError(pydantic.UrlError):
     msg_template = "URL domain invalid"
 
 
+class UrlPathError(pydantic.UrlError):
+    """
+    Custom Pydantic error for missing path in SnowflakeDsn.
+    """
+
+    code = "url.path"
+    msg_template = "URL path missing database/schema"
+
+    def __init__(self, **ctx: Any) -> None:
+        super().__init__(**ctx)
+
+
+class _UrlMissingQueryError(pydantic.UrlError):
+    """
+    Custom Pydantic error for missing query parameters in SnowflakeDsn.
+    """
+
+    def __init__(self, **ctx: Any) -> None:
+        super().__init__(**ctx)
+
+    code = "url.query"
+    msg_template = "URL query param missing"
+
+
 class SnowflakeDsn(AnyUrl):
     allowed_schemes = {
         "snowflake",
@@ -68,21 +156,63 @@ def validate_parts(cls, parts: Parts, validate_port: bool = True) -> Parts:
         if domain is None:
             raise _UrlDomainError()
 
-        return AnyUrl.validate_parts(parts=parts, validate_port=validate_port)
+        validated_parts = AnyUrl.validate_parts(parts=parts, validate_port=validate_port)
+
+        path: str = parts["path"]
+        # raises UrlPathError if path is missing database/schema
+        _extract_path_sections(path)
+
+        return validated_parts
+
+    @property
+    def params(self) -> dict[str, list[str]]:
+        """The query parameters as a dictionary."""
+        if not self.query:
+            return {}
+        return urllib.parse.parse_qs(self.query)
+
+    @property
+    def account_identifier(self) -> str:
+        """Alias for host."""
+        assert self.host
+        return self.host
+
+    @property
+    def database(self) -> str:
+        assert self.path
+        return self.path.split("/")[1]
+
+    @property
+    def schema_(self) -> str:
+        assert self.path
+        return self.path.split("/")[2]
+
+    @property
+    def warehouse(self) -> str | None:
+        return self.params.get("warehouse", [None])[0]
+
+    @property
+    def role(self) -> str | None:
+        return self.params.get("role", [None])[0]
 
 
 class ConnectionDetails(FluentBaseModel):
     """
     Information needed to connect to a Snowflake database.
     Alternative to a connection string.
+
+    https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#additional-connection-parameters
     """
 
     account: str
     user: str
     password: Union[ConfigStr, str]
-    database: Optional[str] = None
-    schema_: Optional[str] = pydantic.Field(
-        None, alias="schema"
+    database: str = pydantic.Field(
+        ...,
+        description="`database` that the Datasource is mapped to.",
+    )
+    schema_: str = pydantic.Field(
+        ..., alias="schema", description="`schema` that the Datasource is mapped to."
     )  # schema is a reserved attr in BaseModel
     warehouse: Optional[str] = None
     role: Optional[str] = None
@@ -103,10 +233,95 @@ class SnowflakeDatasource(SQLDatasource):
 
     type: Literal["snowflake"] = "snowflake"  # type: ignore[assignment]
     # TODO: rename this to `connection` for v1?
-    connection_string: Union[ConnectionDetails, ConfigStr, SnowflakeDsn]  # type: ignore[assignment] # Deviation from parent class as individual args are supported for connection
+    connection_string: Union[ConnectionDetails, ConfigUri, SnowflakeDsn]  # type: ignore[assignment] # Deviation from parent class as individual args are supported for connection
 
     # TODO: add props for account, user, password, etc?
 
+    @property
+    def schema_(self) -> str | None:
+        """
+        Convenience property to get the `schema` regardless of the connection string format.
+
+        `schema_` to avoid conflict with Pydantic models schema property.
+        """
+        if isinstance(self.connection_string, (ConnectionDetails, SnowflakeDsn)):
+            return self.connection_string.schema_
+
+        subbed_str: str | None = _get_config_substituted_connection_string(
+            self, warning_msg="Unable to determine schema"
+        )
+        if not subbed_str:
+            return None
+        url_path: str = urllib.parse.urlparse(subbed_str).path
+        return _extract_path_sections(url_path)["schema"]
+
+    @property
+    def database(self) -> str | None:
+        """Convenience property to get the `database` regardless of the connection string format."""
+        if isinstance(self.connection_string, (ConnectionDetails, SnowflakeDsn)):
+            return self.connection_string.database
+
+        subbed_str: str | None = _get_config_substituted_connection_string(
+            self, warning_msg="Unable to determine database"
+        )
+        if not subbed_str:
+            return None
+        url_path: str = urllib.parse.urlparse(subbed_str).path
+        return _extract_path_sections(url_path)["database"]
+
+    @deprecated_method_or_class(
+        version="1.0.0a4",
+        message="`schema_name` is deprecated."
+        " The schema now comes from the datasource.",
+    )
+    @public_api
+    @override
+    def add_table_asset(  # noqa: PLR0913
+        self,
+        name: str,
+        table_name: str = "",
+        schema_name: Optional[str] = MISSING,  # type: ignore[assignment] # sentinel value
+        order_by: Optional[SortersDefinition] = None,
+        batch_metadata: Optional[BatchMetadata] = None,
+    ) -> TableAsset:
+        """Adds a table asset to this datasource.
+
+        Args:
+            name: The name of this table asset.
+            table_name: The table where the data resides.
+            schema_name: The schema that holds the table. Will use the datasource schema if not
+                provided.
+            order_by: A list of Sorters or Sorter strings.
+            batch_metadata: BatchMetadata we want to associate with this DataAsset and all batches
+                derived from it.
+
+        Returns:
+            The table asset that is added to the datasource.
+            The type of this object will match the necessary type for this datasource.
+        """
+        if schema_name is MISSING:
+            # using MISSING to indicate that the user did not provide a value
+            schema_name = self.schema_
+        else:
+            # deprecated-v0.18.16
+            warnings.warn(
+                "The `schema_name` argument is deprecated and will be removed in a future release."
+                " The schema now comes from the datasource.",
+                category=DeprecationWarning,
+            )
+            if schema_name != self.schema_:
+                warnings.warn(
+                    f"schema_name {schema_name} does not match datasource schema {self.schema_}",
+                    category=GxDatasourceWarning,
+                )
+
+        return super().add_table_asset(
+            name=name,
+            table_name=table_name,
+            schema_name=schema_name,
+            order_by=order_by,
+            batch_metadata=batch_metadata,
+        )
+
     @pydantic.root_validator(pre=True)
     def _convert_root_connection_detail_fields(cls, values: dict) -> dict:
         """
@@ -122,18 +337,38 @@ def _convert_root_connection_detail_fields(cls, values: dict) -> dict:
             *ConnectionDetails.__fields__.keys(),
         }
 
+        connection_string: Any | None = values.get("connection_string")
+        provided_fields = tuple(values.keys())
+
         connection_details = {}
-        for field_name in tuple(values.keys()):
+        for field_name in provided_fields:
             if field_name in connection_detail_fields:
+                if connection_string:
+                    raise ValueError(  # noqa: TRY003
+                        "Provided both connection detail keyword args and `connection_string`."
+                    )
                 connection_details[field_name] = values.pop(field_name)
         if connection_details:
             values["connection_string"] = connection_details
         return values
 
+    @pydantic.validator("connection_string", pre=True)
+    def _check_config_template(cls, connection_string: Any) -> Any:
+        """
+        If connection_string has a config template, parse it as a ConfigUri; ignore other errors.
+        """
+        if isinstance(connection_string, str):
+            if ConfigUri.str_contains_config_template(connection_string):
+                LOGGER.debug("`connection_string` contains config template")
+                return pydantic.parse_obj_as(ConfigUri, connection_string)
+        return connection_string
+
     @pydantic.root_validator
     def _check_xor_input_args(cls, values: dict) -> dict:
         # keeping this validator isn't strictly necessary, but it provides a better error message
-        connection_string: str | ConnectionDetails | None = values.get("connection_string")
+        connection_string: str | ConfigUri | ConnectionDetails | None = values.get(
+            "connection_string"
+        )
         if connection_string:
             # Method 1 - connection string
             is_connection_string: bool = isinstance(
@@ -151,6 +386,32 @@ def _check_xor_input_args(cls, values: dict) -> dict:
             "Must provide either a connection string or a combination of account, user, and password."  # noqa: E501
         )
 
+    @pydantic.validator("connection_string")
+    def _check_for_required_query_params(
+        cls, connection_string: ConnectionDetails | SnowflakeDsn | ConfigUri
+    ) -> ConnectionDetails | SnowflakeDsn | ConfigUri:
+        """
+        If connection_string is a SnowflakeDsn,
+        check for required query parameters according to `REQUIRED_QUERY_PARAMS`.
+        """
+        if not isinstance(connection_string, (SnowflakeDsn, ConfigUri)):
+            return connection_string
+
+        missing_keys: set[str] = set(REQUIRED_QUERY_PARAMS)
+
+        if connection_string.query:
+            query_params: dict[str, list[str]] = urllib.parse.parse_qs(connection_string.query)
+
+            for key in REQUIRED_QUERY_PARAMS:
+                if key in query_params:
+                    missing_keys.remove(key)
+
+        if missing_keys:
+            raise _UrlMissingQueryError(
+                msg=f"missing {', '.join(sorted(missing_keys))}",
+            )
+        return connection_string
+
     class Config:
         @staticmethod
         def schema_extra(schema: dict, model: type[SnowflakeDatasource]) -> None:
diff --git a/great_expectations/datasource/fluent/sources.pyi b/great_expectations/datasource/fluent/sources.pyi
index 49b93080d02a..5c044cb6053e 100644
--- a/great_expectations/datasource/fluent/sources.pyi
+++ b/great_expectations/datasource/fluent/sources.pyi
@@ -609,8 +609,8 @@ class _SourceFactories:
         account: str = ...,
         user: str = ...,
         password: Union[ConfigStr, str] = ...,
-        database: Optional[str] = ...,
-        schema: Optional[str] = ...,
+        database: str = ...,
+        schema: str = ...,
         warehouse: Optional[str] = ...,
         role: Optional[str] = ...,
         numpy: bool = ...,
@@ -647,8 +647,8 @@ class _SourceFactories:
         account: str = ...,
         user: str = ...,
         password: Union[ConfigStr, str] = ...,
-        database: Optional[str] = ...,
-        schema: Optional[str] = ...,
+        database: str = ...,
+        schema: str = ...,
         warehouse: Optional[str] = ...,
         role: Optional[str] = ...,
         numpy: bool = ...,
@@ -685,8 +685,8 @@ class _SourceFactories:
         account: str = ...,
         user: str = ...,
         password: Union[ConfigStr, str] = ...,
-        database: Optional[str] = ...,
-        schema: Optional[str] = ...,
+        database: str = ...,
+        schema: str = ...,
         warehouse: Optional[str] = ...,
         role: Optional[str] = ...,
         numpy: bool = ...,
diff --git a/pyproject.toml b/pyproject.toml
index a8225d797e9b..f903e57e6ca1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -459,6 +459,10 @@ filterwarnings = [
     "ignore:stream argument is deprecated. Use stream parameter in request directly:DeprecationWarning",
     # We likely won't be updating to `marhsmallow` 4, these errors should be filtered out
     "error::marshmallow.warnings.RemovedInMarshmallow4Warning",
+    # pkg_resources is deprecated as an API, but third party libraries still use it
+    "ignore: pkg_resources is deprecated as an API.:DeprecationWarning",
+    'ignore: Deprecated call to `pkg_resources.declare_namespace\(.*\)`',
+
     # --------------------------------------- Great Expectations Warnings ----------------------------------
     # This warning is for configuring the result_format parameter at the Validator-level, which will not be persisted,
diff --git a/tests/conftest.py b/tests/conftest.py
index da3962cfd248..181ba3212b44 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2205,3 +2205,20 @@ def filter_gx_datasource_warnings() -> Generator[None, None, None]:
     with warnings.catch_warnings():
         warnings.simplefilter("ignore", category=GxDatasourceWarning)
         yield
+
+
+@pytest.fixture(scope="function")
+def param_id(request: pytest.FixtureRequest) -> str:
+    """Return the parameter id of the current test.
+
+    Example:
+
+    ```python
+    @pytest.mark.parametrize("my_param", ["a", "b", "c"], ids=lambda x: x.upper())
+    def test_something(param_id: str, my_param: str):
+        assert my_param != param_id
+        assert my_param.upper() == param_id
+    ```
+    """
+    raw_name: str = request.node.name
+    return raw_name.split("[")[1].split("]")[0]
diff --git a/tests/datasource/fluent/great_expectations.yml b/tests/datasource/fluent/great_expectations.yml
index b16f65a0377f..77947860c19f 100644
--- a/tests/datasource/fluent/great_expectations.yml
+++ b/tests/datasource/fluent/great_expectations.yml
@@ -198,7 +198,7 @@ fluent_datasources:
       abs_container: "this_is_always_required"
   my_snowflake_ds:
     type: snowflake
-    connection_string: "snowflake://user_login_name:password@account_identifier"
+    connection_string: "snowflake://user_login_name:password@account_identifier/database/public"
    assets:
      my_table_asset_wo_partitioners:
        id: d8b22f50-d3f9-4d04-9b4c-cfed86b157ff
diff --git a/tests/datasource/fluent/integration/test_connections.py b/tests/datasource/fluent/integration/test_connections.py
index 40820416262d..803057c36fea 100644
--- a/tests/datasource/fluent/integration/test_connections.py
+++ b/tests/datasource/fluent/integration/test_connections.py
@@ -24,28 +24,20 @@ class TestSnowflake:
         "connection_string",
         [
             param(
-                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci",
+                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?warehouse=ci",
                 id="missing role",
             ),
             param(
-                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci_no_select",
+                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?warehouse=ci&role=ci_no_select",
                 id="role wo select",
             ),
-            param(
-                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}?warehouse=ci&role=ci",
-                id="missing database + schema",
-            ),
-            param(
-                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci?warehouse=ci&role=ci",
-                id="missing schema",
-            ),
         ],
     )
     def test_un_queryable_asset_should_raise_error(
        self, context: DataContext, connection_string: str
    ):
         """
-        A SnowflakeDatasource can successfully connect even if things like database, schema, warehouse, and role are omitted.
+        A SnowflakeDatasource can successfully connect even if things like warehouse and role are omitted.
         However, if we try to add an asset that is not queryable with the current datasource connection details,
         then we should expect a TestConnectionError.
 
         https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy#connection-parameters
@@ -84,11 +76,11 @@ def test_un_queryable_asset_should_raise_error(
         "connection_string",
         [
             param(
-                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci",
+                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?warehouse=ci&role=ci",
                 id="full connection string",
             ),
             param(
-                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?role=ci",
+                "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci/public?role=ci",
                 id="missing warehouse",
             ),
         ],
diff --git a/tests/datasource/fluent/integration/test_sql_datasources.py b/tests/datasource/fluent/integration/test_sql_datasources.py
index 1a2980ec14f3..c3ae0299de04 100644
--- a/tests/datasource/fluent/integration/test_sql_datasources.py
+++ b/tests/datasource/fluent/integration/test_sql_datasources.py
@@ -6,6 +6,7 @@
 import shutil
 import sys
 import uuid
+import warnings
 from pprint import pformat as pf
 from typing import (
     TYPE_CHECKING,
@@ -333,9 +334,10 @@ def snowflake_ds(
         pytest.skip("no snowflake credentials")
     ds = context.data_sources.add_snowflake(
         "snowflake",
-        connection_string="snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci",
+        connection_string="snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci"
+        f"/{RAND_SCHEMA}?warehouse=ci&role=ci",
         # NOTE: uncomment this and set SNOWFLAKE_USER to run tests against your own snowflake account  # noqa: E501
-        # connection_string="snowflake://${SNOWFLAKE_USER}@${SNOWFLAKE_CI_ACCOUNT}/DEMO_DB/RESTAURANTS?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser",
+        # connection_string="snowflake://${SNOWFLAKE_USER}@oca29081.us-east-1/DEMO_DB/RESTAURANTS?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser",
     )
     return ds
 
@@ -451,7 +453,7 @@ def test_snowflake(
         if not snowflake_ds:
             pytest.skip("no snowflake datasource")
         # create table
-        schema = get_random_identifier_name()
+        schema = RAND_SCHEMA
         table_factory(
             gx_engine=snowflake_ds.get_execution_engine(),
             table_names={table_name},
@@ -461,7 +463,7 @@ def test_snowflake(
         table_names: list[str] = inspect(snowflake_ds.get_engine()).get_table_names(schema=schema)
         print(f"snowflake tables:\n{pf(table_names)}))")
 
-        snowflake_ds.add_table_asset(asset_name, table_name=table_name, schema_name=schema)
+        snowflake_ds.add_table_asset(asset_name, table_name=table_name)
 
     @pytest.mark.sqlite
     def test_sqlite(
@@ -482,6 +484,9 @@ def test_sqlite(
 
         sqlite_ds.add_table_asset(asset_name, table_name=table_name)
 
+    @pytest.mark.filterwarnings(  # snowflake `add_table_asset` raises warning on passing a schema
+        "once::great_expectations.datasource.fluent.GxDatasourceWarning"
+    )
     @pytest.mark.parametrize(
         "datasource_type,schema",
         [
@@ -518,8 +523,13 @@ def test_checkpoint_run(
             schema=schema,
         )
 
-        asset = datasource.add_table_asset(asset_name, table_name=table_name, schema_name=schema)
-        batch_definition = asset.add_batch_definition_whole_table("whole table!")
+        with warnings.catch_warnings():
+            # passing a schema to snowflake tables is deprecated
+            warnings.simplefilter("once", DeprecationWarning)
+            asset = datasource.add_table_asset(
+                asset_name, table_name=table_name, schema_name=schema
+            )
+            batch_definition = asset.add_batch_definition_whole_table("whole table!")
 
         suite = context.suites.add(ExpectationSuite(name=f"{datasource.name}-{asset.name}"))
         suite.add_expectation(gxe.ExpectColumnValuesToNotBeNull(column="name", mostly=1))
@@ -711,6 +721,9 @@ def _raw_query_check_column_exists(
     return True
 
 
+@pytest.mark.filterwarnings(
+    "once::DeprecationWarning"
+)  # snowflake `add_table_asset` raises warning on passing a schema
 @pytest.mark.parametrize(
     "column_name",
     [
diff --git a/tests/datasource/fluent/test_config_str.py b/tests/datasource/fluent/test_config_str.py
index 77f9f8cc74ef..29bd97a1a918 100644
--- a/tests/datasource/fluent/test_config_str.py
+++ b/tests/datasource/fluent/test_config_str.py
@@ -10,7 +10,11 @@
     _ConfigurationProvider,
     _EnvironmentConfigurationProvider,
 )
-from great_expectations.datasource.fluent.config_str import ConfigStr, SecretStr
+from great_expectations.datasource.fluent.config_str import (
+    ConfigStr,
+    ConfigUri,
+    SecretStr,
+)
 from great_expectations.datasource.fluent.fluent_base_model import FluentBaseModel
 from great_expectations.exceptions import MissingConfigVariableError
 
@@ -245,3 +249,119 @@ def test_serialization(
     assert "my_secret" not in dumped_str
     assert "dont_serialize_me" not in dumped_str
     assert r"${MY_SECRET}" in dumped_str
+
+
+@pytest.mark.parametrize(
+    "uri",
+    [
+        "http://my_user:${MY_PW}@example.com:8000/the/path/?query=here#fragment=is;this=bit",
+        "http://${MY_USER}:${MY_PW}@example.com:8000/the/path/?query=here#fragment=is;this=bit",
+        "snowflake://my_user:${MY_PW}@account/db",
+        "snowflake://${MY_USER}:${MY_PW}@account/db",
+        "postgresql+psycopg2://my_user:${MY_PW}@host/db",
+        "postgresql+psycopg2://${MY_USER}:${MY_PW}@host/db",
+    ],
+)
+class TestConfigUri:
+    def test_parts(
+        self,
+        env_config_provider: _ConfigurationProvider,
+        monkeypatch: MonkeyPatch,
+        uri: str,
+    ):
+        monkeypatch.setenv("MY_USER", "my_user")
+        monkeypatch.setenv("MY_PW", "super_secret")
+
+        parsed = pydantic.parse_obj_as(ConfigUri, uri)
+
+        # ensure attributes are set
+        assert parsed.scheme
+        assert parsed.host
+        assert parsed.path
+        # ensure no attribute errors
+        _ = parsed.query
+        _ = parsed.fragment
+
+        # ensure that the password (and user) are not substituted
+        assert parsed.password == "${MY_PW}"
+        assert parsed.user in ["${MY_USER}", "my_user"]
+
+    def test_substitution(
+        self,
+        env_config_provider: _ConfigurationProvider,
+        monkeypatch: MonkeyPatch,
+        uri: str,
+    ):
+        monkeypatch.setenv("MY_USER", "my_user")
+        monkeypatch.setenv("MY_PW", "super_secret")
+
+        parsed = pydantic.parse_obj_as(ConfigUri, uri)
+
+        substituted = parsed.get_config_value(env_config_provider)
+
+        # ensure attributes are set
+        assert substituted.scheme
+        assert substituted.host
+        assert substituted.path
+        # ensure no attribute errors
+        _ = substituted.query
+        _ = substituted.fragment
+
+        # ensure that the password (and user) are substituted
+        assert substituted.password == "super_secret"
+        assert substituted.user == "my_user"
+
+    def test_leakage(
+        self,
+        env_config_provider: _ConfigurationProvider,
+        monkeypatch: MonkeyPatch,
+        uri: str,
+    ):
+        """
+        Ensure the config values are not leaked in the repr or str of the object
+        or the component parts.
+        """
+        monkeypatch.setenv("MY_USER", "my_user")
+        monkeypatch.setenv("MY_PW", "super_secret")
+
+        parsed = pydantic.parse_obj_as(ConfigUri, uri)
+        assert "super_secret" not in str(parsed)
+        assert "super_secret" not in repr(parsed)
+        assert parsed.password
+        assert "super_secret" not in parsed.password
+
+        if "my_user" not in uri:
+            assert "my_user" not in str(parsed)
+            assert "my_user" not in repr(parsed)
+            assert parsed.user
+            assert "my_user" not in parsed.user
+
+
+class TestConfigUriInvalid:
+    @pytest.mark.parametrize("uri", ["invalid_uri", "http:/example.com"])
+    def test_invalid_uri(self, uri: str):
+        with pytest.raises(pydantic.ValidationError):
+            _ = pydantic.parse_obj_as(ConfigUri, uri)
+
+    @pytest.mark.parametrize(
+        "uri",
+        [
+            "${MY_SCHEME}://me:secret@account/db/schema",
+            "snowflake://me:secret@${MY_ACCOUNT}/db/schema",
+            "snowflake://me:secret@account/${MY_DB}/schema",
+            "snowflake://me:secret@account/db/${MY_SCHEMA}",
+            "snowflake://me:secret@account/db/my_schema?${MY_QUERY_PARAMS}",
+            "snowflake://me:secret@account/db/my_schema?role=${MY_ROLE}",
+        ],
+    )
+    def test_disallowed_substitution(self, uri: str):
+        with pytest.raises(pydantic.ValidationError):
+            _ = pydantic.parse_obj_as(ConfigUri, uri)
+
+    def test_no_template_str(self):
+        with pytest.raises(pydantic.ValidationError):
+            _ = pydantic.parse_obj_as(ConfigUri, "snowflake://me:password@account/db")
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-vv"])
diff --git a/tests/datasource/fluent/test_snowflake_datasource.py b/tests/datasource/fluent/test_snowflake_datasource.py
index 40d02119a551..17e46d4f2d5a 100644
--- a/tests/datasource/fluent/test_snowflake_datasource.py
+++ b/tests/datasource/fluent/test_snowflake_datasource.py
@@ -1,5 +1,9 @@
 from __future__ import annotations
 
+from pprint import pformat as pf
+from sys import version_info as python_version
+from typing import TYPE_CHECKING, Final, Sequence
+
 import pytest
 import sqlalchemy as sa
 from pytest import param
@@ -7,6 +11,7 @@
 from great_expectations.compatibility import pydantic
 from great_expectations.compatibility.snowflake import snowflake
 from great_expectations.data_context import AbstractDataContext
+from great_expectations.datasource.fluent import GxContextWarning
 from great_expectations.datasource.fluent.config_str import ConfigStr
 from great_expectations.datasource.fluent.snowflake_datasource import (
     SnowflakeDatasource,
@@ -14,52 +19,96 @@
 )
 from great_expectations.execution_engine import SqlAlchemyExecutionEngine
 
+if TYPE_CHECKING:
+    from pytest.mark.structures import ParameterSet  # type: ignore[import-not-found]
+
+VALID_DS_CONFIG_PARAMS: Final[Sequence[ParameterSet]] = [
+    param(
+        {
+            "connection_string": "snowflake://my_user:password@my_account/d_public/s_public?numpy=True"
+        },
+        id="connection_string str",
+    ),
+    param(
+        {"connection_string": "snowflake://my_user:${MY_PASSWORD}@my_account/d_public/s_public"},
+        id="connection_string ConfigStr - password sub",
+    ),
+    param(
+        {"connection_string": "snowflake://${MY_USER}:${MY_PASSWORD}@my_account/d_public/s_public"},
+        id="connection_string ConfigStr - user + password sub",
+    ),
+    param(
+        {
+            "connection_string": {
+                "user": "my_user",
+                "password": "password",
+                "account": "my_account",
+                "schema": "s_public",
+                "database": "d_public",
+            }
+        },
+        id="connection_string dict",
+    ),
+    param(
+        {
+            "connection_string": {
+                "user": "my_user",
+                "password": "${MY_PASSWORD}",
+                "account": "my_account",
+                "schema": "s_public",
+                "database": "d_public",
+            }
+        },
id="connection_string dict with password ConfigStr", + ), +] + @pytest.fixture def seed_env_vars(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.setenv("MY_CONN_STR", "snowflake://my_user:password@my_account") + monkeypatch.setenv("MY_USER", "my_user") monkeypatch.setenv("MY_PASSWORD", "my_password") +@pytest.mark.unit +def test_snowflake_dsn(): + dsn = pydantic.parse_obj_as( + SnowflakeDsn, + "snowflake://my_user:password@my_account/my_db/my_schema?role=my_role&warehouse=my_wh", + ) + assert dsn.user == "my_user" + assert dsn.password == "password" + assert dsn.account_identifier == "my_account" + assert dsn.database == "my_db" + assert dsn.schema_ == "my_schema" + assert dsn.role == "my_role" + assert dsn.warehouse == "my_wh" + + @pytest.mark.snowflake # TODO: make this a unit test @pytest.mark.parametrize( "config_kwargs", [ - param( - {"connection_string": "snowflake://my_user:password@my_account"}, - id="connection_string str", - ), - param({"connection_string": "${MY_CONN_STR}"}, id="connection_string ConfigStr"), + *VALID_DS_CONFIG_PARAMS, param( { - "connection_string": { - "user": "my_user", - "password": "password", - "account": "my_account", - } + "user": "my_user", + "password": "password", + "account": "my_account", + "schema": "s_public", + "database": "d_public", }, - id="connection_string dict", - ), - param( - { - "connection_string": { - "user": "my_user", - "password": "${MY_PASSWORD}", - "account": "my_account", - } - }, - id="connection_string dict with password ConfigStr", - ), - param( - {"user": "my_user", "password": "password", "account": "my_account"}, id="old config format - top level keys", ), ], ) def test_valid_config( - empty_file_context: AbstractDataContext, seed_env_vars: None, config_kwargs: dict + empty_file_context: AbstractDataContext, + seed_env_vars: None, + config_kwargs: dict, + param_id: str, ): - my_sf_ds_1 = SnowflakeDatasource(name="my_sf_ds_1", **config_kwargs) + my_sf_ds_1 = SnowflakeDatasource(name=f"my_sf {param_id}", **config_kwargs) assert my_sf_ds_1 my_sf_ds_1._data_context = empty_file_context # attach to enable config substitution @@ -72,31 +121,360 @@ def test_valid_config( @pytest.mark.unit @pytest.mark.parametrize( - "connection_string, connect_args", + ["connection_string", "expected_errors"], [ pytest.param( - "snowflake://:@", - {"account": "my_account", "user": "my_user", "password": "123456"}, + "${MY_CONFIG_VAR}", + [ + { + "loc": ("connection_string", "__root__"), + "msg": "Only password, user may use config substitution;" + " 'domain' substitution not allowed", + "type": "value_error", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="illegal config substitution - full connection string", + ), + pytest.param( + "snowflake://my_user:password@${MY_CONFIG_VAR}/db/schema", + [ + { + "loc": ("connection_string", "__root__"), + "msg": "Only password, user may use config substitution;" + " 'domain' substitution not allowed", + "type": "value_error", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="illegal config substitution - account (domain)", + ), + pytest.param( + "snowflake://my_user:password@account/${MY_CONFIG_VAR}/schema", + [ + { + "loc": ("connection_string", "__root__"), + "msg": "Only password, user may use config substitution;" + " 'path' substitution 
not allowed", + "type": "value_error", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="illegal config substitution - database (path)", + ), + pytest.param( + "snowflake://my_user:password@account/db/${MY_CONFIG_VAR}", + [ + { + "loc": ("connection_string", "__root__"), + "msg": "Only password, user may use config substitution;" + " 'path' substitution not allowed", + "type": "value_error", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="illegal config substitution - schema (path)", + ), + pytest.param( + "snowflake://my_user:password@my_account", + [ + { + "loc": ("connection_string",), + "msg": "value is not a valid dict", + "type": "type_error.dict", + }, + { + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", + "type": "value_error", + }, + { + "loc": ("connection_string",), + "msg": "URL path missing database/schema", + "type": "value_error.url.path", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="missing path", + ), + pytest.param( + "snowflake://my_user:password@my_account//", + [ + { + "loc": ("connection_string",), + "msg": "value is not a valid dict", + "type": "type_error.dict", + }, + { + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", + "type": "value_error", + }, + { + "ctx": {"msg": "missing database"}, + "loc": ("connection_string",), + "msg": "URL path missing database/schema", + "type": "value_error.url.path", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="missing database + schema", + ), + pytest.param( + "snowflake://my_user:password@my_account/my_db", + [ + { + "loc": ("connection_string",), + "msg": "value is not a valid dict", + "type": "type_error.dict", + }, + { + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", + "type": "value_error", + }, + { + "loc": ("connection_string",), + "msg": "URL path missing database/schema", + "type": "value_error.url.path", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="missing schema", + ), + pytest.param( + "snowflake://my_user:password@my_account/my_db/", + [ + { + "loc": ("connection_string",), + "msg": "value is not a valid dict", + "type": "type_error.dict", + }, + { + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", + "type": "value_error", + }, + { + "ctx": {"msg": "missing schema"}, + "loc": ("connection_string",), + "msg": "URL path missing database/schema", + "type": "value_error.url.path", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="missing schema 
2", + ), + pytest.param( + "snowflake://my_user:password@my_account//my_schema", + [ + { + "loc": ("connection_string",), + "msg": "value is not a valid dict", + "type": "type_error.dict", + }, + { + "loc": ("connection_string",), + "msg": "ConfigStr - contains no config template strings in the format " + "'${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", + "type": "value_error", + }, + { + "ctx": {"msg": "missing database"}, + "loc": ("connection_string",), + "msg": "URL path missing database/schema", + "type": "value_error.url.path", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of account, " + "user, and password.", + "type": "value_error", + }, + ], + id="missing database", + ), + ], +) +def test_missing_required_params( + connection_string: str, + expected_errors: list[dict], # TODO: use pydantic error dict +): + with pytest.raises(pydantic.ValidationError) as exc_info: + ds = SnowflakeDatasource( + name="my_sf_ds", + connection_string=connection_string, + ) + print(f"{ds!r}") + + print(f"\n\tErrors:\n{pf(exc_info.value.errors())}") + assert exc_info.value.errors() == expected_errors + + +@pytest.mark.unit +@pytest.mark.parametrize( + "connection_string, connect_args, expected_errors", + [ + pytest.param( + "snowflake://my_user:password@my_account/foo/bar?numpy=True", + { + "account": "my_account", + "user": "my_user", + "password": "123456", + "schema": "foo", + "database": "bar", + }, + [ + { + "loc": ("__root__",), + "msg": "Provided both connection detail keyword args and `connection_string`.", + "type": "value_error", + } + ], id="both connection_string and connect_args", ), - pytest.param(None, {}, id="neither connection_string nor connect_args"), pytest.param( None, - {"account": "my_account", "user": "my_user"}, + {}, + [ + { + "loc": ("connection_string",), + "msg": "none is not an allowed value", + "type": "type_error.none.not_allowed", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", + "type": "value_error", + }, + ], + id="neither connection_string nor connect_args", + ), + pytest.param( + None, + { + "account": "my_account", + "user": "my_user", + "schema": "foo", + "database": "bar", + }, + [ + { + "loc": ("connection_string", "password"), + "msg": "field required", + "type": "value_error.missing", + }, + { + "loc": ("connection_string",), + "msg": "expected string or bytes-like object" + f"""{"" if python_version < (3, 11) else ", got 'dict'"}""", + "type": "type_error", + }, + { + "loc": ("connection_string",), + "msg": "str type expected", + "type": "type_error.str", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", + "type": "value_error", + }, + ], id="incomplete connect_args", ), pytest.param( - {"connection_string": {"account": "my_account", "user": "my_user"}}, + { + "account": "my_account", + "user": "my_user", + "schema": "foo", + "database": "bar", + }, {}, + [ + { + "loc": ("connection_string", "password"), + "msg": "field required", + "type": "value_error.missing", + }, + { + "loc": ("connection_string",), + "msg": "expected string or bytes-like object" + f"""{"" if python_version < (3, 11) else ", got 'dict'"}""", + "type": "type_error", + }, + { + "loc": ("connection_string",), + "msg": "str type expected", + "type": "type_error.str", + }, + { + "loc": ("__root__",), + "msg": "Must provide either a connection string or a combination of 
account, " + "user, and password.", + "type": "value_error", + }, + ], id="incomplete connection_string dict connect_args", ), ], ) def test_conflicting_connection_string_and_args_raises_error( - connection_string: ConfigStr | SnowflakeDsn | None | dict, connect_args: dict + connection_string: ConfigStr | SnowflakeDsn | None | dict, + connect_args: dict, + expected_errors: list[dict], ): - with pytest.raises(ValueError): - _ = SnowflakeDatasource(connection_string=connection_string, **connect_args) + with pytest.raises(pydantic.ValidationError) as exc_info: + _ = SnowflakeDatasource( + name="my_sf_ds", connection_string=connection_string, **connect_args + ) + assert exc_info.value.errors() == expected_errors @pytest.mark.unit @@ -113,7 +491,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("connection_string",), - "msg": "ConfigStr - contains no config template strings in the format '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", # noqa: E501 + "msg": "ConfigStr - contains no config template strings in the format" + " '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", "type": "value_error", }, { @@ -123,7 +502,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of account, user, and password.", # noqa: E501 + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", "type": "value_error", }, ], @@ -139,7 +519,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("connection_string",), - "msg": "ConfigStr - contains no config template strings in the format '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", # noqa: E501 + "msg": "ConfigStr - contains no config template strings in the format" + " '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", "type": "value_error", }, { @@ -149,7 +530,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of account, user, and password.", # noqa: E501 + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", "type": "value_error", }, ], @@ -165,7 +547,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("connection_string",), - "msg": "ConfigStr - contains no config template strings in the format '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", # noqa: E501 + "msg": "ConfigStr - contains no config template strings in the format" + " '${MY_CONFIG_VAR}' or '$MY_CONFIG_VAR'", "type": "value_error", }, { @@ -175,7 +558,8 @@ def test_conflicting_connection_string_and_args_raises_error( }, { "loc": ("__root__",), - "msg": "Must provide either a connection string or a combination of account, user, and password.", # noqa: E501 + "msg": "Must provide either a connection string or a combination of" + " account, user, and password.", "type": "value_error", }, ], @@ -196,11 +580,53 @@ def test_invalid_connection_string_raises_dsn_error( @pytest.mark.skipif(True if not snowflake else False, reason="snowflake is not installed") @pytest.mark.unit def test_get_execution_engine_succeeds(): - connection_string = "snowflake://my_user:password@my_account" + connection_string = "snowflake://my_user:password@my_account/my_db/my_schema" datasource = SnowflakeDatasource(name="my_snowflake", connection_string=connection_string) # testing that this doesn't raise an exception datasource.get_execution_engine() +@pytest.mark.snowflake 
+@pytest.mark.parametrize("ds_config", VALID_DS_CONFIG_PARAMS) +class TestConvenienceProperties: + def test_schema( + self, + ds_config: dict, + seed_env_vars: None, + param_id: str, + ephemeral_context_with_defaults: AbstractDataContext, + ): + datasource = SnowflakeDatasource(name=param_id, **ds_config) + if isinstance(datasource.connection_string, ConfigStr): + # expect a warning if connection string is a ConfigStr + with pytest.warns(GxContextWarning): + assert ( + not datasource.schema_ + ), "Don't expect schema to be available without config_provider" + # attach context to enable config substitution + datasource._data_context = ephemeral_context_with_defaults + + assert datasource.schema_ + + def test_database( + self, + ds_config: dict, + seed_env_vars: None, + param_id: str, + ephemeral_context_with_defaults: AbstractDataContext, + ): + datasource = SnowflakeDatasource(name=param_id, **ds_config) + if isinstance(datasource.connection_string, ConfigStr): + # expect a warning if connection string is a ConfigStr + with pytest.warns(GxContextWarning): + assert ( + not datasource.database + ), "Don't expect schema to be available without config_provider" + # attach context to enable config substitution + datasource._data_context = ephemeral_context_with_defaults + + assert datasource.database + + if __name__ == "__main__": pytest.main([__file__, "-vv"]) diff --git a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py index adae99dcf60f..978f1a7230b0 100644 --- a/tests/integration/cloud/end_to_end/test_snowflake_datasource.py +++ b/tests/integration/cloud/end_to_end/test_snowflake_datasource.py @@ -2,7 +2,7 @@ import os import uuid -from typing import TYPE_CHECKING, Iterator +from typing import TYPE_CHECKING, Final, Iterator import pytest @@ -23,13 +23,21 @@ from great_expectations.validator.validator import Validator from tests.integration.cloud.end_to_end.conftest import TableFactory +RANDOM_SCHEMA: Final[str] = f"i{uuid.uuid4().hex}" + @pytest.fixture(scope="module") def connection_string() -> str: if os.getenv("SNOWFLAKE_CI_USER_PASSWORD") and os.getenv("SNOWFLAKE_CI_ACCOUNT"): - return "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@${SNOWFLAKE_CI_ACCOUNT}/ci/public?warehouse=ci&role=ci" + return ( + "snowflake://ci:${SNOWFLAKE_CI_USER_PASSWORD}@oca29081.us-east-1/ci" + f"/{RANDOM_SCHEMA}?warehouse=ci&role=ci" + ) elif os.getenv("SNOWFLAKE_USER") and os.getenv("SNOWFLAKE_CI_ACCOUNT"): - return "snowflake://${SNOWFLAKE_USER}@${SNOWFLAKE_CI_ACCOUNT}/DEMO_DB?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser" + return ( + "snowflake://${SNOWFLAKE_USER}@oca29081.us-east-1/DEMO_DB" + f"/{RANDOM_SCHEMA}?warehouse=COMPUTE_WH&role=PUBLIC&authenticator=externalbrowser" + ) else: pytest.skip("no snowflake credentials") @@ -89,16 +97,15 @@ def table_asset( asset_name: str, table_factory: TableFactory, ) -> TableAsset: - schema_name = f"i{uuid.uuid4().hex}" table_name = f"i{uuid.uuid4().hex}" table_factory( gx_engine=datasource.get_execution_engine(), table_names={table_name}, - schema_name=schema_name, + schema_name=RANDOM_SCHEMA, ) return datasource.add_table_asset( name=asset_name, - schema_name=schema_name, + schema_name=RANDOM_SCHEMA, table_name=table_name, ) diff --git a/tests/integration/db/taxi_data_utils.py b/tests/integration/db/taxi_data_utils.py index a4e77e108d41..0ff05769e47e 100644 --- a/tests/integration/db/taxi_data_utils.py +++ b/tests/integration/db/taxi_data_utils.py @@ -1,3 +1,6 @@ +from 
+
+import warnings
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, List
 
@@ -5,6 +8,7 @@
 
 import great_expectations as gx
 from great_expectations.core.batch_definition import BatchDefinition
+from great_expectations.datasource.fluent import GxDatasourceWarning
 from great_expectations.execution_engine.sqlalchemy_batch_data import (
     SqlAlchemyBatchData,
 )
@@ -114,7 +118,23 @@ def _execute_taxi_partitioning_test_cases(
     datasource = add_datasource(
         context, name=datasource_name, connection_string=connection_string
     )
-    asset = datasource.add_table_asset(data_asset_name, table_name=table_name)
+
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "once",
+            message="The `schema_name` argument is deprecated",
+            category=DeprecationWarning,
+        )
+        warnings.filterwarnings(
+            "once",
+            message="schema_name None does not match datasource schema",
+            category=GxDatasourceWarning,
+        )
+
+        asset = datasource.add_table_asset(
+            data_asset_name, table_name=table_name, schema_name=None
+        )
+
     add_batch_definition_method = getattr(
         asset, test_case.add_batch_definition_method_name or "MAKE THIS REQUIRED"
     )