From 1cb4eee1286e73a8dce2d6ed8989c18e3fc01023 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Mon, 16 Sep 2024 16:59:54 -0700 Subject: [PATCH] fix error --- tests/integ/modin/io/test_read_snowflake.py | 44 +++++++++++---------- 1 file changed, 23 insertions(+), 21 deletions(-) diff --git a/tests/integ/modin/io/test_read_snowflake.py b/tests/integ/modin/io/test_read_snowflake.py index fb645493205..3de147ce20b 100644 --- a/tests/integ/modin/io/test_read_snowflake.py +++ b/tests/integ/modin/io/test_read_snowflake.py @@ -37,15 +37,12 @@ paramList = [False, True] -@pytest.fixture(params=paramList, autouse=True) -def setup_uo(request, session): - is_cte_optimization_enabled = session._cte_optimization_enabled - is_query_compilation_enabled = session._query_compilation_stage_enabled - session._query_compilation_stage_enabled = request.param - session._cte_optimization_enabled = True +@pytest.fixture(params=paramList) +def setup_use_scoped_object(request, session): + use_scoped_objects = session._use_scoped_temp_objects + session._use_scoped_temp_objects = request.param yield - session._cte_optimization_enabled = is_cte_optimization_enabled - session._query_compilation_stage_enabled = is_query_compilation_enabled + session._use_scoped_temp_objects = use_scoped_objects def read_snowflake_and_verify_snapshot_creation( @@ -102,7 +99,7 @@ def read_snowflake_and_verify_snapshot_creation( @pytest.mark.parametrize( "as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"] ) -def test_read_snowflake_basic(session, as_query): +def test_read_snowflake_basic(setup_use_scoped_object, session, as_query): # create table table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) fully_qualified_name = [ @@ -130,7 +127,9 @@ def test_read_snowflake_basic(session, as_query): @pytest.mark.parametrize( "as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"] ) -def test_read_snowflake_semi_structured_types(session, as_query): +def test_read_snowflake_semi_structured_types( + setup_use_scoped_object, session, as_query +): # create table table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) session.create_dataframe([SEMI_STRUCTURED_TYPE_DATA]).write.save_as_table( @@ -160,7 +159,7 @@ def test_read_snowflake_none_nan(session, as_query): # create snowpark pandas dataframe df = read_snowflake_and_verify_snapshot_creation( - session, table_name, as_query, verify_materialization=False + session, table_name, as_query, False ) pdf = df.to_pandas() @@ -376,7 +375,7 @@ def test_read_snowflake_column_not_list_raises(session) -> None: "as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"] ) def test_read_snowflake_with_views( - session, test_table_name, table_type, caplog, as_query + setup_use_scoped_object, session, test_table_name, table_type, caplog, as_query ) -> None: # create a temporary test table expected_query_count = 6 @@ -424,9 +423,14 @@ def test_read_snowflake_with_views( @pytest.mark.modin_sp_precommit +@pytest.mark.parametrize( + "as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"] +) def test_read_snowflake_row_access_policy_table( + setup_use_scoped_object, session, test_table_name, + as_query, ) -> None: Utils.create_table(session, test_table_name, "col1 int, s text", is_temporary=True) session.sql(f"insert into {test_table_name} values (1, 'ok')").collect() @@ -440,13 +444,9 @@ def test_read_snowflake_row_access_policy_table( ).collect() with SqlCounter(query_count=3): - df = pd.read_snowflake(test_table_name) - - assert df.columns.tolist() == ["COL1", "S"] - assert len(df) == 0 - - with SqlCounter(query_count=3): - df = pd.read_snowflake(f"SELECT * FROM {test_table_name}") + df = read_snowflake_and_verify_snapshot_creation( + session, test_table_name, as_query, True + ) assert df.columns.tolist() == ["COL1", "S"] assert len(df) == 0 @@ -510,7 +510,9 @@ def test_decimal( @pytest.mark.parametrize( "as_query", [True, False], ids=["read_with_select_*", "read_with_table_name"] ) -def test_read_snowflake_with_table_in_different_db(session, caplog, as_query) -> None: +def test_read_snowflake_with_table_in_different_db( + setup_use_scoped_object, session, caplog, as_query +) -> None: db_name = f"testdb_snowpandas_{Utils.random_alphanumeric_str(4)}" schema_name = f"testschema_snowpandas_{Utils.random_alphanumeric_str(4)}" table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) @@ -533,7 +535,7 @@ def test_read_snowflake_with_table_in_different_db(session, caplog, as_query) -> caplog.clear() with caplog.at_level(logging.DEBUG): df = read_snowflake_and_verify_snapshot_creation( - session, table_name, as_query, True + session, table_name, as_query, False ) # verify no temporary table is materialized for regular table assert not ("Materialize temporary table" in caplog.text)