SNOW-1652274, SNOW-1638408: Add support for Timedelta with GroupBy `first`, `last`, `head`, `tail`, `aggregate` (#2239)

SNOW-1652274, SNOW-1638408

This PR adds support for the Timedelta data type in the GroupBy operations
`first`, `last`, `head`, `tail`, and `aggregate`. It also adds a test for
`to_csv` with a `TimedeltaIndex`.
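
For context, a minimal usage sketch of what this change enables, mirroring the test data used below (assumes an active Snowpark session):

```python
import modin.pandas as pd
import pandas as native_pd

import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame(
    {
        "A": native_pd.to_timedelta(["1 days 06:05:01.00003", "16us", "nan", "16us"]),
        "B": [8, 8, 12, 10],
    }
)

# Each of these previously raised NotImplementedError for Timedelta columns:
df.groupby("B").first()
df.groupby("B").last()
df.groupby("B").head(1)
df.groupby("B").tail(1)
df.groupby("B").agg({"A": ["sum", "median"]})
```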

---------

Signed-off-by: Naren Krishna <[email protected]>
sfc-gh-nkrishna authored Sep 5, 2024
1 parent 20be5fb commit 5dcf20b
Showing 14 changed files with 141 additions and 44 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -77,7 +77,8 @@
- support for binary arithmetic and comparisons between `Timedelta` values and numeric values.
- support for lazy `TimedeltaIndex`.
- support for `pd.to_timedelta`.
- support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`.
- support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`, `head`, `tail`, `aggregate`.
- support for `GroupBy` filtrations `first` and `last`.
- support for `TimedeltaIndex` attributes: `days`, `seconds`, `microseconds` and `nanoseconds`.
- support for `diff` with timestamp columns on `axis=0` and `axis=1`
- Added support for index's arithmetic and comparison operators.
@@ -240,7 +240,9 @@ def _columns_coalescing_idxmax_idxmin_helper(


# Map between the pandas input aggregation function (str or numpy function) and
# the corresponding snowflake builtin aggregation function for axis=0.
# the corresponding snowflake builtin aggregation function for axis=0. If any change
# is made to this map, ensure GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE and
# GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES are updated accordingly.
SNOWFLAKE_BUILTIN_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = {
"count": count,
"mean": mean,
@@ -270,6 +272,29 @@ def _columns_coalescing_idxmax_idxmin_helper(
"quantile": column_quantile,
"nunique": count_distinct,
}
GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = (
"min",
"max",
"sum",
"mean",
"median",
"std",
np.max,
np.min,
np.sum,
np.mean,
np.median,
np.std,
)
GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = (
"any",
"all",
"count",
"idxmax",
"idxmin",
"size",
"nunique",
)


class AggFuncWithLabel(NamedTuple):
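
The two tuples added above partition the supported aggregations by how they treat Snowpark pandas extension types such as Timedelta: functions in the first group return the input type unchanged, while functions in the second always produce a plain type (booleans, counts, positions). A hedged sketch of how a caller might consult them; the helper name is illustrative and not part of this PR:

```python
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
    GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE,
    GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES,
)


def resolve_groupby_output_type(agg_func, input_type):
    # Hypothetical helper: pick the Snowpark pandas type of an aggregated
    # column from the classification tuples above.
    if agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE:
        return input_type  # e.g. sum over Timedelta stays Timedelta
    if agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES:
        return None  # e.g. count yields plain integers, no extension type
    return None  # unclassified functions fall back to no extension type
```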
19 changes: 0 additions & 19 deletions src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py
@@ -41,25 +41,6 @@
]

NO_GROUPKEY_ERROR = ValueError("No group keys passed!")
GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = [
"min",
"max",
"sum",
"mean",
"median",
"std",
"first",
"last",
]
GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = [
"any",
"all",
"count",
"idxmax",
"idxmin",
"size",
"nunique",
]


def is_groupby_value_label_like(val: Any) -> bool:
@@ -148,6 +148,8 @@
)
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
AGG_NAME_COL_LABEL,
GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE,
GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES,
AggFuncInfo,
AggFuncWithLabel,
AggregateColumnOpParameters,
@@ -202,8 +204,6 @@
LabelIdentifierPair,
)
from snowflake.snowpark.modin.plugin._internal.groupby_utils import (
GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE,
GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES,
check_is_groupby_supported_by_snowflake,
extract_groupby_column_pandas_labels,
get_frame_with_groupby_columns_as_index,
@@ -3551,23 +3551,34 @@ def convert_func_to_agg_func_info(
agg_col_ops, new_data_column_index_names = generate_column_agg_info(
internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby
)
# Get the column aggregation functions used to check if the function
# preserves Snowpark pandas types.
agg_col_funcs = []
for _, func in column_to_agg_func.items():
if is_list_like(func) and not is_named_tuple(func):
for fn in func:
agg_col_funcs.append(fn.func)
else:
agg_col_funcs.append(func.func)
# the pandas label and quoted identifier generated for each result column
# after aggregation will be used as new pandas label and quoted identifiers.
new_data_column_pandas_labels = []
new_data_column_quoted_identifiers = []
new_data_column_snowpark_pandas_types = []
for col_agg_op in agg_col_ops:
for i in range(len(agg_col_ops)):
col_agg_op = agg_col_ops[i]
col_agg_func = agg_col_funcs[i]
new_data_column_pandas_labels.append(col_agg_op.agg_pandas_label)
new_data_column_quoted_identifiers.append(
col_agg_op.agg_snowflake_quoted_identifier
)
if agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE:
if col_agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE:
new_data_column_snowpark_pandas_types.append(
col_agg_op.data_type
if isinstance(col_agg_op.data_type, SnowparkPandasType)
else None
)
elif agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES:
elif col_agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES:
# In the case where the aggregation overrides the type of the output data column
# (e.g. any always returns boolean data columns), set the output Snowpark pandas type to None
new_data_column_snowpark_pandas_types = None # type: ignore
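
The substantive change in this hunk: each result column is now matched against its own aggregation function (`col_agg_func`) rather than the single top-level `agg_func`, which is what lets per-column specs like `{"B": ["sum", "median"], "C": "min"}` resolve output types correctly. A simplified sketch of the flattening; the real code unwraps `AggFuncInfo` named tuples via `.func` and uses `is_named_tuple` rather than a plain tuple check:

```python
from pandas.api.types import is_list_like

# Illustrative per-column aggregation spec, as passed to agg().
column_to_agg_func = {"B": ["sum", "median"], "C": "min"}

# Flatten in the same order that generate_column_agg_info emits one
# aggregation op per (column, function) pair, so the two lists line up.
agg_col_funcs = []
for _, func in column_to_agg_func.items():
    if is_list_like(func) and not isinstance(func, tuple):
        agg_col_funcs.extend(func)
    else:
        agg_col_funcs.append(func)

assert agg_col_funcs == ["sum", "median", "min"]
```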
@@ -4182,9 +4193,10 @@ def _groupby_first_last(
else:
result = SnowflakeQueryCompiler(
self._modin_frame.update_snowflake_quoted_identifiers_with_expressions(
self._fill_null_values_in_groupby(
quoted_identifier_to_column_map=self._fill_null_values_in_groupby(
fillna_method, by_snowflake_quoted_identifiers_list
)
),
data_column_snowpark_pandas_types=self._modin_frame.cached_data_column_snowpark_pandas_types,
).frame
)
result = result.groupby_agg(
@@ -4230,8 +4242,6 @@ def groupby_first(
Returns:
SnowflakeQueryCompiler: The result of groupby_first()
"""
self._raise_not_implemented_error_for_timedelta()

return self._groupby_first_last(
"first", by, axis, groupby_kwargs, agg_args, agg_kwargs, drop, **kwargs
)
@@ -4265,8 +4275,6 @@ def groupby_last(
Returns:
SnowflakeQueryCompiler: The result of groupby_last()
"""
self._raise_not_implemented_error_for_timedelta()

return self._groupby_first_last(
"last", by, axis, groupby_kwargs, agg_args, agg_kwargs, drop, **kwargs
)
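
Dropping the `_raise_not_implemented_error_for_timedelta` guards from `groupby_first` and `groupby_last` works because `_groupby_first_last` (above) now forwards the frame's cached Snowpark pandas types when it rewrites columns with fill expressions. Conceptually, `first`/`last` fill nulls within each group and then take a plain first/last row; a native pandas sketch of that equivalence, with illustrative data:

```python
import pandas as pd

df = pd.DataFrame(
    {"key": [1, 1, 2], "delta": pd.to_timedelta([None, "16us", "1 days"])}
)

# first() returns the first non-null value per group...
expected = df.groupby("key").first()

# ...which matches back-filling nulls within each group, then keeping each
# group's first row. The Timedelta dtype survives because filling and row
# selection never change column types.
filled = df.assign(delta=df.groupby("key")["delta"].bfill())
via_fill = filled.groupby("key").head(1).set_index("key")

assert expected.equals(via_fill)
```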
@@ -16105,8 +16113,6 @@ def _groupby_head_tail(
Returns:
A SnowflakeQueryCompiler object representing a DataFrame.
"""
self._raise_not_implemented_error_for_timedelta()

original_frame = self._modin_frame
ordered_dataframe = original_frame.ordered_dataframe

@@ -16253,8 +16259,8 @@ def _groupby_head_tail(
data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers,
index_column_pandas_labels=original_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers,
data_column_types=None,
index_column_types=None,
data_column_types=original_frame.cached_data_column_snowpark_pandas_types,
index_column_types=original_frame.cached_index_column_snowpark_pandas_types,
)

return SnowflakeQueryCompiler(new_modin_frame)
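
`head` and `tail` are filtrations: they select rows without transforming any values, so the result frame can simply reuse the input frame's cached data and index column types instead of discarding them as `None`. The same invariant, stated in native pandas terms:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "key": [1, 1, 2],
        "delta": pd.to_timedelta(["1 days", "16us", "nan"]),
    }
)

out = df.groupby("key").head(1)

# A filtration returns a subset of the original rows, so every column
# keeps its original dtype, including timedelta64[ns].
assert out.dtypes.equals(df.dtypes)
```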
2 changes: 1 addition & 1 deletion tests/integ/modin/groupby/test_all_any.py
@@ -44,7 +44,7 @@ def test_all_any_basic(data):
def test_timedelta(agg_func, by):
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(["1 days 06:05:01.00003", "15.5us", "10"]),
"A": native_pd.to_timedelta(["1 days 06:05:01.00003", "15.5us", "15.5us"]),
"B": [10, 8, 12],
}
)
27 changes: 26 additions & 1 deletion tests/integ/modin/groupby/test_groupby_basic_agg.py
@@ -1112,7 +1112,7 @@ def test_timedelta(agg_func, by):
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(
["1 days 06:05:01.00003", "15.5us", "nan", "16us"]
["1 days 06:05:01.00003", "16us", "nan", "16us"]
),
"B": [8, 8, 12, 10],
}
@@ -1122,3 +1122,28 @@ def test_timedelta(agg_func, by):
eval_snowpark_pandas_result(
snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)()
)


def test_timedelta_groupby_agg():
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(
["1 days 06:05:01.00003", "16us", "nan", "16us"]
),
"B": [8, 8, 12, 10],
"C": [True, False, False, True],
}
)
snow_df = pd.DataFrame(native_df)
with SqlCounter(query_count=1):
eval_snowpark_pandas_result(
snow_df,
native_df,
lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}),
)
with SqlCounter(query_count=1):
eval_snowpark_pandas_result(
snow_df,
native_df,
lambda df: df.groupby("B").agg({"A": ["sum", "median"], "C": "min"}),
)
20 changes: 20 additions & 0 deletions tests/integ/modin/groupby/test_groupby_first_last.py
@@ -3,6 +3,7 @@
#
import modin.pandas as pd
import numpy as np
import pandas as native_pd
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
@@ -102,3 +103,22 @@ def test_error_checking():

with pytest.raises(NotImplementedError):
s.groupby(s).last()


@pytest.mark.parametrize("agg_func", ["first", "last"])
@pytest.mark.parametrize("by", ["A", "B"])
@sql_count_checker(query_count=1)
def test_timedelta(agg_func, by):
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(
["1 days 06:05:01.00003", "16us", "nan", "16us"]
),
"B": [8, 8, 12, 10],
}
)
snow_df = pd.DataFrame(native_df)

eval_snowpark_pandas_result(
snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)()
)
19 changes: 19 additions & 0 deletions tests/integ/modin/groupby/test_groupby_head_tail.py
@@ -180,3 +180,22 @@ def test_df_groupby_last_chained_pivot_table_SNOW_1628228():
.groupby("A")
.last(),
)


@pytest.mark.parametrize("agg_func", ["head", "tail"])
@pytest.mark.parametrize("by", ["A", "B"])
@sql_count_checker(query_count=1)
def test_timedelta(agg_func, by):
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(
["1 days 06:05:01.00003", "16us", "nan", "16us"]
),
"B": [8, 8, 12, 10],
}
)
snow_df = pd.DataFrame(native_df)

eval_snowpark_pandas_result(
snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)()
)
2 changes: 1 addition & 1 deletion tests/integ/modin/groupby/test_groupby_idxmax_idxmin.py
@@ -167,7 +167,7 @@ def test_timedelta(agg_func, by):
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(
["1 days 06:05:01.00003", "15.5us", "nan", "16us"]
["1 days 06:05:01.00003", "16us", "nan", "16us"]
),
"B": [8, 8, 12, 10],
}
2 changes: 1 addition & 1 deletion tests/integ/modin/groupby/test_groupby_nunique.py
@@ -88,7 +88,7 @@ def test_timedelta(by):
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(
["1 days 06:05:01.00003", "15.5us", "nan", "16us"]
["1 days 06:05:01.00003", "16us", "nan", "16us"]
),
"B": [8, 8, 12, 10],
"C": ["the", "name", "is", "bond"],
2 changes: 1 addition & 1 deletion tests/integ/modin/groupby/test_groupby_size.py
@@ -98,7 +98,7 @@ def test_timedelta(by):
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(
["1 days 06:05:01.00003", "15.5us", "nan", "16us"]
["1 days 06:05:01.00003", "16us", "nan", "16us"]
),
"B": [8, 8, 12, 10],
"C": ["the", "name", "is", "bond"],
2 changes: 1 addition & 1 deletion tests/integ/modin/groupby/test_min_max.py
@@ -184,7 +184,7 @@ def test_timedelta(agg_func, by):
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(
["1 days 06:05:01.00003", "15.5us", "nan", "16us"]
["1 days 06:05:01.00003", "16us", "nan", "16us"]
),
"B": [8, 8, 12, 10],
"C": ["the", "name", "is", "bond"],
20 changes: 20 additions & 0 deletions tests/integ/modin/io/test_to_csv.py
@@ -271,3 +271,23 @@ def test_timedelta_to_csv_series_local():
pd.Series(native_series).to_csv(snow_path)

assert_file_equal(snow_path, native_path, is_compressed=False)


@sql_count_checker(query_count=1)
def test_timedeltaindex_to_csv_dataframe_local():
native_df = native_pd.DataFrame(
{
"A": native_pd.to_timedelta(["1 days 06:05:01.00003", "15.5us", "nan"]),
"B": [10, 8, 12],
"C": ["bond", "james", "bond"],
}
)
native_df = native_df.groupby("A").min()
native_path, snow_path = get_filepaths(kwargs={}, test_name="series_local")

# Write csv with native pandas.
native_df.to_csv(native_path)
# Write csv with snowpark pandas.
pd.DataFrame(native_df).to_csv(snow_path)

assert_file_equal(snow_path, native_path, is_compressed=False)
4 changes: 2 additions & 2 deletions tests/integ/modin/types/test_timedelta.py
@@ -105,6 +105,6 @@ def test_timedelta_not_supported():
)
with pytest.raises(
NotImplementedError,
match="SnowflakeQueryCompiler::groupby_first is not yet implemented for Timedelta Type",
match="SnowflakeQueryCompiler::groupby_groups is not yet implemented for Timedelta Type",
):
df.groupby("a").first()
df.groupby("a").groups()
