diff --git a/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py index 09572a16d87..44e0a562812 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/groupby_utils.py @@ -41,6 +41,25 @@ ] NO_GROUPKEY_ERROR = ValueError("No group keys passed!") +GROUPBY_AGG_SAME_INPUT_AND_OUTPUT_DATA_TYPES = [ + "min", + "max", + "sum", + "mean", + "median", + "std", + "first", + "last", +] +GROUPBY_AGG_DIFFERENT_INPUT_AND_OUTPUT_DATA_TYPES = [ + "any", + "all", + "count", + "idxmax", + "idxmin", + "size", + "nunique", +] def is_groupby_value_label_like(val: Any) -> bool: diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 7fbc54cd314..79e6433c2fa 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -198,6 +198,8 @@ ) from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame from snowflake.snowpark.modin.plugin._internal.groupby_utils import ( + GROUPBY_AGG_DIFFERENT_INPUT_AND_OUTPUT_DATA_TYPES, + GROUPBY_AGG_SAME_INPUT_AND_OUTPUT_DATA_TYPES, check_is_groupby_supported_by_snowflake, extract_groupby_column_pandas_labels, get_frame_with_groupby_columns_as_index, @@ -3426,30 +3428,13 @@ def convert_func_to_agg_func_info( new_data_column_quoted_identifiers.append( col_agg_op.agg_snowflake_quoted_identifier ) - if agg_func in ( - "min", - "max", - "sum", - "mean", - "median", - "std", - "first", - "last", - ): + if agg_func in GROUPBY_AGG_SAME_INPUT_AND_OUTPUT_DATA_TYPES: new_data_column_snowpark_pandas_types.append( col_agg_op.data_type if isinstance(col_agg_op.data_type, SnowparkPandasType) else None ) - elif agg_func in ( - "any", - "all", - "count", - "idxmax", - "idxmin", - "size", - "nunique", - ): + elif agg_func in GROUPBY_AGG_DIFFERENT_INPUT_AND_OUTPUT_DATA_TYPES: # In the case where the aggregation overrides the type of the output data column # (e.g. any always returns boolean data columns), set the output Snowpark pandas type to None new_data_column_snowpark_pandas_types = None # type: ignore diff --git a/tests/integ/modin/groupby/test_groupby_negative.py b/tests/integ/modin/groupby/test_groupby_negative.py index 31af6b0eaf0..4ac8a813755 100644 --- a/tests/integ/modin/groupby/test_groupby_negative.py +++ b/tests/integ/modin/groupby/test_groupby_negative.py @@ -558,6 +558,7 @@ def test_groupby_agg_invalid_min_count( ) +@sql_count_checker(query_count=0) def test_groupby_negative_var(): native_df = native_pd.DataFrame( { diff --git a/tests/integ/modin/types/test_timedelta.py b/tests/integ/modin/types/test_timedelta.py index bcae016cbf0..a85d149a690 100644 --- a/tests/integ/modin/types/test_timedelta.py +++ b/tests/integ/modin/types/test_timedelta.py @@ -88,23 +88,3 @@ def test_timedelta_precision_insufficient_with_nulls_SNOW_1628925(): eval_snowpark_pandas_result( pd, native_pd, lambda lib: lib.Series([None, timedelta]) ) - - -@sql_count_checker(query_count=0) -def test_timedelta_not_supported(): - df = pd.DataFrame( - { - "a": ["one", "two", "three"], - "b": ["abc", "pqr", "xyz"], - "dt": [ - pd.Timedelta("1 days"), - pd.Timedelta("2 days"), - pd.Timedelta("3 days"), - ], - } - ) - with pytest.raises( - NotImplementedError, - match="validate_groupby is not yet implemented for Timedelta Type", - ): - df.groupby("a").count()