Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUGFIX] 0.18.x - Apply QueryAsset splitting fix #9160

Merged
merged 1 commit into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions great_expectations/execution_engine/sqlalchemy_batch_data.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Optional, Tuple, overload
from typing import Literal, Optional, Tuple, overload

from great_expectations.compatibility import sqlalchemy
from great_expectations.compatibility.sqlalchemy import Selectable
from great_expectations.compatibility.sqlalchemy import (
sqlalchemy as sa,
)
from great_expectations.core.batch import BatchData
from great_expectations.execution_engine.sqlalchemy_dialect import GXSqlDialect
from great_expectations.util import generate_temporary_table_name

if TYPE_CHECKING:
from great_expectations.compatibility.sqlalchemy import Selectable

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -96,7 +94,7 @@ def __init__( # noqa: PLR0913
source_table_name: Optional[str] = None,
) -> None:
"""A Constructor used to initialize a SqlAlchemy Batch, create an id for it, and verify that all necessary
parameters have been provided. If a Query is given, also builds a temporary table for this query
parameters have been provided. Builds a temporary table for the `query` if `create_temp_table=True`.

Args:
engine (SqlAlchemy Engine): \
Expand Down Expand Up @@ -174,7 +172,7 @@ def __init__( # noqa: PLR0913
)
)
elif query:
self._selectable = self._generate_selectable_from_query(
self._selectable = self._generate_selectable_from_query( # type: ignore[call-overload] # https://github.com/python/mypy/issues/14764
query, dialect, create_temp_table, temp_table_schema_name
)
else:
Expand Down Expand Up @@ -208,7 +206,10 @@ def use_quoted_name(self):
return self._use_quoted_name

def _create_temporary_table( # noqa: C901, PLR0912, PLR0915
self, dialect, query, temp_table_schema_name=None
self,
dialect: GXSqlDialect,
query: str,
temp_table_schema_name: str | None = None,
) -> Tuple[str, str]:
"""
Create Temporary table based on sql query. This will be used as a basis for executing expectations.
Expand Down Expand Up @@ -342,23 +343,43 @@ def _generate_selectable_from_schema_name_and_table_name(
schema=schema_name,
)

@overload
def _generate_selectable_from_query(
self,
query: str,
dialect: GXSqlDialect,
create_temp_table: Literal[True],
temp_table_schema_name: Optional[str] = ...,
) -> sqlalchemy.Table:
...

@overload
def _generate_selectable_from_query(
self,
query: str,
dialect: GXSqlDialect,
create_temp_table: Literal[False],
temp_table_schema_name: Optional[str] = ...,
) -> sqlalchemy.TextClause:
...

