reconcile value_counts in query compiler
sfc-gh-joshi committed Jul 26, 2024
1 parent 3a7c36e commit 2d23b99
Showing 1 changed file with 21 additions and 17 deletions.
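
As context for the diff below, here is a minimal sketch, in stock pandas (not Snowpark pandas), of the GroupBy.value_counts behavior this query compiler implements; the frame and its columns are invented for illustration:

import pandas as pd

# Toy frame; the column names are arbitrary.
df = pd.DataFrame(
    {
        "gender": ["male", "male", "female", "male"],
        "education": ["low", "medium", "high", "low"],
    }
)

# as_index=True (the default) yields a Series whose MultiIndex has the
# grouping key first, followed by the remaining counted columns.
print(df.groupby("gender").value_counts())

# normalize=True yields proportions computed within each group,
# not over the whole frame.
print(df.groupby("gender").value_counts(normalize=True))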
@@ -10,7 +10,7 @@
 import uuid
 from collections.abc import Hashable, Iterable, Mapping, Sequence
 from datetime import timedelta, tzinfo
-from typing import Any, Callable, List, Literal, Optional, Tuple, Union, get_args
+from typing import Any, Callable, List, Literal, Optional, Union, get_args
 
 import numpy as np
 import numpy.typing as npt
@@ -4724,6 +4724,8 @@ def groupby_value_counts(
             ErrorMessage.not_implemented(
                 f"Snowpark pandas GroupBy.value_counts {_GROUPBY_UNSUPPORTED_GROUPING_MESSAGE}"
             )
+        if bins is not None:
+            raise ErrorMessage.not_implemented("bins argument is not yet supported")
         if not is_list_like(by):
             by = [by]
         if len(set(by) & set(subset or [])):
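
The new guard above declines the bins argument. As a reminder of what it would otherwise have to emulate: in stock pandas, bins buckets numeric values into equal-width intervals before counting, as in this small sketch (values invented):

import pandas as pd

s = pd.Series([1, 2, 2, 3, 10])
# bins=2 cuts the range [1, 10] into two equal-width intervals and
# counts per interval rather than per distinct value.
print(s.value_counts(bins=2))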
@@ -4732,34 +4734,36 @@
             raise ValueError("Keys in subset cannot be in the groupby column keys")
         if subset is not None:
             if not isinstance(subset, (list, tuple)):
-                subset = [subset]
+                subset_list = [subset]
+            else:
+                subset_list = subset
         else:
             # If subset is unspecified, then all columns should be included.
-            subset = self._modin_frame.data_column_pandas_labels
+            subset_list = self._modin_frame.data_column_pandas_labels
         # The grouping columns are always included in the subset.
         # Furthermore, the columns of the output must have the grouping columns first, in the order
         # that they were specified.
-        subset = by + list(filter(lambda label: label not in by, subset))
+        subset_list = by + list(filter(lambda label: label not in by, subset_list))
 
         if as_index:
             # When as_index=True, the result is a Series with a MultiIndex index.
-            result = self.value_counts(
+            result = self._value_counts_groupby(
+                by=subset_list,
                 # Use sort=False to preserve the original order
                 sort=False,
-                subset=subset,
                 normalize=normalize,
-                bins=bins,
                 ascending=False,
                 dropna=dropna,
+                normalize_within_groups=by,
             )
         else:
             # When as_index=False, the result is a DataFrame where count/proportion is appended as a new named column.
-            result = self.value_counts(
+            result = self._value_counts_groupby(
+                by=subset_list,
                 # Use sort=False to preserve the original order
                 sort=False,
-                subset=subset,
                 normalize=normalize,
-                bins=bins,
                 ascending=False,
                 dropna=dropna,
+                normalize_within_groups=by,
             ).reset_index()
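
The subset_list construction above puts the grouping keys first, in the order the caller gave them, then appends the remaining subset columns with the keys filtered out. A standalone sketch of just that ordering rule, with labels invented for illustration:

by = ["b", "a"]                # grouping keys, in caller order
subset_list = ["a", "c", "d"]  # requested subset
# Keys first, then the rest of the subset minus labels already in `by`.
ordered = by + list(filter(lambda label: label not in by, subset_list))
assert ordered == ["b", "a", "c", "d"]

The two branches then differ only in output shape: as_index=True keeps the grouping keys in a MultiIndex, while as_index=False calls reset_index() so the keys and the count/proportion column become regular columns.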
@@ -10791,8 +10795,6 @@ def value_counts(
         ascending: bool = False,
         bins: Optional[int] = None,
         dropna: bool = True,
-        *,
-        normalize_within_groups: Optional[list[str]] = None,
     ) -> "SnowflakeQueryCompiler":
         """
         Counts the frequency or number of unique values of SnowflakeQueryCompiler.
@@ -10817,10 +10819,6 @@
             This argument is not supported yet.
         dropna : bool, default True
             Don't include counts of NaN.
-        normalize_within_groups : list[str], optional
-            If set, the normalize parameter will normalize based on the specified groups
-            rather than the entire dataset. This parameter is exclusive to the Snowpark pandas
-            query compiler and is only used internally to implement groupby_value_counts.
         """
         # TODO: SNOW-924742 Support bins in Series.value_counts
         if bins is not None:
@@ -10837,11 +10835,13 @@
 
     def _value_counts_groupby(
         self,
-        by: Union[List[Hashable], Tuple[Hashable, ...]],
+        by: Sequence[Hashable],
         normalize: bool,
         sort: bool,
         ascending: bool,
         dropna: bool,
+        *,
+        normalize_within_groups: Optional[list[str]] = None,
     ) -> "SnowflakeQueryCompiler":
         """
         Helper method to obtain the frequency or number of unique values
@@ -10863,6 +10863,10 @@
             Sort in ascending order.
         dropna : bool
             Don't include counts of NaN.
+        normalize_within_groups : list[str], optional
+            If set, the normalize parameter will normalize based on the specified groups
+            rather than the entire dataset. This parameter is exclusive to the Snowpark pandas
+            query compiler and is only used internally to implement groupby_value_counts.
         """
         # validate whether by is valid (e.g., contains duplicates or non-existing labels)
         self.validate_groupby(by=by, axis=0, level=None)
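
A hedged sketch, again in stock pandas, of the distinction the normalize_within_groups parameter captures: dividing each group's counts by that group's total rather than by the grand total (data invented for illustration):

import pandas as pd

df = pd.DataFrame({"key": ["x", "x", "x", "y"], "val": [1, 1, 2, 3]})
counts = df.groupby(["key", "val"]).size()

# Normalizing over the whole dataset: proportions sum to 1.0 overall.
overall = counts / counts.sum()

# Normalizing within groups: proportions sum to 1.0 per "key" group.
within = counts / counts.groupby(level="key").transform("sum")

print(overall, within, sep="\n")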
