diff --git a/great_expectations/execution_engine/sqlalchemy_batch_data.py b/great_expectations/execution_engine/sqlalchemy_batch_data.py index 4362e1b232b4..3d176ce4fd5f 100644 --- a/great_expectations/execution_engine/sqlalchemy_batch_data.py +++ b/great_expectations/execution_engine/sqlalchemy_batch_data.py @@ -1,9 +1,10 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, Tuple, overload +from typing import Literal, Optional, Tuple, overload from great_expectations.compatibility import sqlalchemy +from great_expectations.compatibility.sqlalchemy import Selectable from great_expectations.compatibility.sqlalchemy import ( sqlalchemy as sa, ) @@ -11,9 +12,6 @@ from great_expectations.execution_engine.sqlalchemy_dialect import GXSqlDialect from great_expectations.util import generate_temporary_table_name -if TYPE_CHECKING: - from great_expectations.compatibility.sqlalchemy import Selectable - logger = logging.getLogger(__name__) @@ -96,7 +94,7 @@ def __init__( # noqa: PLR0913 source_table_name: Optional[str] = None, ) -> None: """A Constructor used to initialize and SqlAlchemy Batch, create an id for it, and verify that all necessary - parameters have been provided. If a Query is given, also builds a temporary table for this query + parameters have been provided. Builds a temporary table for the `query` if `create_temp_table=True`. Args: engine (SqlAlchemy Engine): \ @@ -174,7 +172,7 @@ def __init__( # noqa: PLR0913 ) ) elif query: - self._selectable = self._generate_selectable_from_query( + self._selectable = self._generate_selectable_from_query( # type: ignore[call-overload] # https://github.com/python/mypy/issues/14764 query, dialect, create_temp_table, temp_table_schema_name ) else: @@ -208,7 +206,10 @@ def use_quoted_name(self): return self._use_quoted_name def _create_temporary_table( # noqa: C901, PLR0912, PLR0915 - self, dialect, query, temp_table_schema_name=None + self, + dialect: GXSqlDialect, + query: str, + temp_table_schema_name: str | None = None, ) -> Tuple[str, str]: """ Create Temporary table based on sql query. This will be used as a basis for executing expectations. @@ -342,23 +343,43 @@ def _generate_selectable_from_schema_name_and_table_name( schema=schema_name, ) + @overload + def _generate_selectable_from_query( + self, + query: str, + dialect: GXSqlDialect, + create_temp_table: Literal[True], + temp_table_schema_name: Optional[str] = ..., + ) -> sqlalchemy.Table: + ... + + @overload + def _generate_selectable_from_query( + self, + query: str, + dialect: GXSqlDialect, + create_temp_table: Literal[False], + temp_table_schema_name: Optional[str] = ..., + ) -> sqlalchemy.TextClause: + ... + def _generate_selectable_from_query( self, query: str, dialect: GXSqlDialect, create_temp_table: bool, temp_table_schema_name: Optional[str] = None, - ) -> sqlalchemy.Table: + ) -> sqlalchemy.Table | sqlalchemy.TextClause: """Helper method to generate Selectable from query string. Args: query (str): query passed in as RuntimeBatchRequest. dialect (GXSqlDialect): Needed for _create_temporary_table, since different backends name temp_tables differently. - create_temp_table (bool): Should we create a temp_table? + create_temp_table (bool): Should we create a temp_table? If not a `TextClause` will be returned instead of a Table. temp_table_schema_name (Optional[str], optional): Optional string for temp_table schema. Defaults to None. Returns: - sqlalchemy.Table: SqlAlchemy Table that is Selectable. + sqlalchemy.Table: SqlAlchemy Table that is Selectable or a TextClause. """ if not create_temp_table: return sa.text(query) diff --git a/great_expectations/execution_engine/sqlalchemy_execution_engine.py b/great_expectations/execution_engine/sqlalchemy_execution_engine.py index 938f8ffa4313..77d1e46d541c 100644 --- a/great_expectations/execution_engine/sqlalchemy_execution_engine.py +++ b/great_expectations/execution_engine/sqlalchemy_execution_engine.py @@ -1248,7 +1248,7 @@ def get_data_for_batch_identifiers( def _build_selectable_from_batch_spec( self, batch_spec: BatchSpec - ) -> Union[sqlalchemy.Selectable, str]: + ) -> sqlalchemy.Selectable: if ( batch_spec.get("query") is not None and batch_spec.get("sampling_method") is not None @@ -1315,14 +1315,16 @@ def _subselectable(self, batch_spec: BatchSpec) -> sqlalchemy.Selectable: if not isinstance(query, str): raise ValueError(f"SQL query should be a str but got {query}") # Query is a valid SELECT query that begins with r"\w+select\w" - selectable = sa.select(sa.text(query.lstrip()[6:].lstrip())).subquery() + selectable = sa.select( + sa.text(query.lstrip()[6:].strip().rstrip(";").rstrip()) + ).subquery() return selectable @override def get_batch_data_and_markers( self, batch_spec: BatchSpec - ) -> Tuple[Any, BatchMarkers]: + ) -> Tuple[SqlAlchemyBatchData, BatchMarkers]: if not isinstance( batch_spec, (SqlAlchemyDatasourceBatchSpec, RuntimeQueryBatchSpec) ): @@ -1360,22 +1362,27 @@ def get_batch_data_and_markers( create_temp_table: bool = batch_spec.get( "create_temp_table", self._create_temp_table ) + # this is where splitter components are added to the selectable + selectable: sqlalchemy.Selectable = self._build_selectable_from_batch_spec( + batch_spec=batch_spec + ) # NOTE: what's being checked here is the presence of a `query` attribute, we could check this directly # instead of doing an instance check if isinstance(batch_spec, RuntimeQueryBatchSpec): # query != None is already checked when RuntimeQueryBatchSpec is instantiated - query: str = batch_spec.query - + # re-compile the query to include any new parameters + compiled_query = selectable.compile( + dialect=self.engine.dialect, + compile_kwargs={"literal_binds": True}, + ) + query_str = str(compiled_query) batch_data = SqlAlchemyBatchData( execution_engine=self, - query=query, + query=query_str, temp_table_schema_name=temp_table_schema_name, create_temp_table=create_temp_table, ) elif isinstance(batch_spec, SqlAlchemyDatasourceBatchSpec): - selectable: Union[ - sqlalchemy.Selectable, str - ] = self._build_selectable_from_batch_spec(batch_spec=batch_spec) batch_data = SqlAlchemyBatchData( execution_engine=self, selectable=selectable, diff --git a/tests/datasource/fluent/integration/conftest.py b/tests/datasource/fluent/integration/conftest.py index 76d2c9b5e83b..a03ff55bb70b 100644 --- a/tests/datasource/fluent/integration/conftest.py +++ b/tests/datasource/fluent/integration/conftest.py @@ -106,8 +106,7 @@ def pandas_data( def sqlite_datasource( - context: AbstractDataContext, - db_filename: str, + context: AbstractDataContext, db_filename: str | pathlib.Path ) -> SqliteDatasource: relative_path = pathlib.Path( "..", diff --git a/tests/datasource/fluent/integration/test_integration_datasource.py b/tests/datasource/fluent/integration/test_integration_datasource.py index a4c249702fd9..65e957f19b3d 100644 --- a/tests/datasource/fluent/integration/test_integration_datasource.py +++ b/tests/datasource/fluent/integration/test_integration_datasource.py @@ -110,27 +110,47 @@ def test_batch_head( @pytest.mark.sqlite -def test_sql_query_data_asset(empty_data_context): - context = empty_data_context - datasource = sqlite_datasource(context, "yellow_tripdata.db") - passenger_count_value = 5 - asset = ( - datasource.add_query_asset( - name="query_asset", - query=f" SELECT * from yellow_tripdata_sample_2019_02 WHERE passenger_count = {passenger_count_value}", +class TestQueryAssets: + def test_success_with_splitters(self, empty_data_context): + context = empty_data_context + datasource = sqlite_datasource(context, "yellow_tripdata.db") + passenger_count_value = 5 + asset = ( + datasource.add_query_asset( + name="query_asset", + query=f" SELECT * from yellow_tripdata_sample_2019_02 WHERE passenger_count = {passenger_count_value}", + ) + .add_splitter_year_and_month(column_name="pickup_datetime") + .add_sorters(["year"]) ) - .add_splitter_year_and_month(column_name="pickup_datetime") - .add_sorters(["year"]) - ) - validator = context.get_validator( - batch_request=asset.build_batch_request({"year": 2019}) - ) - result = validator.expect_column_distinct_values_to_equal_set( - column="passenger_count", - value_set=[passenger_count_value], - result_format={"result_format": "BOOLEAN_ONLY"}, - ) - assert result.success + validator = context.get_validator( + batch_request=asset.build_batch_request({"year": 2019}) + ) + result = validator.expect_column_distinct_values_to_equal_set( + column="passenger_count", + value_set=[passenger_count_value], + result_format={"result_format": "BOOLEAN_ONLY"}, + ) + assert result.success + + def test_splitter_filtering(self, empty_data_context): + context = empty_data_context + datasource = sqlite_datasource( + context, "../../test_cases_for_sql_data_connector.db" + ) + + asset = datasource.add_query_asset( + name="trip_asset_split_by_event_type", + query="SELECT * FROM table_partitioned_by_date_column__A", + ).add_splitter_column_value("event_type") + batch_request = asset.build_batch_request({"event_type": "start"}) + validator = context.get_validator(batch_request=batch_request) + + # All rows returned by head have the start event_type. + result = validator.execution_engine.batch_manager.active_batch.head(n_rows=50) + unique_event_types = set(result.data["event_type"].unique()) + print(f"{unique_event_types=}") + assert unique_event_types == {"start"} @pytest.mark.filesystem