Skip to content

Commit

Permalink
SNOW-1183322: [Local Testing] Add support for registering sprocs (#1338)
Browse files Browse the repository at this point in the history
* SNOW-1183322: [Local Testing] Add support for registering sprocs

* CHANGELOG

* Review Feedback

* Enable more tests.

* Update license

* update test

* update CHANGELOG
  • Loading branch information
sfc-gh-jrose authored Apr 15, 2024
1 parent f0e04c4 commit 12d8620
Show file tree
Hide file tree
Showing 8 changed files with 403 additions and 57 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### New Features

- Added support for registering UDFs and stored procedures in local testing.
- Added support for the following local testing APIs:
- snowflake.snowpark.Session:
- file.put
Expand All @@ -25,9 +26,10 @@
- current_database
- current_session
- date_trunc
- udf
- object_construct
- object_construct_keep_null
- pow
- sqrt
- Added the function `DataFrame.write.csv` to unload data from a ``DataFrame`` into one or more CSV files in a stage.
- Added telemetry to calculate query plan height and number of duplicate nodes during collect operations.
- Added the functions below to unload data from a `DataFrame` into one or more files in a stage:
Expand Down
14 changes: 14 additions & 0 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,20 @@ def mock_listagg(column: ColumnEmulator, delimiter: str, is_distinct: bool):
)


@patch("sqrt")
def mock_sqrt(column: ColumnEmulator):
    """Local-testing implementation of ``sqrt``.

    Applies ``math.sqrt`` to every element of ``column`` and tags the result
    as a float column that keeps the nullability of the input column.
    """
    nullable = column.sf_type.nullable
    roots = column.apply(math.sqrt)
    roots.sf_type = ColumnType(FloatType(), nullable)
    return roots


@patch("pow")
def mock_pow(left: ColumnEmulator, right: ColumnEmulator):
    """Local-testing implementation of ``pow``.

    Combines ``left`` and ``right`` element by element, raising each value of
    ``left`` to the matching value of ``right``. The result is tagged as a
    float column whose nullability is taken from ``left``.
    """

    def _power(base, exponent):
        return base**exponent

    powered = left.combine(right, _power)
    powered.sf_type = ColumnType(FloatType(), left.sf_type.nullable)
    return powered


