Skip to content

Commit

Permalink
SNOW-1032398: Add SYSTEM$REFERENCE support. (#2057)
Browse files Browse the repository at this point in the history
sfc-gh-jrose authored Aug 20, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 03a1b57 commit 15e7e42
Showing 4 changed files with 75 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -53,6 +53,7 @@
- Added support for passing `parameters` parameter to `Column.rlike` and `Column.regexp`.
- Added support for automatically cleaning up temporary tables created by `df.cache_result()` in the current session, when the DataFrame is no longer referenced (i.e., gets garbage collected). It is still an experimental feature not enabled by default, and can be enabled by setting `session.auto_clean_up_temp_table_enabled` to `True`.
- Added support for string literals to the `fmt` parameter of `snowflake.snowpark.functions.to_date`.
- Added support for system$reference function.

#### Bug Fixes

23 changes: 23 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
@@ -327,6 +327,29 @@ def sql_expr(sql: str) -> Column:
return Column._expr(sql)


def system_reference(
object_type: str,
object_identifier: str,
scope: str = "CALL",
privileges: Optional[List[str]] = None,
):
"""
Returns a reference to an object (a table, view, or function). When you execute SQL actions on a
reference to an object, the actions are performed using the role of the user who created the
reference.
Example::
>>> df = session.create_dataframe([(1,)], schema=["A"])
>>> df.write.save_as_table("my_table", mode="overwrite", table_type="temporary")
>>> df.select(substr(system_reference("table", "my_table"), 1, 14).alias("identifier")).collect()
[Row(IDENTIFIER='ENT_REF_TABLE_')]
"""
privileges = privileges or []
return builtin("system$reference")(
object_type, object_identifier, scope, *privileges
)


def current_session() -> Column:
"""
Returns a unique system identifier for the Snowflake session corresponding to the present connection.
18 changes: 18 additions & 0 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
import pytz

from snowflake.snowpark import Row
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.functions import (
_columns_from_timestamp_parts,
@@ -178,6 +179,7 @@
substring,
sum,
sum_distinct,
system_reference,
tan,
tanh,
time_from_parts,
@@ -249,6 +251,22 @@ def test_lit(session):
assert res == [Row(1), Row(1)]


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="system functions not supported by local testing",
)
def test_system_reference(session):
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
df = session.create_dataframe([(1,)]).to_df(["a"])
df.write.save_as_table(table_name)

try:
data = df.select(system_reference("TABLE", table_name)).collect()
assert data[0][0].startswith("ENT_REF_TABLE")
finally:
session.table(table_name).drop_table()


def test_avg(session):
res = TestData.duplicated_numbers(session).select(avg(col("A"))).collect()
assert res == [Row(Decimal("2.2"))]
33 changes: 33 additions & 0 deletions tests/integ/test_stored_procedure.py
Original file line number Diff line number Diff line change
@@ -40,6 +40,7 @@
pow,
sproc,
sqrt,
system_reference,
)
from snowflake.snowpark.row import Row
from snowflake.snowpark.types import (
@@ -421,6 +422,38 @@ def test_sproc(_session: Session) -> DataFrame:
assert df.dtypes == expected_dtypes


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="system functions not supported by local testing",
)
def test_sproc_pass_system_reference(session):
table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
df = session.create_dataframe([(1,)]).to_df(["a"])
df.write.save_as_table(table_name)

def insert_and_return_count(session_: Session, table_name_: str) -> int:
session_.sql(f"INSERT INTO {table_name_} VALUES (2)").collect()
return session_.table(table_name_).count()

insert_sproc = sproc(insert_and_return_count, return_type=IntegerType())

try:
assert (
insert_sproc(
system_reference(
"TABLE",
table_name,
"SESSION",
["SELECT", "INSERT", "UPDATE", "TRUNCATE"],
)
)
== 2
)
Utils.check_answer(session.table(table_name), [Row(1), Row(2)])
finally:
session.table(table_name).drop_table()


@pytest.mark.parametrize("anonymous", [True, False])
def test_call_table_sproc_triggers_action(session, anonymous):
"""Here we create a table sproc which creates a table. we call the table sproc using

0 comments on commit 15e7e42

Please sign in to comment.