From f0e04c436f0d9a90aa823ed0009804cf77427287 Mon Sep 17 00:00:00 2001 From: Eric Vandenberg Date: Mon, 15 Apr 2024 13:50:46 -0700 Subject: [PATCH] SNOW-1271612 Merge dynamic pivot into snowpark python (#1375) * SNOW-838815 Extend snowpark for dynamic pivot (#117) See https://docs.google.com/document/d/12usSBp73G-CPwRfOoeMuQfser4OfCArgG5Lmm-OIQ_0 Please answer these questions before submitting your pull requests. Thanks! 1. What GitHub issue is this PR addressing? Make sure that there is an accompanying issue to your PR. SNOW-838815 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency 3. Please describe how your code solves the related issue. This extends the pivot API for dynamic pivot (specify None for ANY or subquery via dataframe) as well as default value to fill if there are empty result values. * SNOW-916205 Use cache_result for pivot when creates temp table for large inlined data (#345) SNOW-916205 Use cache_result when snowpark creates temp table for large inlined data Please answer these questions before submitting your pull requests. Thanks! 1. What GitHub issue is this PR addressing? Make sure that there is an accompanying issue to your PR. SNOW-916205 Use cache_result when snowpark creates temp table for large inlined data 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency 3. Please describe how your code solves the related issue. If the size of underlying data exceeds ARRAY_BIND_THRESHOLD (512), then snowpark will automatically offload the data into a temp file. 
Each snowpark action results in the creation, insertion, query and dropping of this temp table. This causes a problem for dynamic pivot schema query which can occur out of band of a snowpark action, say to fetch the schema only, and will fail with table not found. To work around this until [SNOW-916744](https://snowflakecomputing.atlassian.net/browse/SNOW-916744) is fixed we do a cache_result locally which puts the inlined data into a temp table that is not cleaned up until the session ends. [SNOW-916744]: https://snowflakecomputing.atlassian.net/browse/SNOW-916744?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ * SNOW-1271612 Prep dynamic pivot in snowpark for merge back into main * Update change log and fix lint issue * Update * Update doc tests * Update doc test * Update type hints * Update --- CHANGELOG.md | 1 + .../snowpark/_internal/analyzer/analyzer.py | 43 +++- .../_internal/analyzer/analyzer_utils.py | 27 ++- .../_internal/analyzer/snowflake_plan.py | 7 +- .../_internal/analyzer/unary_plan_node.py | 6 +- src/snowflake/snowpark/_internal/utils.py | 64 ++++++ src/snowflake/snowpark/dataframe.py | 46 ++++- .../snowpark/relational_grouped_dataframe.py | 69 +++++-- tests/integ/scala/test_column_suite.py | 17 ++ .../scala/test_dataframe_aggregate_suite.py | 194 ++++++++++++++++++ tests/unit/scala/test_df_suite.py | 2 +- 11 files changed, 438 insertions(+), 38 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 13f4fdad7c2..4a87cd123e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ - save_as_table - Added support for snow:// URLs to `snowflake.snowpark.Session.file.get` and `snowflake.snowpark.Session.file.get_stream` - UDAF client support is ready for public preview. Please stay tuned for the Snowflake announcement of UDAF public preview. +- Added support for dynamic pivot. This feature is currently in private preview. 
### Bug Fixes diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index c73c4d1ae53..df876762d53 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -4,7 +4,7 @@ # import uuid from collections import Counter, defaultdict -from typing import TYPE_CHECKING, DefaultDict, Dict, Optional, Union +from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union import snowflake.snowpark from snowflake.snowpark._internal.analyzer.analyzer_utils import ( @@ -988,21 +988,52 @@ def do_resolve_with_resolved_children( else: child = resolved_children[logical_plan.child] - return self.plan_builder.pivot( + # We retrieve the pivot_values for generating SQL using types: + # List[str] => explicit list of pivot values + # ScalarSubquery => dynamic pivot subquery + # None => dynamic pivot ANY subquery + + if isinstance(logical_plan.pivot_values, List): + pivot_values = [ + self.analyze(pv, df_aliased_col_name_to_real_col_name) + for pv in logical_plan.pivot_values + ] + elif isinstance(logical_plan.pivot_values, ScalarSubquery): + pivot_values = self.analyze( + logical_plan.pivot_values, df_aliased_col_name_to_real_col_name + ) + else: + pivot_values = None + + pivot_plan = self.plan_builder.pivot( self.analyze( logical_plan.pivot_column, df_aliased_col_name_to_real_col_name ), - [ - self.analyze(pv, df_aliased_col_name_to_real_col_name) - for pv in logical_plan.pivot_values - ], + pivot_values, self.analyze( logical_plan.aggregates[0], df_aliased_col_name_to_real_col_name ), + self.analyze( + logical_plan.default_on_null, df_aliased_col_name_to_real_col_name + ) + if logical_plan.default_on_null + else None, child, logical_plan, ) + # If this is a dynamic pivot, then we can't use child.schema_query which is used in the schema_query + # sql generator by default because it is simplified and won't fetch the output columns from the 
underlying + # source. So in this case we use the actual pivot query as the schema query. + if logical_plan.pivot_values is None or isinstance( + logical_plan.pivot_values, ScalarSubquery + ): + # TODO (SNOW-916744): Using the original query here does not work if the query depends on a temp + # table as it may not exist at later point in time when dataframe.schema is called. + pivot_plan.schema_query = pivot_plan.queries[-1].sql + + return pivot_plan + if isinstance(logical_plan, Unpivot): return self.plan_builder.unpivot( logical_plan.value_column, diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 6f30658b616..65f16bd486b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -161,6 +161,8 @@ EXCEPT = f" {Except.sql} " NOT_NULL = " NOT NULL " WITH = "WITH " +DEFAULT_ON_NULL = " DEFAULT ON NULL " +ANY = " ANY " TEMPORARY_STRING_SET = frozenset(["temporary", "temp"]) @@ -1033,8 +1035,22 @@ def create_or_replace_dynamic_table_statement( def pivot_statement( - pivot_column: str, pivot_values: List[str], aggregate: str, child: str + pivot_column: str, + pivot_values: Optional[Union[str, List[str]]], + aggregate: str, + default_on_null: Optional[str], + child: str, ) -> str: + if isinstance(pivot_values, str): + # The subexpression in this case already includes parenthesis. 
+ values_str = pivot_values + else: + values_str = ( + LEFT_PARENTHESIS + + (ANY if pivot_values is None else COMMA.join(pivot_values)) + + RIGHT_PARENTHESIS + ) + return ( SELECT + STAR @@ -1048,9 +1064,12 @@ def pivot_statement( + FOR + pivot_column + IN - + LEFT_PARENTHESIS - + COMMA.join(pivot_values) - + RIGHT_PARENTHESIS + + values_str + + ( + (DEFAULT_ON_NULL + LEFT_PARENTHESIS + default_on_null + RIGHT_PARENTHESIS) + if default_on_null + else EMPTY_STRING + ) + RIGHT_PARENTHESIS ) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 637a9462288..df4ec602e6e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -817,13 +817,16 @@ def limit( def pivot( self, pivot_column: str, - pivot_values: List[str], + pivot_values: Optional[Union[str, List[str]]], aggregate: str, + default_on_null: Optional[str], child: SnowflakePlan, source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: return self.build( - lambda x: pivot_statement(pivot_column, pivot_values, aggregate, x), + lambda x: pivot_statement( + pivot_column, pivot_values, aggregate, default_on_null, x + ), child, source_plan, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index adf507703c9..076aa80a74c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from snowflake.snowpark._internal.analyzer.expression import Expression, NamedExpression from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan @@ -53,8 +53,9 @@ def __init__( self, grouping_columns: List[Expression], pivot_column: Expression, - pivot_values: List[Expression], + pivot_values: Optional[Union[List[Expression], LogicalPlan]], aggregates: List[Expression], + default_on_null: Optional[Expression], child: LogicalPlan, ) -> None: super().__init__(child) @@ -62,6 +63,7 @@ def __init__( self.pivot_column = pivot_column self.pivot_values = pivot_values self.aggregates = aggregates + self.default_on_null = default_on_null class Unpivot(UnaryNode): diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 211f2268065..161eeeb9b86 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -28,6 +28,7 @@ Any, Callable, Dict, + Iterable, Iterator, List, Literal, @@ -870,3 +871,66 @@ def validate_quoted_name(name: str) -> str: def escape_quotes(unescaped: str) -> str: return unescaped.replace(DOUBLE_QUOTE, DOUBLE_QUOTE + DOUBLE_QUOTE) + + +def prepare_pivot_arguments( + df: "snowflake.snowpark.DataFrame", + df_name: str, + pivot_col: "snowflake.snowpark._internal.type_utils.ColumnOrName", + values: Optional[ + Union[ + Iterable["snowflake.snowpark._internal.type_utils.LiteralType"], + "snowflake.snowpark.DataFrame", + ] + ], + default_on_null: Optional["snowflake.snowpark._internal.type_utils.LiteralType"], +): + """ + Prepare dataframe pivot arguments to use in the underlying pivot call. This includes issuing any applicable + warnings, ensuring column types and valid arguments. + Returns: + DataFrame, pivot column, pivot_values and default_on_null value. 
+ """ + from snowflake.snowpark.dataframe import DataFrame + + if values is None or isinstance(values, DataFrame): + warning( + df_name, + "Parameter values is Optional or DataFrame is in private preview since v1.15.0. Do not use it in production.", + ) + + if default_on_null is not None: + warning( + df_name, + "Parameter default_on_null is not None is in private preview since v1.15.0. Do not use it in production.", + ) + + if values is not None and not values: + raise ValueError("values cannot be empty") + + pc = df._convert_cols_to_exprs(f"{df_name}()", pivot_col) + + from snowflake.snowpark._internal.analyzer.expression import Literal, ScalarSubquery + from snowflake.snowpark.column import Column + + if isinstance(values, Iterable): + pivot_values = [ + v._expression if isinstance(v, Column) else Literal(v) for v in values + ] + else: + if isinstance(values, DataFrame): + pivot_values = ScalarSubquery(values._plan) + else: + pivot_values = None + + if len(df.queries.get("post_actions", [])) > 0: + df = df.cache_result() + + if default_on_null is not None: + default_on_null = ( + default_on_null._expression + if isinstance(default_on_null, Column) + else Literal(default_on_null) + ) + + return df, pc, pivot_values, default_on_null diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index a992b43c1f8..8caa7c4d178 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -116,6 +116,7 @@ is_sql_select_statement, parse_positional_args_to_list, parse_table_name, + prepare_pivot_arguments, private_preview, quote_name, random_name_for_temp_object, @@ -1620,7 +1621,10 @@ def drop_duplicates(self, *subset: Union[str, Iterable[str]]) -> "DataFrame": def pivot( self, pivot_col: ColumnOrName, - values: Iterable[LiteralType], + values: Optional[ + Union[Iterable[LiteralType], "snowflake.snowpark.DataFrame"] + ] = None, + default_on_null: Optional[LiteralType] = None, ) -> 
"snowflake.snowpark.RelationalGroupedDataFrame": """Rotates this DataFrame by turning the unique values from one column in the input expression into multiple columns and aggregating results where required on any @@ -1649,21 +1653,43 @@ def pivot( ------------------------------- + >>> df = session.table("monthly_sales") + >>> df.pivot("month").sum("amount").sort("empid").show() + ------------------------------- + |"EMPID" |"'FEB'" |"'JAN'" | + ------------------------------- + |1 |8000 |10400 | + |2 |200 |39500 | + ------------------------------- + + + >>> subquery_df = session.table("monthly_sales").select(col("month")).filter(col("month") == "JAN") + >>> df = session.table("monthly_sales") + >>> df.pivot("month", values=subquery_df).sum("amount").sort("empid").show() + --------------------- + |"EMPID" |"'JAN'" | + --------------------- + |1 |10400 | + |2 |39500 | + --------------------- + + Args: pivot_col: The column or name of the column to use. - values: A list of values in the column. + values: A list of values in the column, + or dynamic based on the DataFrame query, + or None (default) will use all values of the pivot column. + default_on_null: Expression to replace empty result values. 
""" - if not values: - raise ValueError("values cannot be empty") - pc = self._convert_cols_to_exprs("pivot()", pivot_col) - value_exprs = [ - v._expression if isinstance(v, Column) else Literal(v) for v in values - ] + target_df, pc, pivot_values, default_on_null = prepare_pivot_arguments( + self, "DataFrame.pivot", pivot_col, values, default_on_null + ) + return snowflake.snowpark.RelationalGroupedDataFrame( - self, + target_df, [], snowflake.snowpark.relational_grouped_dataframe._PivotType( - pc[0], value_exprs + pc[0], pivot_values, default_on_null ), ) diff --git a/src/snowflake/snowpark/relational_grouped_dataframe.py b/src/snowflake/snowpark/relational_grouped_dataframe.py index c38e70625af..7b525ee9cd5 100644 --- a/src/snowflake/snowpark/relational_grouped_dataframe.py +++ b/src/snowflake/snowpark/relational_grouped_dataframe.py @@ -2,7 +2,7 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import Callable, Dict, Iterable, List, Tuple, Union +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union from snowflake.connector.options import pandas from snowflake.snowpark import functions @@ -10,6 +10,7 @@ Expression, Literal, NamedExpression, + ScalarSubquery, SnowflakeUDF, UnresolvedAttribute, ) @@ -26,7 +27,10 @@ from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.telemetry import relational_group_df_api_usage from snowflake.snowpark._internal.type_utils import ColumnOrName, LiteralType -from snowflake.snowpark._internal.utils import parse_positional_args_to_list +from snowflake.snowpark._internal.utils import ( + parse_positional_args_to_list, + prepare_pivot_arguments, +) from snowflake.snowpark.column import Column from snowflake.snowpark.dataframe import DataFrame from snowflake.snowpark.types import StructType @@ -75,9 +79,15 @@ class _RollupType(_GroupType): class _PivotType(_GroupType): - def __init__(self, pivot_col: 
Expression, values: List[Expression]) -> None: + def __init__( + self, + pivot_col: Expression, + values: Optional[Union[List[Expression], ScalarSubquery]], + default_on_null: Optional[Expression], + ) -> None: self.pivot_col = pivot_col self.values = values + self.default_on_null = default_on_null class GroupingSets: @@ -173,6 +183,7 @@ def _to_df(self, agg_exprs: List[Expression]) -> DataFrame: self._group_type.pivot_col, self._group_type.values, agg_exprs, + self._group_type.default_on_null, self._df._select_statement or self._df._plan, ) else: # pragma: no cover @@ -363,7 +374,10 @@ def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame: applyInPandas = apply_in_pandas def pivot( - self, pivot_col: ColumnOrName, values: Iterable[LiteralType] + self, + pivot_col: ColumnOrName, + values: Optional[Union[Iterable[LiteralType], DataFrame]] = None, + default_on_null: Optional[LiteralType] = None, ) -> "RelationalGroupedDataFrame": """Rotates this DataFrame by turning unique values from one column in the input expression into multiple columns and aggregating results where required on any @@ -373,7 +387,10 @@ def pivot( Args: pivot_col: The column or name of the column to use. - values: A list of values in the column. + values: A list of values in the column, + or dynamic based on the DataFrame query, + or None (default) will use all values of the pivot column. + default_on_null: Expression to replace empty result values. 
Example:: @@ -406,16 +423,42 @@ def pivot( |2 |B |NULL |200 | ---------------------------------------- + + >>> df = session.table("monthly_sales") + >>> df.group_by(["empid", "team"]).pivot("month").sum("amount").sort("empid", "team").show() + ---------------------------------------- + |"EMPID" |"TEAM" |"'FEB'" |"'JAN'" | + ---------------------------------------- + |1 |A |3000 |10000 | + |1 |B |5000 |400 | + |2 |A |NULL |39500 | + |2 |B |200 |NULL | + ---------------------------------------- + + + >>> from snowflake.snowpark.functions import col + >>> subquery_df = session.table("monthly_sales").select("month").filter(col("month") == "JAN") + >>> df = session.table("monthly_sales") + >>> df.group_by(["empid", "team"]).pivot("month", values=subquery_df, default_on_null=999).sum("amount").sort("empid", "team").show() + ------------------------------ + |"EMPID" |"TEAM" |"'JAN'" | + ------------------------------ + |1 |A |10000 | + |1 |B |400 | + |2 |A |39500 | + |2 |B |999 | + ------------------------------ + """ - if not values: - raise ValueError("values cannot be empty") - pc = self._df._convert_cols_to_exprs( - "RelationalGroupedDataFrame.pivot()", pivot_col + self._df, pc, pivot_values, default_on_null = prepare_pivot_arguments( + self._df, + "RelationalGroupedDataFrame.pivot", + pivot_col, + values, + default_on_null, ) - value_exprs = [ - v._expression if isinstance(v, Column) else Literal(v) for v in values - ] - self._group_type = _PivotType(pc[0], value_exprs) + + self._group_type = _PivotType(pc[0], pivot_values, default_on_null) return self @relational_group_df_api_usage diff --git a/tests/integ/scala/test_column_suite.py b/tests/integ/scala/test_column_suite.py index 5e547a95678..64d613bdac5 100644 --- a/tests/integ/scala/test_column_suite.py +++ b/tests/integ/scala/test_column_suite.py @@ -979,3 +979,20 @@ def test_in_expression_with_multiple_queries(session): Utils.check_answer( df2.select(col("a").in_(df1.select("a"))), [Row(True), Row(False)] ) + + 
+def test_pivot_with_multiple_queries(session): + from snowflake.snowpark._internal.analyzer import analyzer + + original_value = analyzer.ARRAY_BIND_THRESHOLD + try: + analyzer.ARRAY_BIND_THRESHOLD = 2 + df1 = session.create_dataframe([[1, "one"], [2, "two"]], schema=["a", "b"]) + finally: + analyzer.ARRAY_BIND_THRESHOLD = original_value + df2 = session.create_dataframe( + [[1, "one"], [11, "one"], [3, "three"]], schema=["a", "b"] + ) + Utils.check_answer( + df2.pivot(col("b"), df1.select(col("b"))).agg(avg(col("a"))), [Row(6, None)] + ) diff --git a/tests/integ/scala/test_dataframe_aggregate_suite.py b/tests/integ/scala/test_dataframe_aggregate_suite.py index 766a00cb474..55cc05914c6 100644 --- a/tests/integ/scala/test_dataframe_aggregate_suite.py +++ b/tests/integ/scala/test_dataframe_aggregate_suite.py @@ -5,11 +5,13 @@ from decimal import Decimal from math import sqrt +from typing import NamedTuple import pytest from snowflake.snowpark import GroupingSets, Row from snowflake.snowpark._internal.utils import TempObjectType +from snowflake.snowpark.column import Column from snowflake.snowpark.exceptions import ( SnowparkDataframeException, SnowparkSQLException, @@ -95,6 +97,61 @@ def test_group_by_pivot(session): ).agg([sum(col("amount")), avg(col("amount"))]) +def test_group_by_pivot_dynamic_any(session): + Utils.check_answer( + TestData.monthly_sales_with_team(session) + .group_by("empid") + .pivot("month") + .agg(sum(col("amount"))) + .sort(col("empid")), + [ + Row(1, 18000, 8000, 10400, 11000), + Row(2, 5300, 90700, 39500, 12000), + ], + sort=False, + ) + + Utils.check_answer( + TestData.monthly_sales_with_team(session) + .group_by(["empid", "team"]) + .pivot("month") + .agg(sum(col("amount"))) + .sort(col("empid"), col("team")), + [ + Row(1, "A", 10000, None, 10400, 5000), + Row(1, "B", 8000, 8000, None, 6000), + Row(2, "A", 5300, 90700, 4500, None), + Row(2, "B", None, None, 35000, 12000), + ], + sort=False, + ) + + +def 
test_group_by_pivot_dynamic_subquery(session): + src = TestData.monthly_sales(session) + subquery_df = src.select(col("month")).filter(col("month") == "JAN") + + Utils.check_answer( + TestData.monthly_sales_with_team(session) + .group_by("empid") + .pivot("month", subquery_df) + .agg(sum(col("amount"))) + .sort(col("empid")), + [Row(1, 10400), Row(2, 39500)], + sort=False, + ) + + Utils.check_answer( + TestData.monthly_sales_with_team(session) + .group_by(["empid", "team"]) + .pivot("month", subquery_df, 999) + .agg(sum(col("amount"))) + .sort(col("empid"), col("team")), + [Row(1, "A", 10400), Row(1, "B", 999), Row(2, "A", 4500), Row(2, "B", 35000)], + sort=False, + ) + + def test_join_on_pivot(session): df1 = ( TestData.monthly_sales(session) @@ -132,6 +189,143 @@ def test_pivot_on_join(session): ) +# TODO (SNOW-916206) If the source is a temp table with inlined data, then we need to validate that +# pivot will materialize the data before executing pivot, otherwise would fail with not finding the +# data when doing a later schema call. +def test_pivot_dynamic_any_with_temp_table_inlined_data(session): + original_df = session.create_dataframe( + [tuple(range(26)) for r in range(20)], schema=list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + ) + + # Validate the data is backed by a temporary table + assert len(original_df.queries.get("post_actions", [])) > 0 + + pivot_op_df = original_df.pivot("a").agg(sum(col("b"))).sort(col("c")) + + # Query and ensure the schema matches as expected, this would fail with an exception if the data is not + # materialized (happens internally) first. 
+ assert {f.column_identifier.name for f in pivot_op_df.schema.fields} == set( + list("CDEFGHIJKLMNOPQRSTUVWXYZ") + ['"0"'] + ) + + assert pivot_op_df.count() == 1 + + +def test_pivot_dynamic_any(session): + Utils.check_answer( + TestData.monthly_sales(session) + .pivot("month") + .agg(sum(col("amount"))) + .sort(col("empid")), + [ + Row(1, 18000, 8000, 10400, 11000), + Row(2, 5300, 90700, 39500, 12000), + ], + sort=False, + ) + + +def test_pivot_dynamic_subquery(session): + src = TestData.monthly_sales(session) + subquery_df = src.select(col("month")).filter(col("month") == "JAN") + + Utils.check_answer( + TestData.monthly_sales(session) + .pivot("month", subquery_df) + .agg(sum(col("amount"))) + .sort(col("empid")), + [Row(1, 10400), Row(2, 39500)], + sort=False, + ) + + +@pytest.mark.skip( + "SNOW-847500: Currently fails because of snowpark is not using DISTINCT keyword expected by server" +) +@pytest.mark.parametrize("is_ascending", [True, False]) +def test_pivot_dynamic_subquery_with_sort(session, is_ascending): + src = TestData.monthly_sales(session) + subquery_df = src.select(col("month")).filter( + (col("month") == "JAN") | (col("month") == "APR") + ) + + Utils.check_answer( + TestData.monthly_sales(session) + .pivot( + "month", + subquery_df.select("month") + .distinct() + .sort("month", ascending=is_ascending), + ) + .agg(sum(col("amount"))) + .sort(col("empid")), + [ + Row(1, 18000, 10400) if is_ascending else Row(1, 10400, 18000), + Row(2, 5300, 39500) if is_ascending else Row(2, 39500, 5300), + ], + sort=False, + ) + + +@pytest.mark.skip( + "SNOW-848987: Requires server changes in 7.22 so can unskip once sfctest0 is on >= 7.22" +) +def test_pivot_dynamic_subquery_with_bad_subquery(session): + src = TestData.monthly_sales(session) + subquery_df = src.select(col("month")).filter( + (col("month") == "JAN") | (col("month") == "APR") + ) + + with pytest.raises(SnowparkSQLException) as ex_info: + TestData.monthly_sales(session).pivot( + "month", 
subquery_df.select("month").sort("month") + ).agg(sum(col("amount"))).collect() + + assert "Invalid subquery pivot order by must be distinct query" in str(ex_info) + + with pytest.raises(SnowparkSQLException) as ex_info: + TestData.monthly_sales(session).pivot( + "month", subquery_df.select(["month", "empid"]) + ).agg(sum(col("amount"))).collect() + + assert "Pivot subquery must select single column" in str(ex_info.value) + + +def test_pivot_default_on_none(session): + class MonthlySales(NamedTuple): + empid: int + amount: int + month: str + + src = session.create_dataframe( + [ + MonthlySales(1, 10000, "JAN"), + MonthlySales(1, 400, "JAN"), + MonthlySales(1, None, "FEB"), + MonthlySales(1, 6000, "MAR"), + MonthlySales(2, 9000, "MAR"), + MonthlySales(2, None, "MAR"), + ] + ) + + for default_on_null in [Decimal(1.5), lit(9999), 9999, 0, None]: + default_value = ( + default_on_null._expression.value + if isinstance(default_on_null, Column) + else default_on_null + ) + Utils.check_answer( + src.pivot("month", ["JAN", "FEB", "MAR"], default_on_null=default_on_null) + .agg(sum(col("amount"))) + .sort(col("empid")), + [ + Row(1, 10400, default_value, 6000), + Row(2, default_value, default_value, 9000), + ], + sort=False, + ) + + @pytest.mark.localtest def test_rel_grouped_dataframe_agg(session): df = ( diff --git a/tests/unit/scala/test_df_suite.py b/tests/unit/scala/test_df_suite.py index 2707465bfec..fb1175bff38 100644 --- a/tests/unit/scala/test_df_suite.py +++ b/tests/unit/scala/test_df_suite.py @@ -15,4 +15,4 @@ def test_to_string_of_relational_grouped_dataframe(): assert _GroupByType().to_string() == "GroupBy" assert _CubeType().to_string() == "Cube" assert _RollupType().to_string() == "Rollup" - assert _PivotType(None, []).to_string() == "Pivot" + assert _PivotType(None, [], None).to_string() == "Pivot"