fix subset checks
sfc-gh-joshi committed Aug 6, 2024
1 parent 13b3f52 commit fce9765
Showing 2 changed files with 73 additions and 41 deletions.
@@ -4707,14 +4707,20 @@ def groupby_value_counts(
)
if not is_list_like(by):
by = [by]
if len(set(by) & set(subset or [])):
# Check for overlap between by and subset. Since column names may contain customer data,
# unlike pandas, we do not include the offending labels in the error message.
raise ValueError("Keys in subset cannot be in the groupby column keys")
if subset is not None:
if not isinstance(subset, (list, tuple)):
subset = [subset]
# All "by" columns must be included in the subset list passed to value_counts
subset = [by_label for by_label in by if by_label not in subset] + subset
else:
# If subset is unspecified, then all columns are part of it
# If subset is unspecified, then all columns should be included.
subset = self._modin_frame.data_column_pandas_labels
# The grouping columns are always included in the subset.
# Furthermore, the columns of the output must have the grouping columns first, in the order
# that they were specified.
subset = by + list(filter(lambda label: label not in by, subset))

if as_index:
# When as_index=True, the result is a Series with a MultiIndex index.
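For reference, the subset handling introduced above can be summarized as a small standalone function. This is a sketch for illustration only (the name normalize_subset and the plain-list arguments are assumptions, not the query compiler's actual API), but it mirrors the overlap check, the default subset, and the grouping-columns-first ordering from the hunk, assuming by is a list of labels and subset is a list or None:

def normalize_subset(by, subset, all_columns):
    # Grouping keys may not appear in subset; the error message deliberately
    # omits the offending labels, since column names may contain customer data.
    if set(by) & set(subset or []):
        raise ValueError("Keys in subset cannot be in the groupby column keys")
    # An unspecified subset means every column is counted.
    if subset is None:
        subset = all_columns
    # The grouping columns always come first, in the order they were specified.
    return list(by) + [label for label in subset if label not in by]

# normalize_subset(["by"], None, ["by", "value1", "value2"])        -> ["by", "value1", "value2"]
# normalize_subset(["by"], ["value2"], ["by", "value1", "value2"])  -> ["by", "value2"]
# normalize_subset(["by"], ["by", "value1"], ["by", "value1"])      raises ValueError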
tests/integ/modin/groupby/test_value_counts.py (102 changes: 64 additions & 38 deletions)
@@ -7,7 +7,7 @@
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.sql_counter import sql_count_checker
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
from tests.integ.modin.utils import (
assert_snowpark_pandas_equal_to_pandas,
create_test_dfs,
@@ -35,15 +35,14 @@


@pytest.mark.parametrize("test_data", TEST_DATA)
@pytest.mark.parametrize("by", ["by"]) # , ["by", "value1"], ["by", "value2"]])
@pytest.mark.parametrize("by", ["by", ["value1", "by"], ["by", "value2"]])
@pytest.mark.parametrize("groupby_sort", [True, False])
@pytest.mark.parametrize(
"subset",
[None, ["value1"], ["value2"], ["value1", "value2"]],
)
@pytest.mark.parametrize("normalize", [True, False])
@pytest.mark.parametrize("dropna", [True, False])
@sql_count_checker(query_count=1)
def test_value_counts_basic(test_data, by, groupby_sort, subset, normalize, dropna):
# In this test, we use check_like=True because Snowpark pandas will always preserve the original
# order of rows within groups, while native pandas provides this guarantee only for the grouping
@@ -58,42 +58,62 @@ def test_value_counts_basic(test_data, by, groupby_sort, subset, normalize, dropna):
# https://github.com/snowflakedb/snowpark-python/pull/1909
# https://github.com/pandas-dev/pandas/issues/15833
by_list = by if isinstance(by, list) else [by]
none_in_by_col = any(None in test_data[col] for col in by_list)
if not dropna and none_in_by_col:
# when dropna is False, pandas gives a different result because it drops all NaN
# keys in the multiindex
# https://github.com/pandas-dev/pandas/issues/56366
# as a workaround, replace all Nones in the pandas frame with a sentinel value
# since NaNs are sorted last, we want the sentinel to sort to the end as well
VALUE_COUNTS_TEST_SENTINEL = "zzzzzz"
snow_df, native_df = create_test_dfs(test_data)
snow_result = snow_df.groupby(by=by, sort=groupby_sort).value_counts(
subset=subset,
normalize=normalize,
dropna=dropna,
)
native_df = native_df.fillna(value=VALUE_COUNTS_TEST_SENTINEL)
native_result = native_df.groupby(by=by, sort=groupby_sort).value_counts(
subset=subset,
normalize=normalize,
dropna=dropna,
)
native_result.index = native_result.index.map(
lambda x: tuple(None if i == VALUE_COUNTS_TEST_SENTINEL else i for i in x)
)
assert_snowpark_pandas_equal_to_pandas(
snow_result, native_result, check_like=True
)
else:
eval_snowpark_pandas_result(
*create_test_dfs(test_data),
lambda df: df.groupby(by=by, sort=groupby_sort).value_counts(
if len(set(by_list) & set(subset or [])):
# If subset and by overlap, check for ValueError
# Unlike pandas, we do not surface label names in the error message
with SqlCounter(query_count=0):
eval_snowpark_pandas_result(
*create_test_dfs(test_data),
lambda df: df.groupby(by=by, sort=groupby_sort).value_counts(
subset=subset,
normalize=normalize,
dropna=dropna,
),
expect_exception=True,
expect_exception_type=ValueError,
expect_exception_match="in subset cannot be in the groupby column keys",
assert_exception_equal=False,
)
return
with SqlCounter(query_count=1):
none_in_by_col = any(None in test_data[col] for col in by_list)
if not dropna and none_in_by_col:
# when dropna is False, pandas gives a different result because it drops all NaN
# keys in the multiindex
# https://github.com/pandas-dev/pandas/issues/56366
# as a workaround, replace all Nones in the pandas frame with a sentinel value
# since NaNs are sorted last, we want the sentinel to sort to the end as well
VALUE_COUNTS_TEST_SENTINEL = "zzzzzz"
snow_df, native_df = create_test_dfs(test_data)
snow_result = snow_df.groupby(by=by, sort=groupby_sort).value_counts(
subset=subset,
normalize=normalize,
dropna=dropna,
),
check_like=True,
)
)
native_df = native_df.fillna(value=VALUE_COUNTS_TEST_SENTINEL)
native_result = native_df.groupby(by=by, sort=groupby_sort).value_counts(
subset=subset,
normalize=normalize,
dropna=dropna,
)
native_result.index = native_result.index.map(
lambda x: tuple(
None if i == VALUE_COUNTS_TEST_SENTINEL else i for i in x
)
)
assert_snowpark_pandas_equal_to_pandas(
snow_result, native_result, check_like=True
)
else:
eval_snowpark_pandas_result(
*create_test_dfs(test_data),
lambda df: df.groupby(by=by, sort=groupby_sort).value_counts(
subset=subset,
normalize=normalize,
dropna=dropna,
),
check_like=True,
)
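The check_like=True comparison referenced in the comments above ignores row order; it is presumably forwarded to pandas' assert_frame_equal, whose check_like flag reindexes one operand like the other before comparing. A minimal native-pandas illustration, separate from the change itself:

import pandas as pd
from pandas.testing import assert_frame_equal

a = pd.DataFrame({"x": [1, 2]}, index=["r1", "r2"])
b = pd.DataFrame({"x": [2, 1]}, index=["r2", "r1"])

# Same rows in a different order: this passes only because check_like=True.
assert_frame_equal(a, b, check_like=True)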


@pytest.mark.parametrize(
@@ -391,10 +410,17 @@ def test_value_counts_pandas_issue_59307(


@pytest.mark.parametrize("test_data", TEST_DATA)
@pytest.mark.parametrize("subset", [["bad_key"], ["by", "bad_key"]])
@pytest.mark.parametrize(
"subset, exception_cls",
[
(["bad_key"], KeyError), # key not in frame
(["by", "bad_key"], KeyError), # key not in frame
(["by"], ValueError), # subset cannot overlap with grouping columns
],
)
# 1 query always runs to validate the length of the by list
@sql_count_checker(query_count=1)
def test_value_counts_bad_subset(test_data, subset):
def test_value_counts_bad_subset(test_data, subset, exception_cls):
eval_snowpark_pandas_result(
*create_test_dfs(test_data),
lambda x: x.groupby(by=["by"]).value_counts(subset=subset),
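As a rough illustration of the error case exercised above: when a subset label overlaps the groupby keys, native pandas raises ValueError, and Snowpark pandas raises ValueError as well, though its message intentionally omits the offending labels (hence the type-only exception comparison in test_value_counts_basic). The snippet below uses native pandas only and is a sketch, not part of the change:

import pandas as pd

df = pd.DataFrame({"by": ["a", "a", "b"], "value1": [0, 1, 1], "value2": [3, 3, 5]})

# subset overlaps the grouping keys -> ValueError
try:
    df.groupby(by=["by"]).value_counts(subset=["by"])
except ValueError as err:
    print(err)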
