Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-976701: Adding moving aggregation dataframe function #1145

Merged
merged 30 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
6840bfd
changes
sfc-gh-rsureshbabu Nov 23, 2023
266a281
updating changelog
sfc-gh-rsureshbabu Nov 23, 2023
2499a6d
fixing comment
sfc-gh-rsureshbabu Nov 23, 2023
4ba1a64
generalizing default formatter
sfc-gh-rsureshbabu Nov 23, 2023
810bfb8
generalizing default formatter 2
sfc-gh-rsureshbabu Nov 23, 2023
b43ffec
fix comment
sfc-gh-rsureshbabu Nov 23, 2023
5a007b2
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
91ff442
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
faccd9b
refactor
sfc-gh-rsureshbabu Nov 23, 2023
e27f0ea
changes
sfc-gh-rsureshbabu Nov 23, 2023
e8d4345
fix test
sfc-gh-rsureshbabu Nov 23, 2023
ccf6655
changes
sfc-gh-rsureshbabu Nov 23, 2023
8e7adad
changes
sfc-gh-rsureshbabu Nov 23, 2023
9a9a045
changes
sfc-gh-rsureshbabu Nov 23, 2023
706f15f
changes
sfc-gh-rsureshbabu Nov 23, 2023
39a3e8a
fix comment
sfc-gh-rsureshbabu Dec 18, 2023
b0c2c85
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Dec 18, 2023
2599167
skip tests when pandas are not available
sfc-gh-rsureshbabu Jan 17, 2024
c4416ad
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Jan 17, 2024
ce061bd
update change log
sfc-gh-rsureshbabu Jan 17, 2024
5c00182
changes
sfc-gh-rsureshbabu Jan 17, 2024
e0e18bf
changes
sfc-gh-rsureshbabu Jan 17, 2024
2d05cb8
adding doctest
sfc-gh-rsureshbabu Jan 18, 2024
9423599
renaming
sfc-gh-rsureshbabu Jan 23, 2024
721383f
renaming 2
sfc-gh-rsureshbabu Jan 23, 2024
27df8d2
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Jan 23, 2024
cfb5e4a
fix comments
sfc-gh-rsureshbabu Jan 23, 2024
9583161
fix error message
sfc-gh-rsureshbabu Jan 24, 2024
045f216
fix merge
sfc-gh-rsureshbabu Jan 24, 2024
05a2da5
fix code coverage
sfc-gh-rsureshbabu Jan 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- `array_except`
- `create_map`
- `sign`/`signum`
- Added `moving_agg` function in `DataFrame.analytics` for enabling moving aggregations like sums and averages with multiple window sizes.

### Bug Fixes

Expand Down Expand Up @@ -45,6 +46,8 @@

## 1.11.1 (2023-12-07)

### New Features

sfc-gh-rsureshbabu marked this conversation as resolved.
Show resolved Hide resolved
### Bug Fixes

