Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-976701: Adding moving aggregation dataframe function #1145

Merged
merged 30 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
6840bfd
changes
sfc-gh-rsureshbabu Nov 23, 2023
266a281
updating changelog
sfc-gh-rsureshbabu Nov 23, 2023
2499a6d
fixing comment
sfc-gh-rsureshbabu Nov 23, 2023
4ba1a64
generalizing default formatter
sfc-gh-rsureshbabu Nov 23, 2023
810bfb8
generalizing default formatter 2
sfc-gh-rsureshbabu Nov 23, 2023
b43ffec
fix comment
sfc-gh-rsureshbabu Nov 23, 2023
5a007b2
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
91ff442
cleaning argument checks
sfc-gh-rsureshbabu Nov 23, 2023
faccd9b
refactor
sfc-gh-rsureshbabu Nov 23, 2023
e27f0ea
changes
sfc-gh-rsureshbabu Nov 23, 2023
e8d4345
fix test
sfc-gh-rsureshbabu Nov 23, 2023
ccf6655
changes
sfc-gh-rsureshbabu Nov 23, 2023
8e7adad
changes
sfc-gh-rsureshbabu Nov 23, 2023
9a9a045
changes
sfc-gh-rsureshbabu Nov 23, 2023
706f15f
changes
sfc-gh-rsureshbabu Nov 23, 2023
39a3e8a
fix comment
sfc-gh-rsureshbabu Dec 18, 2023
b0c2c85
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Dec 18, 2023
2599167
skip tests when pandas are not available
sfc-gh-rsureshbabu Jan 17, 2024
c4416ad
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Jan 17, 2024
ce061bd
update change log
sfc-gh-rsureshbabu Jan 17, 2024
5c00182
changes
sfc-gh-rsureshbabu Jan 17, 2024
e0e18bf
changes
sfc-gh-rsureshbabu Jan 17, 2024
2d05cb8
adding doctest
sfc-gh-rsureshbabu Jan 18, 2024
9423599
renaming
sfc-gh-rsureshbabu Jan 23, 2024
721383f
renaming 2
sfc-gh-rsureshbabu Jan 23, 2024
27df8d2
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu Jan 23, 2024
cfb5e4a
fix comments
sfc-gh-rsureshbabu Jan 23, 2024
9583161
fix error message
sfc-gh-rsureshbabu Jan 24, 2024
045f216
fix merge
sfc-gh-rsureshbabu Jan 24, 2024
05a2da5
fix code coverage
sfc-gh-rsureshbabu Jan 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@

- Add the `conn_error` attribute to `SnowflakeSQLException` that stores the whole underlying exception from `snowflake-connector-python`

### New Features

- Added moving_agg function in DataFrame.transform for time series analysis, enabling moving aggregations like sums and averages with multiple window sizes.

### Bug Fixes

