Commit

Merge branch 'develop' into kml/sourcesError

klavavej authored Dec 18, 2024
2 parents 7ce516f + 5e11e7c commit 9f9794f
Showing 6 changed files with 259 additions and 13 deletions.
38 changes: 35 additions & 3 deletions great_expectations/compatibility/databricks.py
@@ -1,10 +1,42 @@
from __future__ import annotations

from great_expectations.compatibility.not_imported import NotImported

DATABRICKS_CONNECT_NOT_IMPORTED = NotImported(
    "databricks-connect is not installed, please 'pip install databricks-connect'"
)

# The following types are modeled on the documentation that ships with the
# databricks package.
# tl;dr: SQLAlchemy applications should (mostly) "just work" with Databricks,
# aside from the exceptions below.
# https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sqlalchemy/README.sqlalchemy.md

try:
    from databricks.sqlalchemy._types import TIMESTAMP_NTZ as TIMESTAMP_NTZ  # noqa: PLC0414, RUF100
except (ImportError, AttributeError):
    TIMESTAMP_NTZ = DATABRICKS_CONNECT_NOT_IMPORTED  # type: ignore[misc, assignment]

try:
    from databricks.sqlalchemy._types import DatabricksStringType as STRING  # noqa: PLC0414, RUF100
except (ImportError, AttributeError):
    STRING = DATABRICKS_CONNECT_NOT_IMPORTED  # type: ignore[misc, assignment]

try:
    from databricks import connect
except ImportError:
    connect = DATABRICKS_CONNECT_NOT_IMPORTED

try:
    from databricks.sqlalchemy._types import TIMESTAMP as TIMESTAMP  # noqa: PLC0414, RUF100
except (ImportError, AttributeError):
    TIMESTAMP = DATABRICKS_CONNECT_NOT_IMPORTED  # type: ignore[misc, assignment]

try:
    from databricks.sqlalchemy._types import TINYINT as TINYINT  # noqa: PLC0414, RUF100
except (ImportError, AttributeError):
    TINYINT = DATABRICKS_CONNECT_NOT_IMPORTED  # type: ignore[misc, assignment]


class DATABRICKS_TYPES:
    """Namespace for Databricks dialect types"""

    TIMESTAMP_NTZ = TIMESTAMP_NTZ
    STRING = STRING
    TINYINT = TINYINT
    TIMESTAMP = TIMESTAMP
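Every guarded import above follows the same shape: try the optional dependency and, on failure, bind the name to a NotImported sentinel so the error only surfaces when the type is actually used. A minimal sketch of that pattern, with a stand-in sentinel (the real one lives in great_expectations.compatibility.not_imported):

# Stand-in for great_expectations.compatibility.not_imported.NotImported;
# illustrative only.
class NotImported:
    def __init__(self, message: str) -> None:
        self._message = message

    def __getattr__(self, name: str):
        # Any attribute access on the placeholder fails with install guidance.
        raise ModuleNotFoundError(self._message)

    def __call__(self, *args, **kwargs):
        # So does trying to instantiate the "type".
        raise ModuleNotFoundError(self._message)


try:
    from databricks.sqlalchemy._types import TINYINT
except (ImportError, AttributeError):
    TINYINT = NotImported(
        "databricks-connect is not installed, please 'pip install databricks-connect'"
    )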
@@ -458,7 +458,7 @@ def _validate_pandas(  # noqa: C901, PLR0912
def _validate_sqlalchemy(self, actual_column_type, expected_types_list, execution_engine):
    if expected_types_list is None:
        success = True
-   elif execution_engine.dialect_name == GXSqlDialect.SNOWFLAKE:
+   elif execution_engine.dialect_name in [GXSqlDialect.SNOWFLAKE, GXSqlDialect.DATABRICKS]:
        success = isinstance(actual_column_type, str) and any(
            actual_column_type.lower() == expected_type.lower()
            for expected_type in expected_types_list
@@ -412,7 +412,7 @@ def _validate_sqlalchemy(self, actual_column_type, expected_type, execution_engine):

    if expected_type is None:
        success = True
-   elif execution_engine.dialect_name == GXSqlDialect.SNOWFLAKE:
+   elif execution_engine.dialect_name in [GXSqlDialect.SNOWFLAKE, GXSqlDialect.DATABRICKS]:
        success = (
            isinstance(actual_column_type, str)
            and actual_column_type.lower() == expected_type.lower()
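Both hunks above move Databricks onto the comparison path Snowflake already used: the reflected column type arrives as a plain string, so matching is case-insensitive string equality rather than isinstance checks against SQLAlchemy type classes. Reduced to a standalone sketch (the function name is illustrative):

from __future__ import annotations


def type_name_matches(actual_column_type: str, expected_types_list: list[str] | None) -> bool:
    # None means "no expected types", which the expectation treats as success.
    if expected_types_list is None:
        return True
    return any(
        actual_column_type.lower() == expected_type.lower()
        for expected_type in expected_types_list
    )


assert type_name_matches("TIMESTAMP_NTZ", ["timestamp_ntz", "STRING"])
assert not type_name_matches("TINYINT", ["INT", "BIGINT"])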
15 changes: 9 additions & 6 deletions great_expectations/expectations/metrics/util.py
@@ -414,16 +414,19 @@ def get_sqlalchemy_column_metadata(  # noqa: C901
        )

        dialect_name = execution_engine.dialect.name
-       if dialect_name == GXSqlDialect.SNOWFLAKE:
+       if dialect_name in [GXSqlDialect.SNOWFLAKE, GXSqlDialect.DATABRICKS]:
            # WARNING: Do not alter columns in place, as they are cached on the inspector
            columns_copy = [column.copy() for column in columns]
            for column in columns_copy:
                column["type"] = column["type"].compile(dialect=execution_engine.dialect)
-           return [
-               # TODO: SmartColumn should know the dialect and do lookups based on that
-               CaseInsensitiveNameDict(column)
-               for column in columns_copy
-           ]
+           if dialect_name == GXSqlDialect.SNOWFLAKE:
+               return [
+                   # TODO: SmartColumn should know the dialect and do lookups based on that
+                   CaseInsensitiveNameDict(column)
+                   for column in columns_copy
+               ]
+           else:
+               return columns_copy

        return columns
    except AttributeError as e:
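The compile(dialect=...) call is what turns a reflected SQLAlchemy type object into the dialect's own DDL name, i.e. the plain string the comparisons above match against. A standalone illustration using the SQLite dialect that ships with SQLAlchemy (chosen only to keep the example self-contained):

import sqlalchemy as sa
from sqlalchemy.dialects import sqlite

# TypeEngine.compile renders a type as dialect-specific DDL text.
print(sa.Numeric(10, 0).compile(dialect=sqlite.dialect()))  # NUMERIC(10, 0)
print(sa.Text().compile(dialect=sqlite.dialect()))          # TEXT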
@@ -1,9 +1,14 @@
import pandas as pd
import pytest
import sqlalchemy.types as sqltypes
from packaging import version

import great_expectations.expectations as gxe
from great_expectations.compatibility.databricks import DATABRICKS_TYPES
from great_expectations.compatibility.snowflake import SNOWFLAKE_TYPES
from great_expectations.compatibility.sqlalchemy import (
    sqlalchemy as sa,
)
from great_expectations.core.result_format import ResultFormat
from great_expectations.datasource.fluent.interfaces import Batch
from tests.integration.conftest import parameterize_batch_for_data_sources
@@ -379,3 +384,210 @@ def test_success_complete_snowflake(
    assert isinstance(result_dict["observed_value"], str)
    assert isinstance(expectation.type_list, list)
    assert result_dict["observed_value"] in expectation.type_list


@pytest.mark.parametrize(
    "expectation",
    [
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="STRING", type_list=["STRING"]),
            id="STRING",
        ),
        # SqlA Text gets converted to Databricks STRING
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="TEXT", type_list=["STRING"]),
            id="TEXT",
        ),
        # SqlA UNICODE gets converted to Databricks STRING
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="UNICODE", type_list=["STRING"]),
            id="UNICODE",
        ),
        # SqlA UNICODE_TEXT gets converted to Databricks STRING
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="UNICODE_TEXT", type_list=["STRING"]),
            id="UNICODE_TEXT",
        ),
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="BOOLEAN", type_list=["BOOLEAN"]),
            id="BOOLEAN",
        ),
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(
                column="DECIMAL", type_list=["DECIMAL", "DECIMAL(10, 0)"]
            ),
            id="DECIMAL",
        ),
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="DATE", type_list=["DATE"]),
            id="DATE",
        ),
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="TIMESTAMP", type_list=["TIMESTAMP"]),
            id="TIMESTAMP",
        ),
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(
                column="TIMESTAMP_NTZ", type_list=["TIMESTAMP_NTZ"]
            ),
            id="TIMESTAMP_NTZ",
        ),
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="FLOAT", type_list=["FLOAT"]),
            id="FLOAT",
        ),
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="INT", type_list=["INT"]),
            id="INT",
        ),
        pytest.param(
            gxe.ExpectColumnValuesToBeInTypeList(column="TINYINT", type_list=["TINYINT"]),
            id="TINYINT",
        ),
        # SqlA Time gets converted to Databricks STRING,
        # but is not supported by our testing framework
        # pytest.param(
        #     gxe.ExpectColumnValuesToBeInTypeList(column="TIME", type_list=["STRING"]),
        #     id="TIME",
        # ),
        # SqlA UUID gets converted to Databricks STRING,
        # but is not supported by our testing framework.
        # pytest.param(
        #     gxe.ExpectColumnValuesToBeInTypeList(column="UUID", type_list=["STRING"]),
        #     id="UUID",
        # )
    ],
)
@parameterize_batch_for_data_sources(
    data_source_configs=[
        DatabricksDatasourceTestConfig(
            column_types={
                "STRING": DATABRICKS_TYPES.STRING,
                "TEXT": sqltypes.Text,
                "UNICODE": sqltypes.Unicode,
                "UNICODE_TEXT": sqltypes.UnicodeText,
                "BIGINT": sqltypes.BigInteger,
                "BOOLEAN": sqltypes.BOOLEAN,
                "DATE": sqltypes.DATE,
                "TIMESTAMP_NTZ": DATABRICKS_TYPES.TIMESTAMP_NTZ,
                "TIMESTAMP": DATABRICKS_TYPES.TIMESTAMP,
                "FLOAT": sqltypes.Float,
                "INT": sqltypes.Integer,
                "DECIMAL": sqltypes.Numeric,
                "SMALLINT": sqltypes.SmallInteger,
                "TINYINT": DATABRICKS_TYPES.TINYINT,
                # "TIME": sqltypes.Time,
                # "UUID": sqltypes.UUID,
            }
        )
    ],
    data=pd.DataFrame(
        {
            "STRING": ["a", "b", "c"],
            "TEXT": ["a", "b", "c"],
            "UNICODE": ["\u00e9", "\u00e9", "\u00e9"],
            "UNICODE_TEXT": ["a", "b", "c"],
            "BIGINT": [1111, 2222, 3333],
            "BOOLEAN": [True, True, False],
            "DATE": [
                "2021-01-01",
                "2021-01-02",
                "2021-01-03",
            ],
            "TIMESTAMP_NTZ": [
                "2021-01-01 00:00:00",
                "2021-01-02 00:00:00",
                "2021-01-03 00:00:00",
            ],
            "TIMESTAMP": [
                "2021-01-01 00:00:00",
                "2021-01-02 00:00:00",
                "2021-01-03 00:00:00",
            ],
            "DOUBLE": [1.0, 2.0, 3.0],
            "FLOAT": [1.0, 2.0, 3.0],
            "INT": [1, 2, 3],
            "DECIMAL": [1.1, 2.2, 3.3],
            "SMALLINT": [1, 2, 3],
            # "TIME": [
            #     sa.Time("22:17:33.123456"),
            #     sa.Time("22:17:33.123456"),
            #     sa.Time("22:17:33.123456"),
            # ],
            # "UUID": [
            #     uuid.UUID("905993ea-f50e-4284-bea0-5be3f0ed7031"),
            #     uuid.UUID("9406b631-fa2f-41cf-b666-f9a2ac3118c1"),
            #     uuid.UUID("47538f05-32e3-4594-80e2-0b3b33257ae7"),
            # ],
        },
        dtype="object",
    ),
)
def test_success_complete_databricks(
    batch_for_datasource: Batch, expectation: gxe.ExpectColumnValuesToBeInTypeList
) -> None:
    result = batch_for_datasource.validate(expectation, result_format=ResultFormat.COMPLETE)
    result_dict = result.to_json_dict()["result"]

    assert result.success
    assert isinstance(result_dict, dict)
    assert isinstance(result_dict["observed_value"], str)
    assert isinstance(expectation.type_list, list)
    assert result_dict["observed_value"] in expectation.type_list


if version.parse(sa.__version__) >= version.parse("2.0.0"):
    # Note: why not use pytest.skip?
    # The import of `sqltypes.Double` is only possible in sqlalchemy >= 2.0.0,
    # and it happens while the test is instantiated, before any pytest.skip()
    # statement would be processed. Gating on the version here skips
    # instantiating the test entirely.
    @pytest.mark.parametrize(
        "expectation",
        [
            pytest.param(
                gxe.ExpectColumnValuesToBeInTypeList(
                    column="DOUBLE", type_list=["DOUBLE", "FLOAT"]
                ),
                id="DOUBLE",
            )
        ],
    )
    @parameterize_batch_for_data_sources(
        data_source_configs=[
            DatabricksDatasourceTestConfig(
                column_types={
                    "DOUBLE": sqltypes.Double,
                }
            )
        ],
        data=pd.DataFrame(
            {
                "DOUBLE": [1.0, 2.0, 3.0],
            },
            dtype="object",
        ),
    )
    def test_success_complete_databricks_double_type_only(
        batch_for_datasource: Batch, expectation: gxe.ExpectColumnValuesToBeInTypeList
    ) -> None:
        """What does this test and why?

        Databricks mostly uses SqlA types directly, but the Double type is
        only available in sqlalchemy >= 2.0. We therefore split the test in
        two, with this part defined only when the SA version is high enough.
        """
        result = batch_for_datasource.validate(expectation, result_format=ResultFormat.COMPLETE)
        result_dict = result.to_json_dict()["result"]

        assert result.success
        assert isinstance(result_dict, dict)
        assert isinstance(result_dict["observed_value"], str)
        assert isinstance(expectation.type_list, list)
        assert result_dict["observed_value"] in expectation.type_list
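The module-level version gate is the part worth copying: pytest builds parametrized tests at collection time, so a test that references sqltypes.Double must not even be defined on SQLAlchemy < 2.0. The skeleton of the pattern (names here are illustrative, not from the commit):

import sqlalchemy as sa
import sqlalchemy.types as sqltypes
from packaging import version

# Guard at import time: the body is never executed (and Double is never
# referenced) when the installed SQLAlchemy predates 2.0.
if version.parse(sa.__version__) >= version.parse("2.0.0"):
    DOUBLE = sqltypes.Double
else:
    DOUBLE = None  # callers must handle the absence themselves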
@@ -66,13 +66,12 @@ def test_success_for_type__INTEGER(batch_for_datasource: Batch) -> None:
    assert result.success


-@pytest.mark.xfail
@parameterize_batch_for_data_sources(
    data_source_configs=[DatabricksDatasourceTestConfig()],
    data=DATA,
)
def test_success_for_type__Integer(batch_for_datasource: Batch) -> None:
-   expectation = gxe.ExpectColumnValuesToBeOfType(column=INTEGER_COLUMN, type_="Integer")
+   expectation = gxe.ExpectColumnValuesToBeOfType(column=INTEGER_COLUMN, type_="INT")
    result = batch_for_datasource.validate(expectation)
    assert result.success
