[BUGFIX] Support Spark connect dataframes (#10420)
tyler-hoffman authored Sep 19, 2024
1 parent dec5dce commit 5b2a969
Showing 9 changed files with 151 additions and 6 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
@@ -433,6 +433,7 @@ jobs:
- postgresql
- snowflake
- spark
- spark_connect
- trino
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
exclude:
7 changes: 7 additions & 0 deletions assets/docker/spark/docker-compose.yml
@@ -4,3 +4,10 @@ services:
ports:
- "9090:8080"
- "7077:7077"

spark-connect:
image: ${ECR_PULL_THROUGH_REPOSITORY_URL}bitnami/spark:3.5.2
ports:
- "15002:15002"
# See https://spark.apache.org/docs/latest/spark-connect-overview.html#download-and-start-spark-server-with-spark-connect
command: ./sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:3.5.2
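The new service starts a Spark Connect server on port 15002, which the fixtures and integration tests below attach to via pyspark's remote session builder. A minimal local sketch, assuming the compose stack is running and a pyspark build with the Connect client is installed:

from pyspark.sql import SparkSession

# Attach to the Spark Connect server exposed by the docker-compose service above.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.createDataFrame([(1,), (2,), (3,)], ["column"])
print(df.count())  # 3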
5 changes: 5 additions & 0 deletions great_expectations/compatibility/pyspark.py
@@ -39,6 +39,11 @@
except (ImportError, AttributeError):
Column = SPARK_NOT_IMPORTED # type: ignore[assignment,misc]

try:
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
except (ImportError, AttributeError):
ConnectDataFrame = SPARK_NOT_IMPORTED # type: ignore[assignment,misc]

try:
from pyspark.sql import DataFrame
except (ImportError, AttributeError):
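Like the other optional imports in this module, ConnectDataFrame falls back to the falsy SPARK_NOT_IMPORTED sentinel when the Spark Connect client dependencies are unavailable, so callers truth-test the symbol before using it in isinstance checks (the new is_spark_data_frame helper and the spark_connect_session fixture below both rely on this). A minimal sketch of that guard; the helper name here is illustrative:

from great_expectations.compatibility.pyspark import ConnectDataFrame

def is_connect_df(obj: object) -> bool:
    # ConnectDataFrame is a falsy sentinel when pyspark's Connect client is not
    # installed, so only call isinstance against a real class.
    return bool(ConnectDataFrame) and isinstance(obj, ConnectDataFrame)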
19 changes: 14 additions & 5 deletions great_expectations/datasource/fluent/spark_datasource.py
@@ -5,6 +5,7 @@
from pprint import pformat as pf
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Generic,
@@ -27,7 +28,7 @@
StrictInt,
StrictStr,
)
from great_expectations.compatibility.pyspark import DataFrame, pyspark
from great_expectations.compatibility.pyspark import ConnectDataFrame, DataFrame, pyspark
from great_expectations.compatibility.typing_extensions import override
from great_expectations.core import IDDict
from great_expectations.core.batch import LegacyBatchDefinition
@@ -47,7 +48,7 @@
from great_expectations.exceptions.exceptions import BuildBatchRequestError

if TYPE_CHECKING:
from typing_extensions import TypeAlias
from typing_extensions import TypeAlias, TypeGuard

from great_expectations.compatibility.pyspark import SparkSession
from great_expectations.core.batch_definition import BatchDefinition
@@ -231,9 +232,9 @@ def build_batch_request(
if not (options is not None and "dataframe" in options and len(options) == 1):
raise BuildBatchRequestError(message="options must contain exactly 1 key, 'dataframe'.")

if not isinstance(options["dataframe"], DataFrame):
if not self.is_spark_data_frame(options["dataframe"]):
raise BuildBatchRequestError(
message="Can not build batch request for dataframe asset " "without a dataframe."
message="Cannot build batch request without a Spark DataFrame."
)

return BatchRequest(
@@ -255,7 +256,7 @@ def _validate_batch_request(self, batch_request: BatchRequest) -> None:
and batch_request.options
and len(batch_request.options) == 1
and "dataframe" in batch_request.options
and isinstance(batch_request.options["dataframe"], DataFrame)
and self.is_spark_data_frame(batch_request.options["dataframe"])
):
expect_batch_request_form = BatchRequest[None](
datasource_name=self.datasource.name,
@@ -314,6 +315,14 @@ def add_batch_definition_whole_dataframe(self, name: str) -> BatchDefinition:
partitioner=None,
)

@staticmethod
def is_spark_data_frame(df: Any) -> TypeGuard[Union[DataFrame, ConnectDataFrame]]:
"""Check that a given object is a Spark DataFrame.
This could either be a regular Spark DataFrame or a Spark Connect DataFrame.
"""
data_frame_types = [DataFrame, ConnectDataFrame]
return any((cls and isinstance(df, cls)) for cls in data_frame_types)


@public_api
class SparkDatasource(_SparkDatasource):
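With the guard in place, build_batch_request accepts either a classic or a Spark Connect dataframe and raises BuildBatchRequestError with the new message for anything else. A minimal sketch against the fluent API, assuming an ephemeral context and the Spark Connect server from the compose file above (the datasource and asset names are illustrative):

import great_expectations as gx
from great_expectations.exceptions.exceptions import BuildBatchRequestError
from pyspark.sql import SparkSession

context = gx.get_context(mode="ephemeral")
asset = context.data_sources.add_spark(name="spark-ds").add_dataframe_asset(name="df-asset")

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["column"])

batch_request = asset.build_batch_request(options={"dataframe": df})  # Connect dataframes now pass the check

try:
    asset.build_batch_request(options={"dataframe": "not a dataframe"})
except BuildBatchRequestError as err:
    print(err)  # expected to surface: Cannot build batch request without a Spark DataFrame.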
1 change: 1 addition & 0 deletions pyproject.toml
@@ -678,6 +678,7 @@ markers = [
"sqlite: mark test requiring sqlite",
"slow: mark tests taking longer than 1 second.",
"spark: mark a test as Spark-dependent.",
"spark_connect: mark a test as Spark Connect-dependent.",
"trino: mark a test as trino-dependent.",
"unit: mark a test as a unit test.",
"v2_api: mark test as specific to the v2 api (e.g. pre Data Connectors).",
8 changes: 8 additions & 0 deletions reqs/requirements-dev-spark-connect.txt
@@ -0,0 +1,8 @@
# NOTES:
# Spark connect's requirements are here: https://github.com/apache/spark/blob/ed3a9b1aa92957015592b399167a960b68b73beb/dev/requirements.txt#L60
# grpcio and grpcio-status should be bumped up to match that, but that conflicts with our constraints.txt file.
# TODO: Fix in V1-532

googleapis-common-protos>=1.56.4
grpcio>=1.48.1
grpcio-status>=1.48.1
8 changes: 8 additions & 0 deletions tasks.py
@@ -933,6 +933,14 @@ class TestDependencies(NamedTuple):
services=("spark",),
extra_pytest_args=("--spark",),
),
"spark_connect": TestDependencies(
requirement_files=(
"reqs/requirements-dev-spark.txt",
"reqs/requirements-dev-spark-connect.txt",
),
services=("spark",),
extra_pytest_args=("--spark_connect",),
),
"trino": TestDependencies(
("reqs/requirements-dev-trino.txt",),
services=("trino",),
22 changes: 21 additions & 1 deletion tests/conftest.py
@@ -21,6 +21,7 @@

import great_expectations as gx
from great_expectations.analytics.config import ENV_CONFIG
from great_expectations.compatibility import pyspark
from great_expectations.compatibility.sqlalchemy_compatibility_wrappers import (
add_dataframe_to_db,
)
@@ -95,7 +96,6 @@

from pytest_mock import MockerFixture

from great_expectations.compatibility import pyspark
from great_expectations.compatibility.sqlalchemy import Engine

yaml = YAMLHandler()
@@ -130,6 +130,7 @@
"pyarrow",
"snowflake",
"spark",
"spark_connect",
"sqlite",
"trino",
"unit",
@@ -196,6 +197,11 @@ def pytest_addoption(parser):
action="store_true",
help="If set, execute tests against the spark test suite",
)
parser.addoption(
"--spark_connect",
action="store_true",
help="If set, execute tests against the spark-connect test suite",
)
parser.addoption(
"--no-sqlalchemy",
action="store_true",
@@ -492,6 +498,20 @@ def spark_session(test_backends) -> pyspark.SparkSession:
raise ValueError("spark tests are requested, but pyspark is not installed")


@pytest.fixture
def spark_connect_session(test_backends):
from great_expectations.compatibility import pyspark

if pyspark.SparkConnectSession: # type: ignore[truthy-function]
spark_connect_session = pyspark.SparkSession.builder.remote(
"sc://localhost:15002"
).getOrCreate()
assert isinstance(spark_connect_session, pyspark.SparkConnectSession)
return spark_connect_session

raise ValueError("spark tests are requested, but pyspark is not installed")


@pytest.fixture
def basic_spark_df_execution_engine(spark_session):
from great_expectations.execution_engine import SparkDFExecutionEngine
86 changes: 86 additions & 0 deletions tests/integration/spark/test_spark_connect.py
@@ -0,0 +1,86 @@
import logging
from typing import Any

import pytest

import great_expectations as gx
from great_expectations.compatibility.pyspark import ConnectDataFrame, Row, SparkConnectSession
from great_expectations.core.validation_definition import ValidationDefinition
from great_expectations.data_context.data_context.abstract_data_context import AbstractDataContext
from great_expectations.exceptions.exceptions import BuildBatchRequestError

logger = logging.getLogger(__name__)


pytestmark = pytest.mark.spark_connect

DATAFRAME_VALUES = [1, 2, 3]


@pytest.fixture
def spark_validation_definition(
ephemeral_context_with_defaults: AbstractDataContext,
) -> ValidationDefinition:
context = ephemeral_context_with_defaults
bd = (
context.data_sources.add_spark(name="spark-connect-ds")
.add_dataframe_asset(name="spark-connect-asset")
.add_batch_definition_whole_dataframe(name="spark-connect-bd")
)
suite = context.suites.add(
gx.ExpectationSuite(
name="spark-connect-suite",
expectations=[
gx.expectations.ExpectColumnValuesToBeInSet(
column="column", value_set=DATAFRAME_VALUES
),
],
)
)
return context.validation_definitions.add(
gx.ValidationDefinition(name="spark-connect-vd", suite=suite, data=bd)
)


def test_spark_connect(
spark_connect_session: SparkConnectSession,
spark_validation_definition: ValidationDefinition,
):
df = spark_connect_session.createDataFrame(
[Row(column=x) for x in DATAFRAME_VALUES],
)
assert isinstance(df, ConnectDataFrame)

results = spark_validation_definition.run(batch_parameters={"dataframe": df})

assert results.success


@pytest.mark.parametrize("not_a_dataframe", [None, 1, "string", 1.0, True])
def test_error_messages_if_we_get_an_invalid_dataframe(
not_a_dataframe: Any,
spark_validation_definition: ValidationDefinition,
):
with pytest.raises(
BuildBatchRequestError, match="Cannot build batch request without a Spark DataFrame."
):
spark_validation_definition.run(batch_parameters={"dataframe": not_a_dataframe})


def test_spark_connect_with_spark_connect_session_factory_method(
spark_validation_definition: ValidationDefinition,
):
"""This test demonstrates that SparkConnectionSession can be used to create a session.
This test is being added because in some scenarios, this appeared to fail, but it was
the result of other active spark sessions.
"""
spark_connect_session = SparkConnectSession.builder.remote("sc://localhost:15002").getOrCreate()
assert isinstance(spark_connect_session, SparkConnectSession)
df = spark_connect_session.createDataFrame(
[Row(column=x) for x in DATAFRAME_VALUES],
)

results = spark_validation_definition.run(batch_parameters={"dataframe": df})

assert results.success
