Skip to content

Commit

Permalink
[BUGFIX] Fix Databricks SQL Regex and Like based Expectations (#10406)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Kilo59 and pre-commit-ci[bot] authored Sep 17, 2024
1 parent 0b35117 commit c49d728
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 25 deletions.
2 changes: 1 addition & 1 deletion great_expectations/compatibility/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
)

try:
from databricks import connect # type: ignore[import-untyped]
from databricks import connect
except ImportError:
connect = DATABRICKS_CONNECT_NOT_IMPORTED
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ def __init__( # noqa: C901, PLR0912, PLR0913, PLR0915
self.dialect_module = import_library_module(
module_name="clickhouse_sqlalchemy.drivers.base"
)
elif self.dialect_name == GXSqlDialect.DATABRICKS:
self.dialect_module = import_library_module("databricks.sqlalchemy")
else:
self.dialect_module = None

Expand Down
71 changes: 56 additions & 15 deletions great_expectations/expectations/metrics/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
import logging
import re
from collections import UserDict
from types import ModuleType
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
overload,
)

Expand Down Expand Up @@ -63,6 +66,11 @@
except ImportError:
clickhouse_sqlalchemy = None

try:
import databricks.sqlalchemy as sqla_databricks
except ImportError:
sqla_databricks = None # type: ignore[assignment]

_BIGQUERY_MODULE_NAME = "sqlalchemy_bigquery"

from great_expectations.compatibility import bigquery as sqla_bigquery
Expand All @@ -79,12 +87,33 @@
teradatatypes = None


def _is_databricks_dialect(dialect: ModuleType | sa.Dialect | Type[sa.Dialect]) -> bool:
    """Return True if *dialect* refers to the Databricks SQL dialect.

    Accepts any of the three forms callers pass around:
    the ``databricks.sqlalchemy`` module itself (exposes ``DatabricksDialect``),
    an instantiated ``DatabricksDialect``, or a ``DatabricksDialect`` subclass.

    Args:
        dialect: A dialect module, dialect class, or dialect instance.

    Returns:
        True when the Databricks dialect is detected, False otherwise
        (including when ``databricks-sqlalchemy`` is not installed).
    """
    # databricks-sqlalchemy not importable -> cannot be the Databricks dialect.
    if not sqla_databricks:
        return False
    # The databricks.sqlalchemy module exposes the DatabricksDialect class.
    if hasattr(dialect, "DatabricksDialect"):
        return True
    # An already-instantiated DatabricksDialect.
    if isinstance(dialect, sqla_databricks.DatabricksDialect):
        return True
    # A dialect class; issubclass raises TypeError for non-class arguments,
    # which is the only failure mode we expect here — don't swallow everything.
    try:
        return issubclass(dialect, sqla_databricks.DatabricksDialect)  # type: ignore[arg-type]
    except TypeError:
        return False


