Skip to content

Commit

Permalink
SNOW-1526422, SNOW-1524697: Remove axis labels and callable names fro…
Browse files Browse the repository at this point in the history
…m unsupported aggregation messages. (#1915)

Along the way, separate out the NotImplementedError messages about unsupported aggregations from the messages about unsupported groupings (e.g. grouping on axis=1 or grouping by a dataframe).

Signed-off-by: sfc-gh-mvashishtha <[email protected]>
Co-authored-by: Rehan Durrani <[email protected]>
  • Loading branch information
sfc-gh-mvashishtha and sfc-gh-rdurrani authored Jul 22, 2024
1 parent e14b78b commit c5c8004
Show file tree
Hide file tree
Showing 10 changed files with 416 additions and 178 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
#### Improvements
- Removed the public preview warning message upon importing Snowpark pandas.

#### Bug Fixes
- Made passing an unsupported aggregation function to `pivot_table` raise `NotImplementedError` instead of `KeyError`.
- Removed axis labels and callable names from error messages and telemetry about unsupported aggregations.

## 1.20.0 (2024-07-17)

Expand Down
81 changes: 75 additions & 6 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from collections import defaultdict
from collections.abc import Hashable, Iterable
from functools import partial
from typing import Any, Callable, Literal, NamedTuple, Optional, Union
from inspect import getmembers
from types import BuiltinFunctionType
from typing import Any, Callable, Literal, Mapping, NamedTuple, Optional, Union

import numpy as np
from pandas._typing import AggFuncType, AggFuncTypeBase
Expand Down Expand Up @@ -78,6 +80,9 @@
)

AGG_NAME_COL_LABEL = "AGG_FUNC_NAME"
_NUMPY_FUNCTION_TO_NAME = {
function: name for name, function in getmembers(np) if callable(function)
}


def array_agg_keepna(
Expand Down Expand Up @@ -1127,11 +1132,75 @@ def using_named_aggregations_for_func(func: Any) -> bool:
)


def format_kwargs_for_error_message(kwargs: dict[Any, Any]) -> str:
def repr_aggregate_function(agg_func: AggFuncType, agg_kwargs: Mapping) -> str:
"""
Helper method to format a kwargs dictionary for an error message.
Represent an aggregation function as a string.
Use this function to represent aggregation functions in error message to
the user. This function will hide sensitive information, like axis labels or
names of callables, in the function description.
Returns a string containing the keys + values of kwargs formatted like so:
"key1=value1, key2=value2, ..."
Args:
agg_func: AggFuncType
The aggregation function from the user. This may be a list-like or a
dictionary containing multiple aggregations.
agg_kwargs: Mapping
The keyword arguments for the aggregation function.
Returns:
str
The representation of the aggregation function.
"""
return ", ".join([f"{key}={value}" for key, value in kwargs.items()])
if using_named_aggregations_for_func(agg_func):
# New axis labels are sensitive, so replace them with "new_label."
# Existing axis labels are sensitive, so replace them with "label."
return ", ".join(
f"new_label=(label, {repr_aggregate_function(f, agg_kwargs)})"
for _, f in agg_kwargs.values()
)
if isinstance(agg_func, str):
# Strings functions represent names of pandas functions, e.g.
# "sum" means to aggregate with pandas.Series.sum. string function
# identifiers are not sensitive.
return repr(agg_func)
if is_dict_like(agg_func):
# axis labels in the dictionary keys are sensitive, so replace them with
# "label."
return (
"{"
+ ", ".join(
f"label: {repr_aggregate_function(agg_func[key], agg_kwargs)}"
for key in agg_func.keys()
)
+ "}"
)
if is_list_like(agg_func):
return f"[{', '.join(repr_aggregate_function(func, agg_kwargs) for func in agg_func)}]"
if isinstance(agg_func, BuiltinFunctionType):
return repr(agg_func)

# for built-in classes like `list`, return "list" as opposed to repr(list),
# i.e. <class 'list'>, which would be confusing because the user is using
# `list` as a callable in this context.
if agg_func is list:
return "list"
if agg_func is tuple:
return "tuple"
if agg_func is set:
return "set"
if agg_func is str:
return "str"

# Format numpy aggregations, e.g. np.argmin should become "np.argmin"
if agg_func in _NUMPY_FUNCTION_TO_NAME:
return f"np.{_NUMPY_FUNCTION_TO_NAME[agg_func]}"

# agg_func should be callable at this point. pandas error messages at this
# point are not consistent, so choose one style of error message.
if not callable(agg_func):
raise ValueError("aggregation function is not callable")

# Return a constant string instead of some kind of function name to avoid
# exposing sensitive user input in the NotImplemented error message and
# thus in telemetry.
return "Callable"
30 changes: 23 additions & 7 deletions src/snowflake/snowpark/modin/plugin/_internal/pivot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import pandas as pd
from pandas._typing import AggFuncType, AggFuncTypeBase, Scalar
from pandas.api.types import is_dict_like, is_list_like

from snowflake.snowpark.column import Column as SnowparkColumn
from snowflake.snowpark.functions import (
Expand All @@ -24,6 +25,7 @@
from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
get_pandas_aggr_func_name,
get_snowflake_agg_func,
repr_aggregate_function,
)
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.groupby_utils import (
Expand Down Expand Up @@ -52,6 +54,7 @@
LabelTuple,
PandasLabelToSnowflakeIdentifierPair,
)
from snowflake.snowpark.modin.utils import ErrorMessage
from snowflake.snowpark.types import DoubleType, StringType

TEMP_PIVOT_COLUMN_PREFIX = "PIVOT_"
Expand All @@ -78,6 +81,7 @@ def perform_pivot_and_concatenate(
groupby_snowflake_quoted_identifiers: list[str],
pivot_snowflake_quoted_identifiers: list[str],
should_join_along_columns: bool,
original_aggfunc: AggFuncType,
) -> PivotedOrderedDataFrameResult:
"""
Helper function to perform a full pivot (including joining in the case of multiple aggrs or values) on an OrderedDataFrame.
Expand All @@ -88,6 +92,7 @@ def perform_pivot_and_concatenate(
groupby_snowflake_quoted_identifiers: Group by identifiers
pivot_snowflake_quoted_identifiers: Pivot identifiers
should_join_along_columns: Whether to join along columns, or use union to join along rows instead.
original_aggfunc: The aggregation function that the user provided.
"""
last_ordered_dataframe = None
data_column_pandas_labels: list[Hashable] = []
Expand All @@ -114,6 +119,7 @@ def perform_pivot_and_concatenate(
pivot_aggr_grouping.aggr_label_identifier_pair,
pivot_aggr_grouping.aggfunc,
pivot_aggr_grouping.prefix_label,
original_aggfunc,
)

if last_ordered_dataframe:
Expand Down Expand Up @@ -162,6 +168,7 @@ def pivot_helper(
multiple_aggr_funcs: bool,
multiple_values: bool,
index: Optional[list],
original_aggfunc: AggFuncType,
) -> InternalFrame:
"""
Helper function that that performs a full pivot on an InternalFrame.
Expand All @@ -177,6 +184,7 @@ def pivot_helper(
multiple_aggr_funcs: Whether multiple aggregation functions have been passed in.
multiple_values: Whether multiple values columns have been passed in.
index: The index argument passed to `pivot_table` if specified. Will become the pandas labels for the index column.
original_aggfunc: The aggregation function that the user provided.
Returns:
InternalFrame
The result of performing the pivot.
Expand Down Expand Up @@ -350,6 +358,7 @@ def pivot_helper(
groupby_snowflake_quoted_identifiers,
pivot_snowflake_quoted_identifiers,
True,
original_aggfunc,
)
if last_ordered_dataframe is None:
last_ordered_dataframe = pivot_ordered_dataframe
Expand Down Expand Up @@ -382,6 +391,7 @@ def pivot_helper(
groupby_snowflake_quoted_identifiers,
pivot_snowflake_quoted_identifiers,
should_join_along_columns,
original_aggfunc,
)

# When there are no groupby columns, the index is the first column in the OrderedDataFrame.
Expand Down Expand Up @@ -483,6 +493,7 @@ def single_pivot_helper(
value_label_to_identifier_pair: PandasLabelToSnowflakeIdentifierPair,
pandas_aggr_func_name: str,
prefix_pandas_labels: tuple[LabelComponent],
original_aggfunc: AggFuncType,
) -> tuple[OrderedDataFrame, list[str], list[Hashable]]:
"""
Helper function that is a building block for generating a single pivot, that can be used by other pivot like
Expand All @@ -497,6 +508,7 @@ def single_pivot_helper(
pandas_aggr_func_name: pandas label for aggregation function (since used as a label)
prefix_pandas_labels: Any prefix labels that should be added to the result pivot column name, such as
the aggregation function or other labels.
original_aggfunc: The aggregation function that the user provided.
Returns:
Tuple of:
Expand All @@ -507,7 +519,9 @@ def single_pivot_helper(
snowpark_aggr_func = get_snowflake_agg_func(pandas_aggr_func_name, {})
if not is_supported_snowflake_pivot_agg_func(snowpark_aggr_func):
# TODO: (SNOW-853334) Add support for any non-supported snowflake pivot aggregations
raise KeyError(pandas_aggr_func_name)
raise ErrorMessage.not_implemented(
f"Snowpark pandas DataFrame.pivot_table does not yet support the aggregation {repr_aggregate_function(original_aggfunc, agg_kwargs={})} with the given arguments."
)

pandas_aggr_label, aggr_snowflake_quoted_identifier = value_label_to_identifier_pair

Expand Down Expand Up @@ -870,7 +884,7 @@ def generate_single_pivot_labels(
value_pandas_label_to_identifiers: Aggregation value pandas label to snowflake quoted identifier
pandas_single_aggr_func: pandas aggregation function to apply to pandas aggregation label
"""
if isinstance(aggfunc, list):
if not is_dict_like(aggfunc) and is_list_like(aggfunc):
# Fetch all aggregation functions, it will be the same aggregation function list for each aggregation value.
(
pandas_aggfunc_list,
Expand Down Expand Up @@ -976,7 +990,7 @@ def get_pandas_aggr_func_and_prefix(

include_prefix = any([isinstance(af, list) for af in aggfunc.values()])

elif isinstance(aggfunc, list):
elif is_list_like(aggfunc):
pandas_aggr_func = aggfunc

if len(pandas_aggr_func) == 0:
Expand Down Expand Up @@ -1236,26 +1250,28 @@ def expand_pivot_result_with_pivot_table_margins_no_groupby_columns(
pivot_snowflake_quoted_identifiers: list[str],
values: list[str],
margins_name: str,
original_aggfunc: AggFuncType,
) -> "SnowflakeQueryCompiler": # type: ignore[name-defined] # noqa: F821
names = pivot_qc.columns.names
margins_frame = pivot_helper(
original_modin_frame,
pivot_aggr_groupings,
not dropna,
not isinstance(aggfunc, list),
not is_list_like(aggfunc),
columns[:1],
[], # There are no groupby_snowflake_quoted_identifiers
pivot_snowflake_quoted_identifiers[:1],
(isinstance(aggfunc, list) and len(aggfunc) > 1),
(is_list_like(aggfunc) and len(aggfunc) > 1),
(isinstance(values, list) and len(values) > 1),
None, # There is no index.
original_aggfunc,
)
if len(columns) > 1:
# If there is a multiindex on the pivot result, we need to add the margin_name to the margins frame's data column
# pandas labels, as well as any empty postfixes for the remaining pivot columns if there are more than 2.
new_data_column_pandas_labels = []
for label in margins_frame.data_column_pandas_labels:
if isinstance(aggfunc, list):
if is_list_like(aggfunc):
new_label = label + (margins_name,)
else:
new_label = (label, margins_name) + tuple(
Expand Down Expand Up @@ -1336,7 +1352,7 @@ def expand_pivot_result_with_pivot_table_margins_no_groupby_columns(
# tw"o 2
# If there are multiple columns and multiple aggregation functions, we need to groupby the first two columns instead of just the first one -
# as the first column will be the name of the aggregation function, and the second column will be the values from the first pivot column.
if isinstance(aggfunc, list):
if is_list_like(aggfunc):
groupby_columns = mi_as_frame.columns[:2].tolist()
value_column_index = 2
else:
Expand Down
Loading

0 comments on commit c5c8004

Please sign in to comment.