diff --git a/CHANGELOG.md b/CHANGELOG.md index d8e9937bcfb..6a3b351227e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ ### Improvements - Added cleanup logic at interpreter shutdown to close all active sessions. +- Closing sessions within stored procedures now is a no-op logging a warning instead of raising an error. ### Bug Fixes diff --git a/src/snowflake/snowpark/_internal/error_message.py b/src/snowflake/snowpark/_internal/error_message.py index c943fde9864..614a6ad73d8 100644 --- a/src/snowflake/snowpark/_internal/error_message.py +++ b/src/snowflake/snowpark/_internal/error_message.py @@ -423,13 +423,6 @@ def DONT_CREATE_SESSION_IN_SP() -> SnowparkSessionException: error_code="1410", ) - @staticmethod - def DONT_CLOSE_SESSION_IN_SP() -> SnowparkSessionException: - return SnowparkSessionException( - "In a stored procedure, you shouldn't close a session. The stored procedure manages the lifecycle of the provided session.", - error_code="1411", - ) - # General Error codes 15XX @staticmethod diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 88d2401d426..3f7fe37dc32 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -215,6 +215,8 @@ def _close_session_atexit(): This is the helper function to close all active sessions at interpreter shutdown. For example, when a jupyter notebook is shutting down, this will also close all active sessions and make sure send all telemetry to the server. """ + if is_in_stored_procedure(): + return with _session_management_lock: for session in _active_sessions.copy(): try: @@ -477,7 +479,8 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - self.close() + if not is_in_stored_procedure(): + self.close() def __str__(self): return ( @@ -493,7 +496,8 @@ def _generate_new_action_id(self) -> int: def close(self) -> None: """Close this session.""" if is_in_stored_procedure(): - raise SnowparkClientExceptionMessages.DONT_CLOSE_SESSION_IN_SP() + _logger.warning("Closing a session in a stored procedure is a no-op.") + return try: if self._conn.is_closed(): _logger.debug( diff --git a/tests/integ/test_session.py b/tests/integ/test_session.py index 8ad3bc1aee6..d54a72fbe50 100644 --- a/tests/integ/test_session.py +++ b/tests/integ/test_session.py @@ -199,9 +199,8 @@ def test_close_session_in_sp(session): original_platform = internal_utils.PLATFORM internal_utils.PLATFORM = "XP" try: - with pytest.raises(SnowparkSessionException) as exec_info: - session.close() - assert exec_info.value.error_code == "1411" + session.close() + assert not session.connection.is_closed() finally: internal_utils.PLATFORM = original_platform @@ -570,7 +569,7 @@ def test_use_secondary_roles(session): session.use_secondary_roles(current_role[1:-1]) -@pytest.mark.skipif(IS_IN_STORED_PROC, reason="SP doesn't allow to close a session.") +@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Can't create a session in SP") def test_close_session_twice(db_parameters): new_session = Session.builder.configs(db_parameters).create() new_session.close() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 890cd7b2d27..8f8408cd43a 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # import json +import logging import os from typing import Optional from unittest import mock @@ -122,6 +123,47 @@ def test_close_exception(): session.close() +def test_close_session_in_stored_procedure_no_op(): + fake_connection = mock.create_autospec(ServerConnection) + fake_connection._conn = mock.Mock() + fake_connection.is_closed = MagicMock(return_value=False) + session = Session(fake_connection) + with mock.patch.object( + snowflake.snowpark.session, "is_in_stored_procedure" + ) as mock_fn, mock.patch.object( + session._conn, "close" + ) as mock_close, mock.patch.object( + session, "cancel_all" + ) as mock_cancel_all, mock.patch.object( + snowflake.snowpark.session, "_remove_session" + ) as mock_remove: + mock_fn.return_value = True + session.close() + mock_cancel_all.assert_not_called() + mock_close.assert_not_called() + mock_remove.assert_not_called() + + +@pytest.mark.parametrize( + "warning_level, expected", + [(logging.WARNING, True), (logging.INFO, True), (logging.ERROR, False)], +) +def test_close_session_in_stored_procedure_log_level(caplog, warning_level, expected): + caplog.clear() + caplog.set_level(warning_level) + fake_connection = mock.create_autospec(ServerConnection) + fake_connection._conn = mock.Mock() + fake_connection.is_closed = MagicMock(return_value=False) + session = Session(fake_connection) + with mock.patch.object( + snowflake.snowpark.session, "is_in_stored_procedure" + ) as mock_fn: + mock_fn.return_value = True + session.close() + result = "Closing a session in a stored procedure is a no-op." in caplog.text + assert result == expected + + def test_resolve_import_path_ignore_import_path(tmp_path_factory): fake_connection = mock.create_autospec(ServerConnection) fake_connection._conn = mock.Mock()