[BUGFIX] Databricks Fix Type Translation - `ExpectColumnValuesToBeInTypeList` and `ExpectColumnValuesToBeInType` (#10791)
Shinnnyshinshin authored Dec 18, 2024
1 parent 47cc505 commit 5e11e7c
Showing 5 changed files with 224 additions and 10 deletions.
@@ -458,7 +458,7 @@ def _validate_pandas(  # noqa: C901, PLR0912
    def _validate_sqlalchemy(self, actual_column_type, expected_types_list, execution_engine):
        if expected_types_list is None:
            success = True
-       elif execution_engine.dialect_name == GXSqlDialect.SNOWFLAKE:
+       elif execution_engine.dialect_name in [GXSqlDialect.SNOWFLAKE, GXSqlDialect.DATABRICKS]:
            success = isinstance(actual_column_type, str) and any(
                actual_column_type.lower() == expected_type.lower()
                for expected_type in expected_types_list
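For Snowflake (and now Databricks), the inspected column type reaches this validator as a plain string rather than a SQLAlchemy type object (see the `util.py` hunk below), so the check reduces to case-insensitive membership in `type_list`. A minimal standalone sketch of that branch, with illustrative values:

```python
# Sketch of the Snowflake/Databricks branch above: the observed column type
# is a compiled type-name string, so compare against type_list case-insensitively.
actual_column_type = "string"                # illustrative observed value
expected_types_list = ["STRING", "BOOLEAN"]  # illustrative type_list

success = isinstance(actual_column_type, str) and any(
    actual_column_type.lower() == expected_type.lower()
    for expected_type in expected_types_list
)
assert success
```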
@@ -412,7 +412,7 @@ def _validate_sqlalchemy(self, actual_column_type, expected_type, execution_engi

        if expected_type is None:
            success = True
-       elif execution_engine.dialect_name == GXSqlDialect.SNOWFLAKE:
+       elif execution_engine.dialect_name in [GXSqlDialect.SNOWFLAKE, GXSqlDialect.DATABRICKS]:
            success = (
                isinstance(actual_column_type, str)
                and actual_column_type.lower() == expected_type.lower()
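The single-type expectation applies the same rule to exactly one expected type; the sketch below mirrors the branch above with illustrative values:

```python
# Sketch of the single-type branch: one case-insensitive string comparison.
actual_column_type = "TIMESTAMP_NTZ"  # illustrative observed value
expected_type = "timestamp_ntz"       # illustrative type_

success = (
    isinstance(actual_column_type, str)
    and actual_column_type.lower() == expected_type.lower()
)
assert success
```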
15 changes: 9 additions & 6 deletions great_expectations/expectations/metrics/util.py
@@ -414,16 +414,19 @@ def get_sqlalchemy_column_metadata(  # noqa: C901
        )

        dialect_name = execution_engine.dialect.name
-       if dialect_name == GXSqlDialect.SNOWFLAKE:
+       if dialect_name in [GXSqlDialect.SNOWFLAKE, GXSqlDialect.DATABRICKS]:
            # WARNING: Do not alter columns in place, as they are cached on the inspector
            columns_copy = [column.copy() for column in columns]
            for column in columns_copy:
                column["type"] = column["type"].compile(dialect=execution_engine.dialect)
-           return [
-               # TODO: SmartColumn should know the dialect and do lookups based on that
-               CaseInsensitiveNameDict(column)
-               for column in columns_copy
-           ]
+           if dialect_name == GXSqlDialect.SNOWFLAKE:
+               return [
+                   # TODO: SmartColumn should know the dialect and do lookups based on that
+                   CaseInsensitiveNameDict(column)
+                   for column in columns_copy
+               ]
+           else:
+               return columns_copy

        return columns
    except AttributeError as e:
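The key call here is SQLAlchemy's `TypeEngine.compile(dialect=...)`, which renders a type object as that dialect's type-name string; Databricks results are then returned as plain dicts, while Snowflake keeps the `CaseInsensitiveNameDict` wrapper. A small sketch of the compile step, using the built-in SQLite dialect as a stand-in (under the Databricks dialect, `Text` compiles to `STRING`, per the test comments below):

```python
from sqlalchemy import types as sqltypes
from sqlalchemy.dialects.sqlite import dialect as sqlite_dialect

# TypeEngine.compile renders a type object as the dialect's type-name string,
# which is what the column-metadata comparison above operates on.
print(sqltypes.Text().compile(dialect=sqlite_dialect()))     # -> TEXT
print(sqltypes.Integer().compile(dialect=sqlite_dialect()))  # -> INTEGER
```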
@@ -1,9 +1,14 @@
import pandas as pd
import pytest
import sqlalchemy.types as sqltypes
from packaging import version

import great_expectations.expectations as gxe
from great_expectations.compatibility.databricks import DATABRICKS_TYPES
from great_expectations.compatibility.snowflake import SNOWFLAKE_TYPES
from great_expectations.compatibility.sqlalchemy import (
sqlalchemy as sa,
)
from great_expectations.core.result_format import ResultFormat
from great_expectations.datasource.fluent.interfaces import Batch
from tests.integration.conftest import parameterize_batch_for_data_sources
@@ -379,3 +384,210 @@ def test_success_complete_snowflake(
assert isinstance(result_dict["observed_value"], str)
assert isinstance(expectation.type_list, list)
assert result_dict["observed_value"] in expectation.type_list


@pytest.mark.parametrize(
"expectation",
[
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="STRING", type_list=["STRING"]),
id="STRING",
),
# SqlA Text gets converted to Databricks STRING
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="TEXT", type_list=["STRING"]),
id="TEXT",
),
# SqlA UNICODE gets converted to Databricks STRING
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="UNICODE", type_list=["STRING"]),
id="UNICODE",
),
# SqlA UNICODE_TEXT gets converted to Databricks STRING
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="UNICODE_TEXT", type_list=["STRING"]),
id="UNICODE_TEXT",
),
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="BOOLEAN", type_list=["BOOLEAN"]),
id="BOOLEAN",
),
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(
column="DECIMAL", type_list=["DECIMAL", "DECIMAL(10, 0)"]
),
id="DECIMAL",
),
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="DATE", type_list=["DATE"]),
id="DATE",
),
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="TIMESTAMP", type_list=["TIMESTAMP"]),
id="TIMESTAMP",
),
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(
column="TIMESTAMP_NTZ", type_list=["TIMESTAMP_NTZ"]
),
id="TIMESTAMP_NTZ",
),
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="FLOAT", type_list=["FLOAT"]),
id="FLOAT",
),
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="INT", type_list=["INT"]),
id="INT",
),
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(column="TINYINT", type_list=["TINYINT"]),
id="TINYINT",
),
# SqlA Time gets converted to Databricks STRING,
# but is not supported by our testing framework
# pytest.param(
# gxe.ExpectColumnValuesToBeInTypeList(column="TIME", type_list=["STRING"]),
# id="TIME",
# ),
# SqlA UUID gets converted to Databricks STRING,
# but is not supported by our testing framework.
# pytest.param(
# gxe.ExpectColumnValuesToBeInTypeList(column="UUID", type_list=["STRING"]),
# id="UUID",
# )
],
)
@parameterize_batch_for_data_sources(
data_source_configs=[
DatabricksDatasourceTestConfig(
column_types={
"STRING": DATABRICKS_TYPES.STRING,
"TEXT": sqltypes.Text,
"UNICODE": sqltypes.Unicode,
"UNICODE_TEXT": sqltypes.UnicodeText,
"BIGINT": sqltypes.BigInteger,
"BOOLEAN": sqltypes.BOOLEAN,
"DATE": sqltypes.DATE,
"TIMESTAMP_NTZ": DATABRICKS_TYPES.TIMESTAMP_NTZ,
"TIMESTAMP": DATABRICKS_TYPES.TIMESTAMP,
"FLOAT": sqltypes.Float,
"INT": sqltypes.Integer,
"DECIMAL": sqltypes.Numeric,
"SMALLINT": sqltypes.SmallInteger,
"TINYINT": DATABRICKS_TYPES.TINYINT,
# "TIME": sqltypes.Time,
# "UUID": sqltypes.UUID,
}
)
],
data=pd.DataFrame(
{
"STRING": ["a", "b", "c"],
"TEXT": ["a", "b", "c"],
"UNICODE": ["\u00e9", "\u00e9", "\u00e9"],
"UNICODE_TEXT": ["a", "b", "c"],
"BIGINT": [1111, 2222, 3333],
"BOOLEAN": [True, True, False],
"DATE": [
"2021-01-01",
"2021-01-02",
"2021-01-03",
],
"TIMESTAMP_NTZ": [
"2021-01-01 00:00:00",
"2021-01-02 00:00:00",
"2021-01-03 00:00:00",
],
"TIMESTAMP": [
"2021-01-01 00:00:00",
"2021-01-02 00:00:00",
"2021-01-03 00:00:00",
],
"DOUBLE": [1.0, 2.0, 3.0],
"FLOAT": [1.0, 2.0, 3.0],
"INT": [1, 2, 3],
"DECIMAL": [1.1, 2.2, 3.3],
"SMALLINT": [1, 2, 3],
# "TIME": [
# sa.Time("22:17:33.123456"),
# sa.Time("22:17:33.123456"),
# sa.Time("22:17:33.123456"),
# ],
# "UUID": [
# uuid.UUID("905993ea-f50e-4284-bea0-5be3f0ed7031"),
# uuid.UUID("9406b631-fa2f-41cf-b666-f9a2ac3118c1"),
# uuid.UUID("47538f05-32e3-4594-80e2-0b3b33257ae7")
# ],
},
dtype="object",
),
)
def test_success_complete_databricks(
batch_for_datasource: Batch, expectation: gxe.ExpectColumnValuesToBeInTypeList
) -> None:
result = batch_for_datasource.validate(expectation, result_format=ResultFormat.COMPLETE)
result_dict = result.to_json_dict()["result"]

assert result.success
assert isinstance(result_dict, dict)
assert isinstance(result_dict["observed_value"], str)
assert isinstance(expectation.type_list, list)
assert result_dict["observed_value"] in expectation.type_list


if version.parse(sa.__version__) >= version.parse("2.0.0"):
    # Note: why not use pytest.skip?
    # `sqltypes.Double` only exists in sqlalchemy >= 2.0.0, and it is referenced
    # while the test is parametrized at import time -- before any pytest.skip()
    # inside the test could run. Guarding with a module-level `if` skips
    # defining the test entirely on older sqlalchemy versions.
@pytest.mark.parametrize(
"expectation",
[
pytest.param(
gxe.ExpectColumnValuesToBeInTypeList(
column="DOUBLE", type_list=["DOUBLE", "FLOAT"]
),
id="DOUBLE",
)
],
)
@parameterize_batch_for_data_sources(
data_source_configs=[
DatabricksDatasourceTestConfig(
column_types={
"DOUBLE": sqltypes.Double,
}
)
],
data=pd.DataFrame(
{
"DOUBLE": [1.0, 2.0, 3.0],
},
dtype="object",
),
)
def test_success_complete_databricks_double_type_only(
batch_for_datasource: Batch, expectation: gxe.ExpectColumnValuesToBeInTypeList
) -> None:
"""What does this test and why?
Databricks mostly uses SqlA types directly, but the double type is
only available after sqlalchemy 2.0. We therefore split up the test
into 2 parts, with this test being skipped if the SA version is too low.
"""
result = batch_for_datasource.validate(expectation, result_format=ResultFormat.COMPLETE)
result_dict = result.to_json_dict()["result"]

assert result.success
assert isinstance(result_dict, dict)
assert isinstance(result_dict["observed_value"], str)
assert isinstance(expectation.type_list, list)
assert result_dict["observed_value"] in expectation.type_list
@@ -66,13 +66,12 @@ def test_success_for_type__INTEGER(batch_for_datasource: Batch) -> None:
    assert result.success


-@pytest.mark.xfail
 @parameterize_batch_for_data_sources(
     data_source_configs=[DatabricksDatasourceTestConfig()],
     data=DATA,
 )
 def test_success_for_type__Integer(batch_for_datasource: Batch) -> None:
-    expectation = gxe.ExpectColumnValuesToBeOfType(column=INTEGER_COLUMN, type_="Integer")
+    expectation = gxe.ExpectColumnValuesToBeOfType(column=INTEGER_COLUMN, type_="INT")
     result = batch_for_datasource.validate(expectation)
     assert result.success

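With the `util.py` change above, the observed Databricks column type is the compiled dialect name (`INT`), not the SQLAlchemy class name (`Integer`), which is why this previously xfail-ed test passes once the expectation names the dialect type. A sketch of the comparison, with the observed value assumed from the test's own change:

```python
# Illustrative: for a Databricks integer column, get_sqlalchemy_column_metadata
# now yields the compiled dialect name, "INT".
actual_column_type = "INT"

assert actual_column_type.lower() != "Integer".lower()  # old type_: failed (xfail)
assert actual_column_type.lower() == "INT".lower()      # new type_: passes
```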
