Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Snowpark] Add arg for volatile/immutable UDF #1038

Merged
merged 1 commit into from
Sep 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

### New Features

- Added support for VOLATILE/IMMUTABLE keyword when registering UDFs.
- Added support for specifying clustering keys when saving dataframes using `DataFrame.save_as_table`.
- Accept `Iterable` objects input for `schema` when creating dataframes using `Session.create_dataframe`.

5 changes: 4 additions & 1 deletion src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
@@ -971,6 +971,7 @@ def create_python_udf_or_sp(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
sfc-gh-aling marked this conversation as resolved.
Show resolved Hide resolved
sfc-gh-aling marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
runtime_version = (
f"{sys.version_info[0]}.{sys.version_info[1]}"
@@ -1010,8 +1011,9 @@ def create_python_udf_or_sp(
if inline_python_code
else ""
)

mutability = "IMMUTABLE" if immutable else "VOLATILE"
sfc-gh-tbao marked this conversation as resolved.
Show resolved Hide resolved
strict_as_sql = "\nSTRICT" if strict else ""

external_access_integrations_in_sql = (
f"\nEXTERNAL_ACCESS_INTEGRATIONS=({','.join(external_access_integrations)})"
if external_access_integrations
@@ -1028,6 +1030,7 @@ def create_python_udf_or_sp(
{"" if is_permanent else "TEMPORARY"} {"SECURE" if secure else ""} {object_type.value.replace("_", " ")} {"IF NOT EXISTS" if if_not_exists else ""} {object_name}({sql_func_args})
{return_sql}
LANGUAGE PYTHON {strict_as_sql}
{mutability}
RUNTIME_VERSION={runtime_version}
{imports_in_sql}
{packages_in_sql}
18 changes: 18 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
@@ -6532,6 +6532,7 @@ def udf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
sfc-gh-aling marked this conversation as resolved.
Show resolved Hide resolved
) -> Union[UserDefinedFunction, functools.partial]:
"""Registers a Python function as a Snowflake Python UDF and returns the UDF.

@@ -6615,6 +6616,7 @@ def udf(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDF result is deterministic or not for the same input.

Returns:
A UDF function that can be called with :class:`~snowflake.snowpark.Column` expressions.
@@ -6703,6 +6705,7 @@ def udf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
else:
return session.udf.register(
@@ -6724,6 +6727,7 @@ def udf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)


@@ -6746,6 +6750,7 @@ def udtf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> Union[UserDefinedTableFunction, functools.partial]:
"""Registers a Python class as a Snowflake Python UDTF and returns the UDTF.

@@ -6815,6 +6820,7 @@ def udtf(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDTF result is deterministic or not for the same input.

Returns:
A UDTF function that can be called with :class:`~snowflake.snowpark.Column` expressions.
@@ -6913,6 +6919,7 @@ def udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
else:
return session.udtf.register(
@@ -6932,6 +6939,7 @@ def udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)


@@ -6951,6 +6959,7 @@ def udaf(
session: Optional["snowflake.snowpark.session.Session"] = None,
parallel: int = 4,
statement_params: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> Union[UserDefinedAggregateFunction, functools.partial]:
"""Registers a Python class as a Snowflake Python UDAF and returns the UDAF.

@@ -7013,6 +7022,7 @@ def udaf(
Increasing the number of threads can improve performance when uploading
large UDAF files.
statement_params: Dictionary of statement level parameters to be set while executing this action.
immutable: Whether the UDAF result is deterministic or not for the same input.

Returns:
A UDAF function that can be called with :class:`~snowflake.snowpark.Column` expressions.
@@ -7115,6 +7125,7 @@ def udaf(
if_not_exists=if_not_exists,
parallel=parallel,
statement_params=statement_params,
immutable=immutable,
)
else:
return session.udaf.register(
@@ -7130,6 +7141,7 @@ def udaf(
if_not_exists=if_not_exists,
parallel=parallel,
statement_params=statement_params,
immutable=immutable,
)


@@ -7154,6 +7166,7 @@ def pandas_udf(
source_code_display: bool = True,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> Union[UserDefinedFunction, functools.partial]:
"""
Registers a Python function as a vectorized UDF and returns the UDF.
@@ -7227,6 +7240,7 @@ def pandas_udf(
source_code_display=source_code_display,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
else:
return session.udf.register(
@@ -7249,6 +7263,7 @@ def pandas_udf(
source_code_display=source_code_display,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)


@@ -7271,6 +7286,7 @@ def pandas_udtf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
) -> Union[UserDefinedTableFunction, functools.partial]:
"""Registers a Python class as a vectorized Python UDTF and returns the UDTF.

@@ -7370,6 +7386,7 @@ def pandas_udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
else:
return session.udtf.register(
@@ -7389,6 +7406,7 @@ def pandas_udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)


8 changes: 8 additions & 0 deletions src/snowflake/snowpark/udaf.py
Original file line number Diff line number Diff line change
@@ -326,6 +326,7 @@ def register(
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
immutable: bool = False,
**kwargs,
) -> UserDefinedAggregateFunction:
"""
@@ -391,6 +392,7 @@ def register(
The source code is dynamically generated therefore it may not be identical to how the
`func` is originally defined. The default is ``True``.
If it is ``False``, source code will not be generated or displayed.
immutable: Whether the UDAF result is deterministic or not for the same input.

See Also:
- :func:`~snowflake.snowpark.functions.udaf`
@@ -425,6 +427,7 @@ def register(
source_code_display=source_code_display,
api_call_source="UDAFRegistration.register",
is_permanent=is_permanent,
immutable=immutable,
)

def register_from_file(
@@ -445,6 +448,7 @@ def register_from_file(
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
skip_upload_on_content_match: bool = False,
immutable: bool = False,
) -> UserDefinedAggregateFunction:
"""
Registers a Python class as a Snowflake Python UDAF from a Python or zip file,
@@ -518,6 +522,7 @@ def register_from_file(
skip_upload_on_content_match: When set to ``True`` and a version of source file already exists on stage, the given source
file will be uploaded to stage only if the contents of the current file differ from the remote file on stage. Defaults
to ``False``.
immutable: Whether the UDAF result is deterministic or not for the same input.

Note::
The type hints can still be extracted from the local source Python file if they
@@ -555,6 +560,7 @@ def register_from_file(
api_call_source="UDAFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
is_permanent=is_permanent,
immutable=immutable,
)

def _do_register_udaf(
@@ -575,6 +581,7 @@ def _do_register_udaf(
api_call_source: str,
skip_upload_on_content_match: bool = False,
is_permanent: bool = False,
immutable: bool = False,
) -> UserDefinedAggregateFunction:
# get the udaf name, return and input types
(udaf_name, _, _, return_type, input_types,) = process_registration_inputs(
@@ -635,6 +642,7 @@ def _do_register_udaf(
if_not_exists=if_not_exists,
inline_python_code=code,
api_call_source=api_call_source,
immutable=immutable,
)
# an exception might happen during registering a udaf
# (e.g., a dependency might not be found on the stage),
8 changes: 8 additions & 0 deletions src/snowflake/snowpark/udf.py
Original file line number Diff line number Diff line change
@@ -496,6 +496,7 @@ def register(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
@@ -581,6 +582,7 @@ def register(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDF result is deterministic or not for the same input.
See Also:
- :func:`~snowflake.snowpark.functions.udf`
- :meth:`register_from_file`
@@ -615,6 +617,7 @@ def register(
secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
statement_params=statement_params,
source_code_display=source_code_display,
api_call_source="UDFRegistration.register"
@@ -640,6 +643,7 @@ def register_from_file(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
@@ -728,6 +732,7 @@ def register_from_file(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDF result is deterministic or not for the same input.

Note::
The type hints can still be extracted from the local source Python file if they
@@ -760,6 +765,7 @@ def register_from_file(
secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
statement_params=statement_params,
source_code_display=source_code_display,
api_call_source="UDFRegistration.register_from_file",
@@ -785,6 +791,7 @@ def _do_register_udf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
source_code_display: bool = True,
@@ -867,6 +874,7 @@ def _do_register_udf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
# an exception might happen during registering a udf
# (e.g., a dependency might not be found on the stage),
8 changes: 8 additions & 0 deletions src/snowflake/snowpark/udtf.py
Original file line number Diff line number Diff line change
@@ -424,6 +424,7 @@ def register(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
) -> UserDefinedTableFunction:
@@ -493,6 +494,7 @@ def register(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDTF result is deterministic or not for the same input.

See Also:
- :func:`~snowflake.snowpark.functions.udtf`
@@ -524,6 +526,7 @@ def register(
secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
statement_params=statement_params,
api_call_source="UDTFRegistration.register",
is_permanent=is_permanent,
@@ -549,6 +552,7 @@ def register_from_file(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
skip_upload_on_content_match: bool = False,
@@ -628,6 +632,7 @@ def register_from_file(
The secrets can be accessed from handler code. The secrets specified as values must
also be specified in the external access integration and the keys are strings used to
retrieve the secrets using secret API.
immutable: Whether the UDTF result is deterministic or not for the same input.

Note::
The type hints can still be extracted from the local source Python file if they
@@ -660,6 +665,7 @@ def register_from_file(
secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
statement_params=statement_params,
api_call_source="UDTFRegistration.register_from_file",
skip_upload_on_content_match=skip_upload_on_content_match,
@@ -682,6 +688,7 @@ def _do_register_udtf(
secure: bool = False,
external_access_integrations: Optional[List[str]] = None,
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
*,
statement_params: Optional[Dict[str, str]] = None,
api_call_source: str,
@@ -778,6 +785,7 @@ def _do_register_udtf(
secure=secure,
external_access_integrations=external_access_integrations,
secrets=secrets,
immutable=immutable,
)
# an exception might happen during registering a udtf
# (e.g., a dependency might not be found on the stage),
2 changes: 2 additions & 0 deletions tests/integ/test_udaf.py
Original file line number Diff line number Diff line change
@@ -42,6 +42,7 @@ def finish(self):
PythonSumUDAFHandler,
return_type=IntegerType(),
input_types=[IntegerType()],
immutable=True,
)
df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")
Utils.check_answer(df.agg(sum_udaf("a")), [Row(6)])
@@ -394,6 +395,7 @@ def test_register_udaf_from_file_without_type_hints(session, resources_path):
"MyUDAFWithoutTypeHints",
return_type=IntegerType(),
input_types=[IntegerType()],
immutable=True,
)
df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b")
Utils.check_answer(df.agg(sum_udaf("a")), [Row(6)])
Loading