diff --git a/CHANGELOG.md b/CHANGELOG.md index 67aa350a871..dbdeecb186f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,7 +77,8 @@ - support for binary arithmetic and comparisons between `Timedelta` values and numeric values. - support for lazy `TimedeltaIndex`. - support for `pd.to_timedelta`. - - support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`. + - support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`, `head`, `tail`, `aggregate`. + - support for `GroupBy` filtrations `first` and `last`. - support for `TimedeltaIndex` attributes: `days`, `seconds`, `microseconds` and `nanoseconds`. - support for `diff` with timestamp columns on `axis=0` and `axis=1` - Added support for index's arithmetic and comparison operators. diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 88bfee2b1a5..01ccad8f430 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -240,7 +240,9 @@ def _columns_coalescing_idxmax_idxmin_helper( # Map between the pandas input aggregation function (str or numpy function) and -# the corresponding snowflake builtin aggregation function for axis=0. +# the corresponding snowflake builtin aggregation function for axis=0. If any change +# is made to this map, ensure GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE and +# GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES are updated accordingly. SNOWFLAKE_BUILTIN_AGG_FUNC_MAP: dict[Union[str, Callable], Callable] = { "count": count, "mean": mean, @@ -270,6 +272,29 @@ def _columns_coalescing_idxmax_idxmin_helper( "quantile": column_quantile, "nunique": count_distinct, } +GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = ( + "min", + "max", + "sum", + "mean", + "median", + "std", + np.max, + np.min, + np.sum, + np.mean, + np.median, + np.std, +) +GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = ( + "any", + "all", + "count", + "idxmax", + "idxmin", + "size", + "nunique", +) class AggFuncWithLabel(NamedTuple): diff --git a/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py index 369a9bf4ff5..2c50eb23a85 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py @@ -41,25 +41,6 @@ ] NO_GROUPKEY_ERROR = ValueError("No group keys passed!") -GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE = [ - "min", - "max", - "sum", - "mean", - "median", - "std", - "first", - "last", -] -GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES = [ - "any", - "all", - "count", - "idxmax", - "idxmin", - "size", - "nunique", -] def is_groupby_value_label_like(val: Any) -> bool: diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index c2b11264aa8..89019555ad7 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -148,6 +148,8 @@ ) from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( AGG_NAME_COL_LABEL, + GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE, + GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES, AggFuncInfo, AggFuncWithLabel, AggregateColumnOpParameters, @@ -202,8 +204,6 @@ LabelIdentifierPair, ) from snowflake.snowpark.modin.plugin._internal.groupby_utils import ( - GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE, - GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES, check_is_groupby_supported_by_snowflake, extract_groupby_column_pandas_labels, get_frame_with_groupby_columns_as_index, @@ -3551,23 +3551,34 @@ def convert_func_to_agg_func_info( agg_col_ops, new_data_column_index_names = generate_column_agg_info( internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby ) + # Get the column aggregation functions used to check if the function + # preserves Snowpark pandas types. + agg_col_funcs = [] + for _, func in column_to_agg_func.items(): + if is_list_like(func) and not is_named_tuple(func): + for fn in func: + agg_col_funcs.append(fn.func) + else: + agg_col_funcs.append(func.func) # the pandas label and quoted identifier generated for each result column # after aggregation will be used as new pandas label and quoted identifiers. new_data_column_pandas_labels = [] new_data_column_quoted_identifiers = [] new_data_column_snowpark_pandas_types = [] - for col_agg_op in agg_col_ops: + for i in range(len(agg_col_ops)): + col_agg_op = agg_col_ops[i] + col_agg_func = agg_col_funcs[i] new_data_column_pandas_labels.append(col_agg_op.agg_pandas_label) new_data_column_quoted_identifiers.append( col_agg_op.agg_snowflake_quoted_identifier ) - if agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE: + if col_agg_func in GROUPBY_AGG_PRESERVES_SNOWPARK_PANDAS_TYPE: new_data_column_snowpark_pandas_types.append( col_agg_op.data_type if isinstance(col_agg_op.data_type, SnowparkPandasType) else None ) - elif agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES: + elif col_agg_func in GROUPBY_AGG_WITH_NONE_SNOWPARK_PANDAS_TYPES: # In the case where the aggregation overrides the type of the output data column # (e.g. any always returns boolean data columns), set the output Snowpark pandas type to None new_data_column_snowpark_pandas_types = None # type: ignore @@ -4182,9 +4193,10 @@ def _groupby_first_last( else: result = SnowflakeQueryCompiler( self._modin_frame.update_snowflake_quoted_identifiers_with_expressions( - self._fill_null_values_in_groupby( + quoted_identifier_to_column_map=self._fill_null_values_in_groupby( fillna_method, by_snowflake_quoted_identifiers_list - ) + ), + data_column_snowpark_pandas_types=self._modin_frame.cached_data_column_snowpark_pandas_types, ).frame ) result = result.groupby_agg( @@ -4230,8 +4242,6 @@ def groupby_first( Returns: SnowflakeQueryCompiler: The result of groupby_first() """ - self._raise_not_implemented_error_for_timedelta() - return self._groupby_first_last( "first", by, axis, groupby_kwargs, agg_args, agg_kwargs, drop, **kwargs ) @@ -4265,8 +4275,6 @@ def groupby_last( Returns: SnowflakeQueryCompiler: The result of groupby_last() """ - self._raise_not_implemented_error_for_timedelta() - return self._groupby_first_last( "last", by, axis, groupby_kwargs, agg_args, agg_kwargs, drop, **kwargs ) @@ -16105,8 +16113,6 @@ def _groupby_head_tail( Returns: A SnowflakeQueryCompiler object representing a DataFrame. """ - self._raise_not_implemented_error_for_timedelta() - original_frame = self._modin_frame ordered_dataframe = original_frame.ordered_dataframe @@ -16253,8 +16259,8 @@ def _groupby_head_tail( data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, index_column_pandas_labels=original_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=index_column_snowflake_quoted_identifiers, - data_column_types=None, - index_column_types=None, + data_column_types=original_frame.cached_data_column_snowpark_pandas_types, + index_column_types=original_frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(new_modin_frame) diff --git a/tests/integ/modin/groupby/test_all_any.py b/tests/integ/modin/groupby/test_all_any.py index e423f28cf0d..d5234dfbdb5 100644 --- a/tests/integ/modin/groupby/test_all_any.py +++ b/tests/integ/modin/groupby/test_all_any.py @@ -44,7 +44,7 @@ def test_all_any_basic(data): def test_timedelta(agg_func, by): native_df = native_pd.DataFrame( { - "A": native_pd.to_timedelta(["1 days 06:05:01.00003", "15.5us", "10"]), + "A": native_pd.to_timedelta(["1 days 06:05:01.00003", "15.5us", "15.5us"]), "B": [10, 8, 12], } ) diff --git a/tests/integ/modin/groupby/test_groupby_basic_agg.py b/tests/integ/modin/groupby/test_groupby_basic_agg.py index 68cd73ad9ed..197c2e2db26 100644 --- a/tests/integ/modin/groupby/test_groupby_basic_agg.py +++ b/tests/integ/modin/groupby/test_groupby_basic_agg.py @@ -1112,7 +1112,7 @@ def test_timedelta(agg_func, by): native_df = native_pd.DataFrame( { "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "15.5us", "nan", "16us"] + ["1 days 06:05:01.00003", "16us", "nan", "16us"] ), "B": [8, 8, 12, 10], } @@ -1122,3 +1122,28 @@ def test_timedelta(agg_func, by): eval_snowpark_pandas_result( snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)() ) + + +def test_timedelta_groupby_agg(): + native_df = native_pd.DataFrame( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + "C": [True, False, False, True], + } + ) + snow_df = pd.DataFrame(native_df) + with SqlCounter(query_count=1): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby("A").agg({"B": ["sum", "median"], "C": "min"}), + ) + with SqlCounter(query_count=1): + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.groupby("B").agg({"A": ["sum", "median"], "C": "min"}), + ) diff --git a/tests/integ/modin/groupby/test_groupby_first_last.py b/tests/integ/modin/groupby/test_groupby_first_last.py index 2c484eab900..5da35806dd1 100644 --- a/tests/integ/modin/groupby/test_groupby_first_last.py +++ b/tests/integ/modin/groupby/test_groupby_first_last.py @@ -3,6 +3,7 @@ # import modin.pandas as pd import numpy as np +import pandas as native_pd import pytest import snowflake.snowpark.modin.plugin # noqa: F401 @@ -102,3 +103,22 @@ def test_error_checking(): with pytest.raises(NotImplementedError): s.groupby(s).last() + + +@pytest.mark.parametrize("agg_func", ["first", "last"]) +@pytest.mark.parametrize("by", ["A", "B"]) +@sql_count_checker(query_count=1) +def test_timedelta(agg_func, by): + native_df = native_pd.DataFrame( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ) + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)() + ) diff --git a/tests/integ/modin/groupby/test_groupby_head_tail.py b/tests/integ/modin/groupby/test_groupby_head_tail.py index c84b88f03cd..90819ec2d68 100644 --- a/tests/integ/modin/groupby/test_groupby_head_tail.py +++ b/tests/integ/modin/groupby/test_groupby_head_tail.py @@ -180,3 +180,22 @@ def test_df_groupby_last_chained_pivot_table_SNOW_1628228(): .groupby("A") .last(), ) + + +@pytest.mark.parametrize("agg_func", ["head", "tail"]) +@pytest.mark.parametrize("by", ["A", "B"]) +@sql_count_checker(query_count=1) +def test_timedelta(agg_func, by): + native_df = native_pd.DataFrame( + { + "A": native_pd.to_timedelta( + ["1 days 06:05:01.00003", "16us", "nan", "16us"] + ), + "B": [8, 8, 12, 10], + } + ) + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, native_df, lambda df: getattr(df.groupby(by), agg_func)() + ) diff --git a/tests/integ/modin/groupby/test_groupby_idxmax_idxmin.py b/tests/integ/modin/groupby/test_groupby_idxmax_idxmin.py index bc62278d581..ec1e36d1e38 100644 --- a/tests/integ/modin/groupby/test_groupby_idxmax_idxmin.py +++ b/tests/integ/modin/groupby/test_groupby_idxmax_idxmin.py @@ -167,7 +167,7 @@ def test_timedelta(agg_func, by): native_df = native_pd.DataFrame( { "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "15.5us", "nan", "16us"] + ["1 days 06:05:01.00003", "16us", "nan", "16us"] ), "B": [8, 8, 12, 10], } diff --git a/tests/integ/modin/groupby/test_groupby_nunique.py b/tests/integ/modin/groupby/test_groupby_nunique.py index e536696ed0d..345bfe9777a 100644 --- a/tests/integ/modin/groupby/test_groupby_nunique.py +++ b/tests/integ/modin/groupby/test_groupby_nunique.py @@ -88,7 +88,7 @@ def test_timedelta(by): native_df = native_pd.DataFrame( { "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "15.5us", "nan", "16us"] + ["1 days 06:05:01.00003", "16us", "nan", "16us"] ), "B": [8, 8, 12, 10], "C": ["the", "name", "is", "bond"], diff --git a/tests/integ/modin/groupby/test_groupby_size.py b/tests/integ/modin/groupby/test_groupby_size.py index 8d2b9226d04..649a3977d86 100644 --- a/tests/integ/modin/groupby/test_groupby_size.py +++ b/tests/integ/modin/groupby/test_groupby_size.py @@ -98,7 +98,7 @@ def test_timedelta(by): native_df = native_pd.DataFrame( { "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "15.5us", "nan", "16us"] + ["1 days 06:05:01.00003", "16us", "nan", "16us"] ), "B": [8, 8, 12, 10], "C": ["the", "name", "is", "bond"], diff --git a/tests/integ/modin/groupby/test_min_max.py b/tests/integ/modin/groupby/test_min_max.py index ce116c55b3a..021c83e25b6 100644 --- a/tests/integ/modin/groupby/test_min_max.py +++ b/tests/integ/modin/groupby/test_min_max.py @@ -184,7 +184,7 @@ def test_timedelta(agg_func, by): native_df = native_pd.DataFrame( { "A": native_pd.to_timedelta( - ["1 days 06:05:01.00003", "15.5us", "nan", "16us"] + ["1 days 06:05:01.00003", "16us", "nan", "16us"] ), "B": [8, 8, 12, 10], "C": ["the", "name", "is", "bond"], diff --git a/tests/integ/modin/io/test_to_csv.py b/tests/integ/modin/io/test_to_csv.py index ddd148fe583..abdda12f173 100644 --- a/tests/integ/modin/io/test_to_csv.py +++ b/tests/integ/modin/io/test_to_csv.py @@ -271,3 +271,23 @@ def test_timedelta_to_csv_series_local(): pd.Series(native_series).to_csv(snow_path) assert_file_equal(snow_path, native_path, is_compressed=False) + + +@sql_count_checker(query_count=1) +def test_timedeltaindex_to_csv_dataframe_local(): + native_df = native_pd.DataFrame( + { + "A": native_pd.to_timedelta(["1 days 06:05:01.00003", "15.5us", "nan"]), + "B": [10, 8, 12], + "C": ["bond", "james", "bond"], + } + ) + native_df = native_df.groupby("A").min() + native_path, snow_path = get_filepaths(kwargs={}, test_name="series_local") + + # Write csv with native pandas. + native_df.to_csv(native_path) + # Write csv with snowpark pandas. + pd.DataFrame(native_df).to_csv(snow_path) + + assert_file_equal(snow_path, native_path, is_compressed=False) diff --git a/tests/integ/modin/types/test_timedelta.py b/tests/integ/modin/types/test_timedelta.py index 2e630e09d3d..5575f304efe 100644 --- a/tests/integ/modin/types/test_timedelta.py +++ b/tests/integ/modin/types/test_timedelta.py @@ -105,6 +105,6 @@ def test_timedelta_not_supported(): ) with pytest.raises( NotImplementedError, - match="SnowflakeQueryCompiler::groupby_first is not yet implemented for Timedelta Type", + match="SnowflakeQueryCompiler::groupby_groups is not yet implemented for Timedelta Type", ): - df.groupby("a").first() + df.groupby("a").groups()