diff --git a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py index 852a8d1a536..4fd4d2a16ba 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py @@ -9,7 +9,7 @@ from collections.abc import Hashable, Iterable from functools import partial from inspect import getmembers -from types import BuiltinFunctionType +from types import BuiltinFunctionType, MappingProxyType from typing import ( Any, Callable, @@ -429,121 +429,123 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable: # Map between the pandas input aggregation function (str or numpy function) and # _SnowparkPandasAggregation representing information about applying the # aggregation in Snowpark pandas. -_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: dict[ +_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: MappingProxyType[ AggFuncTypeBase, _SnowparkPandasAggregation -] = { - "count": _SnowparkPandasAggregation( - axis_0_aggregation=count, - axis_1_aggregation_skipna=_columns_count, - preserves_snowpark_pandas_types=False, - ), - **{ - k: _SnowparkPandasAggregation( - axis_0_aggregation=mean, - preserves_snowpark_pandas_types=True, - ) - for k in python_cast(tuple[AggFuncTypeBase], ("mean", np.mean)) - }, - **{ - k: _SnowparkPandasAggregation( - axis_0_aggregation=min_, - axis_1_aggregation_keepna=least, - axis_1_aggregation_skipna=_columns_coalescing_min, - preserves_snowpark_pandas_types=True, - ) - for k in python_cast(tuple[AggFuncTypeBase], ("min", np.min)) - }, - **{ - k: _SnowparkPandasAggregation( - axis_0_aggregation=max_, - axis_1_aggregation_keepna=greatest, - axis_1_aggregation_skipna=_columns_coalescing_max, - preserves_snowpark_pandas_types=True, - ) - for k in python_cast(tuple[AggFuncTypeBase], ("max", np.max)) - }, - **{ - k: _SnowparkPandasAggregation( - axis_0_aggregation=sum_, - # IMPORTANT: count and sum use python builtin sum to invoke - # __add__ on each column rather than Snowpark sum_, since - # Snowpark sum_ gets the sum of all rows within a single column. - axis_1_aggregation_keepna=lambda *cols: sum(cols), - axis_1_aggregation_skipna=_columns_coalescing_sum, - preserves_snowpark_pandas_types=True, - ) - for k in python_cast(tuple[AggFuncTypeBase], ("sum", np.sum)) - }, - **{ - k: _SnowparkPandasAggregation( - axis_0_aggregation=median, +] = MappingProxyType( + { + "count": _SnowparkPandasAggregation( + axis_0_aggregation=count, + axis_1_aggregation_skipna=_columns_count, + preserves_snowpark_pandas_types=False, + ), + **{ + k: _SnowparkPandasAggregation( + axis_0_aggregation=mean, + preserves_snowpark_pandas_types=True, + ) + for k in python_cast(tuple[AggFuncTypeBase], ("mean", np.mean)) + }, + **{ + k: _SnowparkPandasAggregation( + axis_0_aggregation=min_, + axis_1_aggregation_keepna=least, + axis_1_aggregation_skipna=_columns_coalescing_min, + preserves_snowpark_pandas_types=True, + ) + for k in python_cast(tuple[AggFuncTypeBase], ("min", np.min)) + }, + **{ + k: _SnowparkPandasAggregation( + axis_0_aggregation=max_, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=_columns_coalescing_max, + preserves_snowpark_pandas_types=True, + ) + for k in python_cast(tuple[AggFuncTypeBase], ("max", np.max)) + }, + **{ + k: _SnowparkPandasAggregation( + axis_0_aggregation=sum_, + # IMPORTANT: count and sum use python builtin sum to invoke + # __add__ on each column rather than Snowpark sum_, since + # Snowpark sum_ gets the sum of all rows within a single column. + axis_1_aggregation_keepna=lambda *cols: sum(cols), + axis_1_aggregation_skipna=_columns_coalescing_sum, + preserves_snowpark_pandas_types=True, + ) + for k in python_cast(tuple[AggFuncTypeBase], ("sum", np.sum)) + }, + **{ + k: _SnowparkPandasAggregation( + axis_0_aggregation=median, + preserves_snowpark_pandas_types=True, + ) + for k in python_cast(tuple[AggFuncTypeBase], ("median", np.median)) + }, + "idxmax": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmax" + ), + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "idxmin": _SnowparkPandasAggregation( + axis_0_aggregation=functools.partial( + _columns_coalescing_idxmax_idxmin_helper, func="idxmin" + ), + axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, + axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, + preserves_snowpark_pandas_types=False, + ), + "skew": _SnowparkPandasAggregation( + axis_0_aggregation=skew, preserves_snowpark_pandas_types=True, - ) - for k in python_cast(tuple[AggFuncTypeBase], ("median", np.median)) - }, - "idxmax": _SnowparkPandasAggregation( - axis_0_aggregation=functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmax" ), - axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, - axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, - preserves_snowpark_pandas_types=False, - ), - "idxmin": _SnowparkPandasAggregation( - axis_0_aggregation=functools.partial( - _columns_coalescing_idxmax_idxmin_helper, func="idxmin" + "all": _SnowparkPandasAggregation( + # all() for a column with no non-null values is NULL in Snowflake, but True in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("booland_agg")(col(c)), pandas_lit(True) + ), + preserves_snowpark_pandas_types=False, ), - axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper, - axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper, - preserves_snowpark_pandas_types=False, - ), - "skew": _SnowparkPandasAggregation( - axis_0_aggregation=skew, - preserves_snowpark_pandas_types=True, - ), - "all": _SnowparkPandasAggregation( - # all() for a column with no non-null values is NULL in Snowflake, but True in pandas. - axis_0_aggregation=lambda c: coalesce( - builtin("booland_agg")(col(c)), pandas_lit(True) + "any": _SnowparkPandasAggregation( + # any() for a column with no non-null values is NULL in Snowflake, but False in pandas. + axis_0_aggregation=lambda c: coalesce( + builtin("boolor_agg")(col(c)), pandas_lit(False) + ), + preserves_snowpark_pandas_types=False, ), - preserves_snowpark_pandas_types=False, - ), - "any": _SnowparkPandasAggregation( - # any() for a column with no non-null values is NULL in Snowflake, but False in pandas. - axis_0_aggregation=lambda c: coalesce( - builtin("boolor_agg")(col(c)), pandas_lit(False) + **{ + k: _SnowparkPandasAggregation( + axis_0_aggregation=stddev, + preserves_snowpark_pandas_types=True, + ) + for k in python_cast(tuple[AggFuncTypeBase], ("std", np.std)) + }, + **{ + k: _SnowparkPandasAggregation( + axis_0_aggregation=variance, + # variance units are the square of the input column units, so + # variance does not preserve types. + preserves_snowpark_pandas_types=False, + ) + for k in python_cast(tuple[AggFuncTypeBase], ("var", np.var)) + }, + "array_agg": _SnowparkPandasAggregation( + axis_0_aggregation=array_agg, + preserves_snowpark_pandas_types=False, ), - preserves_snowpark_pandas_types=False, - ), - **{ - k: _SnowparkPandasAggregation( - axis_0_aggregation=stddev, + "quantile": _SnowparkPandasAggregation( + axis_0_aggregation=column_quantile, preserves_snowpark_pandas_types=True, - ) - for k in python_cast(tuple[AggFuncTypeBase], ("std", np.std)) - }, - **{ - k: _SnowparkPandasAggregation( - axis_0_aggregation=variance, - # variance units are the square of the input column units, so - # variance does not preserve types. + ), + "nunique": _SnowparkPandasAggregation( + axis_0_aggregation=count_distinct, preserves_snowpark_pandas_types=False, - ) - for k in python_cast(tuple[AggFuncTypeBase], ("var", np.var)) - }, - "array_agg": _SnowparkPandasAggregation( - axis_0_aggregation=array_agg, - preserves_snowpark_pandas_types=False, - ), - "quantile": _SnowparkPandasAggregation( - axis_0_aggregation=column_quantile, - preserves_snowpark_pandas_types=True, - ), - "nunique": _SnowparkPandasAggregation( - axis_0_aggregation=count_distinct, - preserves_snowpark_pandas_types=False, - ), -} + ), + } +) class AggregateColumnOpParameters(NamedTuple): @@ -591,12 +593,14 @@ def get_snowflake_agg_func( ) if snowpark_pandas_aggregation is None: + # We don't have any implementation at all for this aggregation. return None snowpark_aggregation = snowpark_pandas_aggregation.axis_0_aggregation if snowpark_aggregation is None: - return None # pragma: no cover + # We don't have an implementation on axis=0 for this aggregation. + return None # Rewrite some aggregations according to `agg_kwargs.` if snowpark_aggregation == stddev or snowpark_aggregation == variance: @@ -624,7 +628,9 @@ def get_snowflake_agg_func( def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn: return column_quantile(col, interpolation, q) - assert snowpark_aggregation is not None + assert ( + snowpark_aggregation is not None + ), "Internal error: Snowpark pandas should have identified a Snowpark aggregation." return SnowflakeAggFunc( snowpark_aggregation=snowpark_aggregation, preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types, diff --git a/tests/unit/modin/test_aggregation_utils.py b/tests/unit/modin/test_aggregation_utils.py index 3ebedf72010..6c9edfd024f 100644 --- a/tests/unit/modin/test_aggregation_utils.py +++ b/tests/unit/modin/test_aggregation_utils.py @@ -2,12 +2,20 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from types import MappingProxyType +from unittest import mock + import numpy as np import pytest +import snowflake.snowpark.modin.plugin._internal.aggregation_utils as aggregation_utils +from snowflake.snowpark.functions import greatest, sum as sum_ from snowflake.snowpark.modin.plugin._internal.aggregation_utils import ( + SnowflakeAggFunc, _is_supported_snowflake_agg_func, + _SnowparkPandasAggregation, check_is_aggregation_supported_in_snowflake, + get_snowflake_agg_func, ) @@ -103,3 +111,40 @@ def test_check_aggregation_snowflake_execution_capability_by_args( agg_func=agg_func, agg_kwargs=agg_kwargs, axis=0 ) assert can_be_distributed == expected_result + + +@pytest.mark.parametrize( + "agg_func, agg_kwargs, axis, expected", + [ + (np.sum, {}, 0, SnowflakeAggFunc(sum_, True)), + ("max", {"skipna": False}, 1, SnowflakeAggFunc(greatest, True)), + ("test", {}, 0, None), + ], +) +def test_get_snowflake_agg_func(agg_func, agg_kwargs, axis, expected): + result = get_snowflake_agg_func(agg_func, agg_kwargs, axis) + if expected is None: + assert result is None + else: + assert result == expected + + +def test_get_snowflake_agg_func_with_no_implementation_on_axis_0(): + """Test get_snowflake_agg_func for a function that we support on axis=1 but not on axis=0.""" + # We have to patch the internal dictionary + # _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION here because there is + # no real function that we support on axis=1 but not on axis=0. + with mock.patch.object( + aggregation_utils, + "_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION", + MappingProxyType( + { + "max": _SnowparkPandasAggregation( + preserves_snowpark_pandas_types=True, + axis_1_aggregation_keepna=greatest, + axis_1_aggregation_skipna=greatest, + ) + } + ), + ): + assert get_snowflake_agg_func(agg_func="max", agg_kwargs={}, axis=0) is None