Skip to content

Commit

Permalink
SNOW-948486 Support specifying input column names for vectorized UDTF (
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-stan authored Nov 3, 2023
1 parent 65024cd commit df85e76
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 29 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

### New Features
- Added support for managing case sensitivity in `DataFrame.to_local_iterator()`.
- Added support for specifying vectorized UDTF's input column names by using the optional parameter `input_names` in `UDTFRegistration.register/register_from_file` and `functions.pandas_udtf`. By default, `RelationalGroupedDataFrame.applyInPandas` will infer the column names from current dataframe schema.

### Bug Fixes

Expand Down
14 changes: 10 additions & 4 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7310,6 +7310,7 @@ def pandas_udtf(
*,
output_schema: Union[StructType, List[str], "PandasDataFrameType"],
input_types: Optional[List[DataType]] = None,
input_names: Optional[List[str]] = None,
name: Optional[Union[str, Iterable[str]]] = None,
is_permanent: bool = False,
stage_location: Optional[str] = None,
Expand Down Expand Up @@ -7366,14 +7367,14 @@ def pandas_udtf(
... def __init__(self):
... self.multiplier = 10
... def end_partition(self, df):
... df.columns = ['id', 'col1', 'col2']
... df.col1 = df.col1*self.multiplier
... df.col2 = df.col2*self.multiplier
... yield df
>>> multiply_udtf = pandas_udtf(
... multiply,
... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]),
... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])]
... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])],
... input_names=['"id"', '"col1"', '"col2"']
... )
>>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"])
>>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
Expand All @@ -7387,12 +7388,15 @@ def pandas_udtf(
Example::
>>> @pandas_udtf(output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]), input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])])
>>> @pandas_udtf(
... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]),
... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])],
... input_names=['"id"', '"col1"', '"col2"']
... )
... class _multiply:
... def __init__(self):
... self.multiplier = 10
... def end_partition(self, df):
... df.columns = ['id', 'col1', 'col2']
... df.col1 = df.col1*self.multiplier
... df.col2 = df.col2*self.multiplier
... yield df
Expand All @@ -7411,6 +7415,7 @@ def pandas_udtf(
session.udtf.register,
output_schema=output_schema,
input_types=input_types,
input_names=input_names,
name=name,
is_permanent=is_permanent,
stage_location=stage_location,
Expand All @@ -7431,6 +7436,7 @@ def pandas_udtf(
handler,
output_schema=output_schema,
input_types=input_types,
input_names=input_names,
name=name,
is_permanent=is_permanent,
stage_location=stage_location,
Expand Down
25 changes: 16 additions & 9 deletions src/snowflake/snowpark/relational_grouped_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,23 +256,28 @@ def is_valid_tuple_for_agg(e: Union[list, tuple]) -> bool:
return self._to_df(agg_exprs)

def apply_in_pandas(
self, func: Callable, output_schema: StructType, **kwargs
self,
func: Callable,
output_schema: StructType,
**kwargs,
) -> DataFrame:
"""Maps each grouped dataframe in to a pandas.DataFrame, applies the given function on
data of each grouped dataframe, and returns a pandas.DataFrame. Internally, a vectorized
UDTF with input ``func`` argument as the ``end_partition`` is registered and called. Additional
``kwargs`` are accepted to specify arguments to register the UDTF. Group by clause used must be
column reference, not a general expression.
Depends on ``pandas`` being installed in the environment and declared as a dependency using
:meth:`~snowflake.snowpark.Session.add_packages` or via ``kwargs["packages"]``.
Requires ``pandas`` to be installed in the execution environment and declared as a dependency by either
specifying the keyword argument `packages=["pandas"]` in this call or calling :meth:`~snowflake.snowpark.Session.add_packages` beforehand.
Args:
func: A Python native function that accepts a single input argument - a ``pandas.DataFrame``
object and returns a ``pandas.Dataframe``. It is used as input to ``end_partition`` in
a vectorized UDTF.
output_schema: A :class:`~snowflake.snowpark.types.StructType` instance that represents the
table function's output columns.
input_names: A list of strings that represents the table function's input column names. Optional,
if unspecified, default column names will be ARG1, ARG2, etc.
kwargs: Additional arguments to register the vectorized UDTF. See
:meth:`~snowflake.snowpark.udtf.UDTFRegistration.register` for all options.
Expand All @@ -282,9 +287,7 @@ def apply_in_pandas(
>>> import pandas as pd
>>> from snowflake.snowpark.types import StructType, StructField, StringType, FloatType
>>> def convert(pandas_df):
... pandas_df.columns = ['location', 'temp_c']
... return pandas_df.assign(temp_f = lambda x: x.temp_c * 9 / 5 + 32)
...
... return pandas_df.assign(TEMP_F = lambda x: x.TEMP_C * 9 / 5 + 32)
>>> df = session.createDataFrame([('SF', 21.0), ('SF', 17.5), ('SF', 24.0), ('NY', 30.9), ('NY', 33.6)],
... schema=['location', 'temp_c'])
>>> df.group_by("location").apply_in_pandas(convert,
Expand All @@ -307,13 +310,13 @@ def apply_in_pandas(
>>> from snowflake.snowpark.types import IntegerType, DoubleType
>>> _ = session.sql("create or replace temp stage mystage").collect()
>>> def group_sum(pdf):
... pdf.columns = ['grade', 'division', 'value']
... return pd.DataFrame([(pdf.grade.iloc[0], pdf.division.iloc[0], pdf.value.sum(), )])
... return pd.DataFrame([(pdf.GRADE.iloc[0], pdf.DIVISION.iloc[0], pdf.VALUE.sum(), )])
...
>>> df = session.createDataFrame([('A', 2, 11.0), ('A', 2, 13.9), ('B', 5, 5.0), ('B', 2, 12.1)],
... schema=["grade", "division", "value"])
>>> df.group_by([df.grade, df.division] ).applyInPandas(
... group_sum, output_schema=StructType([StructField("grade", StringType()),
... group_sum,
... output_schema=StructType([StructField("grade", StringType()),
... StructField("division", IntegerType()),
... StructField("sum", DoubleType())]),
... is_permanent=True, stage_location="@mystage", name="group_sum_in_pandas", replace=True
Expand Down Expand Up @@ -345,6 +348,10 @@ def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame:
"input_types", [field.datatype for field in self._df.schema.fields]
)

kwargs["input_names"] = kwargs.get(
"input_names", [field.name for field in self._df.schema.fields]
)

_apply_in_pandas_udtf = self._df._session.udtf.register(
_ApplyInPandas,
output_schema=output_schema,
Expand Down
49 changes: 43 additions & 6 deletions src/snowflake/snowpark/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,14 +321,14 @@ class UDTFRegistration:
... def __init__(self):
... self.multiplier = 10
... def end_partition(self, df):
... df.columns = ['id', 'col1', 'col2']
... df.col1 = df.col1*self.multiplier
... df.col2 = df.col2*self.multiplier
... yield df
>>> multiply_udtf = session.udtf.register(
... multiply,
... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]),
... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])]
... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])],
... input_names = ['"id"', '"col1"', '"col2"'],
... )
>>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"])
>>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
Expand All @@ -348,13 +348,13 @@ class UDTFRegistration:
... def __init__(self):
... self.multiplier = 10
... def end_partition(self, df: PandasDataFrame[str, int, float]) -> PandasDataFrame[str, int, float]:
... df.columns = ['id', 'col1', 'col2']
... df.col1 = df.col1*self.multiplier
... df.col2 = df.col2*self.multiplier
... yield df
>>> multiply_udtf = session.udtf.register(
... multiply,
... output_schema=["id_", "col1_", "col2_"],
... input_names = ['"id"', '"col1"', '"col2"'],
... )
>>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"])
>>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
Expand All @@ -375,14 +375,42 @@ class UDTFRegistration:
... def __init__(self):
... self.multiplier = 10
... def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
... df.columns = ['id', 'col1', 'col2']
... df.col1 = df.col1*self.multiplier
... df.col2 = df.col2*self.multiplier
... yield df
>>> multiply_udtf = session.udtf.register(
... multiply,
... output_schema=StructType([StructField("id_", StringType()), StructField("col1_", IntegerType()), StructField("col2_", FloatType())]),
... input_types=[StringType(), IntegerType(), FloatType()]
... input_types=[StringType(), IntegerType(), FloatType()],
... input_names = ['"id"', '"col1"', '"col2"'],
... )
>>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"])
>>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
-----------------------------
|"ID_" |"COL1_" |"COL2_" |
-----------------------------
|x |30 |359.0 |
|x |90 |205.0 |
-----------------------------
<BLANKLINE>
Example 14
Same as Example 12, but does not specify `input_names` and instead sets the column names in `end_partition`.
>>> from snowflake.snowpark.types import PandasDataFrameType, IntegerType, StringType, FloatType
>>> class multiply:
... def __init__(self):
... self.multiplier = 10
... def end_partition(self, df):
... df.columns = ["id", "col1", "col2"]
... df.col1 = df.col1*self.multiplier
... df.col2 = df.col2*self.multiplier
... yield df
>>> multiply_udtf = session.udtf.register(
... multiply,
... output_schema=PandasDataFrameType([StringType(), IntegerType(), FloatType()], ["id_", "col1_", "col2_"]),
... input_types=[PandasDataFrameType([StringType(), IntegerType(), FloatType()])],
... input_names = ['"id"', '"col1"', '"col2"'],
... )
>>> df = session.create_dataframe([['x', 3, 35.9],['x', 9, 20.5]], schema=["id", "col1", "col2"])
>>> df.select(multiply_udtf("id", "col1", "col2").over(partition_by=["id"])).sort("col1_").show()
Expand Down Expand Up @@ -412,6 +440,7 @@ def register(
handler: Type,
output_schema: Union[StructType, Iterable[str], "PandasDataFrameType"],
input_types: Optional[List[DataType]] = None,
input_names: Optional[List[str]] = None,
name: Optional[Union[str, Iterable[str]]] = None,
is_permanent: bool = False,
stage_location: Optional[str] = None,
Expand Down Expand Up @@ -442,6 +471,8 @@ def register(
input_types: A list of :class:`~snowflake.snowpark.types.DataType`
representing the input data types of the UDTF. Optional if
type hints are provided.
input_names: A list of `str` representing the input column names of the UDTF. This only applies to vectorized UDTFs and is essentially a no-op for regular UDTFs. If unspecified, default column names will be
ARG1, ARG2, etc.
name: A string or list of strings that specify the name or fully-qualified
object identifier (database name, schema name, and function name) for
the UDTF in Snowflake.
Expand Down Expand Up @@ -515,6 +546,7 @@ def register(
handler,
output_schema,
input_types,
input_names,
name,
stage_location,
imports,
Expand All @@ -538,6 +570,7 @@ def register_from_file(
handler_name: str,
output_schema: Union[StructType, Iterable[str], "PandasDataFrameType"],
input_types: Optional[List[DataType]] = None,
input_names: Optional[List[str]] = None,
name: Optional[Union[str, Iterable[str]]] = None,
is_permanent: bool = False,
stage_location: Optional[str] = None,
Expand Down Expand Up @@ -574,6 +607,8 @@ def register_from_file(
input_types: A list of :class:`~snowflake.snowpark.types.DataType`
representing the input data types of the UDTF. Optional if
type hints are provided.
input_names: A list of `str` representing the input column names of the UDTF. This only applies to vectorized UDTFs and is essentially a no-op for regular UDTFs. If unspecified, default column names will be
ARG1, ARG2, etc.
name: A string or list of strings that specify the name or fully-qualified
object identifier (database name, schema name, and function name) for
the UDTF in Snowflake, which allows you to call this UDTF in a SQL
Expand Down Expand Up @@ -652,6 +687,7 @@ def register_from_file(
(file_path, handler_name),
output_schema,
input_types,
input_names,
name,
stage_location,
imports,
Expand All @@ -675,6 +711,7 @@ def _do_register_udtf(
handler: Union[Callable, Tuple[str, str]],
output_schema: Union[StructType, Iterable[str], "PandasDataFrameType"],
input_types: Optional[List[DataType]],
input_names: Optional[List[str]],
name: Optional[str],
stage_location: Optional[str] = None,
imports: Optional[List[Union[str, Tuple[str, str]]]] = None,
Expand Down Expand Up @@ -730,7 +767,7 @@ def _do_register_udtf(
output_schema=output_schema,
)

arg_names = [f"arg{i + 1}" for i in range(len(input_types))]
arg_names = input_names or [f"arg{i + 1}" for i in range(len(input_types))]
input_args = [
UDFColumn(dt, arg_name) for dt, arg_name in zip(input_types, arg_names)
]
Expand Down
17 changes: 7 additions & 10 deletions tests/integ/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,10 +284,9 @@ def process(

@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_apply_in_pandas(session):
# test with element wise opeartion
# test with element wise operation
def convert(pdf):
pdf.columns = ["location", "temp_c"]
return pdf.assign(temp_f=lambda x: x.temp_c * 9 / 5 + 32)
return pdf.assign(TEMP_F=lambda x: x.TEMP_C * 9 / 5 + 32)

df = session.createDataFrame(
[("SF", 21.0), ("SF", 17.5), ("SF", 24.0), ("NY", 30.9), ("NY", 33.6)],
Expand Down Expand Up @@ -321,9 +320,8 @@ def convert(pdf):
)

def normalize(pdf):
pdf.columns = ["id", "v"]
v = pdf.v
return pdf.assign(v=(v - v.mean()) / v.std())
V = pdf.V
return pdf.assign(V=(V - V.mean()) / V.std())

df = df.group_by("id").applyInPandas(
normalize,
Expand All @@ -350,13 +348,12 @@ def normalize(pdf):
)

def group_sum(pdf):
pdf.columns = ["grade", "division", "value"]
return pd.DataFrame(
[
(
pdf.grade.iloc[0],
pdf.division.iloc[0],
pdf.value.sum(),
pdf.GRADE.iloc[0],
pdf.DIVISION.iloc[0],
pdf.VALUE.sum(),
)
]
)
Expand Down

0 comments on commit df85e76

Please sign in to comment.