From 75141219af7b04d74b74edd74e83d1eaffcd7ef1 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Fri, 26 Jul 2024 10:27:48 -0700 Subject: [PATCH] reconcile value_counts in query compiler --- .../compiler/snowflake_query_compiler.py | 38 ++++++++++--------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index de2a80d5f24..4a088493ea8 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -11,7 +11,7 @@ import uuid from collections.abc import Hashable, Iterable, Mapping, Sequence from datetime import timedelta, tzinfo -from typing import Any, Callable, List, Literal, Optional, Tuple, Union, get_args +from typing import Any, Callable, List, Literal, Optional, Union, get_args import numpy as np import numpy.typing as npt @@ -4705,6 +4705,8 @@ def groupby_value_counts( ErrorMessage.not_implemented( f"Snowpark pandas GroupBy.value_counts {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}" ) + if bins is not None: + raise ErrorMessage.not_implemented("bins argument is not yet supported") if not is_list_like(by): by = [by] if len(set(by) & set(subset or [])): @@ -4713,34 +4715,36 @@ def groupby_value_counts( raise ValueError("Keys in subset cannot be in the groupby column keys") if subset is not None: if not isinstance(subset, (list, tuple)): - subset = [subset] + subset_list = [subset] + else: + subset_list = subset else: # If subset is unspecified, then all columns should be included. - subset = self._modin_frame.data_column_pandas_labels + subset_list = self._modin_frame.data_column_pandas_labels # The grouping columns are always included in the subset. # Furthermore, the columns of the output must have the grouping columns first, in the order # that they were specified. - subset = by + list(filter(lambda label: label not in by, subset)) + subset_list = by + list(filter(lambda label: label not in by, subset_list)) if as_index: # When as_index=True, the result is a Series with a MultiIndex index. - result = self.value_counts( + result = self._value_counts_groupby( + by=subset_list, # Use sort=False to preserve the original order sort=False, - subset=subset, normalize=normalize, - bins=bins, + ascending=False, dropna=dropna, normalize_within_groups=by, ) else: # When as_index=False, the result is a DataFrame where count/proportion is appended as a new named column. - result = self.value_counts( + result = self._value_counts_groupby( + by=subset_list, # Use sort=False to preserve the original order sort=False, - subset=subset, normalize=normalize, - bins=bins, + ascending=False, dropna=dropna, normalize_within_groups=by, ).reset_index() @@ -10924,8 +10928,6 @@ def value_counts( ascending: bool = False, bins: Optional[int] = None, dropna: bool = True, - *, - normalize_within_groups: Optional[list[str]] = None, ) -> "SnowflakeQueryCompiler": """ Counts the frequency or number of unique values of SnowflakeQueryCompiler. @@ -10950,10 +10952,6 @@ def value_counts( This argument is not supported yet. dropna : bool, default True Don't include counts of NaN. - normalize_within_groups : list[str], optional - If set, the normalize parameter will normalize based on the specified groups - rather than the entire dataset. This parameter is exclusive to the Snowpark pandas - query compiler and is only used internally to implement groupby_value_counts. """ # TODO: SNOW-924742 Support bins in Series.value_counts if bins is not None: @@ -10970,11 +10968,13 @@ def value_counts( def _value_counts_groupby( self, - by: Union[List[Hashable], Tuple[Hashable, ...]], + by: Sequence[Hashable], normalize: bool, sort: bool, ascending: bool, dropna: bool, + *, + normalize_within_groups: Optional[list[str]] = None, ) -> "SnowflakeQueryCompiler": """ Helper method to obtain the frequency or number of unique values @@ -10996,6 +10996,10 @@ def _value_counts_groupby( Sort in ascending order. dropna : bool Don't include counts of NaN. + normalize_within_groups : list[str], optional + If set, the normalize parameter will normalize based on the specified groups + rather than the entire dataset. This parameter is exclusive to the Snowpark pandas + query compiler and is only used internally to implement groupby_value_counts. """ # validate whether by is valid (e.g., contains duplicates or non-existing labels) self.validate_groupby(by=by, axis=0, level=None)