SNOW-1271612 Merge dynamic pivot into snowpark python (#1375)
* SNOW-838815 Extend snowpark for dynamic pivot (#117)

See
https://docs.google.com/document/d/12usSBp73G-CPwRfOoeMuQfser4OfCArgG5Lmm-OIQ_0

Please answer these questions before submitting your pull requests.
Thanks!

1. What GitHub issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   SNOW-838815

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency

3. Please describe how your code solves the related issue.

This extends the pivot API to support dynamic pivot (specify None for ANY, or a
subquery via a DataFrame) as well as a default value to fill in empty result
values.
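As a rough illustration of the statement shape this produces (a hand-written sketch; `pivot_sql` is invented here and is not Snowpark's actual SQL generator), the three forms of `values` map onto the `IN (...)` clause like this:

```python
from typing import Optional

def pivot_sql(child: str, agg: str, pivot_col: str, values: str,
              default_on_null: Optional[str] = None) -> str:
    # Sketch of the PIVOT statement shape; the real generator lives in
    # analyzer_utils.pivot_statement and uses its own constants and quoting.
    sql = f"SELECT * FROM ({child}) PIVOT ({agg} FOR {pivot_col} IN {values}"
    if default_on_null is not None:
        sql += f" DEFAULT ON NULL ({default_on_null})"
    return sql + ")"

# values=None in the API corresponds to IN (ANY); a DataFrame becomes a subquery.
print(pivot_sql("SELECT * FROM monthly_sales", "SUM(amount)", "month", "(ANY)"))
print(pivot_sql("SELECT * FROM monthly_sales", "SUM(amount)", "month",
                "(SELECT month FROM monthly_sales WHERE month = 'JAN')"))
print(pivot_sql("SELECT * FROM monthly_sales", "SUM(amount)", "month",
                "('JAN', 'FEB')", default_on_null="0"))
```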

* SNOW-916205 Use cache_result for pivot when creates temp table for large inlined data (#345)

SNOW-916205 Use cache_result when snowpark creates temp table for large
inlined data

Please answer these questions before submitting your pull requests.
Thanks!

1. What GitHub issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

SNOW-916205 Use cache_result when snowpark creates temp table for large
inlined data

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency

3. Please describe how your code solves the related issue.

If the size of the underlying data exceeds ARRAY_BIND_THRESHOLD (512), then
snowpark will automatically offload the data into a temp table. Each snowpark
action results in the creation, insertion, querying and dropping of this temp
table. This causes a problem for the dynamic pivot schema query, which can
occur out of band of a snowpark action (say, to fetch the schema only) and
will fail with a table-not-found error. To work around this until
[SNOW-916744](https://snowflakecomputing.atlassian.net/browse/SNOW-916744)
is fixed, we do a cache_result locally, which puts the inlined data into a
temp table that is not cleaned up until the session ends.
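The failure mode can be modeled without a Snowflake connection. The toy class below is invented for illustration (it is not Snowpark internals): a per-action temp table disappears before an out-of-band schema probe runs, while a cached result persists for the session.

```python
class ToySession:
    """Toy model of temp-table lifetimes; NOT real Snowpark internals."""

    def __init__(self) -> None:
        self.tables = set()

    def run_action(self, name: str) -> None:
        # A normal action: create the temp table, query it, then drop it.
        self.tables.add(name)
        self.tables.discard(name)

    def cache_result(self, name: str) -> None:
        # cache_result: a session-scoped temp table that outlives actions.
        self.tables.add(name)

    def table_exists(self, name: str) -> bool:
        # Stand-in for an out-of-band schema query (e.g. dataframe.schema).
        return name in self.tables


session = ToySession()
session.run_action("INLINED_DATA_TMP")
after_action = session.table_exists("INLINED_DATA_TMP")   # False: already dropped

session.cache_result("CACHED_RESULT_TMP")
after_cache = session.table_exists("CACHED_RESULT_TMP")   # True: persists
print(after_action, after_cache)
```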


* SNOW-1271612 Prep dynamic pivot in snowpark for merge back into main

* Update change log and fix lint issue

* Update

* Update doc tests

* Update doc test

* Update type hints

* Update
sfc-gh-evandenberg authored Apr 15, 2024
1 parent cdebbd9 commit f0e04c4
Showing 11 changed files with 438 additions and 38 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -45,6 +45,7 @@
- save_as_table
- Added support for snow:// URLs to `snowflake.snowpark.Session.file.get` and `snowflake.snowpark.Session.file.get_stream`
- UDAF client support is ready for public preview. Please stay tuned for the Snowflake announcement of UDAF public preview.
- Added support for dynamic pivot. This feature is currently in private preview.

### Bug Fixes

43 changes: 37 additions & 6 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
@@ -4,7 +4,7 @@
#
import uuid
from collections import Counter, defaultdict
from typing import TYPE_CHECKING, DefaultDict, Dict, Optional, Union
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union

import snowflake.snowpark
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
@@ -988,21 +988,52 @@ def do_resolve_with_resolved_children(
else:
child = resolved_children[logical_plan.child]

return self.plan_builder.pivot(
# We retrieve the pivot_values for generating SQL using types:
# List[str] => explicit list of pivot values
# ScalarSubquery => dynamic pivot subquery
# None => dynamic pivot ANY subquery

if isinstance(logical_plan.pivot_values, List):
pivot_values = [
self.analyze(pv, df_aliased_col_name_to_real_col_name)
for pv in logical_plan.pivot_values
]
elif isinstance(logical_plan.pivot_values, ScalarSubquery):
pivot_values = self.analyze(
logical_plan.pivot_values, df_aliased_col_name_to_real_col_name
)
else:
pivot_values = None

pivot_plan = self.plan_builder.pivot(
self.analyze(
logical_plan.pivot_column, df_aliased_col_name_to_real_col_name
),
[
self.analyze(pv, df_aliased_col_name_to_real_col_name)
for pv in logical_plan.pivot_values
],
pivot_values,
self.analyze(
logical_plan.aggregates[0], df_aliased_col_name_to_real_col_name
),
self.analyze(
logical_plan.default_on_null, df_aliased_col_name_to_real_col_name
)
if logical_plan.default_on_null
else None,
child,
logical_plan,
)

# If this is a dynamic pivot, then we can't use child.schema_query which is used in the schema_query
# sql generator by default because it is simplified and won't fetch the output columns from the underlying
# source. So in this case we use the actual pivot query as the schema query.
if logical_plan.pivot_values is None or isinstance(
logical_plan.pivot_values, ScalarSubquery
):
# TODO (SNOW-916744): Using the original query here does not work if the query depends on a temp
# table as it may not exist at later point in time when dataframe.schema is called.
pivot_plan.schema_query = pivot_plan.queries[-1].sql

return pivot_plan

if isinstance(logical_plan, Unpivot):
return self.plan_builder.unpivot(
logical_plan.value_column,
27 changes: 23 additions & 4 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
@@ -161,6 +161,8 @@
EXCEPT = f" {Except.sql} "
NOT_NULL = " NOT NULL "
WITH = "WITH "
DEFAULT_ON_NULL = " DEFAULT ON NULL "
ANY = " ANY "

TEMPORARY_STRING_SET = frozenset(["temporary", "temp"])

@@ -1033,8 +1035,22 @@ def create_or_replace_dynamic_table_statement(


def pivot_statement(
pivot_column: str, pivot_values: List[str], aggregate: str, child: str
pivot_column: str,
pivot_values: Optional[Union[str, List[str]]],
aggregate: str,
default_on_null: Optional[str],
child: str,
) -> str:
if isinstance(pivot_values, str):
# The subexpression in this case already includes parenthesis.
values_str = pivot_values
else:
values_str = (
LEFT_PARENTHESIS
+ (ANY if pivot_values is None else COMMA.join(pivot_values))
+ RIGHT_PARENTHESIS
)

return (
SELECT
+ STAR
@@ -1048,9 +1064,12 @@ def pivot_statement(
+ FOR
+ pivot_column
+ IN
+ LEFT_PARENTHESIS
+ COMMA.join(pivot_values)
+ RIGHT_PARENTHESIS
+ values_str
+ (
(DEFAULT_ON_NULL + LEFT_PARENTHESIS + default_on_null + RIGHT_PARENTHESIS)
if default_on_null
else EMPTY_STRING
)
+ RIGHT_PARENTHESIS
)
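The three-way `values_str` dispatch in `pivot_statement` can be exercised in isolation. This standalone sketch (`render_in_values` is a hypothetical helper with simplified spacing versus the real SQL constants) mirrors the same cases:

```python
from typing import List, Optional, Union

def render_in_values(pivot_values: Optional[Union[str, List[str]]]) -> str:
    # A str is a pre-parenthesized subquery; None means a dynamic pivot
    # over ANY value; a list of rendered values is joined explicitly.
    if isinstance(pivot_values, str):
        return pivot_values
    return "(" + ("ANY" if pivot_values is None else ", ".join(pivot_values)) + ")"

print(render_in_values(None))                                # (ANY)
print(render_in_values(["'JAN'", "'FEB'"]))                  # ('JAN', 'FEB')
print(render_in_values("(SELECT month FROM monthly_sales)"))
```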

7 changes: 5 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
@@ -817,13 +817,16 @@ def limit(
def pivot(
self,
pivot_column: str,
pivot_values: List[str],
pivot_values: Optional[Union[str, List[str]]],
aggregate: str,
default_on_null: Optional[str],
child: SnowflakePlan,
source_plan: Optional[LogicalPlan],
) -> SnowflakePlan:
return self.build(
lambda x: pivot_statement(pivot_column, pivot_values, aggregate, x),
lambda x: pivot_statement(
pivot_column, pivot_values, aggregate, default_on_null, x
),
child,
source_plan,
)
6 changes: 4 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py
@@ -2,7 +2,7 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union

from snowflake.snowpark._internal.analyzer.expression import Expression, NamedExpression
from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan
@@ -53,15 +53,17 @@ def __init__(
self,
grouping_columns: List[Expression],
pivot_column: Expression,
pivot_values: List[Expression],
pivot_values: Optional[Union[List[Expression], LogicalPlan]],
aggregates: List[Expression],
default_on_null: Optional[Expression],
child: LogicalPlan,
) -> None:
super().__init__(child)
self.grouping_columns = grouping_columns
self.pivot_column = pivot_column
self.pivot_values = pivot_values
self.aggregates = aggregates
self.default_on_null = default_on_null


class Unpivot(UnaryNode):
64 changes: 64 additions & 0 deletions src/snowflake/snowpark/_internal/utils.py
@@ -28,6 +28,7 @@
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Literal,
@@ -870,3 +871,66 @@ def validate_quoted_name(name: str) -> str:

def escape_quotes(unescaped: str) -> str:
return unescaped.replace(DOUBLE_QUOTE, DOUBLE_QUOTE + DOUBLE_QUOTE)


def prepare_pivot_arguments(
df: "snowflake.snowpark.DataFrame",
df_name: str,
pivot_col: "snowflake.snowpark._internal.type_utils.ColumnOrName",
values: Optional[
Union[
Iterable["snowflake.snowpark._internal.type_utils.LiteralType"],
"snowflake.snowpark.DataFrame",
]
],
default_on_null: Optional["snowflake.snowpark._internal.type_utils.LiteralType"],
):
"""
Prepare dataframe pivot arguments to use in the underlying pivot call. This includes issuing any applicable
warnings and validating column types and arguments.
Returns:
    DataFrame, pivot column, pivot_values and default_on_null value.
"""
from snowflake.snowpark.dataframe import DataFrame

if values is None or isinstance(values, DataFrame):
warning(
df_name,
"Parameter values is Optional or DataFrame is in private preview since v1.15.0. Do not use it in production.",
)

if default_on_null is not None:
warning(
df_name,
"Parameter default_on_null is not None is in private preview since v1.15.0. Do not use it in production.",
)

if values is not None and not values:
raise ValueError("values cannot be empty")

pc = df._convert_cols_to_exprs(f"{df_name}()", pivot_col)

from snowflake.snowpark._internal.analyzer.expression import Literal, ScalarSubquery
from snowflake.snowpark.column import Column

if isinstance(values, Iterable):
pivot_values = [
v._expression if isinstance(v, Column) else Literal(v) for v in values
]
else:
if isinstance(values, DataFrame):
pivot_values = ScalarSubquery(values._plan)
else:
pivot_values = None

if len(df.queries.get("post_actions", [])) > 0:
df = df.cache_result()

if default_on_null is not None:
default_on_null = (
default_on_null._expression
if isinstance(default_on_null, Column)
else Literal(default_on_null)
)

return df, pc, pivot_values, default_on_null
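The values-normalization performed by `prepare_pivot_arguments` can be sketched standalone. `StubDataFrame` and the simplified `Literal`/`ScalarSubquery` classes below are invented stand-ins for the real Snowpark types:

```python
class Literal:
    """Stand-in for the Snowpark Literal expression."""
    def __init__(self, value):
        self.value = value

class ScalarSubquery:
    """Stand-in for the Snowpark ScalarSubquery expression."""
    def __init__(self, plan):
        self.plan = plan

class StubDataFrame:
    """Stand-in for snowflake.snowpark.DataFrame."""
    def __init__(self, plan):
        self._plan = plan

def normalize_pivot_values(values):
    # None -> dynamic ANY pivot; DataFrame -> scalar subquery;
    # any other iterable -> a list of Literal expressions.
    if values is None:
        return None
    if isinstance(values, StubDataFrame):
        return ScalarSubquery(values._plan)
    if not values:
        raise ValueError("values cannot be empty")
    return [Literal(v) for v in values]

literals = normalize_pivot_values(["JAN", "FEB"])
subquery = normalize_pivot_values(StubDataFrame("SELECT month FROM t"))
print([v.value for v in literals], type(subquery).__name__)
```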
46 changes: 36 additions & 10 deletions src/snowflake/snowpark/dataframe.py
@@ -116,6 +116,7 @@
is_sql_select_statement,
parse_positional_args_to_list,
parse_table_name,
prepare_pivot_arguments,
private_preview,
quote_name,
random_name_for_temp_object,
@@ -1620,7 +1621,10 @@ def drop_duplicates(self, *subset: Union[str, Iterable[str]]) -> "DataFrame":
def pivot(
self,
pivot_col: ColumnOrName,
values: Iterable[LiteralType],
values: Optional[
Union[Iterable[LiteralType], "snowflake.snowpark.DataFrame"]
] = None,
default_on_null: Optional[LiteralType] = None,
) -> "snowflake.snowpark.RelationalGroupedDataFrame":
"""Rotates this DataFrame by turning the unique values from one column in the input
expression into multiple columns and aggregating results where required on any
@@ -1649,21 +1653,43 @@ def pivot(
-------------------------------
<BLANKLINE>
>>> df = session.table("monthly_sales")
>>> df.pivot("month").sum("amount").sort("empid").show()
-------------------------------
|"EMPID" |"'FEB'" |"'JAN'" |
-------------------------------
|1 |8000 |10400 |
|2 |200 |39500 |
-------------------------------
<BLANKLINE>
>>> subquery_df = session.table("monthly_sales").select(col("month")).filter(col("month") == "JAN")
>>> df = session.table("monthly_sales")
>>> df.pivot("month", values=subquery_df).sum("amount").sort("empid").show()
---------------------
|"EMPID" |"'JAN'" |
---------------------
|1 |10400 |
|2 |39500 |
---------------------
<BLANKLINE>
Args:
pivot_col: The column or name of the column to use.
values: A list of values in the column.
values: A list of values in the column, or a DataFrame whose query
provides the values dynamically, or None (default) to pivot on all
values of the pivot column.
default_on_null: Expression to replace empty result values.
"""
if not values:
raise ValueError("values cannot be empty")
pc = self._convert_cols_to_exprs("pivot()", pivot_col)
value_exprs = [
v._expression if isinstance(v, Column) else Literal(v) for v in values
]
target_df, pc, pivot_values, default_on_null = prepare_pivot_arguments(
self, "DataFrame.pivot", pivot_col, values, default_on_null
)

return snowflake.snowpark.RelationalGroupedDataFrame(
self,
target_df,
[],
snowflake.snowpark.relational_grouped_dataframe._PivotType(
pc[0], value_exprs
pc[0], pivot_values, default_on_null
),
)
