diff --git a/CHANGELOG.md b/CHANGELOG.md index 880cfef2af8..cd5a5259bb1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ - Added support for VOLATILE/IMMUTABLE keyword when registering UDFs. - Added support for specifying clustering keys when saving dataframes using `DataFrame.save_as_table`. - Accept `Iterable` objects input for `schema` when creating dataframes using `Session.create_dataframe`. +- Added the property `DataFrame.session` to return a `Session` object. +- Added the property `Session.session_id` to return an integer that represents session ID. +- Added the property `Session.connection` to return a `SnowflakeConnection` object . + - Added support for creating a Snowpark session from a configuration file or environment variables. ### Dependency updates diff --git a/docs/source/dataframe.rst b/docs/source/dataframe.rst index 93ba7afabc6..cab7c071b2c 100644 --- a/docs/source/dataframe.rst +++ b/docs/source/dataframe.rst @@ -125,3 +125,4 @@ DataFrame DataFrame.stat DataFrame.write DataFrame.is_cached + DataFrame.session diff --git a/docs/source/session.rst b/docs/source/session.rst index 0fe3a6d1b9b..468aa7a4a34 100644 --- a/docs/source/session.rst +++ b/docs/source/session.rst @@ -70,4 +70,6 @@ Snowpark Session Session.sql_simplifier_enabled Session.telemetry_enabled Session.udf - Session.udtf \ No newline at end of file + Session.udtf + Session.session_id + Session.connection \ No newline at end of file diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index ece5ffe7a02..dea1d5df060 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -3320,6 +3320,13 @@ def na(self) -> DataFrameNaFunctions: """ return self._na + @property + def session(self) -> "snowflake.snowpark.Session": + """ + Returns a :class:`snowflake.snowpark.Session` object that provides access to the session the current DataFrame is relying on. + """ + return self._session + def describe(self, *cols: Union[str, List[str]]) -> "DataFrame": """ Computes basic statistics for numeric columns, which includes diff --git a/src/snowflake/snowpark/session.py b/src/snowflake/snowpark/session.py index 394a01c5263..339e83af2e6 100644 --- a/src/snowflake/snowpark/session.py +++ b/src/snowflake/snowpark/session.py @@ -1695,6 +1695,18 @@ def read(self) -> "DataFrameReader": supported sources (e.g. a file in a stage) as a DataFrame.""" return DataFrameReader(self) + @property + def session_id(self) -> int: + """Returns an integer that represents the session ID of this session.""" + return self._session_id + + @property + def connection(self) -> "SnowflakeConnection": + """Returns a :class:`SnowflakeConnection` object that allows you to access the connection between the current session + and Snowflake server.""" + return self._conn._conn + + def _run_query( self, query: str, diff --git a/tests/unit/test_dataframe.py b/tests/unit/test_dataframe.py index a58f1c55956..be78b8745c1 100644 --- a/tests/unit/test_dataframe.py +++ b/tests/unit/test_dataframe.py @@ -8,6 +8,7 @@ import pytest import snowflake.snowpark.session +from snowflake.snowpark.session import Session from snowflake.snowpark import ( DataFrame, DataFrameNaFunctions, @@ -292,3 +293,14 @@ def test_dataFrame_printSchema(capfd): out == "root\n |-- A: IntegerType() (nullable = False)\n |-- B: StringType() (nullable = True)\n" ) + + +def test_session(): + fake_session = mock.create_autospec(Session, _session_id=123456) + fake_session._analyzer = mock.Mock() + df = DataFrame(fake_session) + + assert(df.session == fake_session) + assert(df.session._session_id == fake_session._session_id) + + diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 9a76d0604ee..032c926f1c5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -3,7 +3,7 @@ # import json import os -from typing import Optional +from typing import Optional, Dict, Union from unittest import mock from unittest.mock import MagicMock @@ -368,3 +368,25 @@ def test_parse_table_name(): assert parse_table_name('*&^."abc".abc') # unsupported chars in unquoted ids with pytest.raises(SnowparkInvalidObjectNameException): assert parse_table_name('."abc".') # unsupported semantic + + +def test_session_id(): + fake_server_connection = mock.create_autospec(ServerConnection) + fake_server_connection.get_session_id = mock.Mock(return_value=123456) + session = Session(fake_server_connection) + + assert(session.session_id == 123456) + + +def test_connection(): + fake_snowflake_connection = mock.create_autospec(SnowflakeConnection) + fake_snowflake_connection._telemetry = mock.Mock() + fake_snowflake_connection._session_parameters = mock.Mock() + fake_snowflake_connection.is_closed = mock.Mock(return_value=False) + fake_options = {"": ""} + server_connection = ServerConnection(fake_options, fake_snowflake_connection) + session = Session(server_connection) + + assert(session.connection == session._conn._conn) + assert(session.connection == fake_snowflake_connection) +