diff --git a/CHANGELOG.md b/CHANGELOG.md index 13f4fdad7c2..4a87cd123e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,6 +45,7 @@ - save_as_table - Added support for snow:// URLs to `snowflake.snowpark.Session.file.get` and `snowflake.snowpark.Session.file.get_stream` - UDAF client support is ready for public preview. Please stay tuned for the Snowflake announcement of UDAF public preview. +- Added support for dynamic pivot. This feature is currently in private preview. ### Bug Fixes diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer.py b/src/snowflake/snowpark/_internal/analyzer/analyzer.py index c73c4d1ae53..df876762d53 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer.py @@ -4,7 +4,7 @@ # import uuid from collections import Counter, defaultdict -from typing import TYPE_CHECKING, DefaultDict, Dict, Optional, Union +from typing import TYPE_CHECKING, DefaultDict, Dict, List, Optional, Union import snowflake.snowpark from snowflake.snowpark._internal.analyzer.analyzer_utils import ( @@ -988,21 +988,52 @@ def do_resolve_with_resolved_children( else: child = resolved_children[logical_plan.child] - return self.plan_builder.pivot( + # We retrieve the pivot_values for generating SQL using types: + # List[str] => explicit list of pivot values + # ScalarSubquery => dynamic pivot subquery + # None => dynamic pivot ANY subquery + + if isinstance(logical_plan.pivot_values, List): + pivot_values = [ + self.analyze(pv, df_aliased_col_name_to_real_col_name) + for pv in logical_plan.pivot_values + ] + elif isinstance(logical_plan.pivot_values, ScalarSubquery): + pivot_values = self.analyze( + logical_plan.pivot_values, df_aliased_col_name_to_real_col_name + ) + else: + pivot_values = None + + pivot_plan = self.plan_builder.pivot( self.analyze( logical_plan.pivot_column, df_aliased_col_name_to_real_col_name ), - [ - self.analyze(pv, df_aliased_col_name_to_real_col_name) - for pv in 
logical_plan.pivot_values - ], + pivot_values, self.analyze( logical_plan.aggregates[0], df_aliased_col_name_to_real_col_name ), + self.analyze( + logical_plan.default_on_null, df_aliased_col_name_to_real_col_name + ) + if logical_plan.default_on_null + else None, child, logical_plan, ) + # If this is a dynamic pivot, then we can't use child.schema_query which is used in the schema_query + # sql generator by default because it is simplified and won't fetch the output columns from the underlying + # source. So in this case we use the actual pivot query as the schema query. + if logical_plan.pivot_values is None or isinstance( + logical_plan.pivot_values, ScalarSubquery + ): + # TODO (SNOW-916744): Using the original query here does not work if the query depends on a temp + # table as it may not exist at later point in time when dataframe.schema is called. + pivot_plan.schema_query = pivot_plan.queries[-1].sql + + return pivot_plan + if isinstance(logical_plan, Unpivot): return self.plan_builder.unpivot( logical_plan.value_column, diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 6f30658b616..65f16bd486b 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -161,6 +161,8 @@ EXCEPT = f" {Except.sql} " NOT_NULL = " NOT NULL " WITH = "WITH " +DEFAULT_ON_NULL = " DEFAULT ON NULL " +ANY = " ANY " TEMPORARY_STRING_SET = frozenset(["temporary", "temp"]) @@ -1033,8 +1035,22 @@ def create_or_replace_dynamic_table_statement( def pivot_statement( - pivot_column: str, pivot_values: List[str], aggregate: str, child: str + pivot_column: str, + pivot_values: Optional[Union[str, List[str]]], + aggregate: str, + default_on_null: Optional[str], + child: str, ) -> str: + if isinstance(pivot_values, str): + # The subexpression in this case already includes parenthesis. 
+ values_str = pivot_values + else: + values_str = ( + LEFT_PARENTHESIS + + (ANY if pivot_values is None else COMMA.join(pivot_values)) + + RIGHT_PARENTHESIS + ) + return ( SELECT + STAR @@ -1048,9 +1064,12 @@ def pivot_statement( + FOR + pivot_column + IN - + LEFT_PARENTHESIS - + COMMA.join(pivot_values) - + RIGHT_PARENTHESIS + + values_str + + ( + (DEFAULT_ON_NULL + LEFT_PARENTHESIS + default_on_null + RIGHT_PARENTHESIS) + if default_on_null + else EMPTY_STRING + ) + RIGHT_PARENTHESIS ) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 637a9462288..df4ec602e6e 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -817,13 +817,16 @@ def limit( def pivot( self, pivot_column: str, - pivot_values: List[str], + pivot_values: Optional[Union[str, List[str]]], aggregate: str, + default_on_null: Optional[str], child: SnowflakePlan, source_plan: Optional[LogicalPlan], ) -> SnowflakePlan: return self.build( - lambda x: pivot_statement(pivot_column, pivot_values, aggregate, x), + lambda x: pivot_statement( + pivot_column, pivot_values, aggregate, default_on_null, x + ), child, source_plan, ) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py index adf507703c9..076aa80a74c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_plan_node.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from snowflake.snowpark._internal.analyzer.expression import Expression, NamedExpression from snowflake.snowpark._internal.analyzer.snowflake_plan import LogicalPlan @@ -53,8 +53,9 @@ def __init__( self, grouping_columns: List[Expression], pivot_column: Expression, - pivot_values: List[Expression], + pivot_values: Optional[Union[List[Expression], LogicalPlan]], aggregates: List[Expression], + default_on_null: Optional[Expression], child: LogicalPlan, ) -> None: super().__init__(child) @@ -62,6 +63,7 @@ def __init__( self.pivot_column = pivot_column self.pivot_values = pivot_values self.aggregates = aggregates + self.default_on_null = default_on_null class Unpivot(UnaryNode): diff --git a/src/snowflake/snowpark/_internal/utils.py b/src/snowflake/snowpark/_internal/utils.py index 211f2268065..161eeeb9b86 100644 --- a/src/snowflake/snowpark/_internal/utils.py +++ b/src/snowflake/snowpark/_internal/utils.py @@ -28,6 +28,7 @@ Any, Callable, Dict, + Iterable, Iterator, List, Literal, @@ -870,3 +871,66 @@ def validate_quoted_name(name: str) -> str: def escape_quotes(unescaped: str) -> str: return unescaped.replace(DOUBLE_QUOTE, DOUBLE_QUOTE + DOUBLE_QUOTE) + + +def prepare_pivot_arguments( + df: "snowflake.snowpark.DataFrame", + df_name: str, + pivot_col: "snowflake.snowpark._internal.type_utils.ColumnOrName", + values: Optional[ + Union[ + Iterable["snowflake.snowpark._internal.type_utils.LiteralType"], + "snowflake.snowpark.DataFrame", + ] + ], + default_on_null: Optional["snowflake.snowpark._internal.type_utils.LiteralType"], +): + """ + Prepare dataframe pivot arguments to use in the underlying pivot call. This includes issuing any applicable + warnings, ensuring column types and valid arguments. + Returns: + DataFrame, pivot column, pivot_values and default_on_null value. 
+ """ + from snowflake.snowpark.dataframe import DataFrame + + if values is None or isinstance(values, DataFrame): + warning( + df_name, + "Parameter values is Optional or DataFrame is in private preview since v1.15.0. Do not use it in production.", + ) + + if default_on_null is not None: + warning( + df_name, + "Parameter default_on_null is not None is in private preview since v1.15.0. Do not use it in production.", + ) + + if values is not None and not values: + raise ValueError("values cannot be empty") + + pc = df._convert_cols_to_exprs(f"{df_name}()", pivot_col) + + from snowflake.snowpark._internal.analyzer.expression import Literal, ScalarSubquery + from snowflake.snowpark.column import Column + + if isinstance(values, Iterable): + pivot_values = [ + v._expression if isinstance(v, Column) else Literal(v) for v in values + ] + else: + if isinstance(values, DataFrame): + pivot_values = ScalarSubquery(values._plan) + else: + pivot_values = None + + if len(df.queries.get("post_actions", [])) > 0: + df = df.cache_result() + + if default_on_null is not None: + default_on_null = ( + default_on_null._expression + if isinstance(default_on_null, Column) + else Literal(default_on_null) + ) + + return df, pc, pivot_values, default_on_null diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index a992b43c1f8..8caa7c4d178 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -116,6 +116,7 @@ is_sql_select_statement, parse_positional_args_to_list, parse_table_name, + prepare_pivot_arguments, private_preview, quote_name, random_name_for_temp_object, @@ -1620,7 +1621,10 @@ def drop_duplicates(self, *subset: Union[str, Iterable[str]]) -> "DataFrame": def pivot( self, pivot_col: ColumnOrName, - values: Iterable[LiteralType], + values: Optional[ + Union[Iterable[LiteralType], "snowflake.snowpark.DataFrame"] + ] = None, + default_on_null: Optional[LiteralType] = None, ) -> 
"snowflake.snowpark.RelationalGroupedDataFrame": """Rotates this DataFrame by turning the unique values from one column in the input expression into multiple columns and aggregating results where required on any @@ -1649,21 +1653,43 @@ def pivot( ------------------------------- + >>> df = session.table("monthly_sales") + >>> df.pivot("month").sum("amount").sort("empid").show() + ------------------------------- + |"EMPID" |"'FEB'" |"'JAN'" | + ------------------------------- + |1 |8000 |10400 | + |2 |200 |39500 | + ------------------------------- + + + >>> subquery_df = session.table("monthly_sales").select(col("month")).filter(col("month") == "JAN") + >>> df = session.table("monthly_sales") + >>> df.pivot("month", values=subquery_df).sum("amount").sort("empid").show() + --------------------- + |"EMPID" |"'JAN'" | + --------------------- + |1 |10400 | + |2 |39500 | + --------------------- + + Args: pivot_col: The column or name of the column to use. - values: A list of values in the column. + values: A list of values in the column, + or dynamic based on the DataFrame query, + or None (default) will use all values of the pivot column. + default_on_null: Expression to replace empty result values. 
""" - if not values: - raise ValueError("values cannot be empty") - pc = self._convert_cols_to_exprs("pivot()", pivot_col) - value_exprs = [ - v._expression if isinstance(v, Column) else Literal(v) for v in values - ] + target_df, pc, pivot_values, default_on_null = prepare_pivot_arguments( + self, "DataFrame.pivot", pivot_col, values, default_on_null + ) + return snowflake.snowpark.RelationalGroupedDataFrame( - self, + target_df, [], snowflake.snowpark.relational_grouped_dataframe._PivotType( - pc[0], value_exprs + pc[0], pivot_values, default_on_null ), ) diff --git a/src/snowflake/snowpark/relational_grouped_dataframe.py b/src/snowflake/snowpark/relational_grouped_dataframe.py index c38e70625af..7b525ee9cd5 100644 --- a/src/snowflake/snowpark/relational_grouped_dataframe.py +++ b/src/snowflake/snowpark/relational_grouped_dataframe.py @@ -2,7 +2,7 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import Callable, Dict, Iterable, List, Tuple, Union +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union from snowflake.connector.options import pandas from snowflake.snowpark import functions @@ -10,6 +10,7 @@ Expression, Literal, NamedExpression, + ScalarSubquery, SnowflakeUDF, UnresolvedAttribute, ) @@ -26,7 +27,10 @@ from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.telemetry import relational_group_df_api_usage from snowflake.snowpark._internal.type_utils import ColumnOrName, LiteralType -from snowflake.snowpark._internal.utils import parse_positional_args_to_list +from snowflake.snowpark._internal.utils import ( + parse_positional_args_to_list, + prepare_pivot_arguments, +) from snowflake.snowpark.column import Column from snowflake.snowpark.dataframe import DataFrame from snowflake.snowpark.types import StructType @@ -75,9 +79,15 @@ class _RollupType(_GroupType): class _PivotType(_GroupType): - def __init__(self, pivot_col: 
Expression, values: List[Expression]) -> None: + def __init__( + self, + pivot_col: Expression, + values: Optional[Union[List[Expression], ScalarSubquery]], + default_on_null: Optional[Expression], + ) -> None: self.pivot_col = pivot_col self.values = values + self.default_on_null = default_on_null class GroupingSets: @@ -173,6 +183,7 @@ def _to_df(self, agg_exprs: List[Expression]) -> DataFrame: self._group_type.pivot_col, self._group_type.values, agg_exprs, + self._group_type.default_on_null, self._df._select_statement or self._df._plan, ) else: # pragma: no cover @@ -363,7 +374,10 @@ def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame: applyInPandas = apply_in_pandas def pivot( - self, pivot_col: ColumnOrName, values: Iterable[LiteralType] + self, + pivot_col: ColumnOrName, + values: Optional[Union[Iterable[LiteralType], DataFrame]] = None, + default_on_null: Optional[LiteralType] = None, ) -> "RelationalGroupedDataFrame": """Rotates this DataFrame by turning unique values from one column in the input expression into multiple columns and aggregating results where required on any @@ -373,7 +387,10 @@ def pivot( Args: pivot_col: The column or name of the column to use. - values: A list of values in the column. + values: A list of values in the column, + or dynamic based on the DataFrame query, + or None (default) will use all values of the pivot column. + default_on_null: Expression to replace empty result values. 
Example:: @@ -406,16 +423,42 @@ def pivot( |2 |B |NULL |200 | ---------------------------------------- + + >>> df = session.table("monthly_sales") + >>> df.group_by(["empid", "team"]).pivot("month").sum("amount").sort("empid", "team").show() + ---------------------------------------- + |"EMPID" |"TEAM" |"'FEB'" |"'JAN'" | + ---------------------------------------- + |1 |A |3000 |10000 | + |1 |B |5000 |400 | + |2 |A |NULL |39500 | + |2 |B |200 |NULL | + ---------------------------------------- + + + >>> from snowflake.snowpark.functions import col + >>> subquery_df = session.table("monthly_sales").select("month").filter(col("month") == "JAN") + >>> df = session.table("monthly_sales") + >>> df.group_by(["empid", "team"]).pivot("month", values=subquery_df, default_on_null=999).sum("amount").sort("empid", "team").show() + ------------------------------ + |"EMPID" |"TEAM" |"'JAN'" | + ------------------------------ + |1 |A |10000 | + |1 |B |400 | + |2 |A |39500 | + |2 |B |999 | + ------------------------------ + """ - if not values: - raise ValueError("values cannot be empty") - pc = self._df._convert_cols_to_exprs( - "RelationalGroupedDataFrame.pivot()", pivot_col + self._df, pc, pivot_values, default_on_null = prepare_pivot_arguments( + self._df, + "RelationalGroupedDataFrame.pivot", + pivot_col, + values, + default_on_null, ) - value_exprs = [ - v._expression if isinstance(v, Column) else Literal(v) for v in values - ] - self._group_type = _PivotType(pc[0], value_exprs) + + self._group_type = _PivotType(pc[0], pivot_values, default_on_null) return self @relational_group_df_api_usage diff --git a/tests/integ/scala/test_column_suite.py b/tests/integ/scala/test_column_suite.py index 5e547a95678..64d613bdac5 100644 --- a/tests/integ/scala/test_column_suite.py +++ b/tests/integ/scala/test_column_suite.py @@ -979,3 +979,20 @@ def test_in_expression_with_multiple_queries(session): Utils.check_answer( df2.select(col("a").in_(df1.select("a"))), [Row(True), Row(False)] ) + + 
+def test_pivot_with_multiple_queries(session): + from snowflake.snowpark._internal.analyzer import analyzer + + original_value = analyzer.ARRAY_BIND_THRESHOLD + try: + analyzer.ARRAY_BIND_THRESHOLD = 2 + df1 = session.create_dataframe([[1, "one"], [2, "two"]], schema=["a", "b"]) + finally: + analyzer.ARRAY_BIND_THRESHOLD = original_value + df2 = session.create_dataframe( + [[1, "one"], [11, "one"], [3, "three"]], schema=["a", "b"] + ) + Utils.check_answer( + df2.pivot(col("b"), df1.select(col("b"))).agg(avg(col("a"))), [Row(6, None)] + ) diff --git a/tests/integ/scala/test_dataframe_aggregate_suite.py b/tests/integ/scala/test_dataframe_aggregate_suite.py index 766a00cb474..55cc05914c6 100644 --- a/tests/integ/scala/test_dataframe_aggregate_suite.py +++ b/tests/integ/scala/test_dataframe_aggregate_suite.py @@ -5,11 +5,13 @@ from decimal import Decimal from math import sqrt +from typing import NamedTuple import pytest from snowflake.snowpark import GroupingSets, Row from snowflake.snowpark._internal.utils import TempObjectType +from snowflake.snowpark.column import Column from snowflake.snowpark.exceptions import ( SnowparkDataframeException, SnowparkSQLException, @@ -95,6 +97,61 @@ def test_group_by_pivot(session): ).agg([sum(col("amount")), avg(col("amount"))]) +def test_group_by_pivot_dynamic_any(session): + Utils.check_answer( + TestData.monthly_sales_with_team(session) + .group_by("empid") + .pivot("month") + .agg(sum(col("amount"))) + .sort(col("empid")), + [ + Row(1, 18000, 8000, 10400, 11000), + Row(2, 5300, 90700, 39500, 12000), + ], + sort=False, + ) + + Utils.check_answer( + TestData.monthly_sales_with_team(session) + .group_by(["empid", "team"]) + .pivot("month") + .agg(sum(col("amount"))) + .sort(col("empid"), col("team")), + [ + Row(1, "A", 10000, None, 10400, 5000), + Row(1, "B", 8000, 8000, None, 6000), + Row(2, "A", 5300, 90700, 4500, None), + Row(2, "B", None, None, 35000, 12000), + ], + sort=False, + ) + + +def 
test_group_by_pivot_dynamic_subquery(session): + src = TestData.monthly_sales(session) + subquery_df = src.select(col("month")).filter(col("month") == "JAN") + + Utils.check_answer( + TestData.monthly_sales_with_team(session) + .group_by("empid") + .pivot("month", subquery_df) + .agg(sum(col("amount"))) + .sort(col("empid")), + [Row(1, 10400), Row(2, 39500)], + sort=False, + ) + + Utils.check_answer( + TestData.monthly_sales_with_team(session) + .group_by(["empid", "team"]) + .pivot("month", subquery_df, 999) + .agg(sum(col("amount"))) + .sort(col("empid"), col("team")), + [Row(1, "A", 10400), Row(1, "B", 999), Row(2, "A", 4500), Row(2, "B", 35000)], + sort=False, + ) + + def test_join_on_pivot(session): df1 = ( TestData.monthly_sales(session) @@ -132,6 +189,143 @@ def test_pivot_on_join(session): ) +# TODO (SNOW-916206) If the source is a temp table with inlined data, then we need to validate that +# pivot will materialize the data before executing pivot, otherwise would fail with not finding the +# data when doing a later schema call. +def test_pivot_dynamic_any_with_temp_table_inlined_data(session): + original_df = session.create_dataframe( + [tuple(range(26)) for r in range(20)], schema=list("ABCDEFGHIJKLMNOPQRSTUVWXYZ") + ) + + # Validate the data is backed by a temporary table + assert len(original_df.queries.get("post_actions", [])) > 0 + + pivot_op_df = original_df.pivot("a").agg(sum(col("b"))).sort(col("c")) + + # Query and ensure the schema matches as expected, this would fail with an exception if the data is not + # materialized (happens internally) first. 
+ assert {f.column_identifier.name for f in pivot_op_df.schema.fields} == set( + list("CDEFGHIJKLMNOPQRSTUVWXYZ") + ['"0"'] + ) + + assert pivot_op_df.count() == 1 + + +def test_pivot_dynamic_any(session): + Utils.check_answer( + TestData.monthly_sales(session) + .pivot("month") + .agg(sum(col("amount"))) + .sort(col("empid")), + [ + Row(1, 18000, 8000, 10400, 11000), + Row(2, 5300, 90700, 39500, 12000), + ], + sort=False, + ) + + +def test_pivot_dynamic_subquery(session): + src = TestData.monthly_sales(session) + subquery_df = src.select(col("month")).filter(col("month") == "JAN") + + Utils.check_answer( + TestData.monthly_sales(session) + .pivot("month", subquery_df) + .agg(sum(col("amount"))) + .sort(col("empid")), + [Row(1, 10400), Row(2, 39500)], + sort=False, + ) + + +@pytest.mark.skip( + "SNOW-847500: Currently fails because of snowpark is not using DISTINCT keyword expected by server" +) +@pytest.mark.parametrize("is_ascending", [True, False]) +def test_pivot_dynamic_subquery_with_sort(session, is_ascending): + src = TestData.monthly_sales(session) + subquery_df = src.select(col("month")).filter( + (col("month") == "JAN") | (col("month") == "APR") + ) + + Utils.check_answer( + TestData.monthly_sales(session) + .pivot( + "month", + subquery_df.select("month") + .distinct() + .sort("month", ascending=is_ascending), + ) + .agg(sum(col("amount"))) + .sort(col("empid")), + [ + Row(1, 18000, 10400) if is_ascending else Row(1, 10400, 18000), + Row(2, 5300, 39500) if is_ascending else Row(2, 39500, 5300), + ], + sort=False, + ) + + +@pytest.mark.skip( + "SNOW-848987: Requires server changes in 7.22 so can unskip once sfctest0 is on >= 7.22" +) +def test_pivot_dynamic_subquery_with_bad_subquery(session): + src = TestData.monthly_sales(session) + subquery_df = src.select(col("month")).filter( + (col("month") == "JAN") | (col("month") == "APR") + ) + + with pytest.raises(SnowparkSQLException) as ex_info: + TestData.monthly_sales(session).pivot( + "month", 
subquery_df.select("month").sort("month") + ).agg(sum(col("amount"))).collect() + + assert "Invalid subquery pivot order by must be distinct query" in str(ex_info) + + with pytest.raises(SnowparkSQLException) as ex_info: + TestData.monthly_sales(session).pivot( + "month", subquery_df.select(["month", "empid"]) + ).agg(sum(col("amount"))).collect() + + assert "Pivot subquery must select single column" in str(ex_info.value) + + +def test_pivot_default_on_none(session): + class MonthlySales(NamedTuple): + empid: int + amount: int + month: str + + src = session.create_dataframe( + [ + MonthlySales(1, 10000, "JAN"), + MonthlySales(1, 400, "JAN"), + MonthlySales(1, None, "FEB"), + MonthlySales(1, 6000, "MAR"), + MonthlySales(2, 9000, "MAR"), + MonthlySales(2, None, "MAR"), + ] + ) + + for default_on_null in [Decimal(1.5), lit(9999), 9999, 0, None]: + default_value = ( + default_on_null._expression.value + if isinstance(default_on_null, Column) + else default_on_null + ) + Utils.check_answer( + src.pivot("month", ["JAN", "FEB", "MAR"], default_on_null=default_on_null) + .agg(sum(col("amount"))) + .sort(col("empid")), + [ + Row(1, 10400, default_value, 6000), + Row(2, default_value, default_value, 9000), + ], + sort=False, + ) + + @pytest.mark.localtest def test_rel_grouped_dataframe_agg(session): df = ( diff --git a/tests/unit/scala/test_df_suite.py b/tests/unit/scala/test_df_suite.py index 2707465bfec..fb1175bff38 100644 --- a/tests/unit/scala/test_df_suite.py +++ b/tests/unit/scala/test_df_suite.py @@ -15,4 +15,4 @@ def test_to_string_of_relational_grouped_dataframe(): assert _GroupByType().to_string() == "GroupBy" assert _CubeType().to_string() == "Cube" assert _RollupType().to_string() == "Rollup" - assert _PivotType(None, []).to_string() == "Pivot" + assert _PivotType(None, [], None).to_string() == "Pivot"