- Fixed a bug that numpy should not be imported at the top level of mock module.
Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"GetResult",
"DataFrame",
"DataFrameStatFunctions",
"DataFrameAnalyticsFunctions",
"DataFrameNaFunctions",
"DataFrameWriter",
"DataFrameReader",
Expand Down Expand Up @@ -46,6 +47,7 @@
from snowflake.snowpark.async_job import AsyncJob
from snowflake.snowpark.column import CaseExpr, Column
from snowflake.snowpark.dataframe import DataFrame
from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions
from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions
from snowflake.snowpark.dataframe_reader import DataFrameReader
from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
)
from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType
from snowflake.snowpark.column import Column, _to_col_if_sql_expr, _to_col_if_str
from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions
from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions
from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions
from snowflake.snowpark.dataframe_writer import DataFrameWriter
Expand Down Expand Up @@ -522,6 +523,7 @@ def __init__(
self._writer = DataFrameWriter(self)

self._stat = DataFrameStatFunctions(self)
self._analytics = DataFrameAnalyticsFunctions(self)
self.approxQuantile = self.approx_quantile = self._stat.approx_quantile
self.corr = self._stat.corr
self.cov = self._stat.cov
Expand All @@ -539,6 +541,10 @@ def __init__(
def stat(self) -> DataFrameStatFunctions:
return self._stat

@property
def analytics(self) -> DataFrameAnalyticsFunctions:
    """Return the analytics helper object stored on this DataFrame as ``_analytics``."""
    return self._analytics

@overload
def collect(
self,
Expand Down
170 changes: 170 additions & 0 deletions src/snowflake/snowpark/dataframe_analytics_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Callable, Dict, List

import snowflake.snowpark
from snowflake.snowpark.functions import expr
from snowflake.snowpark.window import Window


class DataFrameAnalyticsFunctions:
    """Provides data analytics functions for DataFrames.

    To access an object of this class, use :attr:`DataFrame.analytics`.
    """

    def __init__(self, df: "snowflake.snowpark.DataFrame") -> None:
        # DataFrame that every analytics function of this instance operates on.
        self._df = df

    # NOTE: this is intentionally a plain function without ``self``: it is used,
    # at class-definition time, as the default value of ``moving_agg``'s
    # ``col_formatter`` parameter and is called there as a regular function.
    def _default_col_formatter(input_col: str, operation: str, *args) -> str:
        """Build the default output column name.

        Args:
            input_col: Name of the column the operation is applied to.
            operation: Name of the applied operation (for example ``"SUM"``).
            *args: Extra values (for example the window size); each is converted
                with ``str`` and appended, separated by underscores.

        Returns:
            ``"<input_col>_<operation>"``, followed by ``"_<args joined by '_'>"``
            when the joined extra values are non-empty.
        """
        args_str = "_".join(map(str, args))
        formatted_name = f"{input_col}_{operation}"
        if args_str:
            formatted_name += f"_{args_str}"
        return formatted_name

    def _validate_aggs_argument(self, aggs) -> None:
        """Check that ``aggs`` maps column names to non-empty lists of strings.

        Raises:
            TypeError: If ``aggs`` is not a dictionary.
            ValueError: If ``aggs`` is empty, a key is not a string, or a value
                is not a non-empty list containing only strings.
        """
        argument_requirements = (
            "The 'aggs' argument must adhere to the following rules: "
            "1) It must be a dictionary. "
            "2) It must not be empty. "
            "3) All keys must be strings. "
            "4) All values must be non-empty lists of strings."
        )

        if not isinstance(aggs, dict):
            raise TypeError(f"aggs must be a dictionary. {argument_requirements}")
        if not aggs:
            raise ValueError(f"aggs must not be empty. {argument_requirements}")
        if not all(
            isinstance(key, str)
            and isinstance(val, list)
            and val
            # Rule 4 covers the list items as well, not only the list itself.
            and all(isinstance(func, str) for func in val)
            for key, val in aggs.items()
        ):
            raise ValueError(
                f"aggs must have strings as keys and non-empty lists of strings as values. {argument_requirements}"
            )

    def _validate_string_list_argument(self, data, argument_name) -> None:
        """Check that ``data`` is a non-empty list of strings.

        Args:
            data: Value to validate.
            argument_name: Name of the argument, used in the error messages.

        Raises:
            TypeError: If ``data`` is not a list.
            ValueError: If ``data`` is empty or contains a non-string item.
        """
        argument_requirements = (
            f"The '{argument_name}' argument must adhere to the following rules: "
            "1) It must be a list. "
            "2) It must not be empty. "
            "3) All items in the list must be strings."
        )
        if not isinstance(data, list):
            raise TypeError(f"{argument_name} must be a list. {argument_requirements}")
        if not data:
            raise ValueError(
                f"{argument_name} must not be empty. {argument_requirements}"
            )
        if not all(isinstance(item, str) for item in data):
            raise ValueError(
                f"{argument_name} must be a list of strings. {argument_requirements}"
            )

    def _validate_positive_integer_list_argument(self, data, argument_name) -> None:
        """Check that ``data`` is a non-empty list of integers greater than zero.

        Args:
            data: Value to validate.
            argument_name: Name of the argument, used in the error messages.

        Raises:
            TypeError: If ``data`` is not a list.
            ValueError: If ``data`` is empty or contains an item that is not a
                positive integer (booleans are rejected).
        """
        argument_requirements = (
            f"The '{argument_name}' argument must adhere to the following criteria: "
            "1) It must be a list. "
            "2) It must not be empty. "
            "3) All items in the list must be positive integers."
        )
        if not isinstance(data, list):
            raise TypeError(f"{argument_name} must be a list. {argument_requirements}")
        if not data:
            raise ValueError(
                f"{argument_name} must not be empty. {argument_requirements}"
            )
        if not all(
            # bool is a subclass of int, so True would otherwise pass as "1".
            isinstance(item, int) and not isinstance(item, bool) and item > 0
            for item in data
        ):
            raise ValueError(
                f"{argument_name} must be a list of integers > 0. {argument_requirements}"
            )

    def _validate_formatter_argument(self, formatter) -> None:
        """Check that ``formatter`` is callable.

        Raises:
            TypeError: If ``formatter`` is not callable.
        """
        if not callable(formatter):
            raise TypeError("formatter must be a callable function")

    def moving_agg(
        self,
        aggs: Dict[str, List[str]],
        window_sizes: List[int],
        order_by: List[str],
        group_by: List[str],
        col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
    ) -> "snowflake.snowpark.dataframe.DataFrame":
        """
        Applies moving aggregations to the specified columns of the DataFrame using defined window sizes,
        and grouping and ordering criteria.

        All arguments are required except ``col_formatter``.

        Args:
            aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions.
                Supported aggregations are listed here https://docs.snowflake.com/en/sql-reference/functions-analytic#list-of-functions-that-support-windows.
            window_sizes: A list of positive integers, each representing the size of the window for which to
                calculate the moving aggregate.
            order_by: A list of column names that specify the order in which rows are processed.
            group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
            col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_<agg>_<window>'.
                This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
                and 'value' (int) for the window size, and returns a formatted string for the column name.

        Returns:
            A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation.
            The columns of the original DataFrame are kept.

        Raises:
            ValueError: If an unsupported value is specified in arguments.
            TypeError: If an unsupported type is specified in arguments.
            SnowparkSQLException: If an unsupported aggregation is specified.

        Example:
            The table below illustrates the resulting columns and values; the exact
            rendering produced by ``show()`` may differ.

            >>> data = [
            ...     ["2023-01-01", 101, 200],
            ...     ["2023-01-02", 101, 100],
            ...     ["2023-01-03", 101, 300],
            ...     ["2023-01-04", 102, 250],
            ... ]
            >>> df = session.create_dataframe(data).to_df(
            ...     "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT"
            ... )
            >>> result = df.analytics.moving_agg(
            ...     aggs={"SALESAMOUNT": ["SUM", "AVG"]},
            ...     window_sizes=[2, 3],
            ...     order_by=["ORDERDATE"],
            ...     group_by=["PRODUCTKEY"],
            ... )
            >>> result.order_by("ORDERDATE").show()  # doctest: +SKIP
            +------------+------------+-------------+-------------------+-------------------+-------------------+-------------------+
            | ORDERDATE  | PRODUCTKEY | SALESAMOUNT | SALESAMOUNT_SUM_2 | SALESAMOUNT_AVG_2 | SALESAMOUNT_SUM_3 | SALESAMOUNT_AVG_3 |
            +------------+------------+-------------+-------------------+-------------------+-------------------+-------------------+
            | 2023-01-01 |        101 |         200 |               200 |             200.0 |               200 |             200.0 |
            | 2023-01-02 |        101 |         100 |               300 |             150.0 |               300 |             150.0 |
            | 2023-01-03 |        101 |         300 |               400 |             200.0 |               600 |             200.0 |
            | 2023-01-04 |        102 |         250 |               250 |             250.0 |               250 |             250.0 |
            +------------+------------+-------------+-------------------+-------------------+-------------------+-------------------+
        """
        # Validate input arguments
        self._validate_aggs_argument(aggs)
        self._validate_string_list_argument(order_by, "order_by")
        self._validate_string_list_argument(group_by, "group_by")
        self._validate_positive_integer_list_argument(window_sizes, "window_sizes")
        self._validate_formatter_argument(col_formatter)

        # Partitioning and ordering are the same for every aggregation; only the
        # row frame depends on the window size, so build the common part once.
        base_window = Window.partition_by(group_by).order_by(order_by)

        # Perform window aggregation
        agg_df = self._df
        for column, agg_funcs in aggs.items():
            for window_size in window_sizes:
                # Frame made of the current row and the (window_size - 1) rows before it.
                window_spec = base_window.rows_between(-window_size + 1, 0)
                for agg_func in agg_funcs:
                    # Apply the user-specified aggregation function directly. Snowflake will handle any errors for invalid functions.
                    # NOTE: `agg_func` and `column` are interpolated verbatim into a SQL
                    # expression, so they must not come from untrusted input.
                    agg_col = expr(f"{agg_func}({column})").over(window_spec)

                    formatted_col_name = col_formatter(column, agg_func, window_size)
                    agg_df = agg_df.with_column(formatted_col_name, agg_col)

        return agg_df
144 changes: 144 additions & 0 deletions tests/integ/test_df_analytics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#!/usr/bin/env python3
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

try:
import pandas as pd
from pandas.testing import assert_frame_equal

is_pandas_available = True
except ImportError:
is_pandas_available = False

import pytest

from snowflake.snowpark.exceptions import SnowparkSQLException


def get_sample_dataframe(session):
    """Create the four-row sales DataFrame shared by the tests in this module.

    The resulting columns are ORDERDATE, PRODUCTKEY and SALESAMOUNT.
    """
    column_names = ("ORDERDATE", "PRODUCTKEY", "SALESAMOUNT")
    rows = [
        ["2023-01-01", 101, 200],
        ["2023-01-02", 101, 100],
        ["2023-01-03", 101, 300],
        ["2023-01-04", 102, 250],
    ]
    return session.create_dataframe(rows).to_df(*column_names)

@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_moving_agg(session):
    """Tests df.analytics.moving_agg() happy path."""

    sales_df = get_sample_dataframe(session)

    result = sales_df.analytics.moving_agg(
        aggs={"SALESAMOUNT": ["SUM", "AVG"]},
        window_sizes=[2, 3],
        order_by=["ORDERDATE"],
        group_by=["PRODUCTKEY"],
    )

    # Original columns are kept; one column is added per (aggregation, window size).
    expected_df = pd.DataFrame(
        {
            "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
            "PRODUCTKEY": [101, 101, 101, 102],
            "SALESAMOUNT": [200, 100, 300, 250],
            "SALESAMOUNT_SUM_2": [200, 300, 400, 250],
            "SALESAMOUNT_AVG_2": [200.0, 150.0, 200.0, 250.0],
            "SALESAMOUNT_SUM_3": [200, 300, 600, 250],
            "SALESAMOUNT_AVG_3": [200.0, 150.0, 200.0, 250.0],
        }
    )

    # Sort before comparing so the check does not depend on the returned row order.
    actual_df = result.order_by("ORDERDATE").to_pandas()
    assert_frame_equal(actual_df, expected_df, check_dtype=False, atol=1e-1)


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_moving_agg_custom_formatting(session):
    """Tests df.analytics.moving_agg() with custom formatting of output columns."""

    def custom_formatter(input_col, agg, window):
        # Reverse of the default naming: window size first, input column last.
        return f"{window}_{agg}_{input_col}"

    sales_df = get_sample_dataframe(session)

    result = sales_df.analytics.moving_agg(
        aggs={"SALESAMOUNT": ["SUM", "AVG"]},
        window_sizes=[2, 3],
        order_by=["ORDERDATE"],
        group_by=["PRODUCTKEY"],
        col_formatter=custom_formatter,
    )

    expected_df = pd.DataFrame(
        {
            "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
            "PRODUCTKEY": [101, 101, 101, 102],
            "SALESAMOUNT": [200, 100, 300, 250],
            "2_SUM_SALESAMOUNT": [200, 300, 400, 250],
            "2_AVG_SALESAMOUNT": [200.0, 150.0, 200.0, 250.0],
            "3_SUM_SALESAMOUNT": [200, 300, 600, 250],
            "3_AVG_SALESAMOUNT": [200.0, 150.0, 200.0, 250.0],
        }
    )

    # Sort before comparing so the check does not depend on the returned row order.
    actual_df = result.order_by("ORDERDATE").to_pandas()
    assert_frame_equal(actual_df, expected_df, check_dtype=False, atol=1e-1)


@pytest.mark.skipif(not is_pandas_available, reason="pandas is required")
def test_moving_agg_invalid_inputs(session):
    """Tests df.analytics.moving_agg() with invalid arguments.

    Covers non-positive window sizes, an empty aggregation list, an unknown
    aggregation function and a formatter with the wrong number of parameters.
    """

    df = get_sample_dataframe(session)

    # The messages are checked on ``exc.value`` (the raised exception itself),
    # not on the ``ExceptionInfo`` wrapper returned by ``pytest.raises``.
    with pytest.raises(ValueError) as exc:
        df.analytics.moving_agg(
            aggs={"SALESAMOUNT": ["AVG"]},
            window_sizes=[-1, 2, 3],
            order_by=["ORDERDATE"],
            group_by=["PRODUCTKEY"],
        ).collect()
    assert "window_sizes must be a list of integers > 0" in str(exc.value)

    with pytest.raises(ValueError) as exc:
        df.analytics.moving_agg(
            aggs={"SALESAMOUNT": ["AVG"]},
            window_sizes=[0, 2, 3],
            order_by=["ORDERDATE"],
            group_by=["PRODUCTKEY"],
        ).collect()
    assert "window_sizes must be a list of integers > 0" in str(exc.value)

    # Only `aggs` is invalid here, so the outcome does not depend on the
    # order in which the arguments are validated.
    with pytest.raises(ValueError) as exc:
        df.analytics.moving_agg(
            aggs={"SALESAMOUNT": []},
            window_sizes=[2, 3],
            order_by=["ORDERDATE"],
            group_by=["PRODUCTKEY"],
        ).collect()
    assert "non-empty lists of strings as values" in str(exc.value)

    with pytest.raises(SnowparkSQLException) as exc:
        df.analytics.moving_agg(
            aggs={"SALESAMOUNT": ["INVALID_FUNC"]},
            window_sizes=[1],
            order_by=["ORDERDATE"],
            group_by=["PRODUCTKEY"],
        ).collect()
    assert "Sliding window frame unsupported for function" in str(exc.value)

    def bad_formatter(input_col, agg):
        # Accepts two arguments while moving_agg calls the formatter with three.
        return f"{agg}_{input_col}"

    with pytest.raises(TypeError) as exc:
        df.analytics.moving_agg(
            aggs={"SALESAMOUNT": ["SUM"]},
            window_sizes=[1],
            order_by=["ORDERDATE"],
            group_by=["PRODUCTKEY"],
            col_formatter=bad_formatter,
        ).collect()
    assert "positional arguments but 3 were given" in str(exc.value)
Loading