diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 1c73bcdc84922..e7ecf267abc1c 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -50,7 +50,7 @@
     LambdaFunction,
     UnresolvedNamedLambdaVariable,
 )
-from pyspark.sql.connect.udf import _create_udf
+from pyspark.sql.connect.udf import _create_py_udf
 from pyspark.sql import functions as pysparkfuncs
 from pyspark.sql.types import _from_numpy_type, DataType, StructType, ArrayType, StringType
 
@@ -2461,6 +2461,7 @@ def unwrap_udt(col: "ColumnOrName") -> Column:
 def udf(
     f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
     returnType: "DataTypeOrString" = StringType(),
+    useArrow: Optional[bool] = None,
 ) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
     from pyspark.rdd import PythonEvalType
 
@@ -2469,10 +2470,15 @@ def udf(
         # for decorator use it as a returnType
         return_type = f or returnType
         return functools.partial(
-            _create_udf, returnType=return_type, evalType=PythonEvalType.SQL_BATCHED_UDF
+            _create_py_udf,
+            returnType=return_type,
+            evalType=PythonEvalType.SQL_BATCHED_UDF,
+            useArrow=useArrow,
         )
     else:
-        return _create_udf(f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF)
+        return _create_py_udf(
+            f=f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, useArrow=useArrow
+        )
 
 
 udf.__doc__ = pysparkfuncs.udf.__doc__
diff --git a/python/pyspark/sql/connect/udf.py b/python/pyspark/sql/connect/udf.py
index 9afc6e0e626a5..aab7bb3c0d3f8 100644
--- a/python/pyspark/sql/connect/udf.py
+++ b/python/pyspark/sql/connect/udf.py
@@ -23,6 +23,8 @@
 import sys
 import functools
+import warnings
+from inspect import getfullargspec
 from typing import cast, Callable, Any, TYPE_CHECKING, Optional, Union
 
 from pyspark.rdd import PythonEvalType
@@ -33,7 +35,7 @@
 )
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.types import UnparsedDataType
-from pyspark.sql.types import DataType, StringType
+from pyspark.sql.types import ArrayType, DataType, MapType, StringType, StructType
 from pyspark.sql.udf import UDFRegistration as PySparkUDFRegistration
 
@@ -47,6 +49,48 @@
     from pyspark.sql.types import StringType
 
 
+def _create_py_udf(
+    f: Callable[..., Any],
+    returnType: "DataTypeOrString",
+    evalType: int,
+    useArrow: Optional[bool] = None,
+) -> "UserDefinedFunctionLike":
+    from pyspark.sql.udf import _create_arrow_py_udf
+    from pyspark.sql.connect.session import _active_spark_session
+
+    if _active_spark_session is None:
+        is_arrow_enabled = False
+    else:
+        is_arrow_enabled = (
+            _active_spark_session.conf.get("spark.sql.execution.pythonUDF.arrow.enabled") == "true"
+            if useArrow is None
+            else useArrow
+        )
+
+    regular_udf = _create_udf(f, returnType, evalType)
+    return_type = regular_udf.returnType
+    try:
+        is_func_with_args = len(getfullargspec(f).args) > 0
+    except TypeError:
+        is_func_with_args = False
+    is_output_atomic_type = (
+        not isinstance(return_type, StructType)
+        and not isinstance(return_type, MapType)
+        and not isinstance(return_type, ArrayType)
+    )
+    if is_arrow_enabled:
+        if is_output_atomic_type and is_func_with_args:
+            return _create_arrow_py_udf(regular_udf)
+        else:
+            warnings.warn(
+                "Arrow optimization for Python UDFs cannot be enabled.",
+                UserWarning,
+            )
+            return regular_udf
+    else:
+        return regular_udf
+
+
 def _create_udf(
     f: Callable[..., Any],
     returnType: "DataTypeOrString",
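With the connect-side dispatch above, udf() on Spark Connect accepts the same useArrow flag as vanilla PySpark. A minimal usage sketch, assuming a reachable Spark Connect server (the sc://localhost endpoint is an illustrative assumption):

from pyspark.sql import SparkSession
from pyspark.sql.connect.functions import udf

# The endpoint below is an assumption for illustration; point it at your server.
spark = SparkSession.builder.remote("sc://localhost").getOrCreate()

# useArrow=True requests the Arrow-optimized path for this UDF;
# useArrow=None (the default) defers to spark.sql.execution.pythonUDF.arrow.enabled.
to_str = udf(lambda x: str(x), useArrow=True)
spark.range(3).select(to_str("id")).show()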
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
new file mode 100644
index 0000000000000..e4a64a7d5913e
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow_python_udf.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.sql.tests.connect.test_parity_udf import UDFParityTests
+from pyspark.sql.tests.test_arrow_python_udf import PythonUDFArrowTestsMixin
+
+
+class ArrowPythonUDFParityTests(UDFParityTests, PythonUDFArrowTestsMixin):
+    @classmethod
+    def setUpClass(cls):
+        super(ArrowPythonUDFParityTests, cls).setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
+        finally:
+            super(ArrowPythonUDFParityTests, cls).tearDownClass()
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.sql.tests.connect.test_parity_arrow_python_udf import *  # noqa: F401
+
+    try:
+        import xmlrunner  # type: ignore[import]
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py
index 14d00633cc647..681c42c6a5cd8 100644
--- a/python/pyspark/sql/tests/test_arrow_python_udf.py
+++ b/python/pyspark/sql/tests/test_arrow_python_udf.py
@@ -31,12 +31,7 @@
 @unittest.skipIf(
     not have_pandas or not have_pyarrow, pandas_requirement_message or pyarrow_requirement_message
 )
-class PythonUDFArrowTests(BaseUDFTestsMixin, ReusedSQLTestCase):
-    @classmethod
-    def setUpClass(cls):
-        super(PythonUDFArrowTests, cls).setUpClass()
-        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
-
+class PythonUDFArrowTestsMixin(BaseUDFTestsMixin):
     @unittest.skip("Unrelated test, and it fails when it runs duplicatedly.")
     def test_broadcast_in_udf(self):
         super(PythonUDFArrowTests, self).test_broadcast_in_udf()
@@ -118,6 +113,20 @@ def test_use_arrow(self):
         self.assertEquals(row_false[0], "[1, 2, 3]")
 
 
+class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        super(PythonUDFArrowTests, cls).setUpClass()
+        cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "true")
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            cls.spark.conf.unset("spark.sql.execution.pythonUDF.arrow.enabled")
+        finally:
+            super(PythonUDFArrowTests, cls).tearDownClass()
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests.test_arrow_python_udf import *  # noqa: F401
 
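The parity suite above reuses PythonUDFArrowTestsMixin against Spark Connect. The fallback the mixed-in tests exercise: when Arrow optimization is enabled but the UDF has a non-atomic return type (or takes no arguments), creation falls back to a regular Python UDF with a warning rather than failing. A sketch of that behavior, assuming an active spark session with spark.sql.execution.pythonUDF.arrow.enabled set to "true":

import warnings
from pyspark.sql.functions import udf

# array<int> is a non-atomic return type, so _create_py_udf falls back to a
# regular Python UDF and emits a UserWarning when the UDF is created.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    identity = udf(lambda x: x, "array<int>")

assert any("Arrow optimization" in str(w.message) for w in caught)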
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 26fe735c9c375..d8a464b006f66 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -838,47 +838,6 @@ def setUpClass(cls):
         cls.spark.conf.set("spark.sql.execution.pythonUDF.arrow.enabled", "false")
 
 
-    def test_use_arrow(self):
-        # useArrow=True
-        row_true = (
-            self.spark.range(1)
-            .selectExpr(
-                "array(1, 2, 3) as array",
-            )
-            .select(
-                udf(lambda x: str(x), useArrow=True)("array"),
-            )
-            .first()
-        )
-        # The input is a NumPy array when the Arrow optimization is on.
-        self.assertEquals(row_true[0], "[1 2 3]")
-
-        # useArrow=None
-        row_none = (
-            self.spark.range(1)
-            .selectExpr(
-                "array(1, 2, 3) as array",
-            )
-            .select(
-                udf(lambda x: str(x), useArrow=None)("array"),
-            )
-            .first()
-        )
-
-        # useArrow=False
-        row_false = (
-            self.spark.range(1)
-            .selectExpr(
-                "array(1, 2, 3) as array",
-            )
-            .select(
-                udf(lambda x: str(x), useArrow=False)("array"),
-            )
-            .first()
-        )
-        self.assertEquals(row_false[0], row_none[0])  # "[1, 2, 3]"
-
-
 class UDFInitializationTests(unittest.TestCase):
     def tearDown(self):
         if SparkSession._instantiatedSession is not None:
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 52d02dc00c258..c486d869cba96 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -75,6 +75,7 @@ def _create_udf(
     name: Optional[str] = None,
     deterministic: bool = True,
 ) -> "UserDefinedFunctionLike":
+    """Create a regular(non-Arrow-optimized) Python UDF."""
     # Set the name of the UserDefinedFunction object to be the name of function f
     udf_obj = UserDefinedFunction(
         f, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
@@ -88,6 +89,7 @@ def _create_py_udf(
     evalType: int,
     useArrow: Optional[bool] = None,
 ) -> "UserDefinedFunctionLike":
+    """Create a regular/Arrow-optimized Python UDF."""
     # The following table shows the results when the type coercion in Arrow is needed, that is,
     # when the user-specified return type(SQL Type) of the UDF and the actual instance(Python
     # Value(Type)) that the UDF returns are different.
@@ -138,49 +140,62 @@ def _create_py_udf(
         and not isinstance(return_type, MapType)
         and not isinstance(return_type, ArrayType)
     )
-    if is_arrow_enabled and is_output_atomic_type and is_func_with_args:
-        require_minimum_pandas_version()
-        require_minimum_pyarrow_version()
-
-        import pandas as pd
-        from pyspark.sql.pandas.functions import _create_pandas_udf  # type: ignore[attr-defined]
-
-        # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
-        # optimization.
-        # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
-        # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
-        # successfully.
-        result_func = lambda pdf: pdf  # noqa: E731
-        if type(return_type) == StringType:
-            result_func = lambda r: str(r) if r is not None else r  # noqa: E731
-        elif type(return_type) == BinaryType:
-            result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
-
-        def vectorized_udf(*args: pd.Series) -> pd.Series:
-            if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
-                raise NotImplementedError(
-                    "Struct input type are not supported with Arrow optimization "
-                    "enabled in Python UDFs. Disable "
-                    "'spark.sql.execution.pythonUDF.arrow.enabled' to workaround."
-                )
-            return pd.Series(result_func(f(*a)) for a in zip(*args))
-
-        # Regular UDFs can take callable instances too.
-        vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
-        vectorized_udf.__module__ = (
-            f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
-        )
-        vectorized_udf.__doc__ = f.__doc__
-        pudf = _create_pandas_udf(vectorized_udf, returnType, None)
-        # Keep the attributes as if this is a regular Python UDF.
-        pudf.func = f
-        pudf.returnType = return_type
-        pudf.evalType = regular_udf.evalType
-        return pudf
+    if is_arrow_enabled:
+        if is_output_atomic_type and is_func_with_args:
+            return _create_arrow_py_udf(regular_udf)
+        else:
+            warnings.warn(
+                "Arrow optimization for Python UDFs cannot be enabled.",
+                UserWarning,
+            )
+            return regular_udf
     else:
         return regular_udf
 
 
+def _create_arrow_py_udf(regular_udf):  # type: ignore
+    """Create an Arrow-optimized Python UDF out of a regular Python UDF."""
+    require_minimum_pandas_version()
+    require_minimum_pyarrow_version()
+
+    import pandas as pd
+    from pyspark.sql.pandas.functions import _create_pandas_udf
+
+    f = regular_udf.func
+    return_type = regular_udf.returnType
+
+    # "result_func" ensures the result of a Python UDF to be consistent with/without Arrow
+    # optimization.
+    # Otherwise, an Arrow-optimized Python UDF raises "pyarrow.lib.ArrowTypeError: Expected a
+    # string or bytes dtype, got ..." whereas a non-Arrow-optimized Python UDF returns
+    # successfully.
+    result_func = lambda pdf: pdf  # noqa: E731
+    if type(return_type) == StringType:
+        result_func = lambda r: str(r) if r is not None else r  # noqa: E731
+    elif type(return_type) == BinaryType:
+        result_func = lambda r: bytes(r) if r is not None else r  # noqa: E731
+
+    def vectorized_udf(*args: pd.Series) -> pd.Series:
+        if any(map(lambda arg: isinstance(arg, pd.DataFrame), args)):
+            raise NotImplementedError(
+                "Struct input type are not supported with Arrow optimization "
+                "enabled in Python UDFs. Disable "
+                "'spark.sql.execution.pythonUDF.arrow.enabled' to workaround."
+            )
+        return pd.Series(result_func(f(*a)) for a in zip(*args))
+
+    # Regular UDFs can take callable instances too.
+    vectorized_udf.__name__ = f.__name__ if hasattr(f, "__name__") else f.__class__.__name__
+    vectorized_udf.__module__ = f.__module__ if hasattr(f, "__module__") else f.__class__.__module__
+    vectorized_udf.__doc__ = f.__doc__
+    pudf = _create_pandas_udf(vectorized_udf, return_type, None)
+    # Keep the attributes as if this is a regular Python UDF.
+    pudf.func = f
+    pudf.returnType = return_type
+    pudf.evalType = regular_udf.evalType
+    return pudf
+
+
 class UserDefinedFunction:
     """
     User defined function in Python
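The test_use_arrow case removed from test_udf.py above (now exercised through PythonUDFArrowTestsMixin) pins down the observable difference between the two paths. Condensed, assuming an active spark session:

from pyspark.sql.functions import udf

df = spark.range(1).selectExpr("array(1, 2, 3) as array")

# Arrow path: the UDF receives a NumPy array, so str() renders "[1 2 3]".
row_true = df.select(udf(lambda x: str(x), useArrow=True)("array")).first()
assert row_true[0] == "[1 2 3]"

# Regular path: the UDF receives a Python list, so str() renders "[1, 2, 3]".
row_false = df.select(udf(lambda x: str(x), useArrow=False)("array")).first()
assert row_false[0] == "[1, 2, 3]"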