diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 66ed9bf232a..7e03986036c 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -5040,10 +5040,7 @@ def groupby_value_counts( # unlike pandas, we do not include the offending labels in the error message. raise ValueError("Keys in subset cannot be in the groupby column keys") if subset is not None: - if not isinstance(subset, (list, tuple)): - subset_list = [subset] - else: - subset_list = subset + subset_list = subset else: # If subset is unspecified, then all columns should be included. subset_list = self._modin_frame.data_column_pandas_labels @@ -5150,10 +5147,14 @@ def groupby_value_counts( # the count/proportion column. The left-most column (nearest to the grouping columns # is sorted on last). # Exclude the grouping columns (always the first) from the sort. + if as_index: + # When as_index is true, the non-grouping columns are part of the index columns + columns_to_filter = result._modin_frame.index_column_pandas_labels + else: + # When as_index is false, the non-grouping columns are part of the data columns + columns_to_filter = result._modin_frame.data_column_pandas_labels non_grouping_cols = [ - col_label - for col_label in result._modin_frame.index_column_pandas_labels - if col_label not in by + col_label for col_label in columns_to_filter if col_label not in by ] sort_cols.extend(non_grouping_cols) ascending_cols.extend([True] * len(non_grouping_cols)) diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py b/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py index 70d8d568493..46fab52f85a 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/groupby.py @@ -257,8 +257,8 @@ def value_counts(): >>> df.groupby('gender').value_counts() # doctest: +NORMALIZE_WHITESPACE gender education country - female high US 1 - FR 1 + female high FR 1 + US 1 male low FR 2 US 1 medium FR 1 @@ -266,8 +266,8 @@ def value_counts(): >>> df.groupby('gender').value_counts(ascending=True) # doctest: +NORMALIZE_WHITESPACE gender education country - female high US 1 - FR 1 + female high FR 1 + US 1 male low US 1 medium FR 1 low FR 2 @@ -275,8 +275,8 @@ def value_counts(): >>> df.groupby('gender').value_counts(normalize=True) # doctest: +NORMALIZE_WHITESPACE gender education country - female high US 0.50 - FR 0.50 + female high FR 0.50 + US 0.50 male low FR 0.50 US 0.25 medium FR 0.25 diff --git a/tests/integ/modin/groupby/test_value_counts.py b/tests/integ/modin/groupby/test_value_counts.py index ae4a14ba88b..1f1b2f5c052 100644 --- a/tests/integ/modin/groupby/test_value_counts.py +++ b/tests/integ/modin/groupby/test_value_counts.py @@ -131,6 +131,21 @@ def test_value_counts_normalize( ) +@pytest.mark.parametrize("test_data", TEST_DATA) +@pytest.mark.parametrize("by", ["by", ["value1", "by"], ["by", "value2"]]) +@pytest.mark.parametrize("groupby_sort", [True, False]) +@pytest.mark.parametrize("sort", [True, False]) +@pytest.mark.parametrize("as_index", [True, False]) +@sql_count_checker(query_count=1) +def test_value_counts_as_index(test_data, by, groupby_sort, sort, as_index): + eval_snowpark_pandas_result( + *create_test_dfs(test_data), + lambda df: df.groupby(by=by, sort=groupby_sort, as_index=as_index).value_counts( + sort=sort + ), + ) + + @pytest.mark.parametrize( "subset, exception_cls", [