SNOW-1653121: Support some Timedelta aggregations on axis=0. (#2248)
Fixes SNOW-1653121

Add tests and support for aggregation on axis=0. We still raise
`NotImplementedError` if:

1) the aggregation requires concatenating a frame with timedelta types
2) the aggregation requires transposing a row containing a timedelta
type and other types.

This change also fixes a bug where timedelta aggregations like `mean`
would produce the wrong type (and the wrong result): when the aggregation
`preserves_snowpark_pandas_types`, we now truncate the float result back
to an integer timedelta, as pandas does.
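
A sketch of the resulting behavior, based on the tests added in this
change (frame contents and column names are illustrative, and a
configured Snowflake connection is assumed):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401 -- registers the Snowpark pandas backend

df = pd.DataFrame({"A": [pd.Timedelta(days=1)], "B": [pd.Timedelta(days=2)]})

# Works: every result is a Timedelta, so no concat or transpose is needed.
df.agg({"B": "mean", "A": "sum"})
# Works: a concat is needed, but none of the results is a Timedelta.
df.agg(x=("B", "all"), y=("B", "any"))
# Raises NotImplementedError (case 1): concatenating Timedelta results.
df.agg(["min", "max"])
# Raises NotImplementedError (case 2): transposing a row that mixes a
# Timedelta result ("A": sum) with a non-Timedelta result ("B": idxmax).
df.agg({"B": "idxmax", "A": "sum"})
```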

---------

Signed-off-by: sfc-gh-mvashishtha <[email protected]>
sfc-gh-mvashishtha authored Sep 13, 2024
1 parent f566e25 commit c7be18c
Showing 23 changed files with 487 additions and 123 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@
 #### New Features
 
 - Added support for `TimedeltaIndex.mean` method.
+- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`.
 
 
 ## 1.22.1 (2024-09-11)
16 changes: 16 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/aggregation_utils.py
@@ -56,6 +56,7 @@
     stddev,
     stddev_pop,
     sum as sum_,
+    trunc,
     var_pop,
     variance,
     when,
@@ -698,6 +699,8 @@ def _is_supported_snowflake_agg_func(
         is_valid: bool. Whether it is valid to implement with snowflake or not.
     """
     if isinstance(agg_func, tuple) and len(agg_func) == 2:
+        # For named aggregations, like `df.agg(new_col=("old_col", "sum"))`,
+        # take the second part of the named aggregation.
         agg_func = agg_func[0]
     return get_snowflake_agg_func(agg_func, agg_kwargs, axis) is not None

@@ -963,6 +966,19 @@ def _generate_aggregation_column(
         ), f"No case expression is constructed with skipna({skipna}), min_count({min_count})"
         agg_snowpark_column = case_expr.otherwise(agg_snowpark_column)
 
+    if (
+        isinstance(agg_column_op_params.data_type, TimedeltaType)
+        and agg_column_op_params.snowflake_agg_func.preserves_snowpark_pandas_types
+    ):
+        # timedelta aggregations that produce timedelta results might produce
+        # a decimal type in snowflake, e.g.
+        # pd.Series([pd.Timedelta(1), pd.Timedelta(2)]).mean() produces 1.5 in
+        # Snowflake. We truncate the decimal part of the result, as pandas
+        # does.
+        agg_snowpark_column = cast(
+            trunc(agg_snowpark_column), agg_column_op_params.data_type.snowpark_type
+        )
+
     # rename the column to agg_column_quoted_identifier
     agg_snowpark_column = agg_snowpark_column.as_(
         agg_column_op_params.agg_snowflake_quoted_identifier
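For reference, a minimal Snowpark sketch of the truncation above, assuming
an existing `session` (the column name `ns` is hypothetical; Timedelta
values are stored as integer epoch nanoseconds):

```python
from snowflake.snowpark.functions import avg, cast, trunc
from snowflake.snowpark.types import LongType

# Two timedeltas of 1ns and 2ns, stored as integers.
df = session.create_dataframe([[1], [2]], schema=["ns"])
# AVG yields the decimal 1.5; TRUNC drops the fractional part (matching
# pandas' truncating behavior), and the cast restores the integer type
# that backs Timedelta.
df.select(cast(trunc(avg("ns")), LongType()).alias("mean_ns")).collect()
# -> [Row(MEAN_NS=1)]
```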
src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -3553,7 +3553,6 @@ def convert_func_to_agg_func_info(
         agg_col_ops, new_data_column_index_names = generate_column_agg_info(
             internal_frame, column_to_agg_func, agg_kwargs, is_series_groupby
         )
-
         # the pandas label and quoted identifier generated for each result column
         # after aggregation will be used as new pandas label and quoted identifiers.
         new_data_column_pandas_labels = []
@@ -3570,7 +3569,6 @@
                 and agg_col_op.snowflake_agg_func.preserves_snowpark_pandas_types
                 else None
             )
-
             # The ordering of the named aggregations is changed by us when we process
             # the agg_kwargs into the func dict (named aggregations on the same
             # column are moved to be contiguous, see groupby.py::aggregate for an
@@ -5636,8 +5634,6 @@ def agg(
             args: the arguments passed for the aggregation
             kwargs: keyword arguments passed for the aggregation function.
         """
-        self._raise_not_implemented_error_for_timedelta()
-
         numeric_only = kwargs.get("numeric_only", False)
         # Call fallback if the aggregation function passed in the arg is currently not supported
         # by snowflake engine.
@@ -5683,6 +5679,11 @@ def agg(
                 not is_list_like(value) for value in func.values()
             )
             if axis == 1:
+                if any(
+                    isinstance(t, TimedeltaType)
+                    for t in internal_frame.snowflake_quoted_identifier_to_snowpark_pandas_type.values()
+                ):
+                    ErrorMessage.not_implemented_for_timedelta("agg(axis=1)")
                 if self.is_multiindex():
                     # TODO SNOW-1010307 fix axis=1 behavior with MultiIndex
                     ErrorMessage.not_implemented(
@@ -5862,7 +5863,13 @@ def generate_agg_qc(
                 index_column_snowflake_quoted_identifiers=[
                     agg_name_col_quoted_identifier
                 ],
-                data_column_types=None,
+                data_column_types=[
+                    col.data_type
+                    if isinstance(col.data_type, SnowparkPandasType)
+                    and col.snowflake_agg_func.preserves_snowpark_pandas_types
+                    else None
+                    for col in col_agg_infos
+                ],
                 index_column_types=None,
             )
             return SnowflakeQueryCompiler(single_agg_dataframe)
@@ -9108,7 +9115,9 @@ def transpose_single_row(self) -> "SnowflakeQueryCompiler":
             SnowflakeQueryCompiler
                 Transposed new QueryCompiler object.
         """
-        self._raise_not_implemented_error_for_timedelta()
+        if len(set(self._modin_frame.cached_data_column_snowpark_pandas_types)) > 1:
+            # In this case, transpose may lose types.
+            self._raise_not_implemented_error_for_timedelta()
 
         frame = self._modin_frame
 
@@ -12492,8 +12501,6 @@ def _quantiles_single_col(
             column would allow us to create an accurate row position column, but would require a
             potentially expensive JOIN operator afterwards to apply the correct index labels.
         """
-        self._raise_not_implemented_error_for_timedelta()
-
         assert len(self._modin_frame.data_column_pandas_labels) == 1
 
         if index is not None:
@@ -12558,7 +12565,7 @@ def _quantiles_single_col(
             ],
             index_column_pandas_labels=[None],
             index_column_snowflake_quoted_identifiers=[index_identifier],
-            data_column_types=None,
+            data_column_types=original_frame.cached_data_column_snowpark_pandas_types,
             index_column_types=None,
         )
         # We cannot call astype() directly to convert an index column, so we replicate
@@ -14566,8 +14573,6 @@ def idxmax(
         Returns:
             SnowflakeQueryCompiler
         """
-        self._raise_not_implemented_error_for_timedelta()
-
         return self._idxmax_idxmin(
             func="idxmax", axis=axis, skipna=skipna, numeric_only=numeric_only
         )
@@ -14592,8 +14597,6 @@ def idxmin(
         Returns:
             SnowflakeQueryCompiler
         """
-        self._raise_not_implemented_error_for_timedelta()
-
         return self._idxmax_idxmin(
             func="idxmin", axis=axis, skipna=skipna, numeric_only=numeric_only
         )
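With the gates above removed, aggregations such as these now run end to
end on Timedelta columns instead of raising `NotImplementedError` (a
hedged sketch; values are illustrative):

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

df = pd.DataFrame({"A": [pd.Timedelta(days=1), pd.Timedelta(days=3)]})
df.quantile()  # 2 days; dtype preserved via the new data_column_types
df.median()    # 2 days
df.idxmax()    # 1 -- an integer label, so no Timedelta in the result
```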
44 changes: 18 additions & 26 deletions src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py
@@ -32,21 +32,14 @@
 from pandas._typing import ArrayLike, AxisInt, Dtype, Frequency, Hashable
 from pandas.core.dtypes.common import is_timedelta64_dtype
 
-from snowflake.snowpark import functions as fn
 from snowflake.snowpark.modin.pandas import DataFrame, Series
-from snowflake.snowpark.modin.plugin._internal.aggregation_utils import (
-    AggregateColumnOpParameters,
-    SnowflakeAggFunc,
-    aggregate_with_ordered_dataframe,
-)
 from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
     SnowflakeQueryCompiler,
 )
 from snowflake.snowpark.modin.plugin.extensions.index import Index
 from snowflake.snowpark.modin.plugin.utils.error_message import (
     timedelta_index_not_implemented,
 )
-from snowflake.snowpark.types import LongType
 
 _CONSTRUCTOR_DEFAULTS = {
     "unit": lib.no_default,
@@ -434,26 +427,25 @@ def mean(
             raise ValueError(
                 f"axis should be 0 for TimedeltaIndex.mean, found '{axis}'"
             )
-        # TODO SNOW-1620439: Reuse code from Series.mean.
-        frame = self._query_compiler._modin_frame
-        index_id = frame.index_column_snowflake_quoted_identifiers[0]
-        new_index_id = frame.ordered_dataframe.generate_snowflake_quoted_identifiers(
-            pandas_labels=["mean"]
-        )[0]
-        agg_column_op_params = AggregateColumnOpParameters(
-            index_id,
-            LongType(),
-            "mean",
-            new_index_id,
-            snowflake_agg_func=SnowflakeAggFunc(
-                preserves_snowpark_pandas_types=True, snowpark_aggregation=fn.mean
-            ),
-            ordering_columns=[],
-        )
-        mean_value = aggregate_with_ordered_dataframe(
-            frame.ordered_dataframe, [agg_column_op_params], {"skipna": skipna}
-        ).collect()[0][0]
-        return native_pd.Timedelta(np.nan if mean_value is None else int(mean_value))
+        pandas_dataframe_result = (
+            # reset_index(drop=False) copies the index column of
+            # self._query_compiler into a new data column. Use `drop=False`
+            # so that we don't have to use SQL row_number() to generate a new
+            # index column.
+            self._query_compiler.reset_index(drop=False)
+            # Aggregate the data column.
+            .agg("mean", axis=0, args=(), kwargs={"skipna": skipna})
+            # convert the query compiler to a pandas dataframe with
+            # dimensions 1x1 (note that the frame has a single row even
+            # if `self` is empty).
+            .to_pandas()
+        )
+        assert pandas_dataframe_result.shape == (
+            1,
+            1,
+        ), "Internal error: aggregation result is not 1x1."
+        # Return the only element in the frame.
+        return pandas_dataframe_result.iloc[0, 0]
 
     @timedelta_index_not_implemented()
     def as_unit(self, unit: str) -> TimedeltaIndex:
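A hedged usage sketch of the rewritten `TimedeltaIndex.mean`, which now
routes through the query compiler's `agg` instead of hand-built
aggregation parameters:

```python
import modin.pandas as pd
import snowflake.snowpark.modin.plugin  # noqa: F401

idx = pd.TimedeltaIndex([pd.Timedelta(days=1), pd.Timedelta(days=2)])
idx.mean()  # Timedelta('1 days 12:00:00')
```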
27 changes: 27 additions & 0 deletions tests/integ/modin/conftest.py
@@ -715,3 +715,30 @@ def numeric_test_data_4x4():
         "C": [7, 10, 13, 16],
         "D": [8, 11, 14, 17],
     }
+
+
+@pytest.fixture
+def timedelta_native_df() -> pandas.DataFrame:
+    return pandas.DataFrame(
+        {
+            "A": [
+                pd.Timedelta(days=1),
+                pd.Timedelta(days=2),
+                pd.Timedelta(days=3),
+                pd.Timedelta(days=4),
+            ],
+            "B": [
+                pd.Timedelta(minutes=-1),
+                pd.Timedelta(minutes=0),
+                pd.Timedelta(minutes=5),
+                pd.Timedelta(minutes=6),
+            ],
+            "C": [
+                None,
+                pd.Timedelta(nanoseconds=5),
+                pd.Timedelta(nanoseconds=0),
+                pd.Timedelta(nanoseconds=4),
+            ],
+            "D": pandas.to_timedelta([pd.NaT] * 4),
+        }
+    )
102 changes: 102 additions & 0 deletions tests/integ/modin/frame/test_aggregate.py
@@ -187,6 +187,108 @@ def test_string_sum_with_nulls():
     assert_series_equal(snow_result.to_pandas(), native_pd.Series(["ab"]))
 
 
+class TestTimedelta:
+    """Test aggregating dataframes containing timedelta columns."""
+
+    @pytest.mark.parametrize(
+        "func, union_count",
+        [
+            param(
+                lambda df: df.aggregate(["min"]),
+                0,
+                id="aggregate_list_with_one_element",
+            ),
+            param(lambda df: df.aggregate(x=("A", "max")), 0, id="single_named_agg"),
+            # this works since all results are timedelta and we don't need to do any concats.
+            param(
+                lambda df: df.aggregate({"B": "mean", "A": "sum"}),
+                0,
+                id="dict_producing_two_timedeltas",
+            ),
+            # this works since even though we need to do concats, all the results are non-timedelta.
+            param(
+                lambda df: df.aggregate(x=("B", "all"), y=("B", "any")),
+                1,
+                id="named_agg_producing_two_bools",
+            ),
+            # note that the following aggregation requires a transpose.
+            param(lambda df: df.aggregate(max), 0, id="aggregate_max"),
+            param(lambda df: df.min(), 0, id="min"),
+            param(lambda df: df.max(), 0, id="max"),
+            param(lambda df: df.count(), 0, id="count"),
+            param(lambda df: df.sum(), 0, id="sum"),
+            param(lambda df: df.mean(), 0, id="mean"),
+            param(lambda df: df.median(), 0, id="median"),
+            param(lambda df: df.std(), 0, id="std"),
+            param(lambda df: df.quantile(), 0, id="single_quantile"),
+            param(lambda df: df.quantile([0.01, 0.99]), 1, id="two_quantiles"),
+        ],
+    )
+    def test_supported_axis_0(self, func, union_count, timedelta_native_df):
+        with SqlCounter(query_count=1, union_count=union_count):
+            eval_snowpark_pandas_result(
+                *create_test_dfs(timedelta_native_df),
+                func,
+            )
+
+    @sql_count_checker(query_count=0)
+    @pytest.mark.xfail(strict=True, raises=NotImplementedError, reason="SNOW-1653126")
+    def test_axis_1(self, timedelta_native_df):
+        eval_snowpark_pandas_result(
+            *create_test_dfs(timedelta_native_df), lambda df: df.sum(axis=1)
+        )
+
+    @sql_count_checker(query_count=0)
+    def test_var_invalid(self, timedelta_native_df):
+        eval_snowpark_pandas_result(
+            *create_test_dfs(timedelta_native_df),
+            lambda df: df.var(),
+            expect_exception=True,
+            expect_exception_type=TypeError,
+            assert_exception_equal=False,
+            expect_exception_match=re.escape(
+                "timedelta64 type does not support var operations"
+            ),
+        )
+
+    @sql_count_checker(query_count=0)
+    @pytest.mark.xfail(
+        strict=True,
+        raises=NotImplementedError,
+        reason="requires concat(), which we cannot do with Timedelta.",
+    )
+    @pytest.mark.parametrize(
+        "operation",
+        [
+            lambda df: df.aggregate({"A": ["count", "max"], "B": [max, "min"]}),
+            lambda df: df.aggregate({"B": ["count"], "A": "sum", "C": ["max", "min"]}),
+            lambda df: df.aggregate(
+                x=pd.NamedAgg("A", "max"), y=("B", "min"), c=("A", "count")
+            ),
+            lambda df: df.aggregate(["min", np.max]),
+            lambda df: df.aggregate(x=("A", "max"), y=("C", "min"), z=("A", "min")),
+            lambda df: df.aggregate(x=("A", "max"), y=pd.NamedAgg("A", "max")),
+            lambda df: df.aggregate(
+                {"B": ["idxmax"], "A": "sum", "C": ["max", "idxmin"]}
+            ),
+        ],
+    )
+    def test_agg_requires_concat_with_timedelta(self, timedelta_native_df, operation):
+        eval_snowpark_pandas_result(*create_test_dfs(timedelta_native_df), operation)
+
+    @sql_count_checker(query_count=0)
+    @pytest.mark.xfail(
+        strict=True,
+        raises=NotImplementedError,
+        reason="requires transposing a one-row frame with integer and timedelta.",
+    )
+    def test_agg_produces_timedelta_and_non_timedelta_type(self, timedelta_native_df):
+        eval_snowpark_pandas_result(
+            *create_test_dfs(timedelta_native_df),
+            lambda df: df.aggregate({"B": "idxmax", "A": "sum"}),
+        )
+
+
 @pytest.mark.parametrize(
     "func, expected_union_count",
     [
15 changes: 15 additions & 0 deletions tests/integ/modin/frame/test_describe.py
@@ -358,3 +358,18 @@ def test_describe_object_file(resources_path):
     df = pd.read_csv(test_files.test_concat_file1_csv)
     native_df = df.to_pandas()
     eval_snowpark_pandas_result(df, native_df, lambda x: x.describe(include="O"))
+
+
+@sql_count_checker(query_count=0)
+@pytest.mark.xfail(
+    strict=True,
+    raises=NotImplementedError,
+    reason="requires concat(), which we cannot do with Timedelta.",
+)
+def test_timedelta(timedelta_native_df):
+    eval_snowpark_pandas_result(
+        *create_test_dfs(
+            timedelta_native_df,
+        ),
+        lambda df: df.describe(),
+    )
14 changes: 12 additions & 2 deletions tests/integ/modin/frame/test_idxmax_idxmin.py
@@ -196,8 +196,18 @@ def test_idxmax_idxmin_with_dates(func, axis):
 
 @sql_count_checker(query_count=1)
 @pytest.mark.parametrize("func", ["idxmax", "idxmin"])
-@pytest.mark.parametrize("axis", [0, 1])
-@pytest.mark.xfail(reason="SNOW-1625380 TODO")
+@pytest.mark.parametrize(
+    "axis",
+    [
+        0,
+        pytest.param(
+            1,
+            marks=pytest.mark.xfail(
+                strict=True, raises=NotImplementedError, reason="SNOW-1653126"
+            ),
+        ),
+    ],
+)
 def test_idxmax_idxmin_with_timedelta(func, axis):
     native_df = native_pd.DataFrame(
         data={