diff --git a/CHANGELOG.md b/CHANGELOG.md index ecdbb23e379..b6d61ab5d1d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ ### New Features - Added support for managing case sensitivity in `DataFrame.to_local_iterator()`. +- Added support for specifying vectorized UDTF's input column names by using the optional parameter `input_names` in `UDTFRegistration.register/register_from_file` and `functions.pandas_udtf`. By default, `RelationalGroupedDataFrame.applyInPandas` will infer the column names from the current dataframe schema. ### Bug Fixes diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 6c5b81053dd..cb1a34c16c0 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -7310,6 +7310,7 @@ def pandas_udtf( *, output_schema: Union[StructType, List[str], "PandasDataFrameType"], input_types: Optional[List[DataType]] = None, + input_names: Optional[List[str]] = None, name: Optional[Union[str, Iterable[str]]] = None, is_permanent: bool = False, stage_location: Optional[str] = None, @@ -7366,14 +7367,14 @@ def pandas_udtf( ... def __init__(self): ... self.multiplier = 10 ... def end_partition(self, df): - ... df.columns = ['id', 'col1', 'col2'] ... df.col1 = df.col1*self.multiplier ... df.col2 = df.col2*self.multiplier ... yield df >>> multiply_udtf = pandas_udtf( ... multiply, ... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), - ... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])] + ... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])], + ... input_names=['"id"', '"col1"', '"col2"'] ... 
) >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"]) >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() @@ -7387,12 +7388,15 @@ def pandas_udtf( Example:: - >>> @pandas_udtf(output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])]) + >>> @pandas_udtf( + ... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), + ... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])], + ... input_names=['"id"', '"col1"', '"col2"'] + ... ) ... class _multiply: ... def __init__(self): ... self.multiplier = 10 ... def end_partition(self, df): - ... df.columns = ['id', 'col1', 'col2'] ... df.col1 = df.col1*self.multiplier ... df.col2 = df.col2*self.multiplier ... yield df @@ -7411,6 +7415,7 @@ def pandas_udtf( session.udtf.register, output_schema=output_schema, input_types=input_types, + input_names=input_names, name=name, is_permanent=is_permanent, stage_location=stage_location, @@ -7431,6 +7436,7 @@ def pandas_udtf( handler, output_schema=output_schema, input_types=input_types, + input_names=input_names, name=name, is_permanent=is_permanent, stage_location=stage_location, diff --git a/src/snowflake/snowpark/relational_grouped_dataframe.py b/src/snowflake/snowpark/relational_grouped_dataframe.py index efd9b355e7a..a9b2d0ce533 100644 --- a/src/snowflake/snowpark/relational_grouped_dataframe.py +++ b/src/snowflake/snowpark/relational_grouped_dataframe.py @@ -256,7 +256,10 @@ def is_valid_tuple_for_agg(e: Union[list, tuple]) -> bool: return self._to_df(agg_exprs) def apply_in_pandas( - self, func: Callable, output_schema: StructType, **kwargs + self, + func: Callable, + output_schema: StructType, + **kwargs, ) -> DataFrame: """Maps each grouped dataframe in to a pandas.DataFrame, applies the 
given function on data of each grouped dataframe, and returns a pandas.DataFrame. Internally, a vectorized @@ -264,8 +267,8 @@ def apply_in_pandas( ``kwargs`` are accepted to specify arguments to register the UDTF. Group by clause used must be column reference, not a general expression. - Depends on ``pandas`` being installed in the environment and declared as a dependency using - :meth:`~snowflake.snowpark.Session.add_packages` or via ``kwargs["packages"]``. + Requires ``pandas`` to be installed in the execution environment and declared as a dependency by either + specifying the keyword argument `packages=["pandas"]` in this call or calling :meth:`~snowflake.snowpark.Session.add_packages` beforehand. Args: func: A Python native function that accepts a single input argument - a ``pandas.DataFrame`` @@ -273,6 +276,8 @@ def apply_in_pandas( a vectorized UDTF. output_schema: A :class:`~snowflake.snowpark.types.StructType` instance that represents the table function's output columns. + input_names: A list of strings that represents the table function's input column names. Optional, + if unspecified, the column names are inferred from the input dataframe's schema. kwargs: Additional arguments to register the vectorized UDTF. See :meth:`~snowflake.snowpark.udtf.UDTFRegistration.register` for all options. @@ -282,9 +287,7 @@ def apply_in_pandas( >>> import pandas as pd >>> from snowflake.snowpark.types import StructType, StructField, StringType, FloatType >>> def convert(pandas_df): - ... pandas_df.columns = ['location', 'temp_c'] - ... return pandas_df.assign(temp_f = lambda x: x.temp_c * 9 / 5 + 32) - ... + ... return pandas_df.assign(TEMP_F = lambda x: x.TEMP_C * 9 / 5 + 32) >>> df = session.createDataFrame([('SF', 21.0), ('SF', 17.5), ('SF', 24.0), ('NY', 30.9), ('NY', 33.6)], ... 
schema=['location', 'temp_c']) >>> df.group_by("location").apply_in_pandas(convert, @@ -307,13 +310,13 @@ def apply_in_pandas( >>> from snowflake.snowpark.types import IntegerType, DoubleType >>> _ = session.sql("create or replace temp stage mystage").collect() >>> def group_sum(pdf): - ... pdf.columns = ['grade', 'division', 'value'] - ... return pd.DataFrame([(pdf.grade.iloc[0], pdf.division.iloc[0], pdf.value.sum(), )]) + ... return pd.DataFrame([(pdf.GRADE.iloc[0], pdf.DIVISION.iloc[0], pdf.VALUE.sum(), )]) ... >>> df = session.createDataFrame([('A', 2, 11.0), ('A', 2, 13.9), ('B', 5, 5.0), ('B', 2, 12.1)], ... schema=["grade", "division", "value"]) >>> df.group_by([df.grade, df.division] ).applyInPandas( - ... group_sum, output_schema=StructType([StructField("grade", StringType()), + ... group_sum, + ... output_schema=StructType([StructField("grade", StringType()), ... StructField("division", IntegerType()), ... StructField("sum", DoubleType())]), ... is_permanent=True, stage_location="@mystage", name="group_sum_in_pandas", replace=True @@ -345,6 +348,10 @@ def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame: "input_types", [field.datatype for field in self._df.schema.fields] ) + kwargs["input_names"] = kwargs.get( + "input_names", [field.name for field in self._df.schema.fields] + ) + _apply_in_pandas_udtf = self._df._session.udtf.register( _ApplyInPandas, output_schema=output_schema, diff --git a/src/snowflake/snowpark/udtf.py b/src/snowflake/snowpark/udtf.py index 80c3360a409..784bc72873a 100644 --- a/src/snowflake/snowpark/udtf.py +++ b/src/snowflake/snowpark/udtf.py @@ -321,14 +321,14 @@ class UDTFRegistration: ... def __init__(self): ... self.multiplier = 10 ... def end_partition(self, df): - ... df.columns = ['id', 'col1', 'col2'] ... df.col1 = df.col1*self.multiplier ... df.col2 = df.col2*self.multiplier ... yield df >>> multiply_udtf = session.udtf.register( ... multiply, ... 
output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), - ... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])] + ... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])], + ... input_names = ['"id"', '"col1"', '"col2"'], ... ) >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"]) >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() @@ -348,13 +348,13 @@ class UDTFRegistration: ... def __init__(self): ... self.multiplier = 10 ... def end_partition(self, df: PandasDataFrame[str, int, float]) -> PandasDataFrame[str, int, float]: - ... df.columns = ['id', 'col1', 'col2'] ... df.col1 = df.col1*self.multiplier ... df.col2 = df.col2*self.multiplier ... yield df >>> multiply_udtf = session.udtf.register( ... multiply, ... output_schema=["id_", "col1_", "col2_"], + ... input_names = ['"id"', '"col1"', '"col2"'], ... ) >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"]) >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() @@ -375,14 +375,41 @@ class UDTFRegistration: ... def __init__(self): ... self.multiplier = 10 ... def end_partition(self, df: pd.DataFrame) -> pd.DataFrame: - ... df.columns = ['id', 'col1', 'col2'] ... df.col1 = df.col1*self.multiplier ... df.col2 = df.col2*self.multiplier ... yield df >>> multiply_udtf = session.udtf.register( ... multiply, ... output_schema=StructType([StructField("id_", StringType()), StructField("col1_", IntegerType()), StructField("col2_", FloatType())]), - ... input_types=[StringType(), IntegerType(), FloatType()] + ... input_types=[StringType(), IntegerType(), FloatType()], + ... input_names = ['"id"', '"col1"', '"col2"'], + ... 
) + >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"]) + >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() + ----------------------------- + |"ID_" |"COL1_" |"COL2_" | + ----------------------------- + |x |30 |359.0 | + |x |90 |205.0 | + ----------------------------- + + + Example 14 + Same as Example 12, but does not specify `input_names` and instead sets the column names in `end_partition`. + + >>> from snowflake.snowpark.types import PandasDataFrameType, IntegerType, StringType, FloatType + >>> class multiply: + ... def __init__(self): + ... self.multiplier = 10 + ... def end_partition(self, df): + ... df.columns = ["id", "col1", "col2"] + ... df.col1 = df.col1*self.multiplier + ... df.col2 = df.col2*self.multiplier + ... yield df + >>> multiply_udtf = session.udtf.register( + ... multiply, + ... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), + ... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])], ... ) >>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"]) >>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show() @@ -412,6 +440,7 @@ def register( handler: Type, output_schema: Union[StructType, Iterable[str], "PandasDataFrameType"], input_types: Optional[List[DataType]] = None, + input_names: Optional[List[str]] = None, name: Optional[Union[str, Iterable[str]]] = None, is_permanent: bool = False, stage_location: Optional[str] = None, @@ -442,6 +471,8 @@ def register( input_types: A list of :class:`~snowflake.snowpark.types.DataType` representing the input data types of the UDTF. Optional if type hints are provided. 
+ input_names: A list of `str` representing the input column names of the UDTF. This only applies to a vectorized UDTF and is a no-op for regular UDTFs. If unspecified, default column names will be + ARG1, ARG2, etc. name: A string or list of strings that specify the name or fully-qualified object identifier (database name, schema name, and function name) for the UDTF in Snowflake. @@ -515,6 +546,7 @@ def register( handler, output_schema, input_types, + input_names, name, stage_location, imports, @@ -538,6 +570,7 @@ def register_from_file( handler_name: str, output_schema: Union[StructType, Iterable[str], "PandasDataFrameType"], input_types: Optional[List[DataType]] = None, + input_names: Optional[List[str]] = None, name: Optional[Union[str, Iterable[str]]] = None, is_permanent: bool = False, stage_location: Optional[str] = None, @@ -574,6 +607,8 @@ def register_from_file( input_types: A list of :class:`~snowflake.snowpark.types.DataType` representing the input data types of the UDTF. Optional if type hints are provided. + input_names: A list of `str` representing the input column names of the UDTF. This only applies to a vectorized UDTF and is a no-op for regular UDTFs. If unspecified, default column names will be + ARG1, ARG2, etc. 
name: A string or list of strings that specify the name or fully-qualified object identifier (database name, schema name, and function name) for the UDTF in Snowflake, which allows you to call this UDTF in a SQL @@ -652,6 +687,7 @@ def register_from_file( (file_path, handler_name), output_schema, input_types, + input_names, name, stage_location, imports, @@ -675,6 +711,7 @@ def _do_register_udtf( handler: Union[Callable, Tuple[str, str]], output_schema: Union[StructType, Iterable[str], "PandasDataFrameType"], input_types: Optional[List[DataType]], + input_names: Optional[List[str]], name: Optional[str], stage_location: Optional[str] = None, imports: Optional[List[Union[str, Tuple[str, str]]]] = None, @@ -730,7 +767,7 @@ def _do_register_udtf( output_schema=output_schema, ) - arg_names = [f"arg{i + 1}" for i in range(len(input_types))] + arg_names = input_names or [f"arg{i + 1}" for i in range(len(input_types))] input_args = [ UDFColumn(dt, arg_name) for dt, arg_name in zip(input_types, arg_names) ] diff --git a/tests/integ/test_udtf.py b/tests/integ/test_udtf.py index f235a0486d3..f12fb2e3587 100644 --- a/tests/integ/test_udtf.py +++ b/tests/integ/test_udtf.py @@ -284,10 +284,9 @@ def process( @pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_apply_in_pandas(session): - # test with element wise opeartion + # test with element wise operation def convert(pdf): - pdf.columns = ["location", "temp_c"] - return pdf.assign(temp_f=lambda x: x.temp_c * 9 / 5 + 32) + return pdf.assign(TEMP_F=lambda x: x.TEMP_C * 9 / 5 + 32) df = session.createDataFrame( [("SF", 21.0), ("SF", 17.5), ("SF", 24.0), ("NY", 30.9), ("NY", 33.6)], @@ -321,9 +320,8 @@ def convert(pdf): ) def normalize(pdf): - pdf.columns = ["id", "v"] - v = pdf.v - return pdf.assign(v=(v - v.mean()) / v.std()) + V = pdf.V + return pdf.assign(V=(V - V.mean()) / V.std()) df = df.group_by("id").applyInPandas( normalize, @@ -350,13 +348,12 @@ def normalize(pdf): ) def group_sum(pdf): - 
pdf.columns = ["grade", "division", "value"] return pd.DataFrame( [ ( - pdf.grade.iloc[0], - pdf.division.iloc[0], - pdf.value.sum(), + pdf.GRADE.iloc[0], + pdf.DIVISION.iloc[0], + pdf.VALUE.sum(), ) ] )