def _generate_selectable_from_query(
self,
query: str,
dialect: GXSqlDialect,
create_temp_table: bool,
temp_table_schema_name: Optional[str] = None,
) -> sqlalchemy.Table:
) -> sqlalchemy.Table | sqlalchemy.TextClause:
"""Helper method to generate Selectable from query string.

Args:
query (str): query passed in as RuntimeBatchRequest.
dialect (GXSqlDialect): Needed for _create_temporary_table, since different backends name temp_tables differently.
create_temp_table (bool): Should we create a temp_table?
create_temp_table (bool): Should we create a temp_table? If not, a `TextClause` will be returned instead of a Table.
temp_table_schema_name (Optional[str], optional): Optional string for temp_table schema. Defaults to None.

Returns:
sqlalchemy.Table: SqlAlchemy Table that is Selectable.
sqlalchemy.Table | sqlalchemy.TextClause: A selectable SqlAlchemy Table, or a TextClause when `create_temp_table` is False.
"""
if not create_temp_table:
return sa.text(query)
Expand Down
25 changes: 16 additions & 9 deletions great_expectations/execution_engine/sqlalchemy_execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1248,7 +1248,7 @@ def get_data_for_batch_identifiers(

def _build_selectable_from_batch_spec(
self, batch_spec: BatchSpec
) -> Union[sqlalchemy.Selectable, str]:
) -> sqlalchemy.Selectable:
if (
batch_spec.get("query") is not None
and batch_spec.get("sampling_method") is not None
Expand Down Expand Up @@ -1315,14 +1315,16 @@ def _subselectable(self, batch_spec: BatchSpec) -> sqlalchemy.Selectable:
if not isinstance(query, str):
raise ValueError(f"SQL query should be a str but got {query}")
# Query is a valid SELECT query that begins with r"\w+select\w"
selectable = sa.select(sa.text(query.lstrip()[6:].lstrip())).subquery()
selectable = sa.select(
sa.text(query.lstrip()[6:].strip().rstrip(";").rstrip())
).subquery()

return selectable

@override
def get_batch_data_and_markers(
self, batch_spec: BatchSpec
) -> Tuple[Any, BatchMarkers]:
) -> Tuple[SqlAlchemyBatchData, BatchMarkers]:
if not isinstance(
batch_spec, (SqlAlchemyDatasourceBatchSpec, RuntimeQueryBatchSpec)
):
Expand Down Expand Up @@ -1360,22 +1362,27 @@ def get_batch_data_and_markers(
create_temp_table: bool = batch_spec.get(
"create_temp_table", self._create_temp_table
)
# this is where splitter components are added to the selectable
selectable: sqlalchemy.Selectable = self._build_selectable_from_batch_spec(
batch_spec=batch_spec
)
# NOTE: what's being checked here is the presence of a `query` attribute, we could check this directly
# instead of doing an instance check
if isinstance(batch_spec, RuntimeQueryBatchSpec):
# query != None is already checked when RuntimeQueryBatchSpec is instantiated
query: str = batch_spec.query

# re-compile the query to include any new parameters
compiled_query = selectable.compile(
dialect=self.engine.dialect,
compile_kwargs={"literal_binds": True},
)
query_str = str(compiled_query)
batch_data = SqlAlchemyBatchData(
execution_engine=self,
query=query,
query=query_str,
temp_table_schema_name=temp_table_schema_name,
create_temp_table=create_temp_table,
)
elif isinstance(batch_spec, SqlAlchemyDatasourceBatchSpec):
selectable: Union[
sqlalchemy.Selectable, str
] = self._build_selectable_from_batch_spec(batch_spec=batch_spec)
batch_data = SqlAlchemyBatchData(
execution_engine=self,
selectable=selectable,
Expand Down
3 changes: 1 addition & 2 deletions tests/datasource/fluent/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def pandas_data(


def sqlite_datasource(
context: AbstractDataContext,
db_filename: str,
context: AbstractDataContext, db_filename: str | pathlib.Path
) -> SqliteDatasource:
relative_path = pathlib.Path(
"..",
Expand Down
60 changes: 40 additions & 20 deletions tests/datasource/fluent/integration/test_integration_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,27 +110,47 @@ def test_batch_head(


@pytest.mark.sqlite
def test_sql_query_data_asset(empty_data_context):
context = empty_data_context
datasource = sqlite_datasource(context, "yellow_tripdata.db")
passenger_count_value = 5
asset = (
datasource.add_query_asset(
name="query_asset",
query=f" SELECT * from yellow_tripdata_sample_2019_02 WHERE passenger_count = {passenger_count_value}",
class TestQueryAssets:
def test_success_with_splitters(self, empty_data_context):
context = empty_data_context
datasource = sqlite_datasource(context, "yellow_tripdata.db")
passenger_count_value = 5
asset = (
datasource.add_query_asset(
name="query_asset",
query=f" SELECT * from yellow_tripdata_sample_2019_02 WHERE passenger_count = {passenger_count_value}",
)
.add_splitter_year_and_month(column_name="pickup_datetime")
.add_sorters(["year"])
)
.add_splitter_year_and_month(column_name="pickup_datetime")
.add_sorters(["year"])
)
validator = context.get_validator(
batch_request=asset.build_batch_request({"year": 2019})
)
result = validator.expect_column_distinct_values_to_equal_set(
column="passenger_count",
value_set=[passenger_count_value],
result_format={"result_format": "BOOLEAN_ONLY"},
)
assert result.success
validator = context.get_validator(
batch_request=asset.build_batch_request({"year": 2019})
)
result = validator.expect_column_distinct_values_to_equal_set(
column="passenger_count",
value_set=[passenger_count_value],
result_format={"result_format": "BOOLEAN_ONLY"},
)
assert result.success

def test_splitter_filtering(self, empty_data_context):
context = empty_data_context
datasource = sqlite_datasource(
context, "../../test_cases_for_sql_data_connector.db"
)

asset = datasource.add_query_asset(
name="trip_asset_split_by_event_type",
query="SELECT * FROM table_partitioned_by_date_column__A",
).add_splitter_column_value("event_type")
batch_request = asset.build_batch_request({"event_type": "start"})
validator = context.get_validator(batch_request=batch_request)

# All rows returned by head have the start event_type.
result = validator.execution_engine.batch_manager.active_batch.head(n_rows=50)
unique_event_types = set(result.data["event_type"].unique())
print(f"{unique_event_types=}")
assert unique_event_types == {"start"}


@pytest.mark.filesystem
Expand Down
Loading