SNOW-1489361: Add support for GroupBy.all/any (#1804)
sfc-gh-joshi authored Jun 20, 2024
1 parent 356bee1 commit 79b05c4
Showing 10 changed files with 267 additions and 29 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -48,11 +48,12 @@
- Added support for `Series.dt.dayofweek`, `Series.dt.day_of_week`, `Series.dt.dayofyear`, and `Series.dt.day_of_year`.
- Added support for `Series.str.__getitem__` (`Series.str[...]`).
- Added support for `Series.str.lstrip` and `Series.str.rstrip`.
- Added support for `DataFrameGroupby.size` and `SeriesGroupby.size`.
- Added support for `DataFrameGroupBy.size` and `SeriesGroupBy.size`.
- Added support for `DataFrame.expanding` and `Series.expanding` for aggregations `count`, `sum`, `min`, `max`, `mean`, `std`, and `var` with `axis=0`.
- Added support for `DataFrame.rolling` and `Series.rolling` for aggregation `count` with `axis=0`.
- Added support for `Series.str.match`.
- Added support for `DataFrame.resample` and `Series.resample` for aggregation `size`.
- Added support for `DataFrameGroupBy.all`, `SeriesGroupBy.all`, `DataFrameGroupBy.any`, and `SeriesGroupBy.any`.

#### Bug Fixes

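For context, a minimal usage sketch of the `DataFrameGroupBy.all`/`DataFrameGroupBy.any` support announced in the entry above. The import path is assumed for this release, a configured Snowpark session is required, and only integer/boolean columns are supported (see the compatibility table further down):

import snowflake.snowpark.modin.pandas as pd  # assumed import path for this release

df = pd.DataFrame({"key": ["x", "x", "y"], "flag": [True, False, True], "count": [1, 2, 0]})
# Both reductions run server-side, per group, via BOOLAND_AGG / BOOLOR_AGG.
df.groupby("key").all()   # flag: x -> False, y -> True; count: x -> True, y -> False
df.groupby("key").any()   # flag: x -> True,  y -> True; count: x -> True, y -> False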
4 changes: 4 additions & 0 deletions docs/source/modin/groupby.rst
@@ -34,6 +34,8 @@ GroupBy
.. autosummary::
:toctree: pandas_api/

DataFrameGroupBy.all
DataFrameGroupBy.any
DataFrameGroupBy.count
DataFrameGroupBy.cumcount
DataFrameGroupBy.cummax
@@ -61,6 +63,8 @@ GroupBy
.. autosummary::
:toctree: pandas_api/

SeriesGroupBy.all
SeriesGroupBy.any
SeriesGroupBy.count
SeriesGroupBy.cumcount
SeriesGroupBy.cummax
4 changes: 2 additions & 2 deletions docs/source/modin/supported/groupby_supported.rst
@@ -76,9 +76,9 @@ Computations/descriptive stats
+-----------------------------+---------------------------------+----------------------------------------------------+
| GroupBy method | Snowpark implemented? (Y/N/P/D) | Notes for current implementation |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``all`` | N | |
| ``all`` | P | ``N`` for non-integer/boolean types |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``any`` | N | |
| ``any`` | P | ``N`` for non-integer/boolean types |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``bfill`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
16 changes: 12 additions & 4 deletions src/snowflake/snowpark/modin/pandas/groupby.py
@@ -188,9 +188,13 @@ def mean(
agg_kwargs=dict(numeric_only=numeric_only),
)

def any(self, skipna=True):
def any(self, skipna: bool = True):
# TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
ErrorMessage.method_not_implemented_error(name="any", class_="GroupBy")
return self._wrap_aggregation(
type(self._query_compiler).groupby_any,
numeric_only=False,
agg_kwargs=dict(skipna=skipna),
)

@property
def plot(self): # pragma: no cover
@@ -731,9 +735,13 @@ def __len__(self):
# TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
ErrorMessage.method_not_implemented_error(name="__len__", class_="GroupBy")

def all(self, skipna=True):
def all(self, skipna: bool = True):
# TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
ErrorMessage.method_not_implemented_error(name="all", class_="GroupBy")
return self._wrap_aggregation(
type(self._query_compiler).groupby_all,
numeric_only=False,
agg_kwargs=dict(skipna=skipna),
)

def size(self):
# TODO: SNOW-1063349: Modin upgrade - modin.pandas.groupby.DataFrameGroupBy functions
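The `skipna` argument added above is forwarded to the query compiler through `agg_kwargs`. As a quick illustration of the intended semantics, shown here with native pandas (nulls are ignored when `skipna=True`; when `skipna=False` they are treated as truthy, matching the helper docstring further below):

import numpy as np
import pandas as native_pd

ser = native_pd.Series([0.0, np.nan], index=["g", "g"])
ser.groupby(level=0).any(skipna=True)   # g -> False: NaN is ignored and 0.0 is falsy
ser.groupby(level=0).any(skipna=False)  # g -> True: NaN is treated as truthy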
@@ -252,8 +252,8 @@ def _columns_coalescing_idxmax_idxmin_helper(
"skew": skew,
"std": stddev,
"var": variance,
"booland_agg": builtin("booland_agg"),
"boolor_agg": builtin("boolor_agg"),
"all": builtin("booland_agg"),
"any": builtin("boolor_agg"),
np.max: max_,
np.min: min_,
np.sum: sum_,
@@ -695,8 +695,8 @@ def generate_aggregation_column(
snowflake_agg_func(snowpark_column), pandas_lit(0)
)
elif snowflake_agg_func in (
SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["booland_agg"],
SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["boolor_agg"],
SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["all"],
SNOWFLAKE_BUILTIN_AGG_FUNC_MAP["any"],
):
# Need to wrap column name in IDENTIFIER, or else bool agg function will treat the name as a string literal
agg_snowpark_column = snowflake_agg_func(builtin("identifier")(snowpark_column))
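A minimal sketch of the wrapping described in the hunk above, using the public `snowflake.snowpark.functions.builtin` helper; the quoted column name here is hypothetical. Without the `IDENTIFIER` wrapper, Snowflake would parse the name passed to `BOOLAND_AGG`/`BOOLOR_AGG` as a string literal rather than a column reference.

from snowflake.snowpark.functions import builtin

booland_agg = builtin("booland_agg")   # Snowflake BOOLAND_AGG
quoted_name = '"MY_COL"'               # hypothetical quoted identifier
# Roughly BOOLAND_AGG(IDENTIFIER('"MY_COL"')) instead of BOOLAND_AGG('"MY_COL"')
agg_column = booland_agg(builtin("identifier")(quoted_name))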
@@ -2007,7 +2007,7 @@ def binary_op(
def _bool_reduce_helper(
self,
empty_value: bool,
reduce_op: Literal["and", "or"],
agg_func: Literal["all", "any"],
axis: int,
_bool_only: Optional[bool],
skipna: Optional[bool],
@@ -2017,33 +2017,32 @@ def _bool_reduce_helper(

empty_value: bool
The value returned for an empty dataframe.
reduce_op: {"and", "or"}
The name of the boolean operation to apply.
agg_func: {"all", "any"}
The name of the aggregation to apply.
_bool_only: Optional[bool]
Unused, accepted for compatibility with modin frontend. If true, only boolean columns are included
in the result; this filtering is already performed on the frontend.
skipna: Optional[bool]
Exclude NA/null values. If the entire row/column is NA and skipna is True, then the result will be False,
as for an empty row/column. If skipna is False, then NA are treated as True, because these are not equal to zero.
"""
assert reduce_op in ("and", "or")
assert agg_func in ("all", "any")

frame = self._modin_frame
empty_columns = len(frame.data_columns_index) == 0
if not empty_columns and not all(
is_bool_dtype(t) or is_integer_dtype(t) for t in self.dtypes
):
api_name = "all" if reduce_op == "and" else "any"
# Raise error if columns are non-integer/boolean
ErrorMessage.not_implemented(
f"Snowpark pandas {api_name} API doesn't yet support non-integer/boolean columns"
f"Snowpark pandas {agg_func} API doesn't yet support non-integer/boolean columns"
)

if axis == 1:
# append a new column representing the reduction of all the columns
reduce_expr = pandas_lit(empty_value)
for col_name in frame.data_column_snowflake_quoted_identifiers:
if reduce_op == "and":
if agg_func == "all":
reduce_expr = col(col_name).cast(BooleanType()) & reduce_expr
else:
reduce_expr = col(col_name).cast(BooleanType()) | reduce_expr
@@ -2076,7 +2075,6 @@ def _bool_reduce_helper(
}
).frame
)
agg_func = "booland_agg" if reduce_op == "and" else "boolor_agg"
# The resulting DF is transposed so will have string 'NULL' as a column name,
# so we need to manually remove it
return self.agg(
@@ -2093,7 +2091,7 @@ def all(
skipna: Optional[bool],
) -> "SnowflakeQueryCompiler":
return self._bool_reduce_helper(
True, "and", axis=axis, _bool_only=bool_only, skipna=skipna
True, "all", axis=axis, _bool_only=bool_only, skipna=skipna
)

def any(
@@ -2103,7 +2101,7 @@
skipna: Optional[bool],
) -> "SnowflakeQueryCompiler":
return self._bool_reduce_helper(
False, "or", axis=axis, _bool_only=bool_only, skipna=skipna
False, "any", axis=axis, _bool_only=bool_only, skipna=skipna
)
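The `True`/`False` first arguments passed to `_bool_reduce_helper` above are its `empty_value`, i.e. the reduction identity. A plain-Python sketch (not the Snowpark implementation itself) of the `axis=1` fold this corresponds to:

from functools import reduce

def bool_reduce(values, agg_func="all"):
    # Identity element: True for "all" (AND), False for "any" (OR).
    identity = agg_func == "all"
    combine = (lambda a, b: a and b) if agg_func == "all" else (lambda a, b: a or b)
    return reduce(combine, (bool(v) for v in values), identity)

bool_reduce([], "all")         # True  -> empty_value for all
bool_reduce([], "any")         # False -> empty_value for any
bool_reduce([1, 0, 3], "all")  # False
bool_reduce([1, 0, 3], "any")  # True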

def _parse_names_arguments_from_reset_index(
@@ -4062,6 +4060,52 @@ def groupby_nunique(
drop=drop,
)

def groupby_any(
self,
by: Any,
axis: int,
groupby_kwargs: dict[str, Any],
agg_args: Any,
agg_kwargs: dict[str, Any],
drop: bool = False,
**kwargs: Any,
) -> "SnowflakeQueryCompiler":
# We have to override the Modin version of this function because our groupby frontend passes the
# ignored numeric_only argument to this query compiler method, and BaseQueryCompiler
# does not have **kwargs.
return self.groupby_agg(
by=by,
agg_func="any",
axis=axis,
groupby_kwargs=groupby_kwargs,
agg_args=agg_args,
agg_kwargs=agg_kwargs,
drop=drop,
)

def groupby_all(
self,
by: Any,
axis: int,
groupby_kwargs: dict[str, Any],
agg_args: Any,
agg_kwargs: dict[str, Any],
drop: bool = False,
**kwargs: Any,
) -> "SnowflakeQueryCompiler":
# We have to override the Modin version of this function because our groupby frontend passes the
# ignored numeric_only argument to this query compiler method, and BaseQueryCompiler
# does not have **kwargs.
return self.groupby_agg(
by=by,
agg_func="all",
axis=axis,
groupby_kwargs=groupby_kwargs,
agg_args=agg_args,
agg_kwargs=agg_kwargs,
drop=drop,
)
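
A minimal illustration, with hypothetical function names, of the signature issue noted in the `groupby_any`/`groupby_all` comments above: without `**kwargs`, the extra `numeric_only` keyword forwarded by the frontend would raise a `TypeError`.

def base_signature(by, axis, agg_kwargs, drop=False):
    return ("agg", by, axis, agg_kwargs, drop)

def override_signature(by, axis, agg_kwargs, drop=False, **kwargs):
    # numeric_only lands in kwargs and is dropped before delegating to groupby_agg.
    return ("agg", by, axis, agg_kwargs, drop)

override_signature(by="a", axis=0, agg_kwargs={}, numeric_only=False)   # works
# base_signature(by="a", axis=0, agg_kwargs={}, numeric_only=False)     # TypeError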

def _get_dummies_helper(
self,
column: Hashable,
99 changes: 95 additions & 4 deletions src/snowflake/snowpark/modin/plugin/docstrings/groupby.py
@@ -269,9 +269,6 @@ def mean():
Name: B, dtype: float64
"""

def any():
pass

@property
def plot():
pass
@@ -1325,7 +1322,101 @@ def __len__():
pass

def all():
pass
"""
Return True if all values in the group are truthful, else False.
Parameters
----------
skipna : bool, default True
Flag to ignore nan values during truth testing.
Returns
-------
Series or DataFrame
DataFrame or Series of boolean values, where a value is True if all elements
are True within its respective group, False otherwise.
Examples
--------
For SeriesGroupBy:
>>> lst = ['a', 'a', 'b']
>>> ser = pd.Series([1, 2, 0], index=lst)
>>> ser # doctest: +NORMALIZE_WHITESPACE
a 1
a 2
b 0
dtype: int64
>>> ser.groupby(level=0).all() # doctest: +NORMALIZE_WHITESPACE
a True
b False
dtype: bool
For DataFrameGroupBy:
>>> data = [[1, 0, 3], [1, 5, 6], [7, 8, 9]]
>>> df = pd.DataFrame(data, columns=["a", "b", "c"],
... index=["ostrich", "penguin", "parrot"])
>>> df # doctest: +NORMALIZE_WHITESPACE
a b c
ostrich 1 0 3
penguin 1 5 6
parrot 7 8 9
>>> df.groupby(by=["a"]).all() # doctest: +NORMALIZE_WHITESPACE
b c
a
1 False True
7 True True
"""

def any():
"""
Return True if any value in the group is truthful, else False.
Parameters
----------
skipna : bool, default True
Flag to ignore nan values during truth testing.
Returns
-------
Series or DataFrame
DataFrame or Series of boolean values, where a value is True if any element
is True within its respective group, False otherwise.
Examples
--------
For SeriesGroupBy:
>>> lst = ['a', 'a', 'b']
>>> ser = pd.Series([1, 2, 0], index=lst)
>>> ser # doctest: +NORMALIZE_WHITESPACE
a 1
a 2
b 0
dtype: int64
>>> ser.groupby(level=0).any() # doctest: +NORMALIZE_WHITESPACE
a True
b False
dtype: bool
For DataFrameGroupBy:
>>> data = [[1, 0, 3], [1, 0, 6], [7, 1, 9]]
>>> df = pd.DataFrame(data, columns=["a", "b", "c"],
... index=["ostrich", "penguin", "parrot"])
>>> df # doctest: +NORMALIZE_WHITESPACE
a b c
ostrich 1 0 3
penguin 1 0 6
parrot 7 1 9
>>> df.groupby(by=["a"]).any() # doctest: +NORMALIZE_WHITESPACE
b c
a
1 False True
7 True True
"""

def size():
"""
1 change: 1 addition & 0 deletions tests/integ/modin/frame/test_aggregate.py
@@ -69,6 +69,7 @@ def native_df_multiindex() -> native_pd.DataFrame:
(lambda df: df.aggregate(x=("A", "max")), 0),
(lambda df: df.aggregate(x=("A", "max"), y=pd.NamedAgg("A", "max")), 1),
(lambda df: df.aggregate(x=("A", "max"), y=("C", "min"), z=("A", "min")), 2),
(lambda df: df.aggregate(x=("B", "all"), y=("B", "any")), 1),
# note following aggregation requires transpose
(lambda df: df.aggregate(max), 0),
(lambda df: df.min(), 0),
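The new test row exercises named aggregation with the `"all"`/`"any"` strings. A hedged sketch of what such a call computes, using a small hypothetical frame in place of the fixture (whose data is not shown here), expressed with native pandas:

import pandas as native_pd

df = native_pd.DataFrame({"A": [1, 2, 3], "B": [1, 0, 2]})
result = df.aggregate(x=("B", "all"), y=("B", "any"))
# Column "B": row x -> False (not every value of B is truthy), row y -> True (some value is truthy)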