Skip to content

Commit

Permalink
[Snowpark] Add arg for volatile/immutable UDF
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-tbao committed Sep 5, 2023
1 parent 78b37d8 commit 6eb6915
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### New Features

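- Added support for the `VOLATILE`/`IMMUTABLE` keywords when registering UDFs, UDTFs, and UDAFs, controlled by the new `immutable` argument (defaults to `False`, i.e. `VOLATILE`).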
- Added support for specifying clustering keys when saving dataframes using `DataFrame.save_as_table`.
- Accept `Iterable` objects input for `schema` when creating dataframes using `Session.create_dataframe`.

Expand Down
5 changes: 4 additions & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,6 +971,7 @@ def create_python_udf_or_sp(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> None:
runtime_version = (
f"{sys.version_info[0]}.{sys.version_info[1]}"
Expand Down Expand Up @@ -1010,8 +1011,9 @@ def create_python_udf_or_sp(
if inline_python_code
else ""
)

mutability = "IMMUTABLE" if immutable else "VOLATILE"
strict_as_sql = "\nSTRICT" if strict else ""

external_access_integrations_in_sql = (
f"\nEXTERNAL_ACCESS_INTEGRATIONS=({','.join(external_access_integrations)})"
if external_access_integrations
Expand All @@ -1028,6 +1030,7 @@ def create_python_udf_or_sp(
{"" if is_permanent else "TEMPORARY"} {"SECURE" if secure else ""} {object_type.value.replace("_", " ")} {"IF NOT EXISTS" if if_not_exists else ""} {object_name}({sql_func_args})
{return_sql}
LANGUAGE PYTHON {strict_as_sql}
{mutability}
RUNTIME_VERSION={runtime_version}
{imports_in_sql}
{packages_in_sql}
Expand Down
18 changes: 18 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6532,6 +6532,7 @@ def udf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> Union[UserDefinedFunction, functools.partial]:
"""Registers a Python function as a Snowflake Python UDF and returns the UDF.
Expand Down Expand Up @@ -6615,6 +6616,7 @@ def udf(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDF result is deterministic or not for the same input.
Returns:
A UDF function that can be called with :class:`~snowflake.snowpark.Column` expressions.
Expand Down Expand Up @@ -6703,6 +6705,7 @@ def udf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
else:
return session.udf.register(
Expand All @@ -6724,6 +6727,7 @@ def udf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)


Expand All @@ -6746,6 +6750,7 @@ def udtf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> Union[UserDefinedTableFunction, functools.partial]:
"""Registers a Python class as a Snowflake Python UDTF and returns the UDTF.
Expand Down Expand Up @@ -6815,6 +6820,7 @@ def udtf(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDTF result is deterministic or not for the same input.
Returns:
A UDTF function that can be called with :class:`~snowflake.snowpark.Column` expressions.
Expand Down Expand Up @@ -6913,6 +6919,7 @@ def udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
else:
return session.udtf.register(
Expand All @@ -6932,6 +6939,7 @@ def udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)


Expand All @@ -6951,6 +6959,7 @@ def udaf(
session: Optional["snowflake.snowpark.session.Session"] = None,
parallel: int = 4,
statement_params: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> Union[UserDefinedAggregateFunction, functools.partial]:
"""Registers a Python class as a Snowflake Python UDAF and returns the UDAF.
Expand Down Expand Up @@ -7013,6 +7022,7 @@ def udaf(
Increasing the number of threads can improve performance when uploading
large UDAF files.
statement_params: Dictionary of statement level parameters to be set while executing this action.
immutable: Whether the UDAF result is deterministic or not for the same input.
Returns:
A UDAF function that can be called with :class:`~snowflake.snowpark.Column` expressions.
Expand Down Expand Up @@ -7115,6 +7125,7 @@ def udaf(
if_not_exists=if_not_exists,
parallel=parallel,
statement_params=statement_params,
immutable=immutable,
)
else:
return session.udaf.register(
Expand All @@ -7130,6 +7141,7 @@ def udaf(
if_not_exists=if_not_exists,
parallel=parallel,
statement_params=statement_params,
immutable=immutable,
)


Expand All @@ -7154,6 +7166,7 @@ def pandas_udf(
source_code_display: bool = True,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> Union[UserDefinedFunction, functools.partial]:
"""
Registers a Python function as a vectorized UDF and returns the UDF.
Expand Down Expand Up @@ -7227,6 +7240,7 @@ def pandas_udf(
source_code_display=source_code_display,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
else:
return session.udf.register(
Expand All @@ -7249,6 +7263,7 @@ def pandas_udf(
source_code_display=source_code_display,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)


Expand All @@ -7271,6 +7286,7 @@ def pandas_udtf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> Union[UserDefinedTableFunction, functools.partial]:
"""Registers a Python class as a vectorized Python UDTF and returns the UDTF.
Expand Down Expand Up @@ -7370,6 +7386,7 @@ def pandas_udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
else:
return session.udtf.register(
Expand All @@ -7389,6 +7406,7 @@ def pandas_udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)


Expand Down
8 changes: 8 additions & 0 deletions src/snowflake/snowpark/udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def register(
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
immutable: bool = False,
**kwargs,
) -> UserDefinedAggregateFunction:
"""
Expand Down Expand Up @@ -391,6 +392,7 @@ def register(
The source code is dynamically generated therefore it may not be identical to how the
`func` is originally defined. The default is ``True``.
If it is ``False``, source code will not be generated or displayed.
immutable: Whether the UDAF result is deterministic or not for the same input.
See Also:
- :func:`~snowflake.snowpark.functions.udaf`
Expand Down Expand Up @@ -425,6 +427,7 @@ def register(
source_code_display=source_code_display,
api_call_source="UDAFRegistration.register",
is_permanent=is_permanent,
immutable=immutable,
)

def register_from_file(
Expand All @@ -445,6 +448,7 @@ def register_from_file(
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
skip_upload_on_content_match: bool = False,
immutable: bool = False,
) -> UserDefinedAggregateFunction:
"""
Registers a Python class as a Snowflake Python UDAF from a Python or zip file,
Expand Down Expand Up @@ -518,6 +522,7 @@ def register_from_file(
skip_upload_on_content_match: When set to ``True`` and a version of source file already exists on stage, the given source
file will be uploaded to stage only if the contents of the current file differ from the remote file on stage. Defaults
to ``False``.
immutable: Whether the UDAF result is deterministic or not for the same input.
Note::
The type hints can still be extracted from the local source Python file if they
Expand Down Expand Up @@ -555,6 +560,7 @@ def register_from_file(
api_call_source="UDAFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
immutable=immutable,
)

def _do_register_udaf(
Expand All @@ -575,6 +581,7 @@ def _do_register_udaf(
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
immutable: bool = False,
) -> UserDefinedAggregateFunction:
# get the udaf name, return and input types
(udaf_name, _, _, return_type, input_types,) = process_registration_inputs(
Expand Down Expand Up @@ -635,6 +642,7 @@ def _do_register_udaf(
if_not_exists=if_not_exists,
inline_python_code=code,
api_call_source=api_call_source,
immutable=immutable,
)
# an exception might happen during registering a udaf
# (e.g., a dependency might not be found on the stage),
Expand Down
8 changes: 8 additions & 0 deletions src/snowflake/snowpark/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ def register(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
Expand Down Expand Up @@ -581,6 +582,7 @@ def register(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDF result is deterministic or not for the same input.
See Also:
- :func:`~snowflake.snowpark.functions.udf`
- :meth:`register_from_file`
Expand Down Expand Up @@ -615,6 +617,7 @@ def register(
secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
statement_params=statement_params,
source_code_display=source_code_display,
api_call_source="UDFRegistration.register"
Expand All @@ -640,6 +643,7 @@ def register_from_file(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
Expand Down Expand Up @@ -728,6 +732,7 @@ def register_from_file(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDF result is deterministic or not for the same input.
Note::
The type hints can still be extracted from the local source Python file if they
Expand Down Expand Up @@ -760,6 +765,7 @@ def register_from_file(
secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
statement_params=statement_params,
source_code_display=source_code_display,
api_call_source="UDFRegistration.register_from_file",
Expand All @@ -785,6 +791,7 @@ def _do_register_udf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
Expand Down Expand Up @@ -867,6 +874,7 @@ def _do_register_udf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
# an exception might happen during registering a udf
# (e.g., a dependency might not be found on the stage),
Expand Down
8 changes: 8 additions & 0 deletions src/snowflake/snowpark/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ def register(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
) -> UserDefinedTableFunction:
Expand Down Expand Up @@ -493,6 +494,7 @@ def register(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDTF result is deterministic or not for the same input.
See Also:
- :func:`~snowflake.snowpark.functions.udtf`
Expand Down Expand Up @@ -524,6 +526,7 @@ def register(
secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
statement_params=statement_params,
api_call_source="UDTFRegistration.register",
is_permanent=is_permanent,
Expand All @@ -549,6 +552,7 @@ def register_from_file(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
skip_upload_on_content_match: bool = False,
Expand Down Expand Up @@ -628,6 +632,7 @@ def register_from_file(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDTF result is deterministic or not for the same input.
Note::
The type hints can still be extracted from the local source Python file if they
Expand Down Expand Up @@ -660,6 +665,7 @@ def register_from_file(
secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
statement_params=statement_params,
api_call_source="UDTFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
Expand All @@ -682,6 +688,7 @@ def _do_register_udtf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
api_call_source: str,
Expand Down Expand Up @@ -778,6 +785,7 @@ def _do_register_udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
# an exception might happen during registering a udtf
# (e.g., a dependency might not be found on the stage),
Expand Down
2 changes: 2 additions & 0 deletions tests/integ/test_udaf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def finish(self):
PythonSumUDAFHandler,
return_type=IntegerType(),
input_types=[IntegerType()],
immutable=True,
)
df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")
Utils.check_answer(df.agg(sum_udaf("a")), [Row(6)])
Expand Down Expand Up @@ -394,6 +395,7 @@ def test_register_udaf_from_file_without_type_hints(session, resources_path):
"MyUDAFWithoutTypeHints",
return_type=IntegerType(),
input_types=[IntegerType()],
immutable=True,
)
df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")
Utils.check_answer(df.agg(sum_udaf("a")), [Row(6)])
Expand Down
Loading

0 comments on commit 6eb6915

Please sign in to comment.