From 6eb69151bdacee42c3878a6af209dd11d5282172 Mon Sep 17 00:00:00 2001 From: Tianshu Bao Date: Tue, 5 Sep 2023 11:12:30 -0700 Subject: [PATCH] [Snowpark] Add arg for volatile/immutable UDF --- CHANGELOG.md | 1 + src/snowflake/snowpark/_internal/udf_utils.py | 5 ++++- src/snowflake/snowpark/functions.py | 18 ++++++++++++++++++ src/snowflake/snowpark/udaf.py | 8 ++++++++ src/snowflake/snowpark/udf.py | 8 ++++++++ src/snowflake/snowpark/udtf.py | 8 ++++++++ tests/integ/test_udaf.py | 2 ++ tests/integ/test_udf.py | 11 +++++++++-- tests/integ/test_udtf.py | 2 ++ 9 files changed, 60 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb03cda5d8c..ee89c4820c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### New Features +- Added support for the `VOLATILE`/`IMMUTABLE` keywords when registering UDFs, UDTFs, and UDAFs through the new `immutable` argument. - Added support for specifying clustering keys when saving dataframes using `DataFrame.save_as_table`. - Accept `Iterable` objects input for `schema` when creating dataframes using `Session.create_dataframe`. 
diff --git a/src/snowflake/snowpark/_internal/udf_utils.py b/src/snowflake/snowpark/_internal/udf_utils.py index 2ab73b46e84..f18f29f8605 100644 --- a/src/snowflake/snowpark/_internal/udf_utils.py +++ b/src/snowflake/snowpark/_internal/udf_utils.py @@ -971,6 +971,7 @@ def create_python_udf_or_sp( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, ) -> None: runtime_version = ( f"{sys.version_info[0]}.{sys.version_info[1]}" @@ -1010,8 +1011,9 @@ def create_python_udf_or_sp( if inline_python_code else "" ) - + mutability = "IMMUTABLE" if immutable else "VOLATILE" strict_as_sql = "\nSTRICT" if strict else "" + external_access_integrations_in_sql = ( f"\nEXTERNAL_ACCESS_INTEGRATIONS=({','.join(external_access_integrations)})" if external_access_integrations @@ -1028,6 +1030,7 @@ def create_python_udf_or_sp( {"" if is_permanent else "TEMPORARY"} {"SECURE" if secure else ""} {object_type.value.replace("_", " ")} {"IF NOT EXISTS" if if_not_exists else ""} {object_name}({sql_func_args}) {return_sql} LANGUAGE PYTHON {strict_as_sql} +{mutability} RUNTIME_VERSION={runtime_version} {imports_in_sql} {packages_in_sql} diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index ee1098975d6..65190fd9739 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -6532,6 +6532,7 @@ def udf( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, ) -> Union[UserDefinedFunction, functools.partial]: """Registers a Python function as a Snowflake Python UDF and returns the UDF. @@ -6615,6 +6616,7 @@ def udf( The secrets can be accessed from handler code. The secrets specified as values must also be specified in the external access integration and the keys are strings used to retrieve the secrets using secret API. 
+ immutable: Whether the UDF result is deterministic or not for the same input. Returns: A UDF function that can be called with :class:`~snowflake.snowpark.Column` expressions. @@ -6703,6 +6705,7 @@ def udf( secure=secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) else: return session.udf.register( @@ -6724,6 +6727,7 @@ def udf( secure=secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) @@ -6746,6 +6750,7 @@ def udtf( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, ) -> Union[UserDefinedTableFunction, functools.partial]: """Registers a Python class as a Snowflake Python UDTF and returns the UDTF. @@ -6815,6 +6820,7 @@ def udtf( The secrets can be accessed from handler code. The secrets specified as values must also be specified in the external access integration and the keys are strings used to retrieve the secrets using secret API. + immutable: Whether the UDTF result is deterministic or not for the same input. Returns: A UDTF function that can be called with :class:`~snowflake.snowpark.Column` expressions. @@ -6913,6 +6919,7 @@ def udtf( secure=secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) else: return session.udtf.register( @@ -6932,6 +6939,7 @@ def udtf( secure=secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) @@ -6951,6 +6959,7 @@ def udaf( session: Optional["snowflake.snowpark.session.Session"] = None, parallel: int = 4, statement_params: Optional[Dict[str, str]] = None, + immutable: bool = False, ) -> Union[UserDefinedAggregateFunction, functools.partial]: """Registers a Python class as a Snowflake Python UDAF and returns the UDAF. 
@@ -7013,6 +7022,7 @@ def udaf( Increasing the number of threads can improve performance when uploading large UDAF files. statement_params: Dictionary of statement level parameters to be set while executing this action. + immutable: Whether the UDAF result is deterministic or not for the same input. Returns: A UDAF function that can be called with :class:`~snowflake.snowpark.Column` expressions. @@ -7115,6 +7125,7 @@ def udaf( if_not_exists=if_not_exists, parallel=parallel, statement_params=statement_params, + immutable=immutable, ) else: return session.udaf.register( @@ -7130,6 +7141,7 @@ def udaf( if_not_exists=if_not_exists, parallel=parallel, statement_params=statement_params, + immutable=immutable, ) @@ -7154,6 +7166,7 @@ def pandas_udf( source_code_display: bool = True, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, ) -> Union[UserDefinedFunction, functools.partial]: """ Registers a Python function as a vectorized UDF and returns the UDF. @@ -7227,6 +7240,7 @@ def pandas_udf( source_code_display=source_code_display, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) else: return session.udf.register( @@ -7249,6 +7263,7 @@ def pandas_udf( source_code_display=source_code_display, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) @@ -7271,6 +7286,7 @@ def pandas_udtf( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, ) -> Union[UserDefinedTableFunction, functools.partial]: """Registers a Python class as a vectorized Python UDTF and returns the UDTF. 
@@ -7370,6 +7386,7 @@ def pandas_udtf( secure=secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) else: return session.udtf.register( @@ -7389,6 +7406,7 @@ def pandas_udtf( secure=secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) diff --git a/src/snowflake/snowpark/udaf.py b/src/snowflake/snowpark/udaf.py index a59dff93f54..8b2d34f227c 100644 --- a/src/snowflake/snowpark/udaf.py +++ b/src/snowflake/snowpark/udaf.py @@ -326,6 +326,7 @@ def register( *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, + immutable: bool = False, **kwargs, ) -> UserDefinedAggregateFunction: """ @@ -391,6 +392,7 @@ def register( The source code is dynamically generated therefore it may not be identical to how the `func` is originally defined. The default is ``True``. If it is ``False``, source code will not be generated or displayed. + immutable: Whether the UDAF result is deterministic or not for the same input. See Also: - :func:`~snowflake.snowpark.functions.udaf` @@ -425,6 +427,7 @@ def register( source_code_display=source_code_display, api_call_source="UDAFRegistration.register", is_permanent=is_permanent, + immutable=immutable, ) def register_from_file( @@ -445,6 +448,7 @@ def register_from_file( statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, skip_upload_on_content_match: bool = False, + immutable: bool = False, ) -> UserDefinedAggregateFunction: """ Registers a Python class as a Snowflake Python UDAF from a Python or zip file, @@ -518,6 +522,7 @@ def register_from_file( skip_upload_on_content_match: When set to ``True`` and a version of source file already exists on stage, the given source file will be uploaded to stage only if the contents of the current file differ from the remote file on stage. Defaults to ``False``. 
+ immutable: Whether the UDAF result is deterministic or not for the same input. Note:: The type hints can still be extracted from the local source Python file if they @@ -555,6 +560,7 @@ def register_from_file( api_call_source="UDAFRegistration.register_from_file", skip_upload_on_content_match=skip_upload_on_content_match, is_permanent=is_permanent, + immutable=immutable, ) def _do_register_udaf( @@ -575,6 +581,7 @@ def _do_register_udaf( api_call_source: str, skip_upload_on_content_match: bool = False, is_permanent: bool = False, + immutable: bool = False, ) -> UserDefinedAggregateFunction: # get the udaf name, return and input types (udaf_name, _, _, return_type, input_types,) = process_registration_inputs( @@ -635,6 +642,7 @@ def _do_register_udaf( if_not_exists=if_not_exists, inline_python_code=code, api_call_source=api_call_source, + immutable=immutable, ) # an exception might happen during registering a udaf # (e.g., a dependency might not be found on the stage), diff --git a/src/snowflake/snowpark/udf.py b/src/snowflake/snowpark/udf.py index 8ae5395101d..49d57814d34 100644 --- a/src/snowflake/snowpark/udf.py +++ b/src/snowflake/snowpark/udf.py @@ -496,6 +496,7 @@ def register( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, @@ -581,6 +582,7 @@ def register( The secrets can be accessed from handler code. The secrets specified as values must also be specified in the external access integration and the keys are strings used to retrieve the secrets using secret API. + immutable: Whether the UDF result is deterministic or not for the same input. 
See Also: - :func:`~snowflake.snowpark.functions.udf` - :meth:`register_from_file` @@ -615,6 +617,7 @@ def register( secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, statement_params=statement_params, source_code_display=source_code_display, api_call_source="UDFRegistration.register" @@ -640,6 +643,7 @@ def register_from_file( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, @@ -728,6 +732,7 @@ def register_from_file( The secrets can be accessed from handler code. The secrets specified as values must also be specified in the external access integration and the keys are strings used to retrieve the secrets using secret API. + immutable: Whether the UDF result is deterministic or not for the same input. Note:: The type hints can still be extracted from the local source Python file if they @@ -760,6 +765,7 @@ def register_from_file( secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, statement_params=statement_params, source_code_display=source_code_display, api_call_source="UDFRegistration.register_from_file", @@ -785,6 +791,7 @@ def _do_register_udf( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, *, statement_params: Optional[Dict[str, str]] = None, source_code_display: bool = True, @@ -867,6 +874,7 @@ def _do_register_udf( secure=secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) # an exception might happen during registering a udf # (e.g., a dependency might not be found on the stage), diff --git a/src/snowflake/snowpark/udtf.py b/src/snowflake/snowpark/udtf.py index cda852a6f80..327c029df32 100644 --- 
a/src/snowflake/snowpark/udtf.py +++ b/src/snowflake/snowpark/udtf.py @@ -424,6 +424,7 @@ def register( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, *, statement_params: Optional[Dict[str, str]] = None, ) -> UserDefinedTableFunction: @@ -493,6 +494,7 @@ def register( The secrets can be accessed from handler code. The secrets specified as values must also be specified in the external access integration and the keys are strings used to retrieve the secrets using secret API. + immutable: Whether the UDTF result is deterministic or not for the same input. See Also: - :func:`~snowflake.snowpark.functions.udtf` @@ -524,6 +526,7 @@ def register( secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, statement_params=statement_params, api_call_source="UDTFRegistration.register", is_permanent=is_permanent, @@ -549,6 +552,7 @@ def register_from_file( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, *, statement_params: Optional[Dict[str, str]] = None, skip_upload_on_content_match: bool = False, @@ -628,6 +632,7 @@ def register_from_file( The secrets can be accessed from handler code. The secrets specified as values must also be specified in the external access integration and the keys are strings used to retrieve the secrets using secret API. + immutable: Whether the UDTF result is deterministic or not for the same input. 
Note:: The type hints can still be extracted from the local source Python file if they @@ -660,6 +665,7 @@ def register_from_file( secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, statement_params=statement_params, api_call_source="UDTFRegistration.register_from_file", skip_upload_on_content_match=skip_upload_on_content_match, @@ -682,6 +688,7 @@ def _do_register_udtf( secure: bool = False, external_access_integrations: Optional[List[str]] = None, secrets: Optional[Dict[str, str]] = None, + immutable: bool = False, *, statement_params: Optional[Dict[str, str]] = None, api_call_source: str, @@ -778,6 +785,7 @@ def _do_register_udtf( secure=secure, external_access_integrations=external_access_integrations, secrets=secrets, + immutable=immutable, ) # an exception might happen during registering a udtf # (e.g., a dependency might not be found on the stage), diff --git a/tests/integ/test_udaf.py b/tests/integ/test_udaf.py index bdcf540d905..e6bb82d0bb4 100644 --- a/tests/integ/test_udaf.py +++ b/tests/integ/test_udaf.py @@ -42,6 +42,7 @@ def finish(self): PythonSumUDAFHandler, return_type=IntegerType(), input_types=[IntegerType()], + immutable=True, ) df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b") Utils.check_answer(df.agg(sum_udaf("a")), [Row(6)]) @@ -394,6 +395,7 @@ def test_register_udaf_from_file_without_type_hints(session, resources_path): "MyUDAFWithoutTypeHints", return_type=IntegerType(), input_types=[IntegerType()], + immutable=True, ) df = session.create_dataframe([[1, 3], [1, 4], [2, 5], [2, 6]]).to_df("a", "b") Utils.check_answer(df.agg(sum_udaf("a")), [Row(6)]) diff --git a/tests/integ/test_udf.py b/tests/integ/test_udf.py index 0ea837c586e..8b2f1b10bc5 100644 --- a/tests/integ/test_udf.py +++ b/tests/integ/test_udf.py @@ -125,9 +125,14 @@ def int2str(x): return str(x) return1_udf = udf(return1, return_type=StringType()) - plus1_udf = udf(plus1, 
return_type=IntegerType(), input_types=[IntegerType()]) + plus1_udf = udf( + plus1, return_type=IntegerType(), input_types=[IntegerType()], immutable=True + ) add_udf = udf( - add, return_type=IntegerType(), input_types=[IntegerType(), IntegerType()] + add, + return_type=IntegerType(), + input_types=[IntegerType(), IntegerType()], + immutable=True, ) int2str_udf = udf(int2str, return_type=StringType(), input_types=[IntegerType()]) pow_udf = udf( @@ -412,6 +417,7 @@ def test_register_udf_from_file(session, resources_path, tmpdir): "mod5", return_type=IntegerType(), input_types=[IntegerType()], + immutable=True, ) assert isinstance(mod5_udf.func, tuple) Utils.check_answer( @@ -1841,6 +1847,7 @@ def test_pandas_udf_return_types(session, _type, data, expected_types, expected_ lambda x: x, return_type=PandasSeriesType(_type()), input_types=[PandasSeriesType(_type())], + immutable=True, ) result_df = df.select(series_udf("a")).to_pandas() result_val = result_df.iloc[0][0] diff --git a/tests/integ/test_udtf.py b/tests/integ/test_udtf.py index 766c95c2f7d..db301f08e0b 100644 --- a/tests/integ/test_udtf.py +++ b/tests/integ/test_udtf.py @@ -103,6 +103,7 @@ def test_register_udtf_from_file_no_type_hints(session, resources_path): BinaryType(), BinaryType(), ], + immutable=True, ) assert isinstance(my_udtf.handler, tuple) df = session.table_function( @@ -638,6 +639,7 @@ def end_partition( "q3", "max", ], + immutable=True, ) assert_vectorized_udtf_result(session.table(vectorized_udtf_test_table), my_udtf)