- DataFrame column names qouting check now supports newline characters.
Expand Down
2 changes: 2 additions & 0 deletions src/snowflake/snowpark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"GetResult",
"DataFrame",
"DataFrameStatFunctions",
"DataFrameTransformFunctions",
"DataFrameNaFunctions",
"DataFrameWriter",
"DataFrameReader",
Expand Down Expand Up @@ -49,6 +50,7 @@
from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions
from snowflake.snowpark.dataframe_reader import DataFrameReader
from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions
from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions
from snowflake.snowpark.dataframe_writer import DataFrameWriter
from snowflake.snowpark.file_operation import FileOperation, GetResult, PutResult
from snowflake.snowpark.query_history import QueryHistory, QueryRecord
Expand Down
6 changes: 6 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
from snowflake.snowpark.column import Column, _to_col_if_sql_expr, _to_col_if_str
from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions
from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions
from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions
from snowflake.snowpark.dataframe_writer import DataFrameWriter
from snowflake.snowpark.exceptions import SnowparkDataframeException
from snowflake.snowpark.functions import (
Expand Down Expand Up @@ -521,6 +522,7 @@ def __init__(
self._writer = DataFrameWriter(self)

self._stat = DataFrameStatFunctions(self)
self._transform = DataFrameTransformFunctions(self)
self.approxQuantile = self.approx_quantile = self._stat.approx_quantile
self.corr = self._stat.corr
self.cov = self._stat.cov
Expand All @@ -538,6 +540,10 @@ def __init__(
def stat(self) -> DataFrameStatFunctions:
return self._stat

@property
def transform(self) -> DataFrameTransformFunctions:
sfc-gh-rsureshbabu marked this conversation as resolved.
Show resolved Hide resolved
return self._transform

@overload
def collect(
self,
Expand Down
122 changes: 122 additions & 0 deletions src/snowflake/snowpark/dataframe_transform_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Callable, Dict, List

import snowflake.snowpark
from snowflake.snowpark.functions import expr
from snowflake.snowpark.window import Window


class DataFrameTransformFunctions:
"""Provides data transformation functions for DataFrames.
To access an object of this class, use :attr:`DataFrame.transform`.
"""

def __init__(self, df: "snowflake.snowpark.DataFrame") -> None:
self._df = df
sfc-gh-rsureshbabu marked this conversation as resolved.
Show resolved Hide resolved

def _default_col_formatter(input_col: str, operation: str, *args) -> str:
args_str = "_".join(map(str, args))
formatted_name = f"{input_col}_{operation}"
if args_str:
formatted_name += f"_{args_str}"
return formatted_name

def _validate_aggs_argument(self, aggs):
if not isinstance(aggs, dict):
raise TypeError("aggs must be a dictionary")
if not aggs:
raise ValueError("aggs must not be empty")
if not all(
isinstance(key, str) and isinstance(val, list) and val
for key, val in aggs.items()
):
raise ValueError(
"aggs must have strings as keys and non-empty lists of strings as values"
)
sfc-gh-rsureshbabu marked this conversation as resolved.
Show resolved Hide resolved

def _validate_string_list_argument(self, data, argument_name):
if not isinstance(data, list):
raise TypeError(f"{argument_name} must be a list")
if not data:
raise ValueError(f"{argument_name} must not be empty")
if not all(isinstance(item, str) for item in data):
raise ValueError(f"{argument_name} must be a list of strings")

def _validate_positive_integer_list_argument(self, data, argument_name):
if not isinstance(data, list):
raise TypeError(f"{argument_name} must be a list")
if not data:
raise ValueError(f"{argument_name} must not be empty")
if not all(isinstance(item, int) and item > 0 for item in data):
raise ValueError(f"{argument_name} must be a list of integers > 0")
sfc-gh-rsureshbabu marked this conversation as resolved.
Show resolved Hide resolved

def _validate_formatter_argument(self, fromatter):
if not callable(fromatter):
raise TypeError("formatter must be a callable function")

def moving_agg(
self,
aggs: Dict[str, List[str]],
window_sizes: List[int],
order_by: List[str],
group_by: List[str],
col_formatter: Callable[[str, str, int], str] = _default_col_formatter,
) -> "snowflake.snowpark.dataframe.DataFrame":
"""
Applies moving aggregations to the specified columns of the DataFrame using defined window sizes,
and grouping and ordering criteria.

Args:
aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions.
sfc-gh-rsureshbabu marked this conversation as resolved.
Show resolved Hide resolved
window_sizes: A list of positive integers, each representing the size of the window for which to
calculate the moving aggregate.
order_by: A list of column names that specify the order in which rows are processed.
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations.
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_<agg>_<window>'.
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation,
and 'value' (int) for the window size, and returns a formatted string for the column name.

Returns:
A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation.

Raises:
ValueError: If an unsupported value is specified in arguments.
TypeError: If an unsupported type is specified in arguments.
SnowparkSQLException: If an unsupported aggregration is specified.

Example:
aggregated_df = moving_agg(
aggs={"SALESAMOUNT": ['SUM', 'AVG']},
window_sizes=[1, 2, 3, 7],
order_by=['ORDERDATE'],
group_by=['PRODUCTKEY']
)
sfc-gh-rsureshbabu marked this conversation as resolved.
Show resolved Hide resolved
"""
# Validate input arguments
self._validate_aggs_argument(aggs)
self._validate_string_list_argument(order_by, "order_by")
self._validate_string_list_argument(group_by, "group_by")
self._validate_positive_integer_list_argument(window_sizes, "window_sizes")
self._validate_formatter_argument(col_formatter)

# Perform window aggregation
agg_df = self._df
for column, agg_funcs in aggs.items():
for window_size in window_sizes:
for agg_func in agg_funcs:
window_spec = (
Window.partition_by(group_by)
.order_by(order_by)
.rows_between(-window_size + 1, 0)
)

# Apply the user-specified aggregation function directly. Snowflake will handle any errors for invalid functions.
agg_col = expr(f"{agg_func}({column})").over(window_spec)

formatted_col_name = col_formatter(column, agg_func, window_size)
agg_df = agg_df.with_column(formatted_col_name, agg_col)

return agg_df
135 changes: 135 additions & 0 deletions tests/integ/test_df_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

import pandas as pd
import pytest
from pandas.testing import assert_frame_equal

from snowflake.snowpark.exceptions import SnowparkSQLException


def get_sample_dataframe(session):
data = [
["2023-01-01", 101, 200],
["2023-01-02", 101, 100],
["2023-01-03", 101, 300],
["2023-01-04", 102, 250],
]
return session.create_dataframe(data).to_df(
"ORDERDATE", "PRODUCTKEY", "SALESAMOUNT"
)


def test_moving_agg(session):
"""Tests df.transform.moving_agg() happy path."""

df = get_sample_dataframe(session)

res = df.transform.moving_agg(
aggs={"SALESAMOUNT": ["SUM", "AVG"]},
window_sizes=[2, 3],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
)

expected_data = {
"ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
"PRODUCTKEY": [101, 101, 101, 102],
"SALESAMOUNT": [200, 100, 300, 250],
"SALESAMOUNT_SUM_2": [200, 300, 400, 250],
"SALESAMOUNT_AVG_2": [200.0, 150.0, 200.0, 250.0],
"SALESAMOUNT_SUM_3": [200, 300, 600, 250],
"SALESAMOUNT_AVG_3": [200.0, 150.0, 200.0, 250.0],
}
expected_df = pd.DataFrame(expected_data)
assert_frame_equal(
res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1
)


def test_moving_agg_custom_formatting(session):
"""Tests df.transform.moving_agg() with custom formatting of output columns."""

df = get_sample_dataframe(session)

def custom_formatter(input_col, agg, window):
return f"{window}_{agg}_{input_col}"

res = df.transform.moving_agg(
aggs={"SALESAMOUNT": ["SUM", "AVG"]},
window_sizes=[2, 3],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
col_formatter=custom_formatter,
)

expected_data = {
"ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"],
"PRODUCTKEY": [101, 101, 101, 102],
"SALESAMOUNT": [200, 100, 300, 250],
"2_SUM_SALESAMOUNT": [200, 300, 400, 250],
"2_AVG_SALESAMOUNT": [200.0, 150.0, 200.0, 250.0],
"3_SUM_SALESAMOUNT": [200, 300, 600, 250],
"3_AVG_SALESAMOUNT": [200.0, 150.0, 200.0, 250.0],
}
expected_df = pd.DataFrame(expected_data)
assert_frame_equal(
res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1
)


def test_moving_agg_invalid_inputs(session):
"""Tests df.transform.moving_agg() with invalid window sizes."""

df = get_sample_dataframe(session)

with pytest.raises(ValueError) as exc:
df.transform.moving_agg(
aggs={"SALESAMOUNT": ["AVG"]},
window_sizes=[-1, 2, 3],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
).collect()
assert "window_sizes must be a list of integers > 0" in str(exc)

with pytest.raises(ValueError) as exc:
df.transform.moving_agg(
aggs={"SALESAMOUNT": ["AVG"]},
window_sizes=[0, 2, 3],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
).collect()
assert "window_sizes must be a list of integers > 0" in str(exc)

with pytest.raises(ValueError) as exc:
df.transform.moving_agg(
aggs={"SALESAMOUNT": []},
window_sizes=[0, 2, 3],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
).collect()
assert "non-empty lists of strings as values" in str(exc)

with pytest.raises(SnowparkSQLException) as exc:
df.transform.moving_agg(
aggs={"SALESAMOUNT": ["INVALID_FUNC"]},
window_sizes=[1],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
).collect()
assert "Sliding window frame unsupported for function" in str(exc)

def bad_formatter(input_col, agg):
return f"{agg}_{input_col}"

with pytest.raises(TypeError) as exc:
df.transform.moving_agg(
aggs={"SALESAMOUNT": ["SUM"]},
window_sizes=[1],
order_by=["ORDERDATE"],
group_by=["PRODUCTKEY"],
col_formatter=bad_formatter,
).collect()
assert "positional arguments but 3 were given" in str(exc)
Loading