From 0c1b6d9bfb1ee5fe480b4fd83f2a8534caadacf5 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 6 Sep 2024 17:07:31 -0700 Subject: [PATCH 01/62] add profiler --- src/snowflake/snowpark/profiler.py | 104 +++++++++++++++++++++++++++++ src/snowflake/snowpark/session.py | 11 +++ 2 files changed, 115 insertions(+) create mode 100644 src/snowflake/snowpark/profiler.py diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py new file mode 100644 index 00000000000..190c91f44f8 --- /dev/null +++ b/src/snowflake/snowpark/profiler.py @@ -0,0 +1,104 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# +from contextlib import contextmanager +from typing import List, Optional + +import snowflake.snowpark +from snowflake.snowpark._internal.utils import validate_object_name + + +class Profiler: + def __init__( + self, + stage: str, + active_profiler: str = "LINE", + session: Optional["snowflake.snowpark.Session"] = None, + ) -> None: + self.stage = stage + self.active_profiler = active_profiler + self.modules_to_register = [] + self.register_modules_sql = "" + self.set_targeted_stage_sql = "" + self.enable_profiler_sql = "" + self.disable_profiler_sql = "" + self.set_active_profiler_sql = "" + self.session = session + self.stage_and_profiler_name_validation() + self.prepare_sql() + + def stage_and_profiler_name_validation(self): + if self.active_profiler not in ["LINE", "MEMORY"]: + raise ValueError( + f"active_profiler expect 'LINE' or 'MEMORY', got {self.active_profiler} instead" + ) + validate_object_name(self.stage) + + def prepare_sql(self): + self.register_modules_sql = f"alter session set python_profiler_modules='{','.join(self.modules_to_register)}'" + self.set_targeted_stage_sql = ( + f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{self.stage}"' + ) + self.enable_profiler_sql = "alter session set ENABLE_PYTHON_PROFILER = true" + self.disable_profiler_sql = "alter session set ENABLE_PYTHON_PROFILER = false" + self.set_active_profiler_sql = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler.upper()}'" + + def register_modules(self, modules: List[str]): + self.modules_to_register = modules + self.prepare_sql() + + def set_targeted_stage(self, stage: str): + self.stage = stage + self.prepare_sql() + + def set_active_profiler(self, active_profiler: str): + self.active_profiler = active_profiler + self.prepare_sql() + + def _register_modules(self): + self.session.sql(self.register_modules_sql).collect() + + def _set_targeted_stage(self): + self.session.sql(self.set_targeted_stage_sql).collect() + + def _set_active_profiler(self): + self.session.sql(self.set_active_profiler_sql).collect() + + def enable_profiler(self): + self.session.sql(self.enable_profiler_sql).collect() + + def disable_profiler(self): + self.session.sql(self.disable_profiler_sql).collect() + + def show_profiles(self, query_id: str): + sql = f"select snowflake.core.get_python_profiler_output({query_id});" + res = self.session.sql(sql).collect() + return res + + def dump_profiles(self, query_id: str, dst_file: str): + sql = f"select snowflake.core.get_python_profiler_output({query_id});" + res = self.session.sql(sql).collect() + with open(dst_file, "w") as f: + for row in res: + f.write(str(row)) + + +@contextmanager +def profiler( + stage: str, + active_profiler: str, + session: "snowflake.snowpark.Session", + modules: Optional[List[str]] = None, +): + internal_profiler = Profiler(stage, active_profiler, session) + modules = [] if modules is None else modules + try: + # set up phase + internal_profiler._set_targeted_stage() + internal_profiler._set_active_profiler() + + internal_profiler.register_modules(modules) + internal_profiler._register_modules() + internal_profiler.enable_profiler() + finally: + internal_profiler.disable_profiler() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index c6f430cc980..b8a97e48d6d 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -147,6 +147,7 @@ from snowflake.snowpark.mock._plan_builder import MockSnowflakePlanBuilder from snowflake.snowpark.mock._stored_procedure import MockStoredProcedureRegistration from snowflake.snowpark.mock._udf import MockUDFRegistration +from snowflake.snowpark.profiler import Profiler from snowflake.snowpark.query_history import QueryHistory from snowflake.snowpark.row import Row from snowflake.snowpark.stored_procedure import StoredProcedureRegistration @@ -570,6 +571,7 @@ def __init__( self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) if self._auto_clean_up_temp_table_enabled: self._temp_table_auto_cleaner.start() + self.profiler = None _logger.info("Snowpark Session information: %s", self._session_info) @@ -3385,6 +3387,15 @@ def flatten( set_api_call_source(df, "Session.flatten") return df + def register_profiler(self, profiler: Profiler): + self.profiler = profiler + self.profiler.session = self + if len(self.sql(f"show stages like '{profiler.stage}'").collect()) == 0: + self.sql(f"create or replace temp stage if not exists {profiler.stage}") + self.profiler._register_modules() + self.profiler._set_targeted_stage() + self.profiler._set_active_profiler() + def query_history(self) -> QueryHistory: """Create an instance of :class:`QueryHistory` as a context manager to record queries that are pushed down to the Snowflake database. From 9d911df5ad8cd02fb654626e4215ed49ddec26c0 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Sun, 8 Sep 2024 21:43:38 -0700 Subject: [PATCH 02/62] profiler finish --- src/snowflake/snowpark/profiler.py | 20 ++++++++++++++++++-- src/snowflake/snowpark/session.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 190c91f44f8..561c9eea72d 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -26,6 +26,7 @@ def __init__( self.session = session self.stage_and_profiler_name_validation() self.prepare_sql() + self.query_history = None def stage_and_profiler_name_validation(self): if self.active_profiler not in ["LINE", "MEMORY"]: @@ -70,12 +71,24 @@ def enable_profiler(self): def disable_profiler(self): self.session.sql(self.disable_profiler_sql).collect() - def show_profiles(self, query_id: str): + def _get_last_query_id(self): + sps = self.session.sql("show procedures").collect() + names = [r.name for r in sps] + for query in self.query_history.queries[::-1]: + if query.sql_text.startswith("CALL"): + sp_name = query.sql_text.split(" ")[1].split("(")[0] + if sp_name.upper() in names: + return query.query_id + return None + + def show_profiles(self): + query_id = self._get_last_query_id() sql = f"select snowflake.core.get_python_profiler_output({query_id});" res = self.session.sql(sql).collect() return res - def dump_profiles(self, query_id: str, dst_file: str): + def dump_profiles(self, dst_file: str): + query_id = self._get_last_query_id() sql = f"select snowflake.core.get_python_profiler_output({query_id});" res = self.session.sql(sql).collect() with open(dst_file, "w") as f: @@ -101,4 +114,7 @@ def profiler( internal_profiler._register_modules() internal_profiler.enable_profiler() finally: + yield + internal_profiler.register_modules([]) + internal_profiler._register_modules() internal_profiler.disable_profiler() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index b8a97e48d6d..c4ea463930d 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3388,6 +3388,7 @@ def flatten( return df def register_profiler(self, profiler: Profiler): + """Register a profiler to current session, all action are actually executed during this function""" self.profiler = profiler self.profiler.session = self if len(self.sql(f"show stages like '{profiler.stage}'").collect()) == 0: @@ -3395,6 +3396,35 @@ def register_profiler(self, profiler: Profiler): self.profiler._register_modules() self.profiler._set_targeted_stage() self.profiler._set_active_profiler() + self.profiler.query_history = self.query_history() + + def show_profiles(self): + """Gather and return result of profiler, results are also print to console""" + if self.profiler is not None and isinstance(self.profiler, Profiler): + self.profiler.show_profiles() + else: + raise ValueError( + "profiler is not set, use session.register_profiler or profiler context manager" + ) + + def dump_profiles(self, dst_file: str): + """Gather result of a profiler and redirect it to a file""" + if self.profiler is not None and isinstance(self.profiler, Profiler): + self.profiler.dump_profiles(dst_file=dst_file) + else: + raise ValueError( + "profiler is not set, use session.register_profiler or profiler context manager" + ) + + def register_profiler_modules(self, modules: List[str]): + """Register modules want to create profile""" + if self.profiler is not None and isinstance(self.profiler, Profiler): + self.profiler.register_modules(modules) + else: + sql_statement = ( + f"alter session set python_profiler_modules='{','.join(modules)}'" + ) + self.sql(sql_statement).collect() def query_history(self) -> QueryHistory: """Create an instance of :class:`QueryHistory` as a context manager to record queries that are pushed down to the Snowflake database. From 968ac2cb93ae024eaa2541dae8d4667ce032f079 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Sun, 8 Sep 2024 21:54:35 -0700 Subject: [PATCH 03/62] t --- src/snowflake/snowpark/profiler.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 561c9eea72d..8cb67d6bf55 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -47,14 +47,20 @@ def prepare_sql(self): def register_modules(self, modules: List[str]): self.modules_to_register = modules self.prepare_sql() + if self.session is not None: + self._register_modules() def set_targeted_stage(self, stage: str): self.stage = stage self.prepare_sql() + if self.session is not None: + self._set_targeted_stage() def set_active_profiler(self, active_profiler: str): self.active_profiler = active_profiler self.prepare_sql() + if self.session is not None: + self._set_active_profiler() def _register_modules(self): self.session.sql(self.register_modules_sql).collect() From 71967c6144111aad5dcd8b881026a139c1c58f75 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Sun, 8 Sep 2024 22:21:19 -0700 Subject: [PATCH 04/62] t --- src/snowflake/snowpark/profiler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 8cb67d6bf55..27810465594 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -11,8 +11,8 @@ class Profiler: def __init__( self, - stage: str, - active_profiler: str = "LINE", + stage: Optional[str] = "", + active_profiler: Optional[str] = "LINE", session: Optional["snowflake.snowpark.Session"] = None, ) -> None: self.stage = stage From 3222459708ccf260eab8cb92662236e1daab2ef9 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 9 Sep 2024 22:14:12 -0700 Subject: [PATCH 05/62] add test --- src/snowflake/snowpark/profiler.py | 31 ++++++------ src/snowflake/snowpark/session.py | 4 +- tests/integ/test_profiler.py | 81 ++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 18 deletions(-) create mode 100644 tests/integ/test_profiler.py diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 27810465594..256a5a9baf5 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -24,17 +24,9 @@ def __init__( self.disable_profiler_sql = "" self.set_active_profiler_sql = "" self.session = session - self.stage_and_profiler_name_validation() self.prepare_sql() self.query_history = None - def stage_and_profiler_name_validation(self): - if self.active_profiler not in ["LINE", "MEMORY"]: - raise ValueError( - f"active_profiler expect 'LINE' or 'MEMORY', got {self.active_profiler} instead" - ) - validate_object_name(self.stage) - def prepare_sql(self): self.register_modules_sql = f"alter session set python_profiler_modules='{','.join(self.modules_to_register)}'" self.set_targeted_stage_sql = ( @@ -44,19 +36,24 @@ def prepare_sql(self): self.disable_profiler_sql = "alter session set ENABLE_PYTHON_PROFILER = false" self.set_active_profiler_sql = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler.upper()}'" - def register_modules(self, modules: List[str]): + def register_profiler_modules(self, modules: List[str]): self.modules_to_register = modules self.prepare_sql() if self.session is not None: self._register_modules() def set_targeted_stage(self, stage: str): + validate_object_name(stage) self.stage = stage self.prepare_sql() if self.session is not None: self._set_targeted_stage() def set_active_profiler(self, active_profiler: str): + if self.active_profiler not in ["LINE", "MEMORY"]: + raise ValueError( + f"active_profiler expect 'LINE' or 'MEMORY', got {self.active_profiler} instead" + ) self.active_profiler = active_profiler self.prepare_sql() if self.session is not None: @@ -89,17 +86,17 @@ def _get_last_query_id(self): def show_profiles(self): query_id = self._get_last_query_id() - sql = f"select snowflake.core.get_python_profiler_output({query_id});" + sql = f"select snowflake.core.get_python_profiler_output('{query_id}');" res = self.session.sql(sql).collect() - return res + print(res[0][0]) # noqa: T201: we need to print here. + return res[0][0] def dump_profiles(self, dst_file: str): query_id = self._get_last_query_id() - sql = f"select snowflake.core.get_python_profiler_output({query_id});" + sql = f"select snowflake.core.get_python_profiler_output('{query_id}');" res = self.session.sql(sql).collect() with open(dst_file, "w") as f: - for row in res: - f.write(str(row)) + f.write(str(res[0][0])) @contextmanager @@ -110,17 +107,19 @@ def profiler( modules: Optional[List[str]] = None, ): internal_profiler = Profiler(stage, active_profiler, session) + session.profiler = internal_profiler + internal_profiler.query_history = session.query_history() modules = [] if modules is None else modules try: # set up phase internal_profiler._set_targeted_stage() internal_profiler._set_active_profiler() - internal_profiler.register_modules(modules) + internal_profiler.register_profiler_modules(modules) internal_profiler._register_modules() internal_profiler.enable_profiler() finally: yield - internal_profiler.register_modules([]) + internal_profiler.register_profiler_modules([]) internal_profiler._register_modules() internal_profiler.disable_profiler() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index c4ea463930d..83b07e55c4b 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3401,7 +3401,7 @@ def register_profiler(self, profiler: Profiler): def show_profiles(self): """Gather and return result of profiler, results are also print to console""" if self.profiler is not None and isinstance(self.profiler, Profiler): - self.profiler.show_profiles() + return self.profiler.show_profiles() else: raise ValueError( "profiler is not set, use session.register_profiler or profiler context manager" @@ -3419,7 +3419,7 @@ def dump_profiles(self, dst_file: str): def register_profiler_modules(self, modules: List[str]): """Register modules want to create profile""" if self.profiler is not None and isinstance(self.profiler, Profiler): - self.profiler.register_modules(modules) + self.profiler.register_profiler_modules(modules) else: sql_statement = ( f"alter session set python_profiler_modules='{','.join(modules)}'" diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py new file mode 100644 index 00000000000..b8af2458ef5 --- /dev/null +++ b/tests/integ/test_profiler.py @@ -0,0 +1,81 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import pytest + +import snowflake.snowpark +from snowflake.snowpark import DataFrame +from snowflake.snowpark.functions import sproc +from snowflake.snowpark.profiler import Profiler, profiler +from tests.utils import Utils + +tmp_stage_name = Utils.random_stage_name() + + +@pytest.fixture(scope="module", autouse=True) +def setup(session, resources_path, local_testing_mode): + if not local_testing_mode: + Utils.create_stage(session, tmp_stage_name, is_temporary=True) + session.add_packages("snowflake-snowpark-python") + + +def test_profiler_with_context_manager(session, db_parameters): + @sproc(name="table_sp", replace=True) + def table_sp(session: snowflake.snowpark.Session) -> DataFrame: + return session.sql("select 1") + + session.register_profiler_modules(["table_sp"]) + with profiler( + stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", + active_profiler="LINE", + session=session, + ): + session.call("table_sp").collect() + res = session.show_profiles() + session.register_profiler_modules([]) + assert res is not None + assert "Modules Profiled" in res + + +def test_profiler_with_profiler_class(session, db_parameters): + @sproc(name="table_sp", replace=True) + def table_sp(session: snowflake.snowpark.Session) -> DataFrame: + return session.sql("select 1") + + profiler = Profiler() + profiler.register_profiler_modules(["table_sp"]) + profiler.set_active_profiler("LINE") + profiler.set_targeted_stage( + f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" + ) + session.register_profiler(profiler) + + profiler.enable_profiler() + + session.call("table_sp").collect() + res = session.show_profiles() + + profiler.disable_profiler() + + profiler.register_profiler_modules([]) + assert res is not None + assert "Modules Profiled" in res + + +def test_single_return_value_of_sp(session, db_parameters): + @sproc(name="single_value_sp", replace=True) + def single_value_sp(session: snowflake.snowpark.Session) -> str: + return "success" + + session.register_profiler_modules(["table_sp"]) + with profiler( + stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", + active_profiler="LINE", + session=session, + ): + session.call("single_value_sp") + res = session.show_profiles() + session.register_profiler_modules([]) + assert res is not None + assert "Modules Profiled" in res From e1d314a468e55b5a01678074958c98d9763bde1f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 10 Sep 2024 20:49:00 -0700 Subject: [PATCH 06/62] address comment --- CHANGELOG.md | 1 + src/snowflake/snowpark/profiler.py | 2 +- src/snowflake/snowpark/session.py | 4 +++- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fea42391259..789778ca578 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ - Added following new functions in `snowflake.snowpark.functions`: - `array_remove` - `ln` +- Added snowpark python API for profiler. #### Improvements diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 256a5a9baf5..c5117776468 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -109,7 +109,7 @@ def profiler( internal_profiler = Profiler(stage, active_profiler, session) session.profiler = internal_profiler internal_profiler.query_history = session.query_history() - modules = [] if modules is None else modules + modules = modules or [] try: # set up phase internal_profiler._set_targeted_stage() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 9fad1ebb68f..589caa10a40 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3406,7 +3406,9 @@ def register_profiler(self, profiler: Profiler): self.profiler = profiler self.profiler.session = self if len(self.sql(f"show stages like '{profiler.stage}'").collect()) == 0: - self.sql(f"create or replace temp stage if not exists {profiler.stage}") + self.sql( + f"create or replace temp stage {profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + ).collect() self.profiler._register_modules() self.profiler._set_targeted_stage() self.profiler._set_active_profiler() From bf71cc2e34abd01b5222d9bfb92e5dba743bd344 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 10 Sep 2024 21:14:19 -0700 Subject: [PATCH 07/62] fix get last query id --- src/snowflake/snowpark/profiler.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index c5117776468..eda37199741 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -75,13 +75,9 @@ def disable_profiler(self): self.session.sql(self.disable_profiler_sql).collect() def _get_last_query_id(self): - sps = self.session.sql("show procedures").collect() - names = [r.name for r in sps] for query in self.query_history.queries[::-1]: if query.sql_text.startswith("CALL"): - sp_name = query.sql_text.split(" ")[1].split("(")[0] - if sp_name.upper() in names: - return query.query_id + return query.query_id return None def show_profiles(self): From 33208147d06e26c90eed131fdfdff3fc7621cfd0 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 16 Sep 2024 10:28:16 -0700 Subject: [PATCH 08/62] Update session.py --- src/snowflake/snowpark/session.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 81bfbd922a0..857809e31a8 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -605,13 +605,8 @@ def __init__( self._tmpdir_handler: Optional[tempfile.TemporaryDirectory] = None self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) -<<<<<<< SNOW-1527717 - if self._auto_clean_up_temp_table_enabled: - self._temp_table_auto_cleaner.start() self.profiler = None -======= ->>>>>>> main _logger.info("Snowpark Session information: %s", self._session_info) def __enter__(self): From 791cf2ae206710680341fce1c860a2eddc6c04d0 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 17 Sep 2024 13:28:20 -0700 Subject: [PATCH 09/62] add docstring --- src/snowflake/snowpark/profiler.py | 61 +++++++++++++++++++++++++++--- src/snowflake/snowpark/session.py | 29 ++++++++++++-- tests/integ/test_profiler.py | 1 + 3 files changed, 81 insertions(+), 10 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index eda37199741..5aba9357f70 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -24,10 +24,10 @@ def __init__( self.disable_profiler_sql = "" self.set_active_profiler_sql = "" self.session = session - self.prepare_sql() + self._prepare_sql() self.query_history = None - def prepare_sql(self): + def _prepare_sql(self): self.register_modules_sql = f"alter session set python_profiler_modules='{','.join(self.modules_to_register)}'" self.set_targeted_stage_sql = ( f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{self.stage}"' @@ -37,25 +37,53 @@ def prepare_sql(self): self.set_active_profiler_sql = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler.upper()}'" def register_profiler_modules(self, modules: List[str]): + """ + Register stored procedures to generate profiles for them. + + Note: + Registered nodules will be overwritten by this function, + use this function with an empty string will remove registered modules. + Args: + modules: List of names of stored procedures. + """ self.modules_to_register = modules - self.prepare_sql() + self._prepare_sql() if self.session is not None: self._register_modules() def set_targeted_stage(self, stage: str): + """ + Set targeted stage for profiler output. + + Note: + The stage name must be a fully qualified name. + + Args: + stage: String of fully qualified name of targeted stage + """ validate_object_name(stage) self.stage = stage - self.prepare_sql() + self._prepare_sql() if self.session is not None: self._set_targeted_stage() def set_active_profiler(self, active_profiler: str): + """ + Set active profiler. + + Note: + Active profiler must be either 'LINE' or 'MEMORY' (case-sensitive), + active profiler is set to 'LINE' by default. + Args: + active_profiler: String that represent active_profiler, must be either 'LINE' or 'MEMORY' (case-sensitive). + + """ if self.active_profiler not in ["LINE", "MEMORY"]: raise ValueError( f"active_profiler expect 'LINE' or 'MEMORY', got {self.active_profiler} instead" ) self.active_profiler = active_profiler - self.prepare_sql() + self._prepare_sql() if self.session is not None: self._set_active_profiler() @@ -69,9 +97,15 @@ def _set_active_profiler(self): self.session.sql(self.set_active_profiler_sql).collect() def enable_profiler(self): + """ + Enable profiler. Profiles will be generated until profiler is disabled. + """ self.session.sql(self.enable_profiler_sql).collect() def disable_profiler(self): + """ + Disable profiler. + """ self.session.sql(self.disable_profiler_sql).collect() def _get_last_query_id(self): @@ -80,7 +114,13 @@ def _get_last_query_id(self): return query.query_id return None - def show_profiles(self): + def show_profiles(self) -> str: + """ + Return and show the profiles of last executed stored procedure. + + Note: + This function must be called right after the execution of stored procedure you want to profile. + """ query_id = self._get_last_query_id() sql = f"select snowflake.core.get_python_profiler_output('{query_id}');" res = self.session.sql(sql).collect() @@ -88,6 +128,15 @@ def show_profiles(self): return res[0][0] def dump_profiles(self, dst_file: str): + """ + Write the profiles of last executed stored procedure to given file. + + Note: + This function must be called right after the execution of stored procedure you want to profile. + + Args: + dst_file: String of file name that you want to store the profiles. + """ query_id = self._get_last_query_id() sql = f"select snowflake.core.get_python_profiler_output('{query_id}');" res = self.session.sql(sql).collect() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 857809e31a8..119c5f9b59e 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3476,8 +3476,13 @@ def register_profiler(self, profiler: Profiler): self.profiler._set_active_profiler() self.profiler.query_history = self.query_history() - def show_profiles(self): - """Gather and return result of profiler, results are also print to console""" + def show_profiles(self) -> str: + """ + Return and show the profiles of last executed stored procedure. + + Note: + This function must be called right after the execution of stored procedure you want to profile. + """ if self.profiler is not None and isinstance(self.profiler, Profiler): return self.profiler.show_profiles() else: @@ -3486,7 +3491,15 @@ def show_profiles(self): ) def dump_profiles(self, dst_file: str): - """Gather result of a profiler and redirect it to a file""" + """ + Write the profiles of last executed stored procedure to given file. + + Note: + This function must be called right after the execution of stored procedure you want to profile. + + Args: + dst_file: String of file name that you want to store the profiles. + """ if self.profiler is not None and isinstance(self.profiler, Profiler): self.profiler.dump_profiles(dst_file=dst_file) else: @@ -3495,7 +3508,15 @@ def dump_profiles(self, dst_file: str): ) def register_profiler_modules(self, modules: List[str]): - """Register modules want to create profile""" + """ + Register stored procedures to generate profiles for them. + + Note: + Registered nodules will be overwritten by this function, + use this function with an empty string will remove registered modules. + Args: + modules: List of names of stored procedures. + """ if self.profiler is not None and isinstance(self.profiler, Profiler): self.profiler.register_profiler_modules(modules) else: diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index b8af2458ef5..24eca44b170 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -35,6 +35,7 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: res = session.show_profiles() session.register_profiler_modules([]) assert res is not None + print(type(res)) assert "Modules Profiled" in res From b458a852473c657b3efa3ddcf81775fb5dd002f3 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 17 Sep 2024 14:26:35 -0700 Subject: [PATCH 10/62] add regx for anonymous procedure --- src/snowflake/snowpark/profiler.py | 6 +++++- tests/integ/test_profiler.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 5aba9357f70..4e2c36f135d 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -1,6 +1,7 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import re from contextlib import contextmanager from typing import List, Optional @@ -109,8 +110,11 @@ def disable_profiler(self): self.session.sql(self.disable_profiler_sql).collect() def _get_last_query_id(self): + pattern = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" for query in self.query_history.queries[::-1]: - if query.sql_text.startswith("CALL"): + if query.sql_text.startswith("CALL") or re.match( + pattern, query.sql_text, re.DOTALL + ): return query.query_id return None diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 24eca44b170..bcc5a46d803 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -80,3 +80,21 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: session.register_profiler_modules([]) assert res is not None assert "Modules Profiled" in res + + +def test_anonymous_procedure(session, db_parameters): + def single_value_sp(session: snowflake.snowpark.Session) -> str: + return "success" + + single_value_sp = session.sproc.register(single_value_sp, anonymous=True) + session.register_profiler_modules(["table_sp"]) + with profiler( + stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", + active_profiler="LINE", + session=session, + ): + single_value_sp() + res = session.show_profiles() + session.register_profiler_modules([]) + assert res is not None + assert "Modules Profiled" in res From ba726ff7c2acedb6531bdb7b0b23ee8b77b1e42f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 17 Sep 2024 14:33:45 -0700 Subject: [PATCH 11/62] add docstring --- src/snowflake/snowpark/profiler.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 4e2c36f135d..b9bc5e0b2fe 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -10,6 +10,13 @@ class Profiler: + """ + Setup profiler to receive profiles of stored procedures. + + Note: + This feature cannot be used in owner's right SP because owner's right SP will not be able to set session-level parameters. + """ + def __init__( self, stage: Optional[str] = "", From 242e4ded1953a540b38e9d9b9ef8aa53175a96cf Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 17 Sep 2024 14:38:45 -0700 Subject: [PATCH 12/62] skip in localtesting --- tests/integ/test_profiler.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index bcc5a46d803..91193448a7e 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -20,6 +20,10 @@ def setup(session, resources_path, local_testing_mode): session.add_packages("snowflake-snowpark-python") +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) def test_profiler_with_context_manager(session, db_parameters): @sproc(name="table_sp", replace=True) def table_sp(session: snowflake.snowpark.Session) -> DataFrame: @@ -39,6 +43,10 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: assert "Modules Profiled" in res +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) def test_profiler_with_profiler_class(session, db_parameters): @sproc(name="table_sp", replace=True) def table_sp(session: snowflake.snowpark.Session) -> DataFrame: @@ -64,6 +72,10 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: assert "Modules Profiled" in res +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) def test_single_return_value_of_sp(session, db_parameters): @sproc(name="single_value_sp", replace=True) def single_value_sp(session: snowflake.snowpark.Session) -> str: @@ -82,6 +94,10 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: assert "Modules Profiled" in res +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) def test_anonymous_procedure(session, db_parameters): def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" From bf4d16952c94d4e3cee09754f9e120f1c423ec9e Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 17 Sep 2024 15:36:03 -0700 Subject: [PATCH 13/62] coverage test --- src/snowflake/snowpark/profiler.py | 4 +- tests/integ/test_profiler.py | 73 ++++++++++++++++++++++++++---- 2 files changed, 67 insertions(+), 10 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index b9bc5e0b2fe..b4138d3471e 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -86,9 +86,9 @@ def set_active_profiler(self, active_profiler: str): active_profiler: String that represent active_profiler, must be either 'LINE' or 'MEMORY' (case-sensitive). """ - if self.active_profiler not in ["LINE", "MEMORY"]: + if active_profiler not in ["LINE", "MEMORY"]: raise ValueError( - f"active_profiler expect 'LINE' or 'MEMORY', got {self.active_profiler} instead" + f"active_profiler expect 'LINE' or 'MEMORY', got {active_profiler} instead" ) self.active_profiler = active_profiler self._prepare_sql() diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 91193448a7e..0ad3c4b6a7b 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -48,26 +48,32 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: reason="session.sql is not supported in localtesting", ) def test_profiler_with_profiler_class(session, db_parameters): + another_tmp_stage_name = Utils.random_stage_name() + @sproc(name="table_sp", replace=True) def table_sp(session: snowflake.snowpark.Session) -> DataFrame: return session.sql("select 1") - profiler = Profiler() - profiler.register_profiler_modules(["table_sp"]) - profiler.set_active_profiler("LINE") - profiler.set_targeted_stage( + pro = Profiler() + pro.register_profiler_modules(["table_sp"]) + pro.set_active_profiler("LINE") + pro.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) - session.register_profiler(profiler) + session.register_profiler(pro) + + pro.set_targeted_stage( + f"{db_parameters['database']}.{db_parameters['schema']}.{another_tmp_stage_name}" + ) - profiler.enable_profiler() + pro.enable_profiler() session.call("table_sp").collect() res = session.show_profiles() - profiler.disable_profiler() + pro.disable_profiler() - profiler.register_profiler_modules([]) + pro.register_profiler_modules([]) assert res is not None assert "Modules Profiled" in res @@ -114,3 +120,54 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: session.register_profiler_modules([]) assert res is not None assert "Modules Profiled" in res + + +def test_not_set_profiler_error(session, tmpdir): + with pytest.raises(ValueError) as e: + session.show_profiles() + assert "profiler is not set, use session.register_profiler or profiler context manager" in str(e) + + with pytest.raises(ValueError) as e: + session.dump_profiles(tmpdir.join("file.txt")) + assert "profiler is not set, use session.register_profiler or profiler context manager" in str(e) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) +def test_register_module_without_profiler(session, db_parameters): + session.register_profiler_modules(["fake_module"]) + res = session.sql("show parameters like 'python_profiler_modules'").collect() + assert res[0].value == "fake_module" + session.register_profiler_modules([]) + + +def test_set_incorrect_active_profiler(): + pro = Profiler() + with pytest.raises(ValueError) as e: + pro.set_active_profiler("wrong_active_profiler") + assert "active_profiler expect 'LINE' or 'MEMORY', got wrong_active_profiler instead" in str(e) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) +def test_dump_profile_to_file(session, db_parameters, tmpdir): + file = tmpdir.join("profile.lprof") + def single_value_sp(session: snowflake.snowpark.Session) -> str: + return "success" + + single_value_sp = session.sproc.register(single_value_sp, anonymous=True) + session.register_profiler_modules(["table_sp"]) + with profiler( + stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", + active_profiler="LINE", + session=session, + ): + single_value_sp() + session.dump_profiles(file) + session.register_profiler_modules([]) + with open(file, "r") as f: + assert "Modules Profiled" in f.read() From e788125a55abc364a811b61cbbae608b669e5734 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 17 Sep 2024 15:42:20 -0700 Subject: [PATCH 14/62] fix lint From c45ca1ebc92b41f0dcee2129d4518a5aded88ed9 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 17 Sep 2024 15:42:53 -0700 Subject: [PATCH 15/62] lint fix --- tests/integ/test_profiler.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 0ad3c4b6a7b..17fe81eaa94 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -125,11 +125,17 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: def test_not_set_profiler_error(session, tmpdir): with pytest.raises(ValueError) as e: session.show_profiles() - assert "profiler is not set, use session.register_profiler or profiler context manager" in str(e) + assert ( + "profiler is not set, use session.register_profiler or profiler context manager" + in str(e) + ) with pytest.raises(ValueError) as e: session.dump_profiles(tmpdir.join("file.txt")) - assert "profiler is not set, use session.register_profiler or profiler context manager" in str(e) + assert ( + "profiler is not set, use session.register_profiler or profiler context manager" + in str(e) + ) @pytest.mark.skipif( @@ -147,7 +153,10 @@ def test_set_incorrect_active_profiler(): pro = Profiler() with pytest.raises(ValueError) as e: pro.set_active_profiler("wrong_active_profiler") - assert "active_profiler expect 'LINE' or 'MEMORY', got wrong_active_profiler instead" in str(e) + assert ( + "active_profiler expect 'LINE' or 'MEMORY', got wrong_active_profiler instead" + in str(e) + ) @pytest.mark.skipif( @@ -156,6 +165,7 @@ def test_set_incorrect_active_profiler(): ) def test_dump_profile_to_file(session, db_parameters, tmpdir): file = tmpdir.join("profile.lprof") + def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" @@ -169,5 +179,5 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: single_value_sp() session.dump_profiles(file) session.register_profiler_modules([]) - with open(file, "r") as f: + with open(file) as f: assert "Modules Profiled" in f.read() From 72e0162f321ad6a2c702779cfcd9ffcce275e296 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 18 Sep 2024 09:56:43 -0700 Subject: [PATCH 16/62] fix test --- tests/integ/conftest.py | 33 +++++++++++++ tests/integ/test_profiler.py | 96 ++++++++++++++++++------------------ 2 files changed, 82 insertions(+), 47 deletions(-) diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index ec619605e66..8c0d855d0da 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -219,6 +219,39 @@ def session( session.close() +@pytest.fixture(scope="function") +def profiler_session( + db_parameters, + resources_path, + sql_simplifier_enabled, + local_testing_mode, + cte_optimization_enabled, +): + rule1 = f"rule1{Utils.random_alphanumeric_str(10)}" + rule2 = f"rule2{Utils.random_alphanumeric_str(10)}" + key1 = f"key1{Utils.random_alphanumeric_str(10)}" + key2 = f"key2{Utils.random_alphanumeric_str(10)}" + integration1 = f"integration1{Utils.random_alphanumeric_str(10)}" + integration2 = f"integration2{Utils.random_alphanumeric_str(10)}" + session = ( + Session.builder.configs(db_parameters) + .config("local_testing", local_testing_mode) + .create() + ) + session.sql_simplifier_enabled = sql_simplifier_enabled + session._cte_optimization_enabled = cte_optimization_enabled + if os.getenv("GITHUB_ACTIONS") == "true" and not local_testing_mode: + set_up_external_access_integration_resources( + session, rule1, rule2, key1, key2, integration1, integration2 + ) + yield session + if os.getenv("GITHUB_ACTIONS") == "true" and not local_testing_mode: + clean_up_external_access_integration_resources( + session, rule1, rule2, key1, key2, integration1, integration2 + ) + session.close() + + @pytest.fixture(scope="function") def temp_schema(connection, session, local_testing_mode) -> None: """Set up and tear down a temp schema for cross-schema test. diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 17fe81eaa94..e29d781316d 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -14,30 +14,30 @@ @pytest.fixture(scope="module", autouse=True) -def setup(session, resources_path, local_testing_mode): +def setup(profiler_session, resources_path, local_testing_mode): if not local_testing_mode: - Utils.create_stage(session, tmp_stage_name, is_temporary=True) - session.add_packages("snowflake-snowpark-python") + Utils.create_stage(profiler_session, tmp_stage_name, is_temporary=True) + profiler_session.add_packages("snowflake-snowpark-python") @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_profiler_with_context_manager(session, db_parameters): +def test_profiler_with_context_manager(profiler_session, db_parameters): @sproc(name="table_sp", replace=True) - def table_sp(session: snowflake.snowpark.Session) -> DataFrame: - return session.sql("select 1") + def table_sp(profiler_session: snowflake.snowpark.profiler_session) -> DataFrame: + return profiler_session.sql("select 1") - session.register_profiler_modules(["table_sp"]) + profiler_session.register_profiler_modules(["table_sp"]) with profiler( stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", active_profiler="LINE", - session=session, + session=profiler_session, ): - session.call("table_sp").collect() - res = session.show_profiles() - session.register_profiler_modules([]) + profiler_session.call("table_sp").collect() + res = profiler_session.show_profiles() + profiler_session.register_profiler_modules([]) assert res is not None print(type(res)) assert "Modules Profiled" in res @@ -47,12 +47,12 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_profiler_with_profiler_class(session, db_parameters): +def test_profiler_with_profiler_class(profiler_session, db_parameters): another_tmp_stage_name = Utils.random_stage_name() @sproc(name="table_sp", replace=True) - def table_sp(session: snowflake.snowpark.Session) -> DataFrame: - return session.sql("select 1") + def table_sp(profiler_session: snowflake.snowpark.profiler_session) -> DataFrame: + return profiler_session.sql("select 1") pro = Profiler() pro.register_profiler_modules(["table_sp"]) @@ -60,7 +60,7 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: pro.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) - session.register_profiler(pro) + profiler_session.register_profiler(pro) pro.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{another_tmp_stage_name}" @@ -68,8 +68,8 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: pro.enable_profiler() - session.call("table_sp").collect() - res = session.show_profiles() + profiler_session.call("table_sp").collect() + res = profiler_session.show_profiles() pro.disable_profiler() @@ -82,20 +82,20 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_single_return_value_of_sp(session, db_parameters): +def test_single_return_value_of_sp(profiler_session, db_parameters): @sproc(name="single_value_sp", replace=True) - def single_value_sp(session: snowflake.snowpark.Session) -> str: + def single_value_sp(profiler_session: snowflake.snowpark.profiler_session) -> str: return "success" - session.register_profiler_modules(["table_sp"]) + profiler_session.register_profiler_modules(["table_sp"]) with profiler( stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", active_profiler="LINE", - session=session, + session=profiler_session, ): - session.call("single_value_sp") - res = session.show_profiles() - session.register_profiler_modules([]) + profiler_session.call("single_value_sp") + res = profiler_session.show_profiles() + profiler_session.register_profiler_modules([]) assert res is not None assert "Modules Profiled" in res @@ -104,36 +104,36 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_anonymous_procedure(session, db_parameters): - def single_value_sp(session: snowflake.snowpark.Session) -> str: +def test_anonymous_procedure(profiler_session, db_parameters): + def single_value_sp(profiler_session: snowflake.snowpark.profiler_session) -> str: return "success" - single_value_sp = session.sproc.register(single_value_sp, anonymous=True) - session.register_profiler_modules(["table_sp"]) + single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) + profiler_session.register_profiler_modules(["table_sp"]) with profiler( stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", active_profiler="LINE", - session=session, + session=profiler_session, ): single_value_sp() - res = session.show_profiles() - session.register_profiler_modules([]) + res = profiler_session.show_profiles() + profiler_session.register_profiler_modules([]) assert res is not None assert "Modules Profiled" in res -def test_not_set_profiler_error(session, tmpdir): +def test_not_set_profiler_error(profiler_session, tmpdir): with pytest.raises(ValueError) as e: - session.show_profiles() + profiler_session.show_profiles() assert ( - "profiler is not set, use session.register_profiler or profiler context manager" + "profiler is not set, use profiler_session.register_profiler or profiler context manager" in str(e) ) with pytest.raises(ValueError) as e: - session.dump_profiles(tmpdir.join("file.txt")) + profiler_session.dump_profiles(tmpdir.join("file.txt")) assert ( - "profiler is not set, use session.register_profiler or profiler context manager" + "profiler is not set, use profiler_session.register_profiler or profiler context manager" in str(e) ) @@ -142,11 +142,13 @@ def test_not_set_profiler_error(session, tmpdir): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_register_module_without_profiler(session, db_parameters): - session.register_profiler_modules(["fake_module"]) - res = session.sql("show parameters like 'python_profiler_modules'").collect() +def test_register_module_without_profiler(profiler_session, db_parameters): + profiler_session.register_profiler_modules(["fake_module"]) + res = profiler_session.sql( + "show parameters like 'python_profiler_modules'" + ).collect() assert res[0].value == "fake_module" - session.register_profiler_modules([]) + profiler_session.register_profiler_modules([]) def test_set_incorrect_active_profiler(): @@ -163,21 +165,21 @@ def test_set_incorrect_active_profiler(): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_dump_profile_to_file(session, db_parameters, tmpdir): +def test_dump_profile_to_file(profiler_session, db_parameters, tmpdir): file = tmpdir.join("profile.lprof") - def single_value_sp(session: snowflake.snowpark.Session) -> str: + def single_value_sp(profiler_session: snowflake.snowpark.profiler_session) -> str: return "success" - single_value_sp = session.sproc.register(single_value_sp, anonymous=True) - session.register_profiler_modules(["table_sp"]) + single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) + profiler_session.register_profiler_modules(["table_sp"]) with profiler( stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", active_profiler="LINE", - session=session, + session=profiler_session, ): single_value_sp() - session.dump_profiles(file) - session.register_profiler_modules([]) + profiler_session.dump_profiles(file) + profiler_session.register_profiler_modules([]) with open(file) as f: assert "Modules Profiled" in f.read() From 1d992f2e4f9684dd3825a90773617a1b98506056 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 18 Sep 2024 09:59:04 -0700 Subject: [PATCH 17/62] fix test --- tests/integ/test_profiler.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index e29d781316d..202283a7153 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -26,8 +26,8 @@ def setup(profiler_session, resources_path, local_testing_mode): ) def test_profiler_with_context_manager(profiler_session, db_parameters): @sproc(name="table_sp", replace=True) - def table_sp(profiler_session: snowflake.snowpark.profiler_session) -> DataFrame: - return profiler_session.sql("select 1") + def table_sp(session: snowflake.snowpark.session) -> DataFrame: + return session.sql("select 1") profiler_session.register_profiler_modules(["table_sp"]) with profiler( @@ -51,8 +51,8 @@ def test_profiler_with_profiler_class(profiler_session, db_parameters): another_tmp_stage_name = Utils.random_stage_name() @sproc(name="table_sp", replace=True) - def table_sp(profiler_session: snowflake.snowpark.profiler_session) -> DataFrame: - return profiler_session.sql("select 1") + def table_sp(session: snowflake.snowpark.session) -> DataFrame: + return session.sql("select 1") pro = Profiler() pro.register_profiler_modules(["table_sp"]) @@ -84,7 +84,7 @@ def table_sp(profiler_session: snowflake.snowpark.profiler_session) -> DataFrame ) def test_single_return_value_of_sp(profiler_session, db_parameters): @sproc(name="single_value_sp", replace=True) - def single_value_sp(profiler_session: snowflake.snowpark.profiler_session) -> str: + def single_value_sp(session: snowflake.snowpark.profiler_session) -> str: return "success" profiler_session.register_profiler_modules(["table_sp"]) @@ -105,7 +105,7 @@ def single_value_sp(profiler_session: snowflake.snowpark.profiler_session) -> st reason="session.sql is not supported in localtesting", ) def test_anonymous_procedure(profiler_session, db_parameters): - def single_value_sp(profiler_session: snowflake.snowpark.profiler_session) -> str: + def single_value_sp(session: snowflake.snowpark.profiler_session) -> str: return "success" single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) @@ -126,14 +126,14 @@ def test_not_set_profiler_error(profiler_session, tmpdir): with pytest.raises(ValueError) as e: profiler_session.show_profiles() assert ( - "profiler is not set, use profiler_session.register_profiler or profiler context manager" + "profiler is not set, use session.register_profiler or profiler context manager" in str(e) ) with pytest.raises(ValueError) as e: profiler_session.dump_profiles(tmpdir.join("file.txt")) assert ( - "profiler is not set, use profiler_session.register_profiler or profiler context manager" + "profiler is not set, use session.register_profiler or profiler context manager" in str(e) ) @@ -168,7 +168,7 @@ def test_set_incorrect_active_profiler(): def test_dump_profile_to_file(profiler_session, db_parameters, tmpdir): file = tmpdir.join("profile.lprof") - def single_value_sp(profiler_session: snowflake.snowpark.profiler_session) -> str: + def single_value_sp(session: snowflake.snowpark.profiler_session) -> str: return "success" single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) From 60c4aba0c42273707b61ae30de2780d0ea5286f9 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 18 Sep 2024 11:43:17 -0700 Subject: [PATCH 18/62] fix test --- src/snowflake/snowpark/profiler.py | 30 ++++++++++++++++++++++++++++-- src/snowflake/snowpark/session.py | 14 ++++++++++++-- tests/integ/test_profiler.py | 30 ++++++++++++++++-------------- 3 files changed, 56 insertions(+), 18 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index b4138d3471e..9eb85994723 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -73,6 +73,18 @@ def set_targeted_stage(self, stage: str): self.stage = stage self._prepare_sql() if self.session is not None: + if ( + len(self.session.sql(f"show stages like '{self.stage}'").collect()) == 0 + and len( + self.session.sql( + f"show stages like '{self.stage.split('.')[-1]}'" + ).collect() + ) + == 0 + ): + self.session.sql( + f"create temp stage {self.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + ).collect() self._set_targeted_stage() def set_active_profiler(self, active_profiler: str): @@ -133,7 +145,7 @@ def show_profiles(self) -> str: This function must be called right after the execution of stored procedure you want to profile. """ query_id = self._get_last_query_id() - sql = f"select snowflake.core.get_python_profiler_output('{query_id}');" + sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" res = self.session.sql(sql).collect() print(res[0][0]) # noqa: T201: we need to print here. return res[0][0] @@ -149,7 +161,7 @@ def dump_profiles(self, dst_file: str): dst_file: String of file name that you want to store the profiles. """ query_id = self._get_last_query_id() - sql = f"select snowflake.core.get_python_profiler_output('{query_id}');" + sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" res = self.session.sql(sql).collect() with open(dst_file, "w") as f: f.write(str(res[0][0])) @@ -167,6 +179,20 @@ def profiler( internal_profiler.query_history = session.query_history() modules = modules or [] try: + # create stage if not exist + if ( + len(session.sql(f"show stages like '{internal_profiler.stage}'").collect()) + == 0 + and len( + session.sql( + f"show stages like '{internal_profiler.stage.split('.')[-1]}'" + ).collect() + ) + == 0 + ): + session.sql( + f"create temp stage {internal_profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + ).collect() # set up phase internal_profiler._set_targeted_stage() internal_profiler._set_active_profiler() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 63f5e3846bd..0e877619ea4 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3463,9 +3463,19 @@ def register_profiler(self, profiler: Profiler): """Register a profiler to current session, all action are actually executed during this function""" self.profiler = profiler self.profiler.session = self - if len(self.sql(f"show stages like '{profiler.stage}'").collect()) == 0: + self.sql(f"show stages like '{profiler.stage}'").show() + self.sql(f"show stages like '{profiler.stage.split('.')[-1]}'").show() + if ( + len(self.sql(f"show stages like '{profiler.stage}'").collect()) == 0 + and len( + self.sql( + f"show stages like '{profiler.stage.split('.')[-1]}'" + ).collect() + ) + == 0 + ): self.sql( - f"create or replace temp stage {profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + f"create temp stage {profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" ).collect() self.profiler._register_modules() self.profiler._set_targeted_stage() diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 202283a7153..544ae0ff60c 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -10,13 +10,16 @@ from snowflake.snowpark.profiler import Profiler, profiler from tests.utils import Utils -tmp_stage_name = Utils.random_stage_name() +@pytest.fixture(scope="function") +def tmp_stage_name(): + tmp_stage_name = Utils.random_stage_name() + yield tmp_stage_name -@pytest.fixture(scope="module", autouse=True) + +@pytest.fixture(scope="function", autouse=True) def setup(profiler_session, resources_path, local_testing_mode): if not local_testing_mode: - Utils.create_stage(profiler_session, tmp_stage_name, is_temporary=True) profiler_session.add_packages("snowflake-snowpark-python") @@ -24,9 +27,9 @@ def setup(profiler_session, resources_path, local_testing_mode): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_profiler_with_context_manager(profiler_session, db_parameters): +def test_profiler_with_context_manager(profiler_session, db_parameters, tmp_stage_name): @sproc(name="table_sp", replace=True) - def table_sp(session: snowflake.snowpark.session) -> DataFrame: + def table_sp(session: snowflake.snowpark.Session) -> DataFrame: return session.sql("select 1") profiler_session.register_profiler_modules(["table_sp"]) @@ -39,7 +42,6 @@ def table_sp(session: snowflake.snowpark.session) -> DataFrame: res = profiler_session.show_profiles() profiler_session.register_profiler_modules([]) assert res is not None - print(type(res)) assert "Modules Profiled" in res @@ -47,11 +49,11 @@ def table_sp(session: snowflake.snowpark.session) -> DataFrame: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_profiler_with_profiler_class(profiler_session, db_parameters): +def test_profiler_with_profiler_class(profiler_session, db_parameters, tmp_stage_name): another_tmp_stage_name = Utils.random_stage_name() @sproc(name="table_sp", replace=True) - def table_sp(session: snowflake.snowpark.session) -> DataFrame: + def table_sp(session: snowflake.snowpark.Session) -> DataFrame: return session.sql("select 1") pro = Profiler() @@ -82,9 +84,9 @@ def table_sp(session: snowflake.snowpark.session) -> DataFrame: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_single_return_value_of_sp(profiler_session, db_parameters): +def test_single_return_value_of_sp(profiler_session, db_parameters, tmp_stage_name): @sproc(name="single_value_sp", replace=True) - def single_value_sp(session: snowflake.snowpark.profiler_session) -> str: + def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" profiler_session.register_profiler_modules(["table_sp"]) @@ -104,8 +106,8 @@ def single_value_sp(session: snowflake.snowpark.profiler_session) -> str: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_anonymous_procedure(profiler_session, db_parameters): - def single_value_sp(session: snowflake.snowpark.profiler_session) -> str: +def test_anonymous_procedure(profiler_session, db_parameters, tmp_stage_name): + def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) @@ -165,10 +167,10 @@ def test_set_incorrect_active_profiler(): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_dump_profile_to_file(profiler_session, db_parameters, tmpdir): +def test_dump_profile_to_file(profiler_session, db_parameters, tmpdir, tmp_stage_name): file = tmpdir.join("profile.lprof") - def single_value_sp(session: snowflake.snowpark.profiler_session) -> str: + def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) From eda9559e3a2a255cb518eed89ad10d4f85dca59d Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 18 Sep 2024 16:40:11 -0700 Subject: [PATCH 19/62] add test --- tests/integ/test_profiler.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 544ae0ff60c..9ab705737a0 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -23,6 +23,17 @@ def setup(profiler_session, resources_path, local_testing_mode): profiler_session.add_packages("snowflake-snowpark-python") +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) +def test_profiler_function_exist(profiler_session): + res = profiler_session.sql( + "show functions like 'GET_PYTHON_PROFILER_OUTPUT' in snowflake.core" + ).collect() + assert len(res) != 0 + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", From d196ddd0041900a523db2b75429900bf14332e4f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 18 Sep 2024 16:56:38 -0700 Subject: [PATCH 20/62] add unit test --- src/snowflake/snowpark/profiler.py | 13 ++++++++----- src/snowflake/snowpark/session.py | 2 +- tests/unit/test_profiler.py | 21 +++++++++++++++++++++ 3 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 tests/unit/test_profiler.py diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 9eb85994723..9c21976ea4d 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -31,6 +31,7 @@ def __init__( self.enable_profiler_sql = "" self.disable_profiler_sql = "" self.set_active_profiler_sql = "" + self.pattern = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" self.session = session self._prepare_sql() self.query_history = None @@ -83,7 +84,7 @@ def set_targeted_stage(self, stage: str): == 0 ): self.session.sql( - f"create temp stage {self.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + f"create temp stage if not exist {self.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" ).collect() self._set_targeted_stage() @@ -128,11 +129,13 @@ def disable_profiler(self): """ self.session.sql(self.disable_profiler_sql).collect() + def _is_sp_call(self, query): + return re.match(self.pattern, query, re.DOTALL) is not None + def _get_last_query_id(self): - pattern = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" for query in self.query_history.queries[::-1]: - if query.sql_text.startswith("CALL") or re.match( - pattern, query.sql_text, re.DOTALL + if query.sql_text.upper().startswith("CALL") or self._is_sp_call( + query.sql_text ): return query.query_id return None @@ -191,7 +194,7 @@ def profiler( == 0 ): session.sql( - f"create temp stage {internal_profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + f"create temp stage if not exist {internal_profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" ).collect() # set up phase internal_profiler._set_targeted_stage() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index e17219417cb..73eb387bf84 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3475,7 +3475,7 @@ def register_profiler(self, profiler: Profiler): == 0 ): self.sql( - f"create temp stage {profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + f"create temp stage if not exist {profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" ).collect() self.profiler._register_modules() self.profiler._set_targeted_stage() diff --git a/tests/unit/test_profiler.py b/tests/unit/test_profiler.py new file mode 100644 index 00000000000..74cccb3f2aa --- /dev/null +++ b/tests/unit/test_profiler.py @@ -0,0 +1,21 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from snowflake.snowpark.profiler import Profiler + + +def test_sp_call_match(): + pro = Profiler() + sp_call_sql = """WITH myProcedure AS PROCEDURE () + RETURNS TABLE ( ) + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ( 'snowflake-snowpark-python==1.2.0', 'pandas==1.3.3' ) + IMPORTS = ( '@my_stage/file1.py', '@my_stage/file2.py' ) + HANDLER = 'my_function' + RETURNS NULL ON NULL INPUT +AS 'fake' +CALL myProcedure()INTO :result + """ + assert pro._is_sp_call(sp_call_sql) From 6a7046d3a450e0e412f021f8e03b0805eaef2b60 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 18 Sep 2024 17:42:01 -0700 Subject: [PATCH 21/62] fix test --- src/snowflake/snowpark/profiler.py | 4 ++-- src/snowflake/snowpark/session.py | 2 +- tests/integ/test_profiler.py | 3 +++ 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 9c21976ea4d..ad0c2fc8169 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -84,7 +84,7 @@ def set_targeted_stage(self, stage: str): == 0 ): self.session.sql( - f"create temp stage if not exist {self.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + f"create temp stage if not exists {self.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" ).collect() self._set_targeted_stage() @@ -194,7 +194,7 @@ def profiler( == 0 ): session.sql( - f"create temp stage if not exist {internal_profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + f"create temp stage if not exists {internal_profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" ).collect() # set up phase internal_profiler._set_targeted_stage() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 73eb387bf84..60a4dc1c0b3 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3475,7 +3475,7 @@ def register_profiler(self, profiler: Profiler): == 0 ): self.sql( - f"create temp stage if not exist {profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + f"create temp stage if not exists {profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" ).collect() self.profiler._register_modules() self.profiler._set_targeted_stage() diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 9ab705737a0..067f52e9d0a 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -33,6 +33,9 @@ def test_profiler_function_exist(profiler_session): ).collect() assert len(res) != 0 + res = profiler_session.sql("select current_role()").collect() + assert res[0][0] == "TESTROLE_SNOWPARK_PYTHON" + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", From a27e64c41e43c76fe45a971c5fa020cc3e8ecf4e Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 19 Sep 2024 11:29:52 -0700 Subject: [PATCH 22/62] make test robust --- tests/integ/test_profiler.py | 39 ++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 067f52e9d0a..875f6ec59dc 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -11,6 +11,17 @@ from tests.utils import Utils +def is_profiler_function_exist(profiler_session, local_testing_mode): + if local_testing_mode: + return False + functions = profiler_session.sql( + "show functions like 'GET_PYTHON_PROFILER_OUTPUT' in snowflake.core" + ).collect() + if len(functions) == 0: + return False + return True + + @pytest.fixture(scope="function") def tmp_stage_name(): tmp_stage_name = Utils.random_stage_name() @@ -27,6 +38,10 @@ def setup(profiler_session, resources_path, local_testing_mode): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) +@pytest.mark.skipif( + not is_profiler_function_exist, + reason="profiler function does not exist or in local testing mode", +) def test_profiler_function_exist(profiler_session): res = profiler_session.sql( "show functions like 'GET_PYTHON_PROFILER_OUTPUT' in snowflake.core" @@ -41,6 +56,10 @@ def test_profiler_function_exist(profiler_session): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) +@pytest.mark.skipif( + not is_profiler_function_exist, + reason="profiler function does not exist or in local testing mode", +) def test_profiler_with_context_manager(profiler_session, db_parameters, tmp_stage_name): @sproc(name="table_sp", replace=True) def table_sp(session: snowflake.snowpark.Session) -> DataFrame: @@ -63,6 +82,10 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) +@pytest.mark.skipif( + not is_profiler_function_exist, + reason="profiler function does not exist or in local testing mode", +) def test_profiler_with_profiler_class(profiler_session, db_parameters, tmp_stage_name): another_tmp_stage_name = Utils.random_stage_name() @@ -98,6 +121,10 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) +@pytest.mark.skipif( + not is_profiler_function_exist, + reason="profiler function does not exist or in local testing mode", +) def test_single_return_value_of_sp(profiler_session, db_parameters, tmp_stage_name): @sproc(name="single_value_sp", replace=True) def single_value_sp(session: snowflake.snowpark.Session) -> str: @@ -120,6 +147,10 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) +@pytest.mark.skipif( + not is_profiler_function_exist, + reason="profiler function does not exist or in local testing mode", +) def test_anonymous_procedure(profiler_session, db_parameters, tmp_stage_name): def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" @@ -158,6 +189,10 @@ def test_not_set_profiler_error(profiler_session, tmpdir): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) +@pytest.mark.skipif( + not is_profiler_function_exist, + reason="profiler function does not exist or in local testing mode", +) def test_register_module_without_profiler(profiler_session, db_parameters): profiler_session.register_profiler_modules(["fake_module"]) res = profiler_session.sql( @@ -181,6 +216,10 @@ def test_set_incorrect_active_profiler(): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) +@pytest.mark.skipif( + not is_profiler_function_exist, + reason="profiler function does not exist or in local testing mode", +) def test_dump_profile_to_file(profiler_session, db_parameters, tmpdir, tmp_stage_name): file = tmpdir.join("profile.lprof") From 3e6a77172ddb0168ba65c2272269d5fc79a33a1f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 20 Sep 2024 10:59:10 -0700 Subject: [PATCH 23/62] address comments --- src/snowflake/snowpark/profiler.py | 18 ++++++++++-------- src/snowflake/snowpark/session.py | 9 ++++++--- tests/integ/test_profiler.py | 4 ++-- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index ad0c2fc8169..662450431d8 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -41,8 +41,7 @@ def _prepare_sql(self): self.set_targeted_stage_sql = ( f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{self.stage}"' ) - self.enable_profiler_sql = "alter session set ENABLE_PYTHON_PROFILER = true" - self.disable_profiler_sql = "alter session set ENABLE_PYTHON_PROFILER = false" + self.disable_profiler_sql = "alter session set ACTIVE_PYTHON_PROFILER = ''" self.set_active_profiler_sql = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler.upper()}'" def register_profiler_modules(self, modules: List[str]): @@ -94,7 +93,7 @@ def set_active_profiler(self, active_profiler: str): Note: Active profiler must be either 'LINE' or 'MEMORY' (case-sensitive), - active profiler is set to 'LINE' by default. + active profiler is 'LINE' by default. Args: active_profiler: String that represent active_profiler, must be either 'LINE' or 'MEMORY' (case-sensitive). @@ -117,13 +116,16 @@ def _set_targeted_stage(self): def _set_active_profiler(self): self.session.sql(self.set_active_profiler_sql).collect() - def enable_profiler(self): + def enable(self): """ Enable profiler. Profiles will be generated until profiler is disabled. """ - self.session.sql(self.enable_profiler_sql).collect() + if self.active_profiler == "": + self.active_profiler = "LINE" + self._prepare_sql() + self._set_active_profiler() - def disable_profiler(self): + def disable(self): """ Disable profiler. """ @@ -202,9 +204,9 @@ def profiler( internal_profiler.register_profiler_modules(modules) internal_profiler._register_modules() - internal_profiler.enable_profiler() + internal_profiler.enable() finally: yield internal_profiler.register_profiler_modules([]) internal_profiler._register_modules() - internal_profiler.disable_profiler() + internal_profiler.disable() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 60a4dc1c0b3..860cef96834 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3460,11 +3460,14 @@ def flatten( return df def register_profiler(self, profiler: Profiler): - """Register a profiler to current session, all action are actually executed during this function""" + """Register a profiler to a session, all action are actually executed during this function""" + if ( + profiler.session is not None + and profiler.session._session_id != self._session_id + ): + raise ValueError("A profiler can only be registered to one session.") self.profiler = profiler self.profiler.session = self - self.sql(f"show stages like '{profiler.stage}'").show() - self.sql(f"show stages like '{profiler.stage.split('.')[-1]}'").show() if ( len(self.sql(f"show stages like '{profiler.stage}'").collect()) == 0 and len( diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 875f6ec59dc..d7f09b805bd 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -105,12 +105,12 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: f"{db_parameters['database']}.{db_parameters['schema']}.{another_tmp_stage_name}" ) - pro.enable_profiler() + pro.enable() profiler_session.call("table_sp").collect() res = profiler_session.show_profiles() - pro.disable_profiler() + pro.disable() pro.register_profiler_modules([]) assert res is not None From 0d75b841c003e81624bc50bc6c009cd5bc7a9864 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 20 Sep 2024 11:50:38 -0700 Subject: [PATCH 24/62] fix test --- tests/integ/test_profiler.py | 65 ++++++++++++------------------------ 1 file changed, 22 insertions(+), 43 deletions(-) diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index d7f09b805bd..9050d741b7e 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -11,15 +11,13 @@ from tests.utils import Utils -def is_profiler_function_exist(profiler_session, local_testing_mode): - if local_testing_mode: - return False +@pytest.fixture(scope="function") +def is_profiler_function_exist(profiler_session): functions = profiler_session.sql( "show functions like 'GET_PYTHON_PROFILER_OUTPUT' in snowflake.core" ).collect() if len(functions) == 0: - return False - return True + pytest.skip("profiler function does not exist") @pytest.fixture(scope="function") @@ -38,29 +36,20 @@ def setup(profiler_session, resources_path, local_testing_mode): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -@pytest.mark.skipif( - not is_profiler_function_exist, - reason="profiler function does not exist or in local testing mode", -) -def test_profiler_function_exist(profiler_session): +def test_profiler_function_exist(is_profiler_function_exist, profiler_session): res = profiler_session.sql( "show functions like 'GET_PYTHON_PROFILER_OUTPUT' in snowflake.core" ).collect() assert len(res) != 0 - res = profiler_session.sql("select current_role()").collect() - assert res[0][0] == "TESTROLE_SNOWPARK_PYTHON" - @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -@pytest.mark.skipif( - not is_profiler_function_exist, - reason="profiler function does not exist or in local testing mode", -) -def test_profiler_with_context_manager(profiler_session, db_parameters, tmp_stage_name): +def test_profiler_with_context_manager( + is_profiler_function_exist, profiler_session, db_parameters, tmp_stage_name +): @sproc(name="table_sp", replace=True) def table_sp(session: snowflake.snowpark.Session) -> DataFrame: return session.sql("select 1") @@ -82,11 +71,9 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -@pytest.mark.skipif( - not is_profiler_function_exist, - reason="profiler function does not exist or in local testing mode", -) -def test_profiler_with_profiler_class(profiler_session, db_parameters, tmp_stage_name): +def test_profiler_with_profiler_class( + is_profiler_function_exist, profiler_session, db_parameters, tmp_stage_name +): another_tmp_stage_name = Utils.random_stage_name() @sproc(name="table_sp", replace=True) @@ -121,11 +108,9 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -@pytest.mark.skipif( - not is_profiler_function_exist, - reason="profiler function does not exist or in local testing mode", -) -def test_single_return_value_of_sp(profiler_session, db_parameters, tmp_stage_name): +def test_single_return_value_of_sp( + is_profiler_function_exist, profiler_session, db_parameters, tmp_stage_name +): @sproc(name="single_value_sp", replace=True) def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" @@ -147,11 +132,9 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -@pytest.mark.skipif( - not is_profiler_function_exist, - reason="profiler function does not exist or in local testing mode", -) -def test_anonymous_procedure(profiler_session, db_parameters, tmp_stage_name): +def test_anonymous_procedure( + is_profiler_function_exist, profiler_session, db_parameters, tmp_stage_name +): def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" @@ -189,11 +172,9 @@ def test_not_set_profiler_error(profiler_session, tmpdir): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -@pytest.mark.skipif( - not is_profiler_function_exist, - reason="profiler function does not exist or in local testing mode", -) -def test_register_module_without_profiler(profiler_session, db_parameters): +def test_register_module_without_profiler( + is_profiler_function_exist, profiler_session, db_parameters +): profiler_session.register_profiler_modules(["fake_module"]) res = profiler_session.sql( "show parameters like 'python_profiler_modules'" @@ -216,11 +197,9 @@ def test_set_incorrect_active_profiler(): "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -@pytest.mark.skipif( - not is_profiler_function_exist, - reason="profiler function does not exist or in local testing mode", -) -def test_dump_profile_to_file(profiler_session, db_parameters, tmpdir, tmp_stage_name): +def test_dump_profile_to_file( + is_profiler_function_exist, profiler_session, db_parameters, tmpdir, tmp_stage_name +): file = tmpdir.join("profile.lprof") def single_value_sp(session: snowflake.snowpark.Session) -> str: From f9df3cdf34dccc99af2a85a162e273bdc027c110 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 20 Sep 2024 11:53:58 -0700 Subject: [PATCH 25/62] address comments --- src/snowflake/snowpark/profiler.py | 32 +++++++++++++++--------------- src/snowflake/snowpark/session.py | 10 ++++------ 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 662450431d8..9a2b258a513 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -26,25 +26,25 @@ def __init__( self.stage = stage self.active_profiler = active_profiler self.modules_to_register = [] - self.register_modules_sql = "" - self.set_targeted_stage_sql = "" - self.enable_profiler_sql = "" - self.disable_profiler_sql = "" - self.set_active_profiler_sql = "" + self._register_modules_sql = "" + self._set_targeted_stage_sql = "" + self._enable_profiler_sql = "" + self._disable_profiler_sql = "" + self._set_active_profiler_sql = "" self.pattern = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" self.session = session self._prepare_sql() self.query_history = None def _prepare_sql(self): - self.register_modules_sql = f"alter session set python_profiler_modules='{','.join(self.modules_to_register)}'" - self.set_targeted_stage_sql = ( + self._register_modules_sql = f"alter session set python_profiler_modules='{','.join(self.modules_to_register)}'" + self._set_targeted_stage_sql = ( f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{self.stage}"' ) - self.disable_profiler_sql = "alter session set ACTIVE_PYTHON_PROFILER = ''" - self.set_active_profiler_sql = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler.upper()}'" + self._disable_profiler_sql = "alter session set ACTIVE_PYTHON_PROFILER = ''" + self._set_active_profiler_sql = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler.upper()}'" - def register_profiler_modules(self, modules: List[str]): + def register_profiler_modules(self, stored_procedures: List[str]): """ Register stored procedures to generate profiles for them. @@ -52,9 +52,9 @@ def register_profiler_modules(self, modules: List[str]): Registered nodules will be overwritten by this function, use this function with an empty string will remove registered modules. Args: - modules: List of names of stored procedures. + stored_procedures: List of names of stored procedures. """ - self.modules_to_register = modules + self.modules_to_register = stored_procedures self._prepare_sql() if self.session is not None: self._register_modules() @@ -108,13 +108,13 @@ def set_active_profiler(self, active_profiler: str): self._set_active_profiler() def _register_modules(self): - self.session.sql(self.register_modules_sql).collect() + self.session.sql(self._register_modules_sql).collect() def _set_targeted_stage(self): - self.session.sql(self.set_targeted_stage_sql).collect() + self.session.sql(self._set_targeted_stage_sql).collect() def _set_active_profiler(self): - self.session.sql(self.set_active_profiler_sql).collect() + self.session.sql(self._set_active_profiler_sql).collect() def enable(self): """ @@ -129,7 +129,7 @@ def disable(self): """ Disable profiler. """ - self.session.sql(self.disable_profiler_sql).collect() + self.session.sql(self._disable_profiler_sql).collect() def _is_sp_call(self, query): return re.match(self.pattern, query, re.DOTALL) is not None diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 860cef96834..41269970a2c 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3516,7 +3516,7 @@ def dump_profiles(self, dst_file: str): "profiler is not set, use session.register_profiler or profiler context manager" ) - def register_profiler_modules(self, modules: List[str]): + def register_profiler_modules(self, stored_procedures: List[str]): """ Register stored procedures to generate profiles for them. @@ -3524,14 +3524,12 @@ def register_profiler_modules(self, modules: List[str]): Registered nodules will be overwritten by this function, use this function with an empty string will remove registered modules. Args: - modules: List of names of stored procedures. + stored_procedures: List of names of stored procedures. """ if self.profiler is not None and isinstance(self.profiler, Profiler): - self.profiler.register_profiler_modules(modules) + self.profiler.register_profiler_modules(stored_procedures) else: - sql_statement = ( - f"alter session set python_profiler_modules='{','.join(modules)}'" - ) + sql_statement = f"alter session set python_profiler_modules='{','.join(stored_procedures)}'" self.sql(sql_statement).collect() def query_history(self, include_describe: bool = False) -> QueryHistory: From de0d5c0b35b4fb68f76b8a4cf74fd1c0e4093b62 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 26 Sep 2024 17:00:12 -0700 Subject: [PATCH 26/62] align with doc --- src/snowflake/snowpark/profiler.py | 154 +++++++++-------------------- src/snowflake/snowpark/session.py | 75 +------------- tests/integ/test_profiler.py | 138 ++++++++------------------ 3 files changed, 88 insertions(+), 279 deletions(-) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 9a2b258a513..93c7e632271 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -2,8 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # import re -from contextlib import contextmanager -from typing import List, Optional +from typing import List import snowflake.snowpark from snowflake.snowpark._internal.utils import validate_object_name @@ -19,32 +18,16 @@ class Profiler: def __init__( self, - stage: Optional[str] = "", - active_profiler: Optional[str] = "LINE", - session: Optional["snowflake.snowpark.Session"] = None, + session: "snowflake.snowpark.Session" = None, ) -> None: - self.stage = stage - self.active_profiler = active_profiler - self.modules_to_register = [] - self._register_modules_sql = "" - self._set_targeted_stage_sql = "" - self._enable_profiler_sql = "" - self._disable_profiler_sql = "" - self._set_active_profiler_sql = "" + self.stage = "" + self.active_profiler = "" + self.registered_stored_procedures = [] self.pattern = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" self.session = session - self._prepare_sql() - self.query_history = None + self.query_history = session.query_history() - def _prepare_sql(self): - self._register_modules_sql = f"alter session set python_profiler_modules='{','.join(self.modules_to_register)}'" - self._set_targeted_stage_sql = ( - f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{self.stage}"' - ) - self._disable_profiler_sql = "alter session set ACTIVE_PYTHON_PROFILER = ''" - self._set_active_profiler_sql = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler.upper()}'" - - def register_profiler_modules(self, stored_procedures: List[str]): + def register_modules(self, stored_procedures: List[str]): """ Register stored procedures to generate profiles for them. @@ -54,10 +37,11 @@ def register_profiler_modules(self, stored_procedures: List[str]): Args: stored_procedures: List of names of stored procedures. """ - self.modules_to_register = stored_procedures - self._prepare_sql() - if self.session is not None: - self._register_modules() + self.registered_stored_procedures = stored_procedures + sql_statement = ( + f"alter session set python_profiler_modules='{','.join(stored_procedures)}'" + ) + self.session.sql(sql_statement).collect() def set_targeted_stage(self, stage: str): """ @@ -71,21 +55,20 @@ def set_targeted_stage(self, stage: str): """ validate_object_name(stage) self.stage = stage - self._prepare_sql() - if self.session is not None: - if ( - len(self.session.sql(f"show stages like '{self.stage}'").collect()) == 0 - and len( - self.session.sql( - f"show stages like '{self.stage.split('.')[-1]}'" - ).collect() - ) - == 0 - ): + if ( + len(self.session.sql(f"show stages like '{self.stage}'").collect()) == 0 + and len( self.session.sql( - f"create temp stage if not exists {self.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + f"show stages like '{self.stage.split('.')[-1]}'" ).collect() - self._set_targeted_stage() + ) + == 0 + ): + self.session.sql( + f"create temp stage if not exists {self.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + ).collect() + sql_statement = f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{stage}"' + self.session.sql(sql_statement).collect() def set_active_profiler(self, active_profiler: str): """ @@ -98,38 +81,21 @@ def set_active_profiler(self, active_profiler: str): active_profiler: String that represent active_profiler, must be either 'LINE' or 'MEMORY' (case-sensitive). """ - if active_profiler not in ["LINE", "MEMORY"]: + if active_profiler not in ["LINE", "MEMORY", ""]: raise ValueError( - f"active_profiler expect 'LINE' or 'MEMORY', got {active_profiler} instead" + f"active_profiler expect 'LINE', 'MEMORY' or empty string '', got {active_profiler} instead" ) self.active_profiler = active_profiler - self._prepare_sql() - if self.session is not None: - self._set_active_profiler() - - def _register_modules(self): - self.session.sql(self._register_modules_sql).collect() - - def _set_targeted_stage(self): - self.session.sql(self._set_targeted_stage_sql).collect() - - def _set_active_profiler(self): - self.session.sql(self._set_active_profiler_sql).collect() - - def enable(self): - """ - Enable profiler. Profiles will be generated until profiler is disabled. - """ - if self.active_profiler == "": - self.active_profiler = "LINE" - self._prepare_sql() - self._set_active_profiler() + sql_statement = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler.upper()}'" + self.session.sql(sql_statement).collect() def disable(self): """ Disable profiler. """ - self.session.sql(self._disable_profiler_sql).collect() + self.active_profiler = "" + sql_statement = "alter session set ACTIVE_PYTHON_PROFILER = ''" + self.session.sql(sql_statement).collect() def _is_sp_call(self, query): return re.match(self.pattern, query, re.DOTALL) is not None @@ -142,9 +108,9 @@ def _get_last_query_id(self): return query.query_id return None - def show_profiles(self) -> str: + def show(self) -> None: """ - Return and show the profiles of last executed stored procedure. + Show the profiles of last executed stored procedure. Note: This function must be called right after the execution of stored procedure you want to profile. @@ -153,9 +119,19 @@ def show_profiles(self) -> str: sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" res = self.session.sql(sql).collect() print(res[0][0]) # noqa: T201: we need to print here. - return res[0][0] - def dump_profiles(self, dst_file: str): + def collect(self) -> str: + """ + Return the profiles of last executed stored procedure. + + Note: + This function must be called right after the execution of stored procedure you want to profile. + """ + query_id = self._get_last_query_id() + sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" + return self.session.sql(sql).collect()[0][0] + + def dump(self, dst_file: str): """ Write the profiles of last executed stored procedure to given file. @@ -170,43 +146,3 @@ def dump_profiles(self, dst_file: str): res = self.session.sql(sql).collect() with open(dst_file, "w") as f: f.write(str(res[0][0])) - - -@contextmanager -def profiler( - stage: str, - active_profiler: str, - session: "snowflake.snowpark.Session", - modules: Optional[List[str]] = None, -): - internal_profiler = Profiler(stage, active_profiler, session) - session.profiler = internal_profiler - internal_profiler.query_history = session.query_history() - modules = modules or [] - try: - # create stage if not exist - if ( - len(session.sql(f"show stages like '{internal_profiler.stage}'").collect()) - == 0 - and len( - session.sql( - f"show stages like '{internal_profiler.stage.split('.')[-1]}'" - ).collect() - ) - == 0 - ): - session.sql( - f"create temp stage if not exists {internal_profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" - ).collect() - # set up phase - internal_profiler._set_targeted_stage() - internal_profiler._set_active_profiler() - - internal_profiler.register_profiler_modules(modules) - internal_profiler._register_modules() - internal_profiler.enable() - finally: - yield - internal_profiler.register_profiler_modules([]) - internal_profiler._register_modules() - internal_profiler.disable() diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 41269970a2c..ebe7610f554 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -604,7 +604,7 @@ def __init__( self._conf = self.RuntimeConfig(self, options or {}) self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) - self.profiler = None + self.profiler = Profiler(session=self) _logger.info("Snowpark Session information: %s", self._session_info) @@ -3459,79 +3459,6 @@ def flatten( set_api_call_source(df, "Session.flatten") return df - def register_profiler(self, profiler: Profiler): - """Register a profiler to a session, all action are actually executed during this function""" - if ( - profiler.session is not None - and profiler.session._session_id != self._session_id - ): - raise ValueError("A profiler can only be registered to one session.") - self.profiler = profiler - self.profiler.session = self - if ( - len(self.sql(f"show stages like '{profiler.stage}'").collect()) == 0 - and len( - self.sql( - f"show stages like '{profiler.stage.split('.')[-1]}'" - ).collect() - ) - == 0 - ): - self.sql( - f"create temp stage if not exists {profiler.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" - ).collect() - self.profiler._register_modules() - self.profiler._set_targeted_stage() - self.profiler._set_active_profiler() - self.profiler.query_history = self.query_history() - - def show_profiles(self) -> str: - """ - Return and show the profiles of last executed stored procedure. - - Note: - This function must be called right after the execution of stored procedure you want to profile. - """ - if self.profiler is not None and isinstance(self.profiler, Profiler): - return self.profiler.show_profiles() - else: - raise ValueError( - "profiler is not set, use session.register_profiler or profiler context manager" - ) - - def dump_profiles(self, dst_file: str): - """ - Write the profiles of last executed stored procedure to given file. - - Note: - This function must be called right after the execution of stored procedure you want to profile. - - Args: - dst_file: String of file name that you want to store the profiles. - """ - if self.profiler is not None and isinstance(self.profiler, Profiler): - self.profiler.dump_profiles(dst_file=dst_file) - else: - raise ValueError( - "profiler is not set, use session.register_profiler or profiler context manager" - ) - - def register_profiler_modules(self, stored_procedures: List[str]): - """ - Register stored procedures to generate profiles for them. - - Note: - Registered nodules will be overwritten by this function, - use this function with an empty string will remove registered modules. - Args: - stored_procedures: List of names of stored procedures. - """ - if self.profiler is not None and isinstance(self.profiler, Profiler): - self.profiler.register_profiler_modules(stored_procedures) - else: - sql_statement = f"alter session set python_profiler_modules='{','.join(stored_procedures)}'" - self.sql(sql_statement).collect() - def query_history(self, include_describe: bool = False) -> QueryHistory: """Create an instance of :class:`QueryHistory` as a context manager to record queries that are pushed down to the Snowflake database. diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 9050d741b7e..54d74348d3f 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -7,7 +7,6 @@ import snowflake.snowpark from snowflake.snowpark import DataFrame from snowflake.snowpark.functions import sproc -from snowflake.snowpark.profiler import Profiler, profiler from tests.utils import Utils @@ -43,30 +42,6 @@ def test_profiler_function_exist(is_profiler_function_exist, profiler_session): assert len(res) != 0 -@pytest.mark.skipif( - "config.getoption('local_testing_mode', default=False)", - reason="session.sql is not supported in localtesting", -) -def test_profiler_with_context_manager( - is_profiler_function_exist, profiler_session, db_parameters, tmp_stage_name -): - @sproc(name="table_sp", replace=True) - def table_sp(session: snowflake.snowpark.Session) -> DataFrame: - return session.sql("select 1") - - profiler_session.register_profiler_modules(["table_sp"]) - with profiler( - stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", - active_profiler="LINE", - session=profiler_session, - ): - profiler_session.call("table_sp").collect() - res = profiler_session.show_profiles() - profiler_session.register_profiler_modules([]) - assert res is not None - assert "Modules Profiled" in res - - @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", @@ -74,32 +49,24 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: def test_profiler_with_profiler_class( is_profiler_function_exist, profiler_session, db_parameters, tmp_stage_name ): - another_tmp_stage_name = Utils.random_stage_name() - @sproc(name="table_sp", replace=True) def table_sp(session: snowflake.snowpark.Session) -> DataFrame: return session.sql("select 1") - pro = Profiler() - pro.register_profiler_modules(["table_sp"]) - pro.set_active_profiler("LINE") + pro = profiler_session.profiler + pro.register_modules(["table_sp"]) pro.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) - profiler_session.register_profiler(pro) - pro.set_targeted_stage( - f"{db_parameters['database']}.{db_parameters['schema']}.{another_tmp_stage_name}" - ) - - pro.enable() + pro.set_active_profiler("LINE") profiler_session.call("table_sp").collect() - res = profiler_session.show_profiles() + res = pro.collect() pro.disable() - pro.register_profiler_modules([]) + pro.register_modules([]) assert res is not None assert "Modules Profiled" in res @@ -115,15 +82,19 @@ def test_single_return_value_of_sp( def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" - profiler_session.register_profiler_modules(["table_sp"]) - with profiler( - stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", - active_profiler="LINE", - session=profiler_session, - ): - profiler_session.call("single_value_sp") - res = profiler_session.show_profiles() - profiler_session.register_profiler_modules([]) + profiler_session.profiler.register_modules(["single_value_sp"]) + profiler_session.profiler.pro.set_targeted_stage( + f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" + ) + + profiler_session.profiler.set_active_profiler("LINE") + + profiler_session.call("single_value_sp").collect() + res = profiler_session.profiler.collect() + + profiler_session.profiler.disable() + + profiler_session.profiler.register_modules([]) assert res is not None assert "Modules Profiled" in res @@ -139,54 +110,26 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) - profiler_session.register_profiler_modules(["table_sp"]) - with profiler( - stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", - active_profiler="LINE", - session=profiler_session, - ): - single_value_sp() - res = profiler_session.show_profiles() - profiler_session.register_profiler_modules([]) - assert res is not None - assert "Modules Profiled" in res - -def test_not_set_profiler_error(profiler_session, tmpdir): - with pytest.raises(ValueError) as e: - profiler_session.show_profiles() - assert ( - "profiler is not set, use session.register_profiler or profiler context manager" - in str(e) + profiler_session.profiler.pro.set_targeted_stage( + f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) - with pytest.raises(ValueError) as e: - profiler_session.dump_profiles(tmpdir.join("file.txt")) - assert ( - "profiler is not set, use session.register_profiler or profiler context manager" - in str(e) - ) + profiler_session.profiler.set_active_profiler("LINE") + single_value_sp() + res = profiler_session.profiler.collect() -@pytest.mark.skipif( - "config.getoption('local_testing_mode', default=False)", - reason="session.sql is not supported in localtesting", -) -def test_register_module_without_profiler( - is_profiler_function_exist, profiler_session, db_parameters -): - profiler_session.register_profiler_modules(["fake_module"]) - res = profiler_session.sql( - "show parameters like 'python_profiler_modules'" - ).collect() - assert res[0].value == "fake_module" - profiler_session.register_profiler_modules([]) + profiler_session.profiler.disable() + + profiler_session.profiler.register_modules([]) + assert res is not None + assert "Modules Profiled" in res -def test_set_incorrect_active_profiler(): - pro = Profiler() +def test_set_incorrect_active_profiler(profiler_session): with pytest.raises(ValueError) as e: - pro.set_active_profiler("wrong_active_profiler") + profiler_session.profiler.set_active_profiler("wrong_active_profiler") assert ( "active_profiler expect 'LINE' or 'MEMORY', got wrong_active_profiler instead" in str(e) @@ -206,14 +149,17 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) - profiler_session.register_profiler_modules(["table_sp"]) - with profiler( - stage=f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}", - active_profiler="LINE", - session=profiler_session, - ): - single_value_sp() - profiler_session.dump_profiles(file) - profiler_session.register_profiler_modules([]) + profiler_session.profiler.pro.set_targeted_stage( + f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" + ) + + profiler_session.profiler.set_active_profiler("LINE") + + single_value_sp() + profiler_session.profiler.dump(file) + + profiler_session.profiler.disable() + + profiler_session.profiler.register_modules([]) with open(file) as f: assert "Modules Profiled" in f.read() From 8ee830e64a659267c40c01d858d2927a14cab5e3 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 30 Sep 2024 09:50:12 -0700 Subject: [PATCH 27/62] fix test --- tests/integ/test_profiler.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index 54d74348d3f..b08922fb872 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -61,7 +61,7 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: pro.set_active_profiler("LINE") - profiler_session.call("table_sp").collect() + profiler_session.call("table_sp") res = pro.collect() pro.disable() @@ -83,13 +83,13 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" profiler_session.profiler.register_modules(["single_value_sp"]) - profiler_session.profiler.pro.set_targeted_stage( + profiler_session.profiler.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) profiler_session.profiler.set_active_profiler("LINE") - profiler_session.call("single_value_sp").collect() + profiler_session.call("single_value_sp") res = profiler_session.profiler.collect() profiler_session.profiler.disable() @@ -111,7 +111,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) - profiler_session.profiler.pro.set_targeted_stage( + profiler_session.profiler.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) @@ -130,10 +130,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: def test_set_incorrect_active_profiler(profiler_session): with pytest.raises(ValueError) as e: profiler_session.profiler.set_active_profiler("wrong_active_profiler") - assert ( - "active_profiler expect 'LINE' or 'MEMORY', got wrong_active_profiler instead" - in str(e) - ) + assert "active_profiler expect 'LINE', 'MEMORY' or empty string ''" in str(e) @pytest.mark.skipif( @@ -149,7 +146,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) - profiler_session.profiler.pro.set_targeted_stage( + profiler_session.profiler.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) From 1e2d86c383910d0aef94881ca9eac7cb7a04a63c Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 30 Sep 2024 09:54:37 -0700 Subject: [PATCH 28/62] fix doc --- CHANGELOG.md | 2 +- src/snowflake/snowpark/profiler.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b86eee6b355..e38a776f4a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,7 +47,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det - Added the following new functions in `snowflake.snowpark.functions`: - `array_remove` - `ln` -- Added snowpark python API for profiler. +- Added Snowpark Python API for profiler. #### Improvements diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/profiler.py index 93c7e632271..d46f6ddd2de 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/profiler.py @@ -10,10 +10,10 @@ class Profiler: """ - Setup profiler to receive profiles of stored procedures. + Set up profiler to receive profiles of stored procedures. Note: - This feature cannot be used in owner's right SP because owner's right SP will not be able to set session-level parameters. + This feature cannot be used in owner's right stored procedure because owner's right stored procedure will not be able to set session-level parameters. """ def __init__( @@ -32,8 +32,7 @@ def register_modules(self, stored_procedures: List[str]): Register stored procedures to generate profiles for them. Note: - Registered nodules will be overwritten by this function, - use this function with an empty string will remove registered modules. + Registered modules will be overwritten by this function. Use this function with an empty string will remove registered modules. Args: stored_procedures: List of names of stored procedures. """ @@ -75,8 +74,7 @@ def set_active_profiler(self, active_profiler: str): Set active profiler. Note: - Active profiler must be either 'LINE' or 'MEMORY' (case-sensitive), - active profiler is 'LINE' by default. + Active profiler must be either 'LINE' or 'MEMORY' (case-sensitive). Active profiler is 'LINE' by default. Args: active_profiler: String that represent active_profiler, must be either 'LINE' or 'MEMORY' (case-sensitive). @@ -110,7 +108,7 @@ def _get_last_query_id(self): def show(self) -> None: """ - Show the profiles of last executed stored procedure. + Show the profiles of the last executed stored procedure. Note: This function must be called right after the execution of stored procedure you want to profile. From 4629d8952f212fc29b7c7fbacdb01cef4c0d28b3 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 30 Sep 2024 11:37:12 -0700 Subject: [PATCH 29/62] fix test --- tests/integ/test_profiler.py | 16 ++++++++++++++++ tests/unit/test_profiler.py | 21 --------------------- 2 files changed, 16 insertions(+), 21 deletions(-) delete mode 100644 tests/unit/test_profiler.py diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index b08922fb872..d871397da7f 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -160,3 +160,19 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: profiler_session.profiler.register_modules([]) with open(file) as f: assert "Modules Profiled" in f.read() + + +def test_sp_call_match(profiler_session): + pro = profiler_session.profiler + sp_call_sql = """WITH myProcedure AS PROCEDURE () + RETURNS TABLE ( ) + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ( 'snowflake-snowpark-python==1.2.0', 'pandas==1.3.3' ) + IMPORTS = ( '@my_stage/file1.py', '@my_stage/file2.py' ) + HANDLER = 'my_function' + RETURNS NULL ON NULL INPUT +AS 'fake' +CALL myProcedure()INTO :result + """ + assert pro._is_sp_call(sp_call_sql) diff --git a/tests/unit/test_profiler.py b/tests/unit/test_profiler.py deleted file mode 100644 index 74cccb3f2aa..00000000000 --- a/tests/unit/test_profiler.py +++ /dev/null @@ -1,21 +0,0 @@ -# -# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. -# - -from snowflake.snowpark.profiler import Profiler - - -def test_sp_call_match(): - pro = Profiler() - sp_call_sql = """WITH myProcedure AS PROCEDURE () - RETURNS TABLE ( ) - LANGUAGE PYTHON - RUNTIME_VERSION = '3.8' - PACKAGES = ( 'snowflake-snowpark-python==1.2.0', 'pandas==1.3.3' ) - IMPORTS = ( '@my_stage/file1.py', '@my_stage/file2.py' ) - HANDLER = 'my_function' - RETURNS NULL ON NULL INPUT -AS 'fake' -CALL myProcedure()INTO :result - """ - assert pro._is_sp_call(sp_call_sql) From 43c0f5ae343ceeadfadfd7a9e115f5581c30e758 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Mon, 30 Sep 2024 16:43:39 -0700 Subject: [PATCH 30/62] fix test --- tests/integ/test_profiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/test_profiler.py b/tests/integ/test_profiler.py index d871397da7f..d2bed399b98 100644 --- a/tests/integ/test_profiler.py +++ b/tests/integ/test_profiler.py @@ -63,6 +63,7 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: profiler_session.call("table_sp") res = pro.collect() + pro.show() pro.disable() From 76f5eaa659849ec789511b7c8c0a2ce297c7a032 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 1 Oct 2024 10:12:22 -0700 Subject: [PATCH 31/62] rename to stored procedure profiler --- CHANGELOG.md | 2 +- src/snowflake/snowpark/session.py | 4 ++-- .../snowpark/{profiler.py => stored_procedure_profiler.py} | 2 +- .../{test_profiler.py => test_stored_procedure_profiler.py} | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename src/snowflake/snowpark/{profiler.py => stored_procedure_profiler.py} (99%) rename tests/integ/{test_profiler.py => test_stored_procedure_profiler.py} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index e38a776f4a7..66911e7f79e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,7 +47,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det - Added the following new functions in `snowflake.snowpark.functions`: - `array_remove` - `ln` -- Added Snowpark Python API for profiler. +- Added Snowpark Python API for stored procedure profiler. #### Improvements diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index ebe7610f554..3ddaacafaf8 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -158,10 +158,10 @@ from snowflake.snowpark.mock._plan_builder import MockSnowflakePlanBuilder from snowflake.snowpark.mock._stored_procedure import MockStoredProcedureRegistration from snowflake.snowpark.mock._udf import MockUDFRegistration -from snowflake.snowpark.profiler import Profiler from snowflake.snowpark.query_history import QueryHistory from snowflake.snowpark.row import Row from snowflake.snowpark.stored_procedure import StoredProcedureRegistration +from snowflake.snowpark.stored_procedure_profiler import StoredProcedureProfiler from snowflake.snowpark.table import Table from snowflake.snowpark.table_function import ( TableFunctionCall, @@ -604,7 +604,7 @@ def __init__( self._conf = self.RuntimeConfig(self, options or {}) self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) - self.profiler = Profiler(session=self) + self.profiler = StoredProcedureProfiler(session=self) _logger.info("Snowpark Session information: %s", self._session_info) diff --git a/src/snowflake/snowpark/profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py similarity index 99% rename from src/snowflake/snowpark/profiler.py rename to src/snowflake/snowpark/stored_procedure_profiler.py index d46f6ddd2de..ce0072adbe7 100644 --- a/src/snowflake/snowpark/profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -8,7 +8,7 @@ from snowflake.snowpark._internal.utils import validate_object_name -class Profiler: +class StoredProcedureProfiler: """ Set up profiler to receive profiles of stored procedures. diff --git a/tests/integ/test_profiler.py b/tests/integ/test_stored_procedure_profiler.py similarity index 100% rename from tests/integ/test_profiler.py rename to tests/integ/test_stored_procedure_profiler.py From 1a626f1f28e2d2671749b701cfb310b161ffb185 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 1 Oct 2024 10:15:39 -0700 Subject: [PATCH 32/62] rename to stored procedure profiler --- src/snowflake/snowpark/session.py | 2 +- tests/integ/test_stored_procedure_profiler.py | 40 ++++++++++--------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 3ddaacafaf8..9e4383bde1c 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -604,7 +604,7 @@ def __init__( self._conf = self.RuntimeConfig(self, options or {}) self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) - self.profiler = StoredProcedureProfiler(session=self) + self.stored_procedure_profiler = StoredProcedureProfiler(session=self) _logger.info("Snowpark Session information: %s", self._session_info) diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index d2bed399b98..46bf872a30c 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -53,7 +53,7 @@ def test_profiler_with_profiler_class( def table_sp(session: snowflake.snowpark.Session) -> DataFrame: return session.sql("select 1") - pro = profiler_session.profiler + pro = profiler_session.stored_procedure_profiler pro.register_modules(["table_sp"]) pro.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" @@ -83,19 +83,19 @@ def test_single_return_value_of_sp( def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" - profiler_session.profiler.register_modules(["single_value_sp"]) - profiler_session.profiler.set_targeted_stage( + profiler_session.stored_procedure_profiler.register_modules(["single_value_sp"]) + profiler_session.stored_procedure_profiler.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) - profiler_session.profiler.set_active_profiler("LINE") + profiler_session.stored_procedure_profiler.set_active_profiler("LINE") profiler_session.call("single_value_sp") - res = profiler_session.profiler.collect() + res = profiler_session.stored_procedure_profiler.collect() - profiler_session.profiler.disable() + profiler_session.stored_procedure_profiler.disable() - profiler_session.profiler.register_modules([]) + profiler_session.stored_procedure_profiler.register_modules([]) assert res is not None assert "Modules Profiled" in res @@ -112,25 +112,27 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) - profiler_session.profiler.set_targeted_stage( + profiler_session.stored_procedure_profiler.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) - profiler_session.profiler.set_active_profiler("LINE") + profiler_session.stored_procedure_profiler.set_active_profiler("LINE") single_value_sp() - res = profiler_session.profiler.collect() + res = profiler_session.stored_procedure_profiler.collect() - profiler_session.profiler.disable() + profiler_session.stored_procedure_profiler.disable() - profiler_session.profiler.register_modules([]) + profiler_session.stored_procedure_profiler.register_modules([]) assert res is not None assert "Modules Profiled" in res def test_set_incorrect_active_profiler(profiler_session): with pytest.raises(ValueError) as e: - profiler_session.profiler.set_active_profiler("wrong_active_profiler") + profiler_session.stored_procedure_profiler.set_active_profiler( + "wrong_active_profiler" + ) assert "active_profiler expect 'LINE', 'MEMORY' or empty string ''" in str(e) @@ -147,24 +149,24 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) - profiler_session.profiler.set_targeted_stage( + profiler_session.stored_procedure_profiler.set_targeted_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) - profiler_session.profiler.set_active_profiler("LINE") + profiler_session.stored_procedure_profiler.set_active_profiler("LINE") single_value_sp() - profiler_session.profiler.dump(file) + profiler_session.stored_procedure_profiler.dump(file) - profiler_session.profiler.disable() + profiler_session.stored_procedure_profiler.disable() - profiler_session.profiler.register_modules([]) + profiler_session.stored_procedure_profiler.register_modules([]) with open(file) as f: assert "Modules Profiled" in f.read() def test_sp_call_match(profiler_session): - pro = profiler_session.profiler + pro = profiler_session.stored_procedure_profiler sp_call_sql = """WITH myProcedure AS PROCEDURE () RETURNS TABLE ( ) LANGUAGE PYTHON From f5c246a512fb53e99e06c0aaea641b4bea908ecd Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 1 Oct 2024 15:14:22 -0700 Subject: [PATCH 33/62] multi thread compatiable --- src/snowflake/snowpark/stored_procedure_profiler.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index ce0072adbe7..346a770bcfe 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # import re +import threading from typing import List import snowflake.snowpark @@ -99,11 +100,14 @@ def _is_sp_call(self, query): return re.match(self.pattern, query, re.DOTALL) is not None def _get_last_query_id(self): + current_thread = threading.get_ident() for query in self.query_history.queries[::-1]: - if query.sql_text.upper().startswith("CALL") or self._is_sp_call( - query.sql_text - ): - return query.query_id + query_thread = getattr(query, "thread_id", None) + if query_thread is None or query_thread == current_thread: + if query.sql_text.upper().startswith("CALL") or self._is_sp_call( + query.sql_text + ): + return query.query_id return None def show(self) -> None: From e67cca4be4b06c6c30a0c234c4d403e070d99793 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Tue, 1 Oct 2024 17:01:57 -0700 Subject: [PATCH 34/62] multi thread support --- src/snowflake/snowpark/stored_procedure_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 346a770bcfe..3b25a3bd01a 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -26,7 +26,7 @@ def __init__( self.registered_stored_procedures = [] self.pattern = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" self.session = session - self.query_history = session.query_history() + self.query_history = session.query_history(include_thread_id=True) def register_modules(self, stored_procedures: List[str]): """ From 233a9f48f5d6279001fa1c683925dad6e3167d7d Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 10:58:59 -0700 Subject: [PATCH 35/62] remove show and dump --- .../snowpark/stored_procedure_profiler.py | 30 +-------------- tests/integ/test_stored_procedure_profiler.py | 37 ++----------------- 2 files changed, 4 insertions(+), 63 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 3b25a3bd01a..534832009d2 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -110,19 +110,7 @@ def _get_last_query_id(self): return query.query_id return None - def show(self) -> None: - """ - Show the profiles of the last executed stored procedure. - - Note: - This function must be called right after the execution of stored procedure you want to profile. - """ - query_id = self._get_last_query_id() - sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" - res = self.session.sql(sql).collect() - print(res[0][0]) # noqa: T201: we need to print here. - - def collect(self) -> str: + def get_output(self) -> str: """ Return the profiles of last executed stored procedure. @@ -132,19 +120,3 @@ def collect(self) -> str: query_id = self._get_last_query_id() sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" return self.session.sql(sql).collect()[0][0] - - def dump(self, dst_file: str): - """ - Write the profiles of last executed stored procedure to given file. - - Note: - This function must be called right after the execution of stored procedure you want to profile. - - Args: - dst_file: String of file name that you want to store the profiles. - """ - query_id = self._get_last_query_id() - sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" - res = self.session.sql(sql).collect() - with open(dst_file, "w") as f: - f.write(str(res[0][0])) diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 46bf872a30c..87d1ce09991 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -62,9 +62,7 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: pro.set_active_profiler("LINE") profiler_session.call("table_sp") - res = pro.collect() - pro.show() - + res = pro.get_output() pro.disable() pro.register_modules([]) @@ -91,7 +89,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: profiler_session.stored_procedure_profiler.set_active_profiler("LINE") profiler_session.call("single_value_sp") - res = profiler_session.stored_procedure_profiler.collect() + res = profiler_session.stored_procedure_profiler.get_output() profiler_session.stored_procedure_profiler.disable() @@ -119,7 +117,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: profiler_session.stored_procedure_profiler.set_active_profiler("LINE") single_value_sp() - res = profiler_session.stored_procedure_profiler.collect() + res = profiler_session.stored_procedure_profiler.get_output() profiler_session.stored_procedure_profiler.disable() @@ -136,35 +134,6 @@ def test_set_incorrect_active_profiler(profiler_session): assert "active_profiler expect 'LINE', 'MEMORY' or empty string ''" in str(e) -@pytest.mark.skipif( - "config.getoption('local_testing_mode', default=False)", - reason="session.sql is not supported in localtesting", -) -def test_dump_profile_to_file( - is_profiler_function_exist, profiler_session, db_parameters, tmpdir, tmp_stage_name -): - file = tmpdir.join("profile.lprof") - - def single_value_sp(session: snowflake.snowpark.Session) -> str: - return "success" - - single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) - profiler_session.stored_procedure_profiler.set_targeted_stage( - f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" - ) - - profiler_session.stored_procedure_profiler.set_active_profiler("LINE") - - single_value_sp() - profiler_session.stored_procedure_profiler.dump(file) - - profiler_session.stored_procedure_profiler.disable() - - profiler_session.stored_procedure_profiler.register_modules([]) - with open(file) as f: - assert "Modules Profiled" in f.read() - - def test_sp_call_match(profiler_session): pro = profiler_session.stored_procedure_profiler sp_call_sql = """WITH myProcedure AS PROCEDURE () From aa90a19bb021df3f8a0b65ff6c8974086d31d132 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 11:01:08 -0700 Subject: [PATCH 36/62] fix docstring --- src/snowflake/snowpark/stored_procedure_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 534832009d2..0e955a49ec5 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -112,7 +112,7 @@ def _get_last_query_id(self): def get_output(self) -> str: """ - Return the profiles of last executed stored procedure. + Return the profiles of last executed stored procedure in current thread. Note: This function must be called right after the execution of stored procedure you want to profile. From bb4220d6bed29fa17ff12e2236385973343c0769 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 14:39:13 -0700 Subject: [PATCH 37/62] address comments --- .../snowpark/stored_procedure_profiler.py | 29 ++++++++++--------- tests/integ/test_stored_procedure_profiler.py | 2 +- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 0e955a49ec5..55940d64f95 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -3,7 +3,7 @@ # import re import threading -from typing import List +from typing import List, Optional import snowflake.snowpark from snowflake.snowpark._internal.utils import validate_object_name @@ -22,11 +22,11 @@ def __init__( session: "snowflake.snowpark.Session" = None, ) -> None: self.stage = "" - self.active_profiler = "" + self.active_profiler_type = "" self.registered_stored_procedures = [] - self.pattern = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" self.session = session - self.query_history = session.query_history(include_thread_id=True) + self._pattern = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" + self._query_history = session.query_history(include_thread_id=True) def register_modules(self, stored_procedures: List[str]): """ @@ -70,38 +70,41 @@ def set_targeted_stage(self, stage: str): sql_statement = f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{stage}"' self.session.sql(sql_statement).collect() - def set_active_profiler(self, active_profiler: str): + def set_active_profiler(self, active_profiler_type: Optional[str] = None): """ Set active profiler. Note: Active profiler must be either 'LINE' or 'MEMORY' (case-sensitive). Active profiler is 'LINE' by default. Args: - active_profiler: String that represent active_profiler, must be either 'LINE' or 'MEMORY' (case-sensitive). + active_profiler_type: String that represent active_profiler, must be either 'LINE' or 'MEMORY' (case-sensitive). """ - if active_profiler not in ["LINE", "MEMORY", ""]: + if active_profiler_type is None: + active_profiler_type = "" + + if active_profiler_type not in ["LINE", "MEMORY", ""]: raise ValueError( - f"active_profiler expect 'LINE', 'MEMORY' or empty string '', got {active_profiler} instead" + f"active_profiler expect 'LINE', 'MEMORY' or empty string '', got {active_profiler_type} instead" ) - self.active_profiler = active_profiler - sql_statement = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler.upper()}'" + self.active_profiler_type = active_profiler_type + sql_statement = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler_type.upper()}'" self.session.sql(sql_statement).collect() def disable(self): """ Disable profiler. """ - self.active_profiler = "" + self.active_profiler_type = "" sql_statement = "alter session set ACTIVE_PYTHON_PROFILER = ''" self.session.sql(sql_statement).collect() def _is_sp_call(self, query): - return re.match(self.pattern, query, re.DOTALL) is not None + return re.match(self._pattern, query, re.DOTALL) is not None def _get_last_query_id(self): current_thread = threading.get_ident() - for query in self.query_history.queries[::-1]: + for query in self._query_history.queries[::-1]: query_thread = getattr(query, "thread_id", None) if query_thread is None or query_thread == current_thread: if query.sql_text.upper().startswith("CALL") or self._is_sp_call( diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 87d1ce09991..1e98776ec95 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -119,7 +119,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: single_value_sp() res = profiler_session.stored_procedure_profiler.get_output() - profiler_session.stored_procedure_profiler.disable() + profiler_session.stored_procedure_profiler.set_active_profiler() profiler_session.stored_procedure_profiler.register_modules([]) assert res is not None From d6ba864c60b40361f26cd3ad11da68656fa3af9a Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 15:46:33 -0700 Subject: [PATCH 38/62] make prifler thread safe --- .../snowpark/stored_procedure_profiler.py | 59 +++++++++---------- tests/integ/test_stored_procedure_profiler.py | 4 +- 2 files changed, 29 insertions(+), 34 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 55940d64f95..68fe4d6dc5b 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -3,29 +3,28 @@ # import re import threading -from typing import List, Optional +from typing import List, Literal import snowflake.snowpark from snowflake.snowpark._internal.utils import validate_object_name +STORED_PROCEDURE_CALL_PATTERN = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" + class StoredProcedureProfiler: """ Set up profiler to receive profiles of stored procedures. Note: - This feature cannot be used in owner's right stored procedure because owner's right stored procedure will not be able to set session-level parameters. + This feature cannot be used in owner's right stored procedure because owner's right stored procedure will not be + able to set session-level parameters. """ def __init__( self, session: "snowflake.snowpark.Session" = None, ) -> None: - self.stage = "" - self.active_profiler_type = "" - self.registered_stored_procedures = [] - self.session = session - self._pattern = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" + self._session = session self._query_history = session.query_history(include_thread_id=True) def register_modules(self, stored_procedures: List[str]): @@ -33,15 +32,15 @@ def register_modules(self, stored_procedures: List[str]): Register stored procedures to generate profiles for them. Note: - Registered modules will be overwritten by this function. Use this function with an empty string will remove registered modules. + Registered modules will be overwritten by this function. Use this function with an empty string will remove + registered modules. Args: stored_procedures: List of names of stored procedures. """ - self.registered_stored_procedures = stored_procedures sql_statement = ( f"alter session set python_profiler_modules='{','.join(stored_procedures)}'" ) - self.session.sql(sql_statement).collect() + self._session.sql(sql_statement).collect() def set_targeted_stage(self, stage: str): """ @@ -54,53 +53,49 @@ def set_targeted_stage(self, stage: str): stage: String of fully qualified name of targeted stage """ validate_object_name(stage) - self.stage = stage if ( - len(self.session.sql(f"show stages like '{self.stage}'").collect()) == 0 + len(self._session.sql(f"show stages like '{stage}'").collect()) == 0 and len( - self.session.sql( - f"show stages like '{self.stage.split('.')[-1]}'" + self._session.sql( + f"show stages like '{stage.split('.')[-1]}'" ).collect() ) == 0 ): - self.session.sql( - f"create temp stage if not exists {self.stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" + self._session.sql( + f"create temp stage if not exists {stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" ).collect() sql_statement = f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{stage}"' - self.session.sql(sql_statement).collect() + self._session.sql(sql_statement).collect() - def set_active_profiler(self, active_profiler_type: Optional[str] = None): + def set_active_profiler(self, active_profiler_type: Literal["LINE", "MEMORY"]): """ Set active profiler. Note: Active profiler must be either 'LINE' or 'MEMORY' (case-sensitive). Active profiler is 'LINE' by default. Args: - active_profiler_type: String that represent active_profiler, must be either 'LINE' or 'MEMORY' (case-sensitive). + active_profiler_type: String that represent active_profiler, must be either 'LINE' or 'MEMORY' + (case-sensitive). """ - if active_profiler_type is None: - active_profiler_type = "" - - if active_profiler_type not in ["LINE", "MEMORY", ""]: + if active_profiler_type not in ["LINE", "MEMORY"]: raise ValueError( - f"active_profiler expect 'LINE', 'MEMORY' or empty string '', got {active_profiler_type} instead" + f"active_profiler expect 'LINE', 'MEMORY', got {active_profiler_type} instead" ) - self.active_profiler_type = active_profiler_type - sql_statement = f"alter session set ACTIVE_PYTHON_PROFILER = '{self.active_profiler_type.upper()}'" - self.session.sql(sql_statement).collect() + sql_statement = f"alter session set ACTIVE_PYTHON_PROFILER = '{active_profiler_type.upper()}'" + self._session.sql(sql_statement).collect() def disable(self): """ Disable profiler. """ - self.active_profiler_type = "" sql_statement = "alter session set ACTIVE_PYTHON_PROFILER = ''" - self.session.sql(sql_statement).collect() + self._session.sql(sql_statement).collect() - def _is_sp_call(self, query): - return re.match(self._pattern, query, re.DOTALL) is not None + @staticmethod + def _is_sp_call(query): + return re.match(STORED_PROCEDURE_CALL_PATTERN, query, re.DOTALL) is not None def _get_last_query_id(self): current_thread = threading.get_ident() @@ -122,4 +117,4 @@ def get_output(self) -> str: """ query_id = self._get_last_query_id() sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" - return self.session.sql(sql).collect()[0][0] + return self._session.sql(sql).collect()[0][0] diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 1e98776ec95..e4218afa9e2 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -119,7 +119,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: single_value_sp() res = profiler_session.stored_procedure_profiler.get_output() - profiler_session.stored_procedure_profiler.set_active_profiler() + profiler_session.stored_procedure_profiler.disable() profiler_session.stored_procedure_profiler.register_modules([]) assert res is not None @@ -131,7 +131,7 @@ def test_set_incorrect_active_profiler(profiler_session): profiler_session.stored_procedure_profiler.set_active_profiler( "wrong_active_profiler" ) - assert "active_profiler expect 'LINE', 'MEMORY' or empty string ''" in str(e) + assert "active_profiler expect 'LINE', 'MEMORY'" in str(e) def test_sp_call_match(profiler_session): From 6b265dc27f073c4b88b5f55fbbc3eca561df2430 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 16:25:52 -0700 Subject: [PATCH 39/62] destroy query history when not in use --- .../snowpark/stored_procedure_profiler.py | 4 ++- tests/integ/test_stored_procedure_profiler.py | 26 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 68fe4d6dc5b..c0d6e56b83f 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -25,7 +25,7 @@ def __init__( session: "snowflake.snowpark.Session" = None, ) -> None: self._session = session - self._query_history = session.query_history(include_thread_id=True) + self._query_history = None def register_modules(self, stored_procedures: List[str]): """ @@ -85,11 +85,13 @@ def set_active_profiler(self, active_profiler_type: Literal["LINE", "MEMORY"]): ) sql_statement = f"alter session set ACTIVE_PYTHON_PROFILER = '{active_profiler_type.upper()}'" self._session.sql(sql_statement).collect() + self._query_history = self._session.query_history(include_thread_id=True) def disable(self): """ Disable profiler. """ + self._session._conn.remove_query_listener(self._query_history) sql_statement = "alter session set ACTIVE_PYTHON_PROFILER = ''" self._session.sql(sql_statement).collect() diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index e4218afa9e2..370260f1ba5 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -148,3 +148,29 @@ def test_sp_call_match(profiler_session): CALL myProcedure()INTO :result """ assert pro._is_sp_call(sp_call_sql) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) +def test_query_history_destroyed_after_finish_profiling( + profiler_session, db_parameters, tmp_stage_name +): + profiler_session.stored_procedure_profiler.set_targeted_stage( + f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" + ) + + profiler_session.stored_procedure_profiler.set_active_profiler("LINE") + assert ( + profiler_session.stored_procedure_profiler._query_history + in profiler_session._conn._query_listener + ) + + profiler_session.stored_procedure_profiler.disable() + assert ( + profiler_session.stored_procedure_profiler._query_history + not in profiler_session._conn._query_listener + ) + + profiler_session.stored_procedure_profiler.register_modules([]) From 63908d141fdd147bdf84ab75b6f3433889ec02bd Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 16:48:49 -0700 Subject: [PATCH 40/62] address comments --- .../snowpark/stored_procedure_profiler.py | 10 +++++----- tests/integ/test_stored_procedure_profiler.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index c0d6e56b83f..79aef8b21db 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -96,17 +96,17 @@ def disable(self): self._session.sql(sql_statement).collect() @staticmethod - def _is_sp_call(query): - return re.match(STORED_PROCEDURE_CALL_PATTERN, query, re.DOTALL) is not None + def _is_sp_call(query: str): + return re.match( + STORED_PROCEDURE_CALL_PATTERN, query.strip(" "), re.DOTALL + ) is not None or query.upper().strip(" ").startswith("CALL") def _get_last_query_id(self): current_thread = threading.get_ident() for query in self._query_history.queries[::-1]: query_thread = getattr(query, "thread_id", None) if query_thread is None or query_thread == current_thread: - if query.sql_text.upper().startswith("CALL") or self._is_sp_call( - query.sql_text - ): + if self._is_sp_call(query.sql_text): return query.query_id return None diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 370260f1ba5..1d0dea48b8e 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -149,6 +149,21 @@ def test_sp_call_match(profiler_session): """ assert pro._is_sp_call(sp_call_sql) + sp_call_sql = """CALL MY_SPROC()""" + assert pro._is_sp_call(sp_call_sql) + + sp_call_sql = """ CALL MY_SPROC()""" + assert pro._is_sp_call(sp_call_sql) + + sp_call_sql = """WITH myProcedure AS PROCEDURE () CALL myProcedure""" + assert pro._is_sp_call(sp_call_sql) + + sp_call_sql = """ WITH myProcedure AS PROCEDURE ... CALL myProcedure""" + assert pro._is_sp_call(sp_call_sql) + + sp_call_sql = """WITH SNOWPARK_TEMP_CTE_1234 AS (SELECT 1 as A) SELECT * FROM SNOWPARK_TEMP_CTE_1234 AS PROCEDURE () CALL myprocedure""" + assert pro._is_sp_call(sp_call_sql) + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", From 37fa336c5f3f904ff357e25ea57125503bbbc408 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 17:09:21 -0700 Subject: [PATCH 41/62] address comments --- .../snowpark/stored_procedure_profiler.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 79aef8b21db..4ce45cef0d4 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -3,7 +3,7 @@ # import re import threading -from typing import List, Literal +from typing import List, Literal, Union import snowflake.snowpark from snowflake.snowpark._internal.utils import validate_object_name @@ -40,7 +40,7 @@ def register_modules(self, stored_procedures: List[str]): sql_statement = ( f"alter session set python_profiler_modules='{','.join(stored_procedures)}'" ) - self._session.sql(sql_statement).collect() + self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() def set_targeted_stage(self, stage: str): """ @@ -54,19 +54,24 @@ def set_targeted_stage(self, stage: str): """ validate_object_name(stage) if ( - len(self._session.sql(f"show stages like '{stage}'").collect()) == 0 + len( + self._session.sql( + f"show stages like '{stage}'" + )._internal_collect_with_tag_no_telemetry() + ) + == 0 and len( self._session.sql( f"show stages like '{stage.split('.')[-1]}'" - ).collect() + )._internal_collect_with_tag_no_telemetry() ) == 0 ): self._session.sql( f"create temp stage if not exists {stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" - ).collect() + )._internal_collect_with_tag_no_telemetry() sql_statement = f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{stage}"' - self._session.sql(sql_statement).collect() + self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() def set_active_profiler(self, active_profiler_type: Literal["LINE", "MEMORY"]): """ @@ -84,24 +89,24 @@ def set_active_profiler(self, active_profiler_type: Literal["LINE", "MEMORY"]): f"active_profiler expect 'LINE', 'MEMORY', got {active_profiler_type} instead" ) sql_statement = f"alter session set ACTIVE_PYTHON_PROFILER = '{active_profiler_type.upper()}'" - self._session.sql(sql_statement).collect() + self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() self._query_history = self._session.query_history(include_thread_id=True) - def disable(self): + def disable(self) -> None: """ Disable profiler. """ self._session._conn.remove_query_listener(self._query_history) sql_statement = "alter session set ACTIVE_PYTHON_PROFILER = ''" - self._session.sql(sql_statement).collect() + self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() @staticmethod - def _is_sp_call(query: str): + def _is_sp_call(query: str) -> bool: return re.match( STORED_PROCEDURE_CALL_PATTERN, query.strip(" "), re.DOTALL ) is not None or query.upper().strip(" ").startswith("CALL") - def _get_last_query_id(self): + def _get_last_query_id(self) -> Union[str, None]: current_thread = threading.get_ident() for query in self._query_history.queries[::-1]: query_thread = getattr(query, "thread_id", None) @@ -119,4 +124,4 @@ def get_output(self) -> str: """ query_id = self._get_last_query_id() sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" - return self._session.sql(sql).collect()[0][0] + return self._session.sql(sql)._internal_collect_with_tag_no_telemetry()[0][0] From 393e98e5e451cf145fa7d5e78bdfcaf997ff6042 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 17:52:08 -0700 Subject: [PATCH 42/62] address comment --- src/snowflake/snowpark/stored_procedure_profiler.py | 2 ++ tests/integ/test_stored_procedure_profiler.py | 10 +++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 4ce45cef0d4..05aa68abffc 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -123,5 +123,7 @@ def get_output(self) -> str: This function must be called right after the execution of stored procedure you want to profile. """ query_id = self._get_last_query_id() + if query_id is None: + raise ValueError("Last executed stored procedure does not exist") sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" return self._session.sql(sql)._internal_collect_with_tag_no_telemetry()[0][0] diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 1d0dea48b8e..9e728e45e50 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -126,13 +126,21 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: assert "Modules Profiled" in res -def test_set_incorrect_active_profiler(profiler_session): +def test_set_incorrect_active_profiler(profiler_session, db_parameters, tmp_stage_name): with pytest.raises(ValueError) as e: profiler_session.stored_procedure_profiler.set_active_profiler( "wrong_active_profiler" ) assert "active_profiler expect 'LINE', 'MEMORY'" in str(e) + with pytest.raises(ValueError) as e: + profiler_session.stored_procedure_profiler.set_targeted_stage( + f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" + ) + profiler_session.stored_procedure_profiler.set_active_profiler("LINE") + profiler_session.stored_procedure_profiler.get_output() + assert "Last executed stored procedure does not exist" in str(e) + def test_sp_call_match(profiler_session): pro = profiler_session.stored_procedure_profiler From e53866932f06d1dffd0d599b561297adde50d09e Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 17:54:47 -0700 Subject: [PATCH 43/62] address comments --- docs/source/snowpark/session.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/snowpark/session.rst b/docs/source/snowpark/session.rst index b0a5ef8d4f6..a19f752114d 100644 --- a/docs/source/snowpark/session.rst +++ b/docs/source/snowpark/session.rst @@ -58,6 +58,7 @@ Snowpark Session Session.get_packages Session.get_session_stage Session.query_history + Session.stored_procedure_profiler Session.range Session.remove_import Session.remove_package From 468872d92ba9b9939e2e537a2ff5a8223896b7f1 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 17:55:39 -0700 Subject: [PATCH 44/62] address comments --- src/snowflake/snowpark/stored_procedure_profiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 05aa68abffc..57d0e7a3ffd 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -88,7 +88,9 @@ def set_active_profiler(self, active_profiler_type: Literal["LINE", "MEMORY"]): raise ValueError( f"active_profiler expect 'LINE', 'MEMORY', got {active_profiler_type} instead" ) - sql_statement = f"alter session set ACTIVE_PYTHON_PROFILER = '{active_profiler_type.upper()}'" + sql_statement = ( + f"alter session set ACTIVE_PYTHON_PROFILER = '{active_profiler_type}'" + ) self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() self._query_history = self._session.query_history(include_thread_id=True) From 51da343204614651d45786bd10e30b65c112a67b Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 18:48:57 -0700 Subject: [PATCH 45/62] fix doc test --- src/snowflake/snowpark/session.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index a9cd39a82e7..ae20b856de0 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -602,7 +602,7 @@ def __init__( self._conf = self.RuntimeConfig(self, options or {}) self._runtime_version_from_requirement: str = None self._temp_table_auto_cleaner: TempTableAutoCleaner = TempTableAutoCleaner(self) - self.stored_procedure_profiler = StoredProcedureProfiler(session=self) + self._sp_profiler = StoredProcedureProfiler(session=self) _logger.info("Snowpark Session information: %s", self._session_info) @@ -3245,6 +3245,14 @@ def sproc(self) -> StoredProcedureRegistration: """ return self._sp_registration + @property + def stored_procedure_profiler(self) -> StoredProcedureProfiler: + """ + Returns a :class:`stored_procedure_profiler.StoredProcedureProfiler` object that you can use to profile stored procedures. + See details of how to use this object in :class:`stored_procedure_profiler.StoredProcedureProfiler`. + """ + return self._sp_profiler + def _infer_is_return_table( self, sproc_name: str, *args: Any, log_on_exception: bool = False ) -> bool: From 60885bcb72c0eabe1b65b9cfd3ecf9d867bf3c9d Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Wed, 2 Oct 2024 19:00:44 -0700 Subject: [PATCH 46/62] fixtest --- tests/integ/test_stored_procedure_profiler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 9e728e45e50..32cd0716537 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -126,6 +126,10 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: assert "Modules Profiled" in res +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) def test_set_incorrect_active_profiler(profiler_session, db_parameters, tmp_stage_name): with pytest.raises(ValueError) as e: profiler_session.stored_procedure_profiler.set_active_profiler( From 77b94576bf8b895b5c59b05a2a38321b035330f3 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 3 Oct 2024 09:26:26 -0700 Subject: [PATCH 47/62] add type check --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 701d450c660..6b61f32315b 100644 --- a/tox.ini +++ b/tox.ini @@ -169,6 +169,7 @@ deps = pyright==1.1.338 commands = pyright src/snowflake/snowpark/_internal/analyzer pyright src/snowflake/snowpark/_internal/compiler + pyright src/snowflake/snowpark/stored_procedure_profiler [testenv:dev] description = create dev environment From 5d299df326ecae06ad87baa154442ee64bca8084 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 3 Oct 2024 09:50:48 -0700 Subject: [PATCH 48/62] fix type check --- src/snowflake/snowpark/stored_procedure_profiler.py | 12 ++++++------ tox.ini | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 57d0e7a3ffd..1593e1700f3 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -22,7 +22,7 @@ class StoredProcedureProfiler: def __init__( self, - session: "snowflake.snowpark.Session" = None, + session: "snowflake.snowpark.Session", ) -> None: self._session = session self._query_history = None @@ -55,13 +55,13 @@ def set_targeted_stage(self, stage: str): validate_object_name(stage) if ( len( - self._session.sql( + self._session.sql( # type: ignore f"show stages like '{stage}'" )._internal_collect_with_tag_no_telemetry() ) == 0 and len( - self._session.sql( + self._session.sql( # type: ignore f"show stages like '{stage.split('.')[-1]}'" )._internal_collect_with_tag_no_telemetry() ) @@ -98,7 +98,7 @@ def disable(self) -> None: """ Disable profiler. """ - self._session._conn.remove_query_listener(self._query_history) + self._session._conn.remove_query_listener(self._query_history) # type: ignore sql_statement = "alter session set ACTIVE_PYTHON_PROFILER = ''" self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() @@ -110,7 +110,7 @@ def _is_sp_call(query: str) -> bool: def _get_last_query_id(self) -> Union[str, None]: current_thread = threading.get_ident() - for query in self._query_history.queries[::-1]: + for query in self._query_history.queries[::-1]: # type: ignore query_thread = getattr(query, "thread_id", None) if query_thread is None or query_thread == current_thread: if self._is_sp_call(query.sql_text): @@ -128,4 +128,4 @@ def get_output(self) -> str: if query_id is None: raise ValueError("Last executed stored procedure does not exist") sql = f"select snowflake.core.get_python_profiler_output('{query_id}')" - return self._session.sql(sql)._internal_collect_with_tag_no_telemetry()[0][0] + return self._session.sql(sql)._internal_collect_with_tag_no_telemetry()[0][0] # type: ignore diff --git a/tox.ini b/tox.ini index 6b61f32315b..cf9d98547c8 100644 --- a/tox.ini +++ b/tox.ini @@ -169,7 +169,7 @@ deps = pyright==1.1.338 commands = pyright src/snowflake/snowpark/_internal/analyzer pyright src/snowflake/snowpark/_internal/compiler - pyright src/snowflake/snowpark/stored_procedure_profiler + pyright src/snowflake/snowpark/stored_procedure_profiler.py [testenv:dev] description = create dev environment From b335b77469a3fe029b9af6ef5049885372ed97db Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 3 Oct 2024 12:13:47 -0700 Subject: [PATCH 49/62] address comments --- .../snowpark/stored_procedure_profiler.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 1593e1700f3..f41f1f5031a 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -3,7 +3,7 @@ # import re import threading -from typing import List, Literal, Union +from typing import List, Literal, Optional import snowflake.snowpark from snowflake.snowpark._internal.utils import validate_object_name @@ -27,7 +27,7 @@ def __init__( self._session = session self._query_history = None - def register_modules(self, stored_procedures: List[str]): + def register_modules(self, stored_procedures: List[str]) -> None: """ Register stored procedures to generate profiles for them. @@ -42,7 +42,7 @@ def register_modules(self, stored_procedures: List[str]): ) self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() - def set_targeted_stage(self, stage: str): + def set_targeted_stage(self, stage: str) -> None: """ Set targeted stage for profiler output. @@ -73,7 +73,9 @@ def set_targeted_stage(self, stage: str): sql_statement = f'alter session set PYTHON_PROFILER_TARGET_STAGE ="{stage}"' self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() - def set_active_profiler(self, active_profiler_type: Literal["LINE", "MEMORY"]): + def set_active_profiler( + self, active_profiler_type: Literal["LINE", "MEMORY"] + ) -> None: """ Set active profiler. @@ -104,17 +106,17 @@ def disable(self) -> None: @staticmethod def _is_sp_call(query: str) -> bool: + query = query.upper().strip(" ") return re.match( - STORED_PROCEDURE_CALL_PATTERN, query.strip(" "), re.DOTALL - ) is not None or query.upper().strip(" ").startswith("CALL") + STORED_PROCEDURE_CALL_PATTERN, query, re.DOTALL + ) is not None or query.startswith("CALL") - def _get_last_query_id(self) -> Union[str, None]: + def _get_last_query_id(self) -> Optional[str]: current_thread = threading.get_ident() for query in self._query_history.queries[::-1]: # type: ignore query_thread = getattr(query, "thread_id", None) - if query_thread is None or query_thread == current_thread: - if self._is_sp_call(query.sql_text): - return query.query_id + if query_thread == current_thread and self._is_sp_call(query.sql_text): + return query.query_id return None def get_output(self) -> str: From 1188594892a1a5814d312da9514e4622fe2199b5 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 3 Oct 2024 13:15:43 -0700 Subject: [PATCH 50/62] change logic of finding stage --- .../snowpark/stored_procedure_profiler.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index f41f1f5031a..d8fadf2aa82 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -53,19 +53,14 @@ def set_targeted_stage(self, stage: str) -> None: stage: String of fully qualified name of targeted stage """ validate_object_name(stage) + [db_name, schema_name, stage_name] = stage.split(".") + existing_stages = self._session.sql( + f"show stages like '{stage_name}' in database {db_name}" + )._internal_collect_with_tag_no_telemetry() + existing_stages_schema_list = [row.schema_name for row in existing_stages] # type: ignore if ( - len( - self._session.sql( # type: ignore - f"show stages like '{stage}'" - )._internal_collect_with_tag_no_telemetry() - ) - == 0 - and len( - self._session.sql( # type: ignore - f"show stages like '{stage.split('.')[-1]}'" - )._internal_collect_with_tag_no_telemetry() - ) - == 0 + existing_stages_schema_list == [] + or schema_name not in existing_stages_schema_list ): self._session.sql( f"create temp stage if not exists {stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" From 157bc3db0dbdbae2c2a2e9dac644b609454a47b9 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 3 Oct 2024 13:18:26 -0700 Subject: [PATCH 51/62] remove bracket --- src/snowflake/snowpark/stored_procedure_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index d8fadf2aa82..ec5b2bd6815 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -53,7 +53,7 @@ def set_targeted_stage(self, stage: str) -> None: stage: String of fully qualified name of targeted stage """ validate_object_name(stage) - [db_name, schema_name, stage_name] = stage.split(".") + db_name, schema_name, stage_name = stage.split(".") existing_stages = self._session.sql( f"show stages like '{stage_name}' in database {db_name}" )._internal_collect_with_tag_no_telemetry() From 9fb57ac53af495b1b7bd85ffe25e7fcbb19c85f5 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 3 Oct 2024 14:01:33 -0700 Subject: [PATCH 52/62] use existing pattern --- src/snowflake/snowpark/stored_procedure_profiler.py | 12 ++++++------ tests/integ/test_stored_procedure_profiler.py | 3 --- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index ec5b2bd6815..7ac81ca7df0 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -1,14 +1,14 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -import re import threading from typing import List, Literal, Optional import snowflake.snowpark -from snowflake.snowpark._internal.utils import validate_object_name - -STORED_PROCEDURE_CALL_PATTERN = r"WITH\s+.*?\s+AS\s+PROCEDURE\s+.*?\s+CALL\s+.*" +from snowflake.snowpark._internal.utils import ( + SNOWFLAKE_ANONYMOUS_CALL_WITH_PATTERN, + validate_object_name, +) class StoredProcedureProfiler: @@ -102,8 +102,8 @@ def disable(self) -> None: @staticmethod def _is_sp_call(query: str) -> bool: query = query.upper().strip(" ") - return re.match( - STORED_PROCEDURE_CALL_PATTERN, query, re.DOTALL + return SNOWFLAKE_ANONYMOUS_CALL_WITH_PATTERN.match( + query ) is not None or query.startswith("CALL") def _get_last_query_id(self) -> Optional[str]: diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 32cd0716537..13a6a72a580 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -173,9 +173,6 @@ def test_sp_call_match(profiler_session): sp_call_sql = """ WITH myProcedure AS PROCEDURE ... CALL myProcedure""" assert pro._is_sp_call(sp_call_sql) - sp_call_sql = """WITH SNOWPARK_TEMP_CTE_1234 AS (SELECT 1 as A) SELECT * FROM SNOWPARK_TEMP_CTE_1234 AS PROCEDURE () CALL myprocedure""" - assert pro._is_sp_call(sp_call_sql) - @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", From 773e69e417426bc6f1761e3d39e0564e4b562d37 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 3 Oct 2024 14:05:04 -0700 Subject: [PATCH 53/62] change test --- tests/integ/test_stored_procedure_profiler.py | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 13a6a72a580..f0df88c6f77 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -146,31 +146,29 @@ def test_set_incorrect_active_profiler(profiler_session, db_parameters, tmp_stag assert "Last executed stored procedure does not exist" in str(e) -def test_sp_call_match(profiler_session): +@pytest.mark.parametrize( + "sp_call_sql", + [ + """WITH myProcedure AS PROCEDURE () + RETURNS TABLE ( ) + LANGUAGE PYTHON + RUNTIME_VERSION = '3.8' + PACKAGES = ( 'snowflake-snowpark-python==1.2.0', 'pandas==1.3.3' ) + IMPORTS = ( '@my_stage/file1.py', '@my_stage/file2.py' ) + HANDLER = 'my_function' + RETURNS NULL ON NULL INPUT + AS 'fake' + CALL myProcedure()INTO :result + """, + """CALL MY_SPROC()""", + """ CALL MY_SPROC()""", + """WITH myProcedure AS PROCEDURE () CALL myProcedure""", + """ WITH myProcedure AS PROCEDURE ... CALL myProcedure""", + ], +) +def test_sp_call_match(profiler_session, sp_call_sql): pro = profiler_session.stored_procedure_profiler - sp_call_sql = """WITH myProcedure AS PROCEDURE () - RETURNS TABLE ( ) - LANGUAGE PYTHON - RUNTIME_VERSION = '3.8' - PACKAGES = ( 'snowflake-snowpark-python==1.2.0', 'pandas==1.3.3' ) - IMPORTS = ( '@my_stage/file1.py', '@my_stage/file2.py' ) - HANDLER = 'my_function' - RETURNS NULL ON NULL INPUT -AS 'fake' -CALL myProcedure()INTO :result - """ - assert pro._is_sp_call(sp_call_sql) - - sp_call_sql = """CALL MY_SPROC()""" - assert pro._is_sp_call(sp_call_sql) - - sp_call_sql = """ CALL MY_SPROC()""" - assert pro._is_sp_call(sp_call_sql) - - sp_call_sql = """WITH myProcedure AS PROCEDURE () CALL myProcedure""" - assert pro._is_sp_call(sp_call_sql) - sp_call_sql = """ WITH myProcedure AS PROCEDURE ... CALL myProcedure""" assert pro._is_sp_call(sp_call_sql) From 292834c9fd93bf5e9b95a9cd66a29f906384e2ca Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Thu, 3 Oct 2024 17:41:19 -0700 Subject: [PATCH 54/62] thread safe change --- .../snowpark/stored_procedure_profiler.py | 16 ++++++++++-- tests/integ/test_stored_procedure_profiler.py | 26 +++++++++++++++++++ 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index 7ac81ca7df0..ebdf6d84aa8 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -26,6 +26,8 @@ def __init__( ) -> None: self._session = session self._query_history = None + self._lock = threading.RLock() + self._active_profiler_number = 0 def register_modules(self, stored_procedures: List[str]) -> None: """ @@ -81,6 +83,8 @@ def set_active_profiler( (case-sensitive). """ + with self._lock: + self._active_profiler_number += 1 if active_profiler_type not in ["LINE", "MEMORY"]: raise ValueError( f"active_profiler expect 'LINE', 'MEMORY', got {active_profiler_type} instead" @@ -89,13 +93,21 @@ def set_active_profiler( f"alter session set ACTIVE_PYTHON_PROFILER = '{active_profiler_type}'" ) self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() - self._query_history = self._session.query_history(include_thread_id=True) + with self._lock: + if self._query_history is None: + self._query_history = self._session.query_history( + include_thread_id=True + ) def disable(self) -> None: """ Disable profiler. """ - self._session._conn.remove_query_listener(self._query_history) # type: ignore + with self._lock: + self._active_profiler_number -= 1 + if self._active_profiler_number == 0: + self._session._conn.remove_query_listener(self._query_history) # type: ignore + self._query_history = None sql_statement = "alter session set ACTIVE_PYTHON_PROFILER = ''" self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index f0df88c6f77..282ede9682d 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -1,15 +1,22 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from concurrent.futures import ThreadPoolExecutor import pytest import snowflake.snowpark from snowflake.snowpark import DataFrame from snowflake.snowpark.functions import sproc +from snowflake.snowpark.stored_procedure_profiler import StoredProcedureProfiler from tests.utils import Utils +def multi_thread_helper_function(pro: StoredProcedureProfiler): + pro.set_active_profiler("LINE") + pro.disable() + + @pytest.fixture(scope="function") def is_profiler_function_exist(profiler_session): functions = profiler_session.sql( @@ -196,3 +203,22 @@ def test_query_history_destroyed_after_finish_profiling( ) profiler_session.stored_procedure_profiler.register_modules([]) + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) +def test_thread_safe_on_activate_and_disable( + profiler_session, db_parameters, tmp_stage_name +): + pro = profiler_session.stored_procedure_profiler + pro.register_modules(["table_sp"]) + pro.set_targeted_stage( + f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" + ) + with ThreadPoolExecutor(max_workers=2) as tpe: + for _ in range(6): + tpe.submit(multi_thread_helper_function, pro) + assert pro._query_history is None + pro.register_modules([]) From 95cb15d6068913a5df08fbc1180f83f6b074495f Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 4 Oct 2024 14:38:13 -0700 Subject: [PATCH 55/62] address comments --- CHANGELOG.md | 4 +- src/snowflake/snowpark/session.py | 3 + .../snowpark/stored_procedure_profiler.py | 55 ++++++++----------- tests/integ/test_stored_procedure_profiler.py | 24 ++++---- 4 files changed, 42 insertions(+), 44 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dcaf297bda3..e685b966cf0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,8 @@ - Added support for constructing `Series` and `DataFrame` objects with the lazy `Index` object as `data`, `index`, and `columns` arguments. - Added support for constructing `Series` and `DataFrame` objects with `index` and `column` values not present in `DataFrame`/`Series` `data`. - Added `thread_id` to `QueryRecord` to track the thread id submitting the query history. -- +- Added support for `Session.stored_procedure_profiler`. + #### Improvements #### Bug Fixes @@ -70,7 +71,6 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det - Added the following new functions in `snowflake.snowpark.functions`: - `array_remove` - `ln` -- Added Snowpark Python API for stored procedure profiler. #### Improvements diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index ae20b856de0..399806cd326 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3250,6 +3250,9 @@ def stored_procedure_profiler(self) -> StoredProcedureProfiler: """ Returns a :class:`stored_procedure_profiler.StoredProcedureProfiler` object that you can use to profile stored procedures. See details of how to use this object in :class:`stored_procedure_profiler.StoredProcedureProfiler`. + + See more details about stored procedure profiler at: + https://docs.snowflake.com/LIMITEDACCESS/stored-procedures-python-profiler """ return self._sp_profiler diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index ebdf6d84aa8..b68aae0f428 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -7,17 +7,17 @@ import snowflake.snowpark from snowflake.snowpark._internal.utils import ( SNOWFLAKE_ANONYMOUS_CALL_WITH_PATTERN, - validate_object_name, + parse_table_name, ) class StoredProcedureProfiler: """ - Set up profiler to receive profiles of stored procedures. + Set up profiler to receive profiles of stored procedures. This feature cannot be used in owner's right stored + procedure because owner's right stored procedure will not be able to set session-level parameters. - Note: - This feature cannot be used in owner's right stored procedure because owner's right stored procedure will not be - able to set session-level parameters. + See more details about stored procedure profiler at: + https://docs.snowflake.com/LIMITEDACCESS/stored-procedures-python-profiler """ def __init__( @@ -29,41 +29,34 @@ def __init__( self._lock = threading.RLock() self._active_profiler_number = 0 - def register_modules(self, stored_procedures: List[str]) -> None: + def register_modules(self, stored_procedures: Optional[List[str]] = None) -> None: """ Register stored procedures to generate profiles for them. - Note: - Registered modules will be overwritten by this function. Use this function with an empty string will remove - registered modules. Args: - stored_procedures: List of names of stored procedures. + stored_procedures: List of names of stored procedures. Registered modules will be overwritten by this input. + Input None or an empty list will remove registered modules. """ - sql_statement = ( - f"alter session set python_profiler_modules='{','.join(stored_procedures)}'" - ) + sp_string = ",".join(stored_procedures) if stored_procedures is not None else "" + sql_statement = f"alter session set python_profiler_modules='{sp_string}'" self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() - def set_targeted_stage(self, stage: str) -> None: + def set_target_stage(self, stage: str) -> None: """ Set targeted stage for profiler output. - Note: - The stage name must be a fully qualified name. - Args: stage: String of fully qualified name of targeted stage """ - validate_object_name(stage) - db_name, schema_name, stage_name = stage.split(".") + names = parse_table_name(stage) + if len(names) != 3: + raise ValueError( + f"stage name must be fully qualified name, got {stage} instead" + ) existing_stages = self._session.sql( - f"show stages like '{stage_name}' in database {db_name}" + f"show stages like '{names[2]}' in schema {names[0]}.{names[1]}" )._internal_collect_with_tag_no_telemetry() - existing_stages_schema_list = [row.schema_name for row in existing_stages] # type: ignore - if ( - existing_stages_schema_list == [] - or schema_name not in existing_stages_schema_list - ): + if len(existing_stages) == 0: self._session.sql( f"create temp stage if not exists {stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" )._internal_collect_with_tag_no_telemetry() @@ -71,16 +64,14 @@ def set_targeted_stage(self, stage: str) -> None: self._session.sql(sql_statement)._internal_collect_with_tag_no_telemetry() def set_active_profiler( - self, active_profiler_type: Literal["LINE", "MEMORY"] + self, active_profiler_type: Literal["LINE", "MEMORY"] = "LINE" ) -> None: """ Set active profiler. - Note: - Active profiler must be either 'LINE' or 'MEMORY' (case-sensitive). Active profiler is 'LINE' by default. Args: active_profiler_type: String that represent active_profiler, must be either 'LINE' or 'MEMORY' - (case-sensitive). + (case-sensitive). Active profiler is 'LINE' by default. """ with self._lock: @@ -128,10 +119,10 @@ def _get_last_query_id(self) -> Optional[str]: def get_output(self) -> str: """ - Return the profiles of last executed stored procedure in current thread. + Return the profiles of last executed stored procedure in current thread. If there is no previous + stored procedure call, an error will be raised. + Please call this function right after the stored procedure you want to profile to avoid any error. - Note: - This function must be called right after the execution of stored procedure you want to profile. """ query_id = self._get_last_query_id() if query_id is None: diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 282ede9682d..ba0a8cd179a 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -62,7 +62,7 @@ def table_sp(session: snowflake.snowpark.Session) -> DataFrame: pro = profiler_session.stored_procedure_profiler pro.register_modules(["table_sp"]) - pro.set_targeted_stage( + pro.set_target_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) @@ -89,7 +89,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: return "success" profiler_session.stored_procedure_profiler.register_modules(["single_value_sp"]) - profiler_session.stored_procedure_profiler.set_targeted_stage( + profiler_session.stored_procedure_profiler.set_target_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) @@ -100,7 +100,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: profiler_session.stored_procedure_profiler.disable() - profiler_session.stored_procedure_profiler.register_modules([]) + profiler_session.stored_procedure_profiler.register_modules() assert res is not None assert "Modules Profiled" in res @@ -117,7 +117,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: single_value_sp = profiler_session.sproc.register(single_value_sp, anonymous=True) - profiler_session.stored_procedure_profiler.set_targeted_stage( + profiler_session.stored_procedure_profiler.set_target_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) @@ -128,7 +128,7 @@ def single_value_sp(session: snowflake.snowpark.Session) -> str: profiler_session.stored_procedure_profiler.disable() - profiler_session.stored_procedure_profiler.register_modules([]) + profiler_session.stored_procedure_profiler.register_modules() assert res is not None assert "Modules Profiled" in res @@ -145,7 +145,11 @@ def test_set_incorrect_active_profiler(profiler_session, db_parameters, tmp_stag assert "active_profiler expect 'LINE', 'MEMORY'" in str(e) with pytest.raises(ValueError) as e: - profiler_session.stored_procedure_profiler.set_targeted_stage( + profiler_session.stored_procedure_profiler.set_target_stage(f"{tmp_stage_name}") + assert "stage name must be fully qualified name" in str(e) + + with pytest.raises(ValueError) as e: + profiler_session.stored_procedure_profiler.set_target_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) profiler_session.stored_procedure_profiler.set_active_profiler("LINE") @@ -186,7 +190,7 @@ def test_sp_call_match(profiler_session, sp_call_sql): def test_query_history_destroyed_after_finish_profiling( profiler_session, db_parameters, tmp_stage_name ): - profiler_session.stored_procedure_profiler.set_targeted_stage( + profiler_session.stored_procedure_profiler.set_target_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) @@ -202,7 +206,7 @@ def test_query_history_destroyed_after_finish_profiling( not in profiler_session._conn._query_listener ) - profiler_session.stored_procedure_profiler.register_modules([]) + profiler_session.stored_procedure_profiler.register_modules() @pytest.mark.skipif( @@ -214,11 +218,11 @@ def test_thread_safe_on_activate_and_disable( ): pro = profiler_session.stored_procedure_profiler pro.register_modules(["table_sp"]) - pro.set_targeted_stage( + pro.set_target_stage( f"{db_parameters['database']}.{db_parameters['schema']}.{tmp_stage_name}" ) with ThreadPoolExecutor(max_workers=2) as tpe: for _ in range(6): tpe.submit(multi_thread_helper_function, pro) assert pro._query_history is None - pro.register_modules([]) + pro.register_modules() From 766d77ea8a6b2de9d5d8dcf57894afe27dee0032 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 4 Oct 2024 14:56:04 -0700 Subject: [PATCH 56/62] remove link --- src/snowflake/snowpark/session.py | 3 --- src/snowflake/snowpark/stored_procedure_profiler.py | 6 ++---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 399806cd326..ae20b856de0 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3250,9 +3250,6 @@ def stored_procedure_profiler(self) -> StoredProcedureProfiler: """ Returns a :class:`stored_procedure_profiler.StoredProcedureProfiler` object that you can use to profile stored procedures. See details of how to use this object in :class:`stored_procedure_profiler.StoredProcedureProfiler`. - - See more details about stored procedure profiler at: - https://docs.snowflake.com/LIMITEDACCESS/stored-procedures-python-profiler """ return self._sp_profiler diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index b68aae0f428..d9e867e17bd 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -8,6 +8,7 @@ from snowflake.snowpark._internal.utils import ( SNOWFLAKE_ANONYMOUS_CALL_WITH_PATTERN, parse_table_name, + strip_double_quotes_in_like_statement_in_table_name, ) @@ -15,9 +16,6 @@ class StoredProcedureProfiler: """ Set up profiler to receive profiles of stored procedures. This feature cannot be used in owner's right stored procedure because owner's right stored procedure will not be able to set session-level parameters. - - See more details about stored procedure profiler at: - https://docs.snowflake.com/LIMITEDACCESS/stored-procedures-python-profiler """ def __init__( @@ -54,7 +52,7 @@ def set_target_stage(self, stage: str) -> None: f"stage name must be fully qualified name, got {stage} instead" ) existing_stages = self._session.sql( - f"show stages like '{names[2]}' in schema {names[0]}.{names[1]}" + f"show stages like '{strip_double_quotes_in_like_statement_in_table_name(names[2])}' in schema {names[0]}.{names[1]}" )._internal_collect_with_tag_no_telemetry() if len(existing_stages) == 0: self._session.sql( From 776ba6854f560497277218eabfdefe5f8331eab9 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 4 Oct 2024 14:59:12 -0700 Subject: [PATCH 57/62] add prpr tag --- src/snowflake/snowpark/session.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index ae20b856de0..814acaf4f78 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -107,6 +107,7 @@ normalize_local_file, normalize_remote_file_or_dir, parse_positional_args_to_list, + private_preview, quote_name, random_name_for_temp_object, strip_double_quotes_in_like_statement_in_table_name, @@ -3245,6 +3246,7 @@ def sproc(self) -> StoredProcedureRegistration: """ return self._sp_registration + @private_preview(version="1.23.0") @property def stored_procedure_profiler(self) -> StoredProcedureProfiler: """ From 1324b1d26e72b588254d7edcd531d3ca80c5c1e1 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 4 Oct 2024 15:19:00 -0700 Subject: [PATCH 58/62] add prpr tag --- src/snowflake/snowpark/session.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 814acaf4f78..779b849e927 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -3246,8 +3246,8 @@ def sproc(self) -> StoredProcedureRegistration: """ return self._sp_registration - @private_preview(version="1.23.0") @property + @private_preview(version="1.23.0") def stored_procedure_profiler(self) -> StoredProcedureProfiler: """ Returns a :class:`stored_procedure_profiler.StoredProcedureProfiler` object that you can use to profile stored procedures. From e89f6b9bfa225d3e807d225ca66f58a5c51391cd Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 4 Oct 2024 15:26:16 -0700 Subject: [PATCH 59/62] type check --- src/snowflake/snowpark/stored_procedure_profiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/stored_procedure_profiler.py b/src/snowflake/snowpark/stored_procedure_profiler.py index d9e867e17bd..04a2545598a 100644 --- a/src/snowflake/snowpark/stored_procedure_profiler.py +++ b/src/snowflake/snowpark/stored_procedure_profiler.py @@ -54,7 +54,7 @@ def set_target_stage(self, stage: str) -> None: existing_stages = self._session.sql( f"show stages like '{strip_double_quotes_in_like_statement_in_table_name(names[2])}' in schema {names[0]}.{names[1]}" )._internal_collect_with_tag_no_telemetry() - if len(existing_stages) == 0: + if len(existing_stages) == 0: # type: ignore self._session.sql( f"create temp stage if not exists {stage} FILE_FORMAT = (RECORD_DELIMITER = NONE FIELD_DELIMITER = NONE )" )._internal_collect_with_tag_no_telemetry() From 375f11b32abdc14e28f5f181369e337ca7b707d5 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 4 Oct 2024 15:49:52 -0700 Subject: [PATCH 60/62] add test for coverage --- tests/integ/test_stored_procedure_profiler.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index ba0a8cd179a..0b67173a385 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -226,3 +226,24 @@ def test_thread_safe_on_activate_and_disable( tpe.submit(multi_thread_helper_function, pro) assert pro._query_history is None pro.register_modules() + + +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in localtesting", +) +def test_create_temp_stage(profiler_session, tmp_stage_name): + pro = profiler_session.stored_procedure_profiler + db_name = Utils.random_temp_database() + schema_name = Utils.random_temp_schema() + try: + profiler_session.sql(f"create database {db_name}").collect() + profiler_session.sql(f"create schema {schema_name}").collect() + pro.set_target_stage(f"{db_name}.{schema_name}.{tmp_stage_name}") + + res = profiler_session.sql( + f"show stages like '{tmp_stage_name}' in schema {db_name}.{schema_name}" + ).collect() + assert len(res) != 0 + finally: + profiler_session.sql(f"drop database if exists {db_name}").collect() From 2906d0bab311dd8e562683016d660b0646eb0f5d Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 4 Oct 2024 16:06:47 -0700 Subject: [PATCH 61/62] fix test --- tests/integ/test_stored_procedure_profiler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 0b67173a385..32b76339a66 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -236,6 +236,7 @@ def test_create_temp_stage(profiler_session, tmp_stage_name): pro = profiler_session.stored_procedure_profiler db_name = Utils.random_temp_database() schema_name = Utils.random_temp_schema() + current_db = profiler_session.sql("select current_database()").collect()[0][0] try: profiler_session.sql(f"create database {db_name}").collect() profiler_session.sql(f"create schema {schema_name}").collect() @@ -247,3 +248,4 @@ def test_create_temp_stage(profiler_session, tmp_stage_name): assert len(res) != 0 finally: profiler_session.sql(f"drop database if exists {db_name}").collect() + profiler_session.sql(f"use database {current_db}").collect() From dfe2de8b8c123ca0a4c45f6ee86e328ed9cc7896 Mon Sep 17 00:00:00 2001 From: Yuyang Wang Date: Fri, 4 Oct 2024 16:59:27 -0700 Subject: [PATCH 62/62] fix test --- tests/integ/test_stored_procedure_profiler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/integ/test_stored_procedure_profiler.py b/tests/integ/test_stored_procedure_profiler.py index 32b76339a66..6de17057b0d 100644 --- a/tests/integ/test_stored_procedure_profiler.py +++ b/tests/integ/test_stored_procedure_profiler.py @@ -232,18 +232,19 @@ def test_thread_safe_on_activate_and_disable( "config.getoption('local_testing_mode', default=False)", reason="session.sql is not supported in localtesting", ) -def test_create_temp_stage(profiler_session, tmp_stage_name): +def test_create_temp_stage(profiler_session): pro = profiler_session.stored_procedure_profiler db_name = Utils.random_temp_database() schema_name = Utils.random_temp_schema() + temp_stage = Utils.random_stage_name() current_db = profiler_session.sql("select current_database()").collect()[0][0] try: profiler_session.sql(f"create database {db_name}").collect() profiler_session.sql(f"create schema {schema_name}").collect() - pro.set_target_stage(f"{db_name}.{schema_name}.{tmp_stage_name}") + pro.set_target_stage(f"{db_name}.{schema_name}.{temp_stage}") res = profiler_session.sql( - f"show stages like '{tmp_stage_name}' in schema {db_name}.{schema_name}" + f"show stages like '{temp_stage}' in schema {db_name}.{schema_name}" ).collect() assert len(res) != 0 finally: