Skip to content

Commit

Permalink
Respond to comments.
Browse files Browse the repository at this point in the history
Signed-off-by: sfc-gh-mvashishtha <[email protected]>
  • Loading branch information
sfc-gh-mvashishtha committed Sep 11, 2024
1 parent 212f663 commit b44a816
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 110 deletions.
226 changes: 116 additions & 110 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Hashable, Iterable
from functools import partial
from inspect import getmembers
from types import BuiltinFunctionType
from types import BuiltinFunctionType, MappingProxyType
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -429,121 +429,123 @@ def _columns_coalescing_sum(*cols: SnowparkColumn) -> Callable:
# Map between the pandas input aggregation function (str or numpy function) and
# _SnowparkPandasAggregation representing information about applying the
# aggregation in Snowpark pandas.
_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: dict[
_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION: MappingProxyType[
AggFuncTypeBase, _SnowparkPandasAggregation
] = {
"count": _SnowparkPandasAggregation(
axis_0_aggregation=count,
axis_1_aggregation_skipna=_columns_count,
preserves_snowpark_pandas_types=False,
),
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=mean,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("mean", np.mean))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=min_,
axis_1_aggregation_keepna=least,
axis_1_aggregation_skipna=_columns_coalescing_min,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("min", np.min))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=max_,
axis_1_aggregation_keepna=greatest,
axis_1_aggregation_skipna=_columns_coalescing_max,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("max", np.max))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=sum_,
# IMPORTANT: count and sum use python builtin sum to invoke
# __add__ on each column rather than Snowpark sum_, since
# Snowpark sum_ gets the sum of all rows within a single column.
axis_1_aggregation_keepna=lambda *cols: sum(cols),
axis_1_aggregation_skipna=_columns_coalescing_sum,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("sum", np.sum))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=median,
] = MappingProxyType(
{
"count": _SnowparkPandasAggregation(
axis_0_aggregation=count,
axis_1_aggregation_skipna=_columns_count,
preserves_snowpark_pandas_types=False,
),
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=mean,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("mean", np.mean))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=min_,
axis_1_aggregation_keepna=least,
axis_1_aggregation_skipna=_columns_coalescing_min,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("min", np.min))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=max_,
axis_1_aggregation_keepna=greatest,
axis_1_aggregation_skipna=_columns_coalescing_max,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("max", np.max))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=sum_,
# IMPORTANT: count and sum use python builtin sum to invoke
# __add__ on each column rather than Snowpark sum_, since
# Snowpark sum_ gets the sum of all rows within a single column.
axis_1_aggregation_keepna=lambda *cols: sum(cols),
axis_1_aggregation_skipna=_columns_coalescing_sum,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("sum", np.sum))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=median,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("median", np.median))
},
"idxmax": _SnowparkPandasAggregation(
axis_0_aggregation=functools.partial(
_columns_coalescing_idxmax_idxmin_helper, func="idxmax"
),
axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
preserves_snowpark_pandas_types=False,
),
"idxmin": _SnowparkPandasAggregation(
axis_0_aggregation=functools.partial(
_columns_coalescing_idxmax_idxmin_helper, func="idxmin"
),
axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
preserves_snowpark_pandas_types=False,
),
"skew": _SnowparkPandasAggregation(
axis_0_aggregation=skew,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("median", np.median))
},
"idxmax": _SnowparkPandasAggregation(
axis_0_aggregation=functools.partial(
_columns_coalescing_idxmax_idxmin_helper, func="idxmax"
),
axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
preserves_snowpark_pandas_types=False,
),
"idxmin": _SnowparkPandasAggregation(
axis_0_aggregation=functools.partial(
_columns_coalescing_idxmax_idxmin_helper, func="idxmin"
"all": _SnowparkPandasAggregation(
# all() for a column with no non-null values is NULL in Snowflake, but True in pandas.
axis_0_aggregation=lambda c: coalesce(
builtin("booland_agg")(col(c)), pandas_lit(True)
),
preserves_snowpark_pandas_types=False,
),
axis_1_aggregation_skipna=_columns_coalescing_idxmax_idxmin_helper,
axis_1_aggregation_keepna=_columns_coalescing_idxmax_idxmin_helper,
preserves_snowpark_pandas_types=False,
),
"skew": _SnowparkPandasAggregation(
axis_0_aggregation=skew,
preserves_snowpark_pandas_types=True,
),
"all": _SnowparkPandasAggregation(
# all() for a column with no non-null values is NULL in Snowflake, but True in pandas.
axis_0_aggregation=lambda c: coalesce(
builtin("booland_agg")(col(c)), pandas_lit(True)
"any": _SnowparkPandasAggregation(
# any() for a column with no non-null values is NULL in Snowflake, but False in pandas.
axis_0_aggregation=lambda c: coalesce(
builtin("boolor_agg")(col(c)), pandas_lit(False)
),
preserves_snowpark_pandas_types=False,
),
preserves_snowpark_pandas_types=False,
),
"any": _SnowparkPandasAggregation(
# any() for a column with no non-null values is NULL in Snowflake, but False in pandas.
axis_0_aggregation=lambda c: coalesce(
builtin("boolor_agg")(col(c)), pandas_lit(False)
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=stddev,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("std", np.std))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=variance,
# variance units are the square of the input column units, so
# variance does not preserve types.
preserves_snowpark_pandas_types=False,
)
for k in python_cast(tuple[AggFuncTypeBase], ("var", np.var))
},
"array_agg": _SnowparkPandasAggregation(
axis_0_aggregation=array_agg,
preserves_snowpark_pandas_types=False,
),
preserves_snowpark_pandas_types=False,
),
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=stddev,
"quantile": _SnowparkPandasAggregation(
axis_0_aggregation=column_quantile,
preserves_snowpark_pandas_types=True,
)
for k in python_cast(tuple[AggFuncTypeBase], ("std", np.std))
},
**{
k: _SnowparkPandasAggregation(
axis_0_aggregation=variance,
# variance units are the square of the input column units, so
# variance does not preserve types.
),
"nunique": _SnowparkPandasAggregation(
axis_0_aggregation=count_distinct,
preserves_snowpark_pandas_types=False,
)
for k in python_cast(tuple[AggFuncTypeBase], ("var", np.var))
},
"array_agg": _SnowparkPandasAggregation(
axis_0_aggregation=array_agg,
preserves_snowpark_pandas_types=False,
),
"quantile": _SnowparkPandasAggregation(
axis_0_aggregation=column_quantile,
preserves_snowpark_pandas_types=True,
),
"nunique": _SnowparkPandasAggregation(
axis_0_aggregation=count_distinct,
preserves_snowpark_pandas_types=False,
),
}
),
}
)


class AggregateColumnOpParameters(NamedTuple):
Expand Down Expand Up @@ -591,12 +593,14 @@ def get_snowflake_agg_func(
)

if snowpark_pandas_aggregation is None:
# We don't have any implementation at all for this aggregation.
return None

snowpark_aggregation = snowpark_pandas_aggregation.axis_0_aggregation

if snowpark_aggregation is None:
return None # pragma: no cover
# We don't have an implementation on axis=0 for this aggregation.
return None

# Rewrite some aggregations according to `agg_kwargs.`
if snowpark_aggregation == stddev or snowpark_aggregation == variance:
Expand Down Expand Up @@ -624,7 +628,9 @@ def get_snowflake_agg_func(
def snowpark_aggregation(col: SnowparkColumn) -> SnowparkColumn:
return column_quantile(col, interpolation, q)

assert snowpark_aggregation is not None
assert (
snowpark_aggregation is not None
), "Internal error: Snowpark pandas should have identified a Snowpark aggregation."
return SnowflakeAggFunc(
snowpark_aggregation=snowpark_aggregation,
preserves_snowpark_pandas_types=snowpark_pandas_aggregation.preserves_snowpark_pandas_types,
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/modin/test_aggregation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,20 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from types import MappingProxyType
from unittest import mock

import numpy as np
import pytest

import snowflake.snowpark.modin.plugin._internal.aggregation_utils as aggregation_utils
from snowflake.snowpark.functions import greatest, sum as sum_
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
SnowflakeAggFunc,
_is_supported_snowflake_agg_func,
_SnowparkPandasAggregation,
check_is_aggregation_supported_in_snowflake,
get_snowflake_agg_func,
)


Expand Down Expand Up @@ -103,3 +111,40 @@ def test_check_aggregation_snowflake_execution_capability_by_args(
agg_func=agg_func, agg_kwargs=agg_kwargs, axis=0
)
assert can_be_distributed == expected_result


@pytest.mark.parametrize(
    "agg_func, agg_kwargs, axis, expected",
    [
        (np.sum, {}, 0, SnowflakeAggFunc(sum_, True)),
        ("max", {"skipna": False}, 1, SnowflakeAggFunc(greatest, True)),
        ("test", {}, 0, None),
    ],
)
def test_get_snowflake_agg_func(agg_func, agg_kwargs, axis, expected):
    """Check that get_snowflake_agg_func maps a pandas aggregation to the
    expected SnowflakeAggFunc, or to None when the aggregation is unsupported.
    """
    actual = get_snowflake_agg_func(agg_func, agg_kwargs, axis)
    # Unsupported aggregations must come back as None (identity check),
    # not merely compare equal to None.
    if expected is None:
        assert actual is None
        return
    assert actual == expected


def test_get_snowflake_agg_func_with_no_implementation_on_axis_0():
    """Test get_snowflake_agg_func for a function that we support on axis=1 but not on axis=0."""
    # No real aggregation is supported on axis=1 but not on axis=0, so
    # simulate one by patching the internal mapping
    # _PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION with an entry that
    # only defines the axis-1 aggregations.
    axis_1_only_aggregation = _SnowparkPandasAggregation(
        preserves_snowpark_pandas_types=True,
        axis_1_aggregation_keepna=greatest,
        axis_1_aggregation_skipna=greatest,
    )
    patched_mapping = MappingProxyType({"max": axis_1_only_aggregation})
    with mock.patch.object(
        aggregation_utils,
        "_PANDAS_AGGREGATION_TO_SNOWPARK_PANDAS_AGGREGATION",
        patched_mapping,
    ):
        assert get_snowflake_agg_func(agg_func="max", agg_kwargs={}, axis=0) is None

0 comments on commit b44a816

Please sign in to comment.