From 158aa20bacf52f49d294a1ad904b10e531d83df5 Mon Sep 17 00:00:00 2001
From: Tyler Hoffman
Date: Wed, 11 Sep 2024 14:49:39 -0400
Subject: [PATCH] Remove a return type :(

---
Reviewer notes (free-form text placed here, after the three-dash line and
before the diff, is not part of the commit and is ignored when the patch
is applied):

* tests/conftest.py: the `spark_connect_session` fixture loses its
  `-> pyspark.SparkConnectSession` return annotation. Only the signature
  line changes; the fixture body is untouched by this patch. The reason is
  not stated in the patch itself.
  NOTE(review): presumably the annotation could not be resolved or
  type-checked in some environment -- confirm with the author.

* tests/integration/spark/test_spark_connect.py (new file, 44 lines):
  the module is tagged with `pytest.mark.spark` through `pytestmark`.
  `test_spark_connect` does the following, in order:
    1. builds a three-row DataFrame (field `column` = 1, 2, 5) through
       the `spark_connect_session` fixture;
    2. on the `ephemeral_context_with_defaults` context, chains
       add_spark -> add_dataframe_asset ->
       add_batch_definition_whole_dataframe to get a batch definition;
    3. adds a suite holding one `ExpectColumnValuesToBeInSet` on
       `column` with value_set [1, 2, 5];
    4. adds a validation definition tying the suite to the batch
       definition, runs it with the DataFrame passed as the `dataframe`
       batch parameter, and asserts `results.success`.
  The module-level `logger` is defined but not used anywhere in the new
  file.

 tests/conftest.py                             |  2 +-
 tests/integration/spark/test_spark_connect.py | 44 +++++++++++++++++++
 2 files changed, 45 insertions(+), 1 deletion(-)
 create mode 100644 tests/integration/spark/test_spark_connect.py

diff --git a/tests/conftest.py b/tests/conftest.py
index 431bc0cfb708..2fd93b1995a3 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -493,7 +493,7 @@ def spark_session(test_backends) -> pyspark.SparkSession:
 
 
 @pytest.fixture
-def spark_connect_session(test_backends) -> pyspark.SparkConnectSession:
+def spark_connect_session(test_backends):
     from great_expectations.compatibility import pyspark
 
     if pyspark.SparkConnectSession:  # type: ignore[truthy-function]
diff --git a/tests/integration/spark/test_spark_connect.py b/tests/integration/spark/test_spark_connect.py
new file mode 100644
index 000000000000..9ddb01461150
--- /dev/null
+++ b/tests/integration/spark/test_spark_connect.py
@@ -0,0 +1,44 @@
+import logging
+
+import pytest
+from pyspark.sql import Row
+
+import great_expectations as gx
+
+logger = logging.getLogger(__name__)
+
+
+pytestmark = pytest.mark.spark
+
+
+def test_spark_connect(spark_connect_session, ephemeral_context_with_defaults):
+    context = ephemeral_context_with_defaults
+    df = spark_connect_session.createDataFrame(
+        [
+            Row(column=1),
+            Row(column=2),
+            Row(column=5),
+        ]
+    )
+
+    bd = (
+        context.data_sources.add_spark(name="spark-connect-ds")
+        .add_dataframe_asset(name="spark-connect-asset")
+        .add_batch_definition_whole_dataframe(name="spark-connect-bd")
+    )
+    suite = context.suites.add(
+        gx.ExpectationSuite(
+            name="spark-connect-suite",
+            expectations=[
+                gx.expectations.ExpectColumnValuesToBeInSet(column="column", value_set=[1, 2, 5]),
+            ],
+        )
+    )
+
+    vd = context.validation_definitions.add(
+        gx.ValidationDefinition(name="spark-connect-vd", suite=suite, data=bd)
+    )
+
+    results = vd.run(batch_parameters={"dataframe": df})
+
+    assert results.success