[BUGFIX] validator head query limit (#9036)
dctalbot committed Dec 8, 2023
1 parent a0d74c1 commit 33d09f9
Showing 3 changed files with 57 additions and 204 deletions.
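
The gist of the fix, going by the diff below: the head query now pushes the row limit into the SQL itself (via SQLAlchemy's .limit()) and passes the limited selectable straight to pandas, instead of reading the full table in chunks and trimming rows in memory. Below is a minimal standalone sketch of that pattern, not part of this commit. The in-memory SQLite engine, the events table, and the direct call to pd.read_sql (which assumes pandas >= 2.0 with SQLAlchemy >= 2.0; the repository's pandas_read_sql wrapper handles older combinations) are illustrative assumptions, and sa.text("*") stands in for the select("*") used in the diff, since plain-string columns are not accepted by select() in SQLAlchemy 2.0.

import pandas as pd
import sqlalchemy as sa

# Hypothetical in-memory table with a handful of rows.
engine = sa.create_engine("sqlite://")
with engine.begin() as con:
    con.execute(sa.text("CREATE TABLE events (id INTEGER, name TEXT)"))
    con.execute(sa.text("INSERT INTO events VALUES (1, 'a'), (2, 'b'), (3, 'c')"))

n_rows = 2          # plays the role of metric_value_kwargs["n_rows"]
fetch_all = False   # plays the role of metric_value_kwargs["fetch_all"]
limit = None if fetch_all else n_rows  # None means no LIMIT clause

# Apply the limit in SQL, mirroring
# sa.select("*").select_from(selectable).limit(limit) from table_head.py.
stmt = sa.select(sa.text("*")).select_from(sa.table("events")).limit(limit)

with engine.connect() as con:
    df = pd.read_sql(sql=stmt, con=con)

print(len(df))  # 2: the database enforces the limit, not pandas
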
45 changes: 0 additions & 45 deletions great_expectations/compatibility/sqlalchemy_and_pandas.py
@@ -82,48 +82,3 @@ def pandas_read_sql(sql, con, **kwargs) -> pd.DataFrame | Iterator[pd.DataFrame]
sql = sa.select(sa.text("*")).select_from(sql)
return_value = pd.read_sql(sql=sql, con=con, **kwargs)
return return_value


def pandas_read_sql_query(
sql, con, execution_engine, chunksize=None, **kwargs
) -> pd.DataFrame:
"""Suppress deprecation warnings while executing the pandas read_sql_query function.
Note that this passes params straight to the pandas read_sql_query method; please
see the pandas documentation
(currently https://pandas.pydata.org/docs/reference/api/pandas.read_sql_query.html)
for more information on this method.
If the pandas version is below 2.0 and sqlalchemy is installed, we suppress
the sqlalchemy 2.0 warning and raise our own warning. pandas does not
support sqlalchemy 2.0 until pandas 2.0 (see https://pandas.pydata.org/docs/dev/whatsnew/v2.0.0.html#other-enhancements)
Args:
sql: str or SQLAlchemy Selectable (select or text object)
con: SQLAlchemy connectable, str, or sqlite3 connection
chunksize: If specified, return an iterator where `chunksize` is the number of rows to include in each chunk.
**kwargs: Other keyword arguments, not enumerated here since they differ
between pandas versions.
Returns:
dataframe
"""
if (
sa
and is_version_greater_or_equal(sa.__version__, "2.0.0")
and is_version_less_than(pd.__version__, "2.0.0")
):
warn_pandas_less_than_2_0_and_sqlalchemy_greater_than_or_equal_2_0()

with warnings.catch_warnings():
# Note that RemovedIn20Warning is the warning class that we see from sqlalchemy
# but using the base class here since sqlalchemy is an optional dependency and this
# warning type only exists in sqlalchemy < 2.0.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
return_value = pd.read_sql_query(
sql=sql, con=con, chunksize=chunksize, **kwargs
)
else:
return_value = pd.read_sql_query(
sql=sql, con=con, chunksize=chunksize, **kwargs
)
return return_value
171 changes: 12 additions & 159 deletions great_expectations/expectations/metrics/table_metrics/table_head.py
@@ -1,28 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Iterator, Optional
from typing import TYPE_CHECKING, Any

import pandas as pd

from great_expectations.compatibility import sqlalchemy
from great_expectations.compatibility.not_imported import (
is_version_less_than,
)
from great_expectations.compatibility.sqlalchemy import sqlalchemy as sa
from great_expectations.compatibility.sqlalchemy_and_pandas import (
pandas_read_sql,
pandas_read_sql_query,
)
from great_expectations.compatibility.sqlalchemy_compatibility_wrappers import (
read_sql_table_as_df,
)
from great_expectations.core.metric_domain_types import MetricDomainTypes
from great_expectations.execution_engine import (
PandasExecutionEngine,
SparkDFExecutionEngine,
SqlAlchemyExecutionEngine,
)
from great_expectations.execution_engine.sqlalchemy_dialect import GXSqlDialect
from great_expectations.expectations.metrics.metric_provider import metric_value
from great_expectations.expectations.metrics.table_metric_provider import (
TableMetricProvider,
@@ -72,65 +63,26 @@ def _sqlalchemy( # noqa: PLR0913
selectable, _, _ = execution_engine.get_compute_domain(
metric_domain_kwargs, domain_type=MetricDomainTypes.TABLE
)
dialect = execution_engine.engine.dialect.name.lower()

if dialect not in GXSqlDialect.get_all_dialect_names():
dialect = GXSqlDialect.OTHER

table_name = getattr(selectable, "name", None)
n_rows: int = (
metric_value_kwargs.get("n_rows")
if metric_value_kwargs.get("n_rows") is not None
else cls.default_kwarg_values["n_rows"]
)

if is_version_less_than(pd.__version__, "1.4.0"):
df = TableHead._sqlalchemy_head_pandas_less_than14(
selectable=selectable,
execution_engine=execution_engine,
metric_value_kwargs=metric_value_kwargs,
n_rows=n_rows,
)
return df

# None means no limit
limit: int | None = n_rows
if metric_value_kwargs["fetch_all"]:
df = TableHead._return_full_sql_table_as_head(
table_name=table_name,
execution_engine=execution_engine,
selectable=selectable,
dialect=dialect,
)
return df
limit = None

selectable = sa.select("*").select_from(selectable).limit(limit).selectable

try:
df_chunk_iterator: Iterator[pd.DataFrame]
if table_name and not isinstance(table_name, sqlalchemy._anonymous_label):
with execution_engine.get_connection() as con:
# passing chunksize causes the Iterator to be returned
df_chunk_iterator = read_sql_table_as_df(
table_name=getattr(selectable, "name", None),
schema=getattr(selectable, "schema", None),
con=con,
chunksize=abs(n_rows),
dialect=dialect,
)
df = TableHead._get_head_df_from_df_iterator(
df_chunk_iterator=df_chunk_iterator, n_rows=n_rows
)
else:
# passing chunksize causes the Iterator to be returned
with execution_engine.get_connection() as con:
# convert subquery into query using select_from()
if not selectable.supports_execution:
selectable = sa.select(sa.text("*")).select_from(selectable)
df_chunk_iterator = pandas_read_sql_query(
sql=selectable,
con=con,
execution_engine=execution_engine,
chunksize=abs(n_rows),
)
df = TableHead._get_head_df_from_df_iterator(
df_chunk_iterator=df_chunk_iterator, n_rows=n_rows
)
with execution_engine.get_connection() as con:
df = pandas_read_sql(
sql=selectable,
con=con,
)
except StopIteration:
# empty table. At least try to get the column names
validator = Validator(execution_engine=execution_engine)
@@ -140,26 +92,6 @@ def _sqlalchemy( # noqa: PLR0913
df = pd.DataFrame(columns=columns)
return df

@staticmethod
def _get_head_df_from_df_iterator(
df_chunk_iterator: Iterator[pd.DataFrame], n_rows: int
) -> pd.DataFrame:
if n_rows > 0:
df = next(df_chunk_iterator)
else:
# if n_rows is zero or negative, remove the last chunk
df_chunk_list: list[pd.DataFrame]
df_last_chunk: pd.DataFrame
*df_chunk_list, df_last_chunk = df_chunk_iterator
if df_chunk_list:
df = pd.concat(objs=df_chunk_list, ignore_index=True)
else:
# if n_rows is zero, the last chunk is the entire dataframe,
# so we truncate it to preserve the header
df = df_last_chunk.head(0)

return df

@metric_value(engine=SparkDFExecutionEngine)
def _spark( # noqa: PLR0913
cls,
@@ -190,82 +122,3 @@ def _spark( # noqa: PLR0913
df = pd.DataFrame(data=rows)

return df

@staticmethod
def _return_full_sql_table_as_head(
table_name: Optional[Any],
execution_engine: SqlAlchemyExecutionEngine,
selectable: sa.sql.selectable.Selectable,
dialect: str,
) -> pd.DataFrame:
if table_name and not isinstance(table_name, sqlalchemy._anonymous_label):
with execution_engine.get_connection() as con:
# using named table
df = read_sql_table_as_df(
table_name=getattr(selectable, "name", None),
schema=getattr(selectable, "schema", None),
con=con,
dialect=dialect,
)
else:
# use selectable as query. If custom query is passed, it will be used
with execution_engine.get_connection() as con:
df = pandas_read_sql(
sql=selectable,
con=con,
)
return df

@staticmethod
def _sqlalchemy_head_pandas_less_than14(
selectable: sa.sql.selectable.Selectable,
execution_engine: SqlAlchemyExecutionEngine,
metric_value_kwargs: dict,
n_rows: int,
) -> pd.DataFrame:
"""
Helper function for _sqlalchemy_head_pandas.
The MetaData used by pd.read_sql_table cannot work on a temp table with pandas < 1.4.0.
If that fails, we try to get the data using read_sql() instead.
"""
stmt = sa.select("*").select_from(selectable)
fetch_all = metric_value_kwargs["fetch_all"]
if fetch_all:
sql = stmt.compile(
dialect=execution_engine.engine.dialect,
compile_kwargs={"literal_binds": True},
)
elif execution_engine.engine.dialect.name.lower() == GXSqlDialect.MSSQL:
# limit doesn't compile properly for mssql
sql = str(
stmt.compile(
dialect=execution_engine.engine.dialect,
compile_kwargs={"literal_binds": True},
)
)
if n_rows > 0:
sql = f"SELECT TOP {n_rows}{sql[6:]}"
else:
if n_rows > 0:
stmt = stmt.limit(n_rows)

sql = stmt.compile(
dialect=execution_engine.engine.dialect,
compile_kwargs={"literal_binds": True},
)

if n_rows <= 0 and not fetch_all:
with execution_engine.get_connection() as con:
df_chunk_iterator = pandas_read_sql(
sql=sql, con=con, chunksize=abs(n_rows)
)
df = TableHead._get_head_df_from_df_iterator(
df_chunk_iterator=df_chunk_iterator, n_rows=n_rows
)
else:
with execution_engine.get_connection() as con:
df = pandas_read_sql_query(
sql=sql, con=con, execution_engine=execution_engine
)
return df
45 changes: 45 additions & 0 deletions tests/expectations/metrics/table_metrics/test_table_head.py
@@ -1,3 +1,5 @@
from unittest import mock

import pandas as pd
import pytest

@@ -160,3 +162,46 @@ def test_table_head_sqlite(
len(get_sqlite_temp_table_names_from_engine(engine.engine))
== expected_temp_tables
)


@pytest.mark.sqlite
@pytest.mark.parametrize(
"execution_engine",
[
"sqlite_batch_with_table_name",
"sqlite_batch_with_selectable_with_temp_table",
"sqlite_batch_with_selectable_without_temp_table",
],
)
@pytest.mark.parametrize(
"n_rows",
[None, 0, 1, 2],
)
@pytest.mark.parametrize(
"fetch_all",
[None, True, False],
)
def test_limit_included_in_head_query(
execution_engine,
n_rows,
fetch_all,
request,
):
engine = request.getfixturevalue(execution_engine)
table_head = TableHead()

with mock.patch(
"great_expectations.compatibility.sqlalchemy_and_pandas.pd.read_sql"
) as mock_node:
table_head._sqlalchemy(
execution_engine=engine,
metric_domain_kwargs={},
metric_value_kwargs={"n_rows": n_rows, "fetch_all": fetch_all},
metrics={},
runtime_configuration={},
)

args, kwargs = mock_node.call_args
mock_node.assert_called_once()

assert ("limit" in str(kwargs["sql"]).lower()) == (fetch_all is not True)