@patch("to_date")
def mock_to_date(
column: ColumnEmulator,
Expand Down
238 changes: 238 additions & 0 deletions src/snowflake/snowpark/mock/_stored_procedure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import json
import sys
import typing
from types import ModuleType
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import snowflake.snowpark
from snowflake.snowpark._internal.udf_utils import (
check_python_runtime_version,
process_registration_inputs,
)
from snowflake.snowpark._internal.utils import TempObjectType
from snowflake.snowpark.column import Column
from snowflake.snowpark.dataframe import DataFrame
from snowflake.snowpark.exceptions import SnowparkSQLException
from snowflake.snowpark.mock import CUSTOM_JSON_ENCODER
from snowflake.snowpark.mock._plan import calculate_expression
from snowflake.snowpark.mock._snowflake_data_type import ColumnEmulator
from snowflake.snowpark.stored_procedure import (
StoredProcedure,
StoredProcedureRegistration,
)
from snowflake.snowpark.types import (
ArrayType,
DataType,
MapType,
StructType,
_FractionalType,
_IntegralType,
)

from ._telemetry import LocalTestOOBTelemetryService

if sys.version_info <= (3, 9):
from typing import Iterable
else:
from collections.abc import Iterable


def sproc_types_are_compatible(x, y):
    """Return True when ``x`` may be passed where ``y`` is expected.

    Two values are considered compatible when ``x`` is an instance of the
    type of ``y``, when both are ``_IntegralType`` instances, or when both
    are ``_FractionalType`` instances. Anything else is incompatible.
    """
    # Same type (or x is an instance of a subclass of y's type).
    if isinstance(x, type(y)):
        return True

    # Any two integral types are interchangeable.
    if isinstance(x, _IntegralType) and isinstance(y, _IntegralType):
        return True

    # Any two fractional types are interchangeable.
    return isinstance(x, _FractionalType) and isinstance(y, _FractionalType)


class MockStoredProcedure(StoredProcedure):
    """Stored procedure object used by local testing.

    Calling an instance runs the wrapped Python function (``self.func``)
    directly in the current process.
    """

    def __call__(
        self,
        *args: Any,
        session: Optional["snowflake.snowpark.session.Session"] = None,
        statement_params: Optional[Dict[str, str]] = None,
    ) -> Any:
        """Run the wrapped function with ``args``.

        Args:
            *args: Arguments for the procedure. Any ``Column`` argument is
                evaluated down to a single literal value before the call.
            session: Session to run with; resolved by ``_validate_call``.
            statement_params: Accepted for signature compatibility; it is not
                used by this implementation.

        Returns:
            Whatever the wrapped function returns. When the declared return
            type is ``ArrayType``, ``MapType`` or ``StructType`` and the
            result is not a ``DataFrame``, the result is serialized to a JSON
            string instead.

        Raises:
            ValueError: If a ``Column`` argument declares a datatype that is
                incompatible with the expected input type, or if it does not
                evaluate to exactly one value.
        """
        args, session = self._validate_call(args, session)

        # Unpack columns if passed
        parsed_args = []
        for arg, expected_type in zip(args, self._input_types):
            if not isinstance(arg, Column):
                parsed_args.append(arg)
                continue

            expr = arg._expression

            # If the expression does not define its datatype we cannot verify
            # that it is compatible. This is potentially unsafe.
            if expr.datatype and not sproc_types_are_compatible(
                expr.datatype, expected_type
            ):
                raise ValueError(
                    f"Unexpected type {expr.datatype} for sproc argument of type {expected_type}"
                )

            # Expression may be a nested expression. Expression should not need any input data
            # and should only return one value so that it can be passed as a literal value.
            # We pass in a single None value so that the expression evaluator has some data to
            # pass to the expressions.
            resolved_expr = calculate_expression(
                expr,
                ColumnEmulator(data=[None]),
                session._analyzer,
                {},
            )

            # If the resolved expression is not a single value we cannot pass it as a literal.
            if len(resolved_expr) != 1:
                # The f-prefix is required so the expression class name is
                # actually interpolated into the message.
                raise ValueError(
                    f"[Local Testing] Unexpected argument type {expr.__class__.__name__} for call to sproc"
                )
            parsed_args.append(resolved_expr[0])

        result = self.func(session, *parsed_args)

        # Semi-structured types are serialized in json
        if isinstance(
            self._return_type,
            (
                ArrayType,
                MapType,
                StructType,
            ),
        ) and not isinstance(result, DataFrame):
            result = json.dumps(result, indent=2, cls=CUSTOM_JSON_ENCODER)

        return result


class MockStoredProcedureRegistration(StoredProcedureRegistration):
    """Stored procedure registration used by local testing.

    Registered procedures are kept in an in-memory registry keyed by the
    resolved procedure name.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the base registration and an empty in-memory registry."""
        super().__init__(*args, **kwargs)
        # Maps resolved procedure name -> registered MockStoredProcedure.
        self._registry: Dict[str, Callable] = dict()

    def register_from_file(
        self,
        file_path: str,
        func_name: str,
        return_type: Optional[DataType] = None,
        input_types: Optional[List[DataType]] = None,
        name: Optional[Union[str, Iterable[str]]] = None,
        is_permanent: bool = False,
        stage_location: Optional[str] = None,
        imports: Optional[List[Union[str, Tuple[str, str]]]] = None,
        packages: Optional[List[Union[str, ModuleType]]] = None,
        replace: bool = False,
        if_not_exists: bool = False,
        parallel: int = 4,
        execute_as: typing.Literal["caller", "owner"] = "owner",
        strict: bool = False,
        external_access_integrations: Optional[List[str]] = None,
        secrets: Optional[Dict[str, str]] = None,
        *,
        statement_params: Optional[Dict[str, str]] = None,
        source_code_display: bool = True,
        skip_upload_on_content_match: bool = False,
    ) -> StoredProcedure:
        """Report that registering a procedure from a file is unsupported.

        All arguments are ignored; the call is forwarded to the local-testing
        telemetry service with ``raise_error=NotImplementedError``.
        """
        LocalTestOOBTelemetryService.get_instance().log_not_supported_error(
            external_feature_name="register sproc from file",
            internal_feature_name="MockStoredProcedureRegistration.register_from_file",
            parameters_info={},
            raise_error=NotImplementedError,
        )

    def _do_register_sp(
        self,
        func: Union[Callable, Tuple[str, str]],
        return_type: DataType,
        input_types: List[DataType],
        sp_name: str,
        stage_location: Optional[str],
        imports: Optional[List[Union[str, Tuple[str, str]]]],
        packages: Optional[List[Union[str, ModuleType]]],
        replace: bool,
        if_not_exists: bool,
        parallel: int,
        strict: bool,
        *,
        source_code_display: bool = False,
        statement_params: Optional[Dict[str, str]] = None,
        execute_as: typing.Literal["caller", "owner"] = "owner",
        anonymous: bool = False,
        api_call_source: str,
        skip_upload_on_content_match: bool = False,
        is_permanent: bool = False,
        external_access_integrations: Optional[List[str]] = None,
        secrets: Optional[Dict[str, str]] = None,
        force_inline_code: bool = False,
    ) -> StoredProcedure:
        """Register ``func`` in the in-memory registry.

        Args:
            func: The callable (or ``(file, name)`` tuple) to register.
            return_type: Declared return type, resolved further by
                ``process_registration_inputs``.
            input_types: Declared input types, resolved the same way.
            sp_name: Requested procedure name.
            imports: Not supported; a non-empty value is reported as
                unsupported.
            packages: Not supported; a non-empty value is reported as
                unsupported.
            replace: Overwrite an existing procedure with the same name.
            if_not_exists: When a procedure with the same name already exists
                and ``replace`` is false, return the existing procedure
                instead of raising.
            execute_as: Forwarded to the created ``MockStoredProcedure``.
            anonymous: Forwarded to ``process_registration_inputs``.

        The remaining parameters are accepted for signature compatibility and
        are not used here.

        Returns:
            The newly registered ``MockStoredProcedure``, or the already
            registered one when ``if_not_exists`` applies.

        Raises:
            TypeError: If the inputs describe a pandas stored procedure.
            SnowparkSQLException: If the name is already registered and
                neither ``replace`` nor ``if_not_exists`` is set.
        """
        (
            udf_name,
            is_pandas_udf,
            is_dataframe_input,
            return_type,
            input_types,
        ) = process_registration_inputs(
            self._session,
            TempObjectType.PROCEDURE,
            func,
            return_type,
            input_types,
            sp_name,
            anonymous,
        )

        if is_pandas_udf:
            raise TypeError("pandas stored procedure is not supported")

        if packages or imports:
            LocalTestOOBTelemetryService.get_instance().log_not_supported_error(
                external_feature_name="uploading imports and packages for sprocs",
                internal_feature_name="MockStoredProcedureRegistration._do_register_sp",
                parameters_info={},
                raise_error=NotImplementedError,
            )

        check_python_runtime_version(self._session._runtime_version_from_requirement)

        if udf_name in self._registry and not replace:
            # Honor if_not_exists: keep the existing registration untouched
            # rather than failing. replace=True still overwrites as before.
            if if_not_exists:
                return self._registry[udf_name]
            raise SnowparkSQLException(
                f"002002 (42710): SQL compilation error: \nObject '{udf_name}' already exists.",
                error_code="1304",
            )

        sproc = MockStoredProcedure(
            func,
            return_type,
            input_types,
            udf_name,
            execute_as=execute_as,
        )

        self._registry[udf_name] = sproc

        return sproc

    def call(
        self,
        sproc_name: str,
        *args: Any,
        session: Optional["snowflake.snowpark.session.Session"] = None,
        statement_params: Optional[Dict[str, str]] = None,
    ):
        """Look up ``sproc_name`` in the registry and call it.

        Args:
            sproc_name: Name the procedure was registered under.
            *args: Positional arguments forwarded to the procedure.
            session: Forwarded to the procedure call.
            statement_params: Forwarded to the procedure call.

        Returns:
            The value returned by the registered procedure.

        Raises:
            SnowparkSQLException: If no procedure is registered under
                ``sproc_name``.
        """
        if sproc_name not in self._registry:
            raise SnowparkSQLException(
                f"[Local Testing] sproc {sproc_name} does not exist."
            )

        return self._registry[sproc_name](
            *args, session=session, statement_params=statement_params
        )
15 changes: 9 additions & 6 deletions src/snowflake/snowpark/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@
_extract_schema_and_data_from_pandas_df,
)
from snowflake.snowpark.mock._plan_builder import MockSnowflakePlanBuilder
from snowflake.snowpark.mock._stored_procedure import MockStoredProcedureRegistration
from snowflake.snowpark.mock._udf import MockUDFRegistration
from snowflake.snowpark.query_history import QueryHistory
from snowflake.snowpark.row import Row
Expand Down Expand Up @@ -441,12 +442,14 @@ def __init__(

if isinstance(conn, MockServerConnection):
self._udf_registration = MockUDFRegistration(self)
self._sp_registration = MockStoredProcedureRegistration(self)
else:
self._udf_registration = UDFRegistration(self)
self._sp_registration = StoredProcedureRegistration(self)

self._udtf_registration = UDTFRegistration(self)
self._udaf_registration = UDAFRegistration(self)
self._sp_registration = StoredProcedureRegistration(self)

self._plan_builder = (
SnowflakePlanBuilder(self)
if isinstance(self._conn, ServerConnection)
Expand Down Expand Up @@ -2805,11 +2808,6 @@ def sproc(self) -> StoredProcedureRegistration:
Returns a :class:`stored_procedure.StoredProcedureRegistration` object that you can use to register stored procedures.
See details of how to use this object in :class:`stored_procedure.StoredProcedureRegistration`.
"""
if isinstance(self, MockServerConnection):
self._conn.log_not_supported_error(
external_feature_name="Session.sproc",
raise_error=NotImplementedError,
)
return self._sp_registration

def _infer_is_return_table(
Expand Down Expand Up @@ -2917,6 +2915,11 @@ def _call(
is_return_table: When set to a non-null value, it signifies whether the return type of sproc_name
is a table return type. This skips infer check and returns a dataframe with appropriate sql call.
"""
if isinstance(self._sp_registration, MockStoredProcedureRegistration):
return self._sp_registration.call(
sproc_name, *args, session=self, statement_params=statement_params
)

validate_object_name(sproc_name)
query = generate_call_python_sp_sql(self, sproc_name, *args)

Expand Down
17 changes: 13 additions & 4 deletions src/snowflake/snowpark/stored_procedure.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,11 @@ def __init__(
self._anonymous_sp_sql = anonymous_sp_sql
self._is_return_table = isinstance(return_type, StructType)

def __call__(
def _validate_call(
self,
*args: Any,
args: List[Any],
session: Optional["snowflake.snowpark.session.Session"] = None,
statement_params: Optional[Dict[str, str]] = None,
) -> Any:
):
if args and isinstance(args[0], snowflake.snowpark.session.Session):
if session:
raise ValueError(
Expand All @@ -98,6 +97,16 @@ def __call__(
f"Incorrect number of arguments passed to the stored procedure. Expected: {len(self._input_types)}, Found: {len(args)}"
)

return args, session

def __call__(
self,
*args: Any,
session: Optional["snowflake.snowpark.session.Session"] = None,
statement_params: Optional[Dict[str, str]] = None,
) -> Any:
args, session = self._validate_call(args, session)

session._conn._telemetry_client.send_function_usage_telemetry(
"StoredProcedure.__call__", TelemetryField.FUNC_CAT_USAGE.value
)
Expand Down
2 changes: 2 additions & 0 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ def test_random(session):
df.select(random()).collect()


@pytest.mark.localtest
def test_sqrt(session):
Utils.check_answer(
TestData.test_data1(session).select(sqrt(col("NUM"))),
Expand Down Expand Up @@ -552,6 +553,7 @@ def test_log(session):
)


@pytest.mark.localtest
def test_pow(session):
Utils.check_answer(
TestData.double2(session).select(pow(col("A"), col("B"))),
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ def test_strtok_to_array(session):
assert res[0] == "a" and res[1] == "b" and res[2] == "c"


@pytest.mark.local
@pytest.mark.localtest
@pytest.mark.parametrize("use_col", [True, False])
@pytest.mark.parametrize(
"values,expected",
Expand All @@ -431,7 +431,7 @@ def test_greatest(session, use_col, values, expected):
assert res[0][0] == expected


@pytest.mark.local
@pytest.mark.localtest
@pytest.mark.parametrize("use_col", [True, False])
@pytest.mark.parametrize(
"values,expected",
Expand Down
Loading

0 comments on commit 12d8620

Please sign in to comment.