Skip to content

Commit

Permalink
fix error
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-yzou committed Sep 16, 2024
1 parent 378cf18 commit 1cb4eee
Showing 1 changed file with 23 additions and 21 deletions.
44 changes: 23 additions & 21 deletions tests/integ/modin/io/test_read_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,12 @@
paramList = [False, True]


@pytest.fixture(params=paramList, autouse=True)
def setup_uo(request, session):
is_cte_optimization_enabled = session._cte_optimization_enabled
is_query_compilation_enabled = session._query_compilation_stage_enabled
session._query_compilation_stage_enabled = request.param
session._cte_optimization_enabled = True
@pytest.fixture(params=paramList)
def setup_use_scoped_object(request, session):
use_scoped_objects = session._use_scoped_temp_objects
session._use_scoped_temp_objects = request.param
yield
session._cte_optimization_enabled = is_cte_optimization_enabled
session._query_compilation_stage_enabled = is_query_compilation_enabled
session._use_scoped_temp_objects = use_scoped_objects


def read_snowflake_and_verify_snapshot_creation(
Expand Down Expand Up @@ -102,7 +99,7 @@ def read_snowflake_and_verify_snapshot_creation(
@pytest.mark.parametrize(
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_basic(session, as_query):
def test_read_snowflake_basic(setup_use_scoped_object, session, as_query):
# create table
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
fully_qualified_name = [
Expand Down Expand Up @@ -130,7 +127,9 @@ def test_read_snowflake_basic(session, as_query):
@pytest.mark.parametrize(
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_semi_structured_types(session, as_query):
def test_read_snowflake_semi_structured_types(
setup_use_scoped_object, session, as_query
):
# create table
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
session.create_dataframe([SEMI_STRUCTURED_TYPE_DATA]).write.save_as_table(
Expand Down Expand Up @@ -160,7 +159,7 @@ def test_read_snowflake_none_nan(session, as_query):

# create snowpark pandas dataframe
df = read_snowflake_and_verify_snapshot_creation(
session, table_name, as_query, verify_materialization=False
session, table_name, as_query, False
)

pdf = df.to_pandas()
Expand Down Expand Up @@ -376,7 +375,7 @@ def test_read_snowflake_column_not_list_raises(session) -> None:
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_with_views(
session, test_table_name, table_type, caplog, as_query
setup_use_scoped_object, session, test_table_name, table_type, caplog, as_query
) -> None:
# create a temporary test table
expected_query_count = 6
Expand Down Expand Up @@ -424,9 +423,14 @@ def test_read_snowflake_with_views(


@pytest.mark.modin_sp_precommit
@pytest.mark.parametrize(
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_row_access_policy_table(
setup_use_scoped_object,
session,
test_table_name,
as_query,
) -> None:
Utils.create_table(session, test_table_name, "col1 int, s text", is_temporary=True)
session.sql(f"insert into {test_table_name} values (1, 'ok')").collect()
Expand All @@ -440,13 +444,9 @@ def test_read_snowflake_row_access_policy_table(
).collect()

with SqlCounter(query_count=3):
df = pd.read_snowflake(test_table_name)

assert df.columns.tolist() == ["COL1", "S"]
assert len(df) == 0

with SqlCounter(query_count=3):
df = pd.read_snowflake(f"SELECT * FROM {test_table_name}")
df = read_snowflake_and_verify_snapshot_creation(
session, test_table_name, as_query, True
)

assert df.columns.tolist() == ["COL1", "S"]
assert len(df) == 0
Expand Down Expand Up @@ -510,7 +510,9 @@ def test_decimal(
@pytest.mark.parametrize(
"as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"]
)
def test_read_snowflake_with_table_in_different_db(session, caplog, as_query) -> None:
def test_read_snowflake_with_table_in_different_db(
setup_use_scoped_object, session, caplog, as_query
) -> None:
db_name = f"testdb_snowpandas_{Utils.random_alphanumeric_str(4)}"
schema_name = f"testschema_snowpandas_{Utils.random_alphanumeric_str(4)}"
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
Expand All @@ -533,7 +535,7 @@ def test_read_snowflake_with_table_in_different_db(session, caplog, as_query) ->
caplog.clear()
with caplog.at_level(logging.DEBUG):
df = read_snowflake_and_verify_snapshot_creation(
session, table_name, as_query, True
session, table_name, as_query, False
)
# verify no temporary table is materialized for regular table
assert not ("Materialize temporary table" in caplog.text)
Expand Down

0 comments on commit 1cb4eee

Please sign in to comment.