def get_dialect_regex_expression( # noqa: C901, PLR0911, PLR0912, PLR0915
column, regex, dialect, positive=True
):
column: sa.Column,
regex: str,
dialect: ModuleType | Type[sa.Dialect] | sa.Dialect,
positive: bool = True,
) -> sa.SQLColumnExpression | None:
try:
# postgres
if issubclass(dialect.dialect, sa.dialects.postgresql.dialect):
if issubclass(dialect.dialect, sa.dialects.postgresql.dialect): # type: ignore[union-attr]
if positive:
return sqlalchemy.BinaryExpression(
column, sqlalchemy.literal(regex), sqlalchemy.custom_op("~")
Expand All @@ -96,11 +125,18 @@ def get_dialect_regex_expression( # noqa: C901, PLR0911, PLR0912, PLR0915
except AttributeError:
pass

# databricks sql
if _is_databricks_dialect(dialect):
if positive:
return sa.func.regexp_like(column, sqlalchemy.literal(regex))
else:
return sa.not_(sa.func.regexp_like(column, sqlalchemy.literal(regex)))

# redshift
# noinspection PyUnresolvedReferences
try:
if hasattr(dialect, "RedshiftDialect") or (
aws.redshiftdialect and issubclass(dialect.dialect, aws.redshiftdialect.RedshiftDialect)
aws.redshiftdialect and issubclass(dialect.dialect, aws.redshiftdialect.RedshiftDialect) # type: ignore[union-attr]
):
if positive:
return sqlalchemy.BinaryExpression(
Expand All @@ -117,7 +153,7 @@ def get_dialect_regex_expression( # noqa: C901, PLR0911, PLR0912, PLR0915

try:
# MySQL
if issubclass(dialect.dialect, sa.dialects.mysql.dialect):
if issubclass(dialect.dialect, sa.dialects.mysql.dialect): # type: ignore[union-attr]
if positive:
return sqlalchemy.BinaryExpression(
column, sqlalchemy.literal(regex), sqlalchemy.custom_op("REGEXP")
Expand All @@ -134,7 +170,7 @@ def get_dialect_regex_expression( # noqa: C901, PLR0911, PLR0912, PLR0915
try:
# Snowflake
if issubclass(
dialect.dialect,
dialect.dialect, # type: ignore[union-attr]
snowflake.sqlalchemy.snowdialect.SnowflakeDialect,
):
if positive:
Expand Down Expand Up @@ -216,7 +252,7 @@ def get_dialect_regex_expression( # noqa: C901, PLR0911, PLR0912, PLR0915

try:
# Teradata
if issubclass(dialect.dialect, teradatasqlalchemy.dialect.TeradataDialect):
if issubclass(dialect.dialect, teradatasqlalchemy.dialect.TeradataDialect): # type: ignore[union-attr]
if positive:
return (
sa.func.REGEXP_SIMILAR(
Expand All @@ -237,7 +273,7 @@ def get_dialect_regex_expression( # noqa: C901, PLR0911, PLR0912, PLR0915
try:
# sqlite
# regex_match for sqlite introduced in sqlalchemy v1.4
if issubclass(dialect.dialect, sa.dialects.sqlite.dialect) and version.parse(
if issubclass(dialect.dialect, sa.dialects.sqlite.dialect) and version.parse( # type: ignore[union-attr]
sa.__version__
) >= version.parse("1.4"):
if positive:
Expand All @@ -256,7 +292,9 @@ def get_dialect_regex_expression( # noqa: C901, PLR0911, PLR0912, PLR0915
return None


def _get_dialect_type_module(dialect=None):
def _get_dialect_type_module(
dialect: ModuleType | Type[sa.Dialect] | sa.Dialect | None = None,
) -> ModuleType | Type[sa.Dialect] | sa.Dialect:
if dialect is None:
logger.warning("No sqlalchemy dialect found; relying in top-level sqlalchemy types.")
return sa
Expand All @@ -274,7 +312,7 @@ def _get_dialect_type_module(dialect=None):
if (
isinstance(
dialect,
sqla_bigquery.BigQueryDialect,
sqla_bigquery.BigQueryDialect, # type: ignore[attr-defined]
)
and bigquery_types_tuple is not None
):
Expand All @@ -286,7 +324,7 @@ def _get_dialect_type_module(dialect=None):
try:
if (
issubclass(
dialect,
dialect, # type: ignore[arg-type]
teradatasqlalchemy.dialect.TeradataDialect,
)
and teradatatypes is not None
Expand Down Expand Up @@ -836,14 +874,14 @@ def _get_normalized_column_name_mapping_if_exists(
return None if verify_only else normalized_batch_columns_mappings


def parse_value_set(value_set):
def parse_value_set(value_set: Iterable) -> list:
    """Return *value_set* as a list, parsing any string members.

    Non-string members are passed through unchanged; strings are run
    through ``parse`` (e.g. to coerce date-like strings).
    """
    parsed_value_set = []
    for value in value_set:
        if isinstance(value, str):
            parsed_value_set.append(parse(value))
        else:
            parsed_value_set.append(value)
    return parsed_value_set


def get_dialect_like_pattern_expression( # noqa: C901, PLR0912
column, dialect, like_pattern, positive=True
):
def get_dialect_like_pattern_expression( # noqa: C901, PLR0912, PLR0915
column: sa.Column, dialect: ModuleType, like_pattern: str, positive: bool = True
) -> sa.BinaryExpression | None:
dialect_supported: bool = False

try:
Expand All @@ -868,6 +906,9 @@ def get_dialect_like_pattern_expression( # noqa: C901, PLR0912
):
dialect_supported = True

if _is_databricks_dialect(dialect):
dialect_supported = True

try:
if hasattr(dialect, "RedshiftDialect"):
dialect_supported = True
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def _sqlalchemy(cls, column, _dialect, **kwargs):
regex_expression = get_dialect_regex_expression(column, cls.regex, _dialect)

if regex_expression is None:
logger.warning(f"Regex is not supported for dialect {_dialect.dialect.name!s}")
raise NotImplementedError
msg = f"Regex is not supported for dialect {_dialect.dialect.name!s}"
logger.warning(msg)
raise NotImplementedError(msg)

return regex_expression

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ module = [
"boto3.*",
"botocore.*",
"clickhouse_sqlalchemy.*",
"databricks.*",
"google.*",
"great_expectations.compatibility.pydantic.*",
"ipywidgets.*",
Expand Down
28 changes: 21 additions & 7 deletions tests/datasource/fluent/integration/test_sql_datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pprint import pformat as pf
from typing import (
TYPE_CHECKING,
Any,
Final,
Generator,
Literal,
Expand Down Expand Up @@ -719,7 +720,7 @@ def _fails_expectation(param_id: str) -> bool:
This does not mean that it SHOULD fail, but that it currently does.
"""
column_name: ColNameParamId
dialect, column_name, _ = param_id.split("-") # type: ignore[assignment]
dialect, column_name, *_ = param_id.split("-") # type: ignore[assignment]
dialects_need_fixes: list[DatabaseType] = FAILS_EXPECTATION.get(column_name, [])
return dialect in dialects_need_fixes

Expand Down Expand Up @@ -765,15 +766,25 @@ def _raw_query_check_column_exists(


_EXPECTATION_TYPES: Final[tuple[ParameterSet, ...]] = (
param("expect_column_to_exist"),
param("expect_column_values_to_not_be_null"),
param("expect_column_to_exist", {}, id="expect_column_to_exist"),
param("expect_column_values_to_not_be_null", {}, id="expect_column_values_to_not_be_null"),
param(
"expect_column_values_to_match_regex",
{"regex": r".*"},
id="expect_column_values_to_match_regex",
),
param(
"expect_column_values_to_match_like_pattern",
{"like_pattern": r"%"},
id="expect_column_values_to_match_like_pattern",
),
)


@pytest.mark.filterwarnings(
"once::DeprecationWarning"
) # snowflake `add_table_asset` raises warning on passing a schema
@pytest.mark.parametrize("expectation_type", _EXPECTATION_TYPES)
@pytest.mark.parametrize("expectation_type, extra_exp_kwargs", _EXPECTATION_TYPES)
class TestColumnExpectations:
@pytest.mark.parametrize(
"column_name",
Expand Down Expand Up @@ -806,6 +817,7 @@ def test_unquoted_params(
table_factory: TableFactory,
column_name: str | quoted_name,
expectation_type: str,
extra_exp_kwargs: dict[str, Any],
request: pytest.FixtureRequest,
):
"""
Expand Down Expand Up @@ -862,7 +874,7 @@ def test_unquoted_params(
suite = context.suites.add(ExpectationSuite(name=f"{datasource.name}-{asset.name}"))
suite.add_expectation_configuration(
expectation_configuration=ExpectationConfiguration(
type=expectation_type, kwargs={"column": column_name}
type=expectation_type, kwargs={"column": column_name, **extra_exp_kwargs}
)
)
suite.save()
Expand Down Expand Up @@ -908,6 +920,7 @@ def test_quoted_params(
table_factory: TableFactory,
column_name: str | quoted_name,
expectation_type: str,
extra_exp_kwargs: dict[str, Any],
request: pytest.FixtureRequest,
):
"""
Expand Down Expand Up @@ -966,7 +979,7 @@ def test_quoted_params(
suite = context.suites.add(ExpectationSuite(name=f"{datasource.name}-{asset.name}"))
suite.add_expectation_configuration(
expectation_configuration=ExpectationConfiguration(
type=expectation_type, kwargs={"column": column_name}
type=expectation_type, kwargs={"column": column_name, **extra_exp_kwargs}
)
)
suite.save()
Expand Down Expand Up @@ -1028,6 +1041,7 @@ def test_desired_state(
table_factory: TableFactory,
column_name: str | quoted_name,
expectation_type: str,
extra_exp_kwargs: dict[str, Any],
request: pytest.FixtureRequest,
):
"""
Expand Down Expand Up @@ -1096,7 +1110,7 @@ def test_desired_state(
suite = context.suites.add(ExpectationSuite(name=f"{datasource.name}-{asset.name}"))
suite.add_expectation_configuration(
expectation_configuration=ExpectationConfiguration(
type=expectation_type, kwargs={"column": column_name}
type=expectation_type, kwargs={"column": column_name, **extra_exp_kwargs}
)
)
suite.save()
Expand Down

0 comments on commit c49d728

Please sign in to comment.