Skip to content

Commit

Permalink
[BUGFIX] 0.18.x cherrypick create_temp_tables fixes from develop (
Browse files Browse the repository at this point in the history
#9124)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Kilo59 and pre-commit-ci[bot] authored Dec 18, 2023
1 parent 4b78ada commit eb36875
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
},
"create_temp_table": {
"title": "Create Temp Table",
"default": true,
"default": false,
"type": "boolean"
},
"kwargs": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
},
"create_temp_table": {
"title": "Create Temp Table",
"default": true,
"default": false,
"type": "boolean"
},
"kwargs": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
},
"create_temp_table": {
"title": "Create Temp Table",
"default": true,
"default": false,
"type": "boolean"
},
"kwargs": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
},
"create_temp_table": {
"title": "Create Temp Table",
"default": true,
"default": false,
"type": "boolean"
},
"kwargs": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
},
"create_temp_table": {
"title": "Create Temp Table",
"default": true,
"default": false,
"type": "boolean"
},
"kwargs": {
Expand Down
5 changes: 4 additions & 1 deletion great_expectations/datasource/fluent/sql_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,7 +1025,7 @@ class SQLDatasource(Datasource):
# left side enforces the names on instance creation
type: Literal["sql"] = "sql"
connection_string: Union[ConfigStr, str]
create_temp_table: bool = True
create_temp_table: bool = False
kwargs: Dict[str, Union[ConfigStr, Any]] = pydantic.Field(
default={},
description="Optional dictionary of `kwargs` will be passed to the SQLAlchemy Engine"
Expand Down Expand Up @@ -1088,6 +1088,9 @@ def get_execution_engine(self) -> SqlAlchemyExecutionEngine:
current_execution_engine_kwargs = self.dict(
exclude=self._get_exec_engine_excludes(),
config_provider=self._config_provider,
# by default we exclude unset values to prevent lots of extra values in the yaml files
# but we want to include them here
exclude_unset=False,
)
if (
current_execution_engine_kwargs != self._cached_execution_engine_kwargs
Expand Down
1 change: 1 addition & 0 deletions tests/datasource/fluent/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def sqlite_datasource(
datasource = context.sources.add_sqlite(
name="test_datasource",
connection_string=f"sqlite:///{db_file}",
create_temp_table=True,
)
return datasource

Expand Down
133 changes: 110 additions & 23 deletions tests/datasource/fluent/test_sql_datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,18 @@

import logging
import warnings
from typing import TYPE_CHECKING, Generator
from pprint import pformat as pf
from typing import TYPE_CHECKING, Any, Generator
from unittest import mock

import pytest
from pytest import param

from great_expectations.compatibility import sqlalchemy
from great_expectations.compatibility.sqlalchemy import sqlalchemy as sa
from great_expectations.datasource.fluent import GxDatasourceWarning, SQLDatasource
from great_expectations.datasource.fluent.sql_datasource import TableAsset
from great_expectations.execution_engine import SqlAlchemyExecutionEngine

if TYPE_CHECKING:
from pytest_mock import MockerFixture
Expand All @@ -29,6 +32,23 @@ def create_engine_spy(mocker: MockerFixture) -> Generator[mock.MagicMock, None,
LOGGER.warning("SQLAlchemy create_engine was not called")


@pytest.fixture
def gx_sqlalchemy_execution_engine_spy(
mocker: MockerFixture, monkeypatch: pytest.MonkeyPatch
) -> Generator[mock.MagicMock, None, None]:
"""
Mock the SQLDatasource.execution_engine_type property to return a spy so that what would be passed to
the GX SqlAlchemyExecutionEngine constructor can be inspected.
NOTE: This is not exactly what gets passed to the sqlalchemy.engine.create_engine() function, but it is close.
"""
spy = mocker.Mock(spec=SqlAlchemyExecutionEngine)
monkeypatch.setattr(SQLDatasource, "execution_engine_type", spy)
yield spy
if not spy.call_count:
LOGGER.warning("SqlAlchemyExecutionEngine.__init__() was not called")


@pytest.fixture
def create_engine_fake(monkeypatch: pytest.MonkeyPatch) -> None:
"""Monkeypatch sqlalchemy.create_engine to always return a in-memory sqlite engine."""
Expand All @@ -45,33 +65,100 @@ def _fake_create_engine(*args, **kwargs) -> sa.engine.Engine:
@pytest.mark.parametrize(
"ds_kwargs",
[
dict(
connection_string="sqlite:///",
kwargs={"isolation_level": "SERIALIZABLE"},
param(
dict(
connection_string="sqlite:///",
),
id="connection_string only",
),
param(
dict(
connection_string="sqlite:///",
kwargs={"isolation_level": "SERIALIZABLE"},
),
id="no subs + kwargs",
),
param(
dict(
connection_string="${MY_CONN_STR}",
kwargs={"isolation_level": "SERIALIZABLE"},
),
id="subs + kwargs",
),
dict(
connection_string="${MY_CONN_STR}",
kwargs={"isolation_level": "SERIALIZABLE"},
param(
dict(
connection_string="sqlite:///",
create_temp_table=True,
),
id="create_temp_table=True",
),
param(
dict(
connection_string="sqlite:///",
create_temp_table=False,
),
id="create_temp_table=False",
),
],
)
def test_kwargs_are_passed_to_create_engine(
create_engine_spy: mock.MagicMock,
monkeypatch: pytest.MonkeyPatch,
ephemeral_context_with_defaults: EphemeralDataContext,
ds_kwargs: dict,
filter_gx_datasource_warnings: None,
):
monkeypatch.setenv("MY_CONN_STR", "sqlite:///")

context = ephemeral_context_with_defaults
ds = context.sources.add_or_update_sql(name="my_datasource", **ds_kwargs)
print(ds)
ds.test_connection()
class TestConfigPasstrough:
def test_kwargs_passed_to_create_engine(
self,
create_engine_spy: mock.MagicMock,
monkeypatch: pytest.MonkeyPatch,
ephemeral_context_with_defaults: EphemeralDataContext,
ds_kwargs: dict,
filter_gx_datasource_warnings: None,
):
monkeypatch.setenv("MY_CONN_STR", "sqlite:///")

context = ephemeral_context_with_defaults
ds = context.sources.add_or_update_sql(name="my_datasource", **ds_kwargs)
print(ds)
ds.test_connection()

create_engine_spy.assert_called_once_with(
"sqlite:///",
**{
**ds.dict(include={"kwargs"}, exclude_unset=False)["kwargs"],
**ds_kwargs.get("kwargs", {}),
},
)

create_engine_spy.assert_called_once_with(
"sqlite:///", **{"isolation_level": "SERIALIZABLE"}
)
def test_ds_config_passed_to_gx_sqlalchemy_execution_engine(
self,
gx_sqlalchemy_execution_engine_spy: mock.MagicMock,
monkeypatch: pytest.MonkeyPatch,
ephemeral_context_with_defaults: EphemeralDataContext,
ds_kwargs: dict,
filter_gx_datasource_warnings: None,
):
monkeypatch.setenv("MY_CONN_STR", "sqlite:///")

context = ephemeral_context_with_defaults
ds = context.sources.add_or_update_sql(name="my_datasource", **ds_kwargs)
print(ds)
gx_execution_engine: SqlAlchemyExecutionEngine = ds.get_execution_engine()
print(f"{gx_execution_engine=}")

expected_args: dict[str, Any] = {
# kwargs that we expect are passed to SqlAlchemyExecutionEngine
# including datasource field default values
**ds.dict(
exclude_unset=False,
exclude={"kwargs", *ds_kwargs.keys(), *ds._get_exec_engine_excludes()},
),
**{k: v for k, v in ds_kwargs.items() if k not in ["kwargs"]},
**ds_kwargs.get("kwargs", {}),
# config substitution should have been performed
**ds.dict(
include={"connection_string"}, config_provider=ds._config_provider
),
}
assert "create_temp_table" in expected_args

print(f"\nExpected SqlAlchemyExecutionEngine arguments:\n{pf(expected_args)}")
gx_sqlalchemy_execution_engine_spy.assert_called_once_with(**expected_args)


@pytest.mark.unit
Expand Down
2 changes: 1 addition & 1 deletion tests/execution_engine/test_sqlalchemy_execution_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ def test_resolve_metric_bundle_with_compute_domain_kwargs_json_serialization(sa)
@pytest.mark.sqlite
def test_get_batch_data_and_markers_using_query(sqlite_view_engine, test_df):
my_execution_engine: SqlAlchemyExecutionEngine = SqlAlchemyExecutionEngine(
engine=sqlite_view_engine
engine=sqlite_view_engine,
)
add_dataframe_to_db(df=test_df, name="test_table_0", con=my_execution_engine.engine)

Expand Down

0 comments on commit eb36875

Please sign in to comment.