Skip to content

Commit

Permalink
fix as_index and doctests
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-joshi committed Aug 29, 2024
1 parent ebf2be6 commit 8a57df9
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5040,10 +5040,7 @@ def groupby_value_counts(
# unlike pandas, we do not include the offending labels in the error message.
raise ValueError("Keys in subset cannot be in the groupby column keys")
if subset is not None:
if not isinstance(subset, (list, tuple)):
subset_list = [subset]
else:
subset_list = subset
subset_list = subset
else:
# If subset is unspecified, then all columns should be included.
subset_list = self._modin_frame.data_column_pandas_labels
Expand Down Expand Up @@ -5150,10 +5147,14 @@ def groupby_value_counts(
# the count/proportion column. The left-most column (nearest to the grouping columns)
# is sorted on last.
# Exclude the grouping columns (always the first) from the sort.
if as_index:
# When as_index is true, the non-grouping columns are part of the index columns
columns_to_filter = result._modin_frame.index_column_pandas_labels
else:
# When as_index is false, the non-grouping columns are part of the data columns
columns_to_filter = result._modin_frame.data_column_pandas_labels
non_grouping_cols = [
col_label
for col_label in result._modin_frame.index_column_pandas_labels
if col_label not in by
col_label for col_label in columns_to_filter if col_label not in by
]
sort_cols.extend(non_grouping_cols)
ascending_cols.extend([True] * len(non_grouping_cols))
Expand Down
12 changes: 6 additions & 6 deletions src/snowflake/snowpark/modin/plugin/docstrings/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,26 +257,26 @@ def value_counts():
>>> df.groupby('gender').value_counts() # doctest: +NORMALIZE_WHITESPACE
gender education country
female high US 1
FR 1
female high FR 1
US 1
male low FR 2
US 1
medium FR 1
Name: count, dtype: int64
>>> df.groupby('gender').value_counts(ascending=True) # doctest: +NORMALIZE_WHITESPACE
gender education country
female high US 1
FR 1
female high FR 1
US 1
male low US 1
medium FR 1
low FR 2
Name: count, dtype: int64
>>> df.groupby('gender').value_counts(normalize=True) # doctest: +NORMALIZE_WHITESPACE
gender education country
female high US 0.50
FR 0.50
female high FR 0.50
US 0.50
male low FR 0.50
US 0.25
medium FR 0.25
Expand Down
15 changes: 15 additions & 0 deletions tests/integ/modin/groupby/test_value_counts.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,21 @@ def test_value_counts_normalize(
)


@pytest.mark.parametrize("test_data", TEST_DATA)
@pytest.mark.parametrize("by", ["by", ["value1", "by"], ["by", "value2"]])
@pytest.mark.parametrize("groupby_sort", [True, False])
@pytest.mark.parametrize("sort", [True, False])
@pytest.mark.parametrize("as_index", [True, False])
@sql_count_checker(query_count=1)
def test_value_counts_as_index(test_data, by, groupby_sort, sort, as_index):
    """Exercise ``groupby(...).value_counts(sort=...)`` for each ``as_index`` setting.

    The frames produced by ``create_test_dfs(test_data)`` are handed to
    ``eval_snowpark_pandas_result`` together with the operation under test.
    The parametrization covers the grouping keys (``by``), the groupby-level
    ``sort`` flag (``groupby_sort``), the ``value_counts``-level ``sort`` flag,
    and ``as_index``.
    """

    def grouped_value_counts(df):
        # Build the groupby first, then count values, so each parametrized
        # argument is applied at the level it belongs to.
        grouped = df.groupby(by=by, sort=groupby_sort, as_index=as_index)
        return grouped.value_counts(sort=sort)

    test_frames = create_test_dfs(test_data)
    eval_snowpark_pandas_result(*test_frames, grouped_value_counts)


@pytest.mark.parametrize(
"subset, exception_cls",
[
Expand Down

0 comments on commit 8a57df9

Please sign in to comment.