-
Notifications
You must be signed in to change notification settings - Fork 117
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SNOW-976701: Adding moving aggregation dataframe function #1145
Merged
sfc-gh-rsureshbabu
merged 30 commits into
main
from
rsureshbabu-SNOW-SNOW-976701-movingagg
Jan 25, 2024
Merged
Changes from all commits
Commits
Show all changes
30 commits
Select commit
Hold shift + click to select a range
6840bfd
changes
sfc-gh-rsureshbabu 266a281
updating changelog
sfc-gh-rsureshbabu 2499a6d
fixing comment
sfc-gh-rsureshbabu 4ba1a64
generalizing default formatter
sfc-gh-rsureshbabu 810bfb8
generalizing default formatter 2
sfc-gh-rsureshbabu b43ffec
fix comment
sfc-gh-rsureshbabu 5a007b2
cleaning argument checks
sfc-gh-rsureshbabu 91ff442
cleaning argument checks
sfc-gh-rsureshbabu faccd9b
refactor
sfc-gh-rsureshbabu e27f0ea
changes
sfc-gh-rsureshbabu e8d4345
fix test
sfc-gh-rsureshbabu ccf6655
changes
sfc-gh-rsureshbabu 8e7adad
changes
sfc-gh-rsureshbabu 9a9a045
changes
sfc-gh-rsureshbabu 706f15f
changes
sfc-gh-rsureshbabu 39a3e8a
fix comment
sfc-gh-rsureshbabu b0c2c85
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu 2599167
skip tests when pandas are not available
sfc-gh-rsureshbabu c4416ad
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu ce061bd
update change log
sfc-gh-rsureshbabu 5c00182
changes
sfc-gh-rsureshbabu e0e18bf
changes
sfc-gh-rsureshbabu 2d05cb8
adding doctest
sfc-gh-rsureshbabu 9423599
renaming
sfc-gh-rsureshbabu 721383f
renaming 2
sfc-gh-rsureshbabu 27df8d2
Merge branch 'main' into rsureshbabu-SNOW-SNOW-976701-movingagg
sfc-gh-rsureshbabu cfb5e4a
fix comments
sfc-gh-rsureshbabu 9583161
fix error message
sfc-gh-rsureshbabu 045f216
fix merge
sfc-gh-rsureshbabu 05a2da5
fix code coverage
sfc-gh-rsureshbabu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
170 changes: 170 additions & 0 deletions
170
src/snowflake/snowpark/dataframe_analytics_functions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
# | ||
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. | ||
# | ||
|
||
from typing import Callable, Dict, List | ||
|
||
import snowflake.snowpark | ||
from snowflake.snowpark.functions import expr | ||
from snowflake.snowpark.window import Window | ||
|
||
|
||
class DataFrameAnalyticsFunctions: | ||
"""Provides data analytics functions for DataFrames. | ||
To access an object of this class, use :attr:`DataFrame.analytics`. | ||
""" | ||
|
||
def __init__(self, df: "snowflake.snowpark.DataFrame") -> None: | ||
self._df = df | ||
|
||
def _default_col_formatter(input_col: str, operation: str, *args) -> str: | ||
args_str = "_".join(map(str, args)) | ||
formatted_name = f"{input_col}_{operation}" | ||
if args_str: | ||
formatted_name += f"_{args_str}" | ||
return formatted_name | ||
|
||
def _validate_aggs_argument(self, aggs): | ||
argument_requirements = ( | ||
"The 'aggs' argument must adhere to the following rules: " | ||
"1) It must be a dictionary. " | ||
"2) It must not be empty. " | ||
"3) All keys must be strings. " | ||
"4) All values must be non-empty lists of strings." | ||
) | ||
|
||
if not isinstance(aggs, dict): | ||
raise TypeError(f"aggs must be a dictionary. {argument_requirements}") | ||
if not aggs: | ||
raise ValueError(f"aggs must not be empty. {argument_requirements}") | ||
if not all( | ||
isinstance(key, str) and isinstance(val, list) and val | ||
for key, val in aggs.items() | ||
): | ||
raise ValueError( | ||
f"aggs must have strings as keys and non-empty lists of strings as values. {argument_requirements}" | ||
) | ||
|
||
def _validate_string_list_argument(self, data, argument_name): | ||
argument_requirements = ( | ||
f"The '{argument_name}' argument must adhere to the following rules: " | ||
"1) It must be a list. " | ||
"2) It must not be empty. " | ||
"3) All items in the list must be strings." | ||
) | ||
if not isinstance(data, list): | ||
raise TypeError(f"{argument_name} must be a list. {argument_requirements}") | ||
if not data: | ||
raise ValueError( | ||
f"{argument_name} must not be empty. {argument_requirements}" | ||
) | ||
if not all(isinstance(item, str) for item in data): | ||
raise ValueError( | ||
f"{argument_name} must be a list of strings. {argument_requirements}" | ||
) | ||
|
||
def _validate_positive_integer_list_argument(self, data, argument_name): | ||
argument_requirements = ( | ||
f"The '{argument_name}' argument must adhere to the following criteria: " | ||
"1) It must be a list. " | ||
"2) It must not be empty. " | ||
"3) All items in the list must be positive integers." | ||
) | ||
if not isinstance(data, list): | ||
raise TypeError(f"{argument_name} must be a list. {argument_requirements}") | ||
if not data: | ||
raise ValueError( | ||
f"{argument_name} must not be empty. {argument_requirements}" | ||
) | ||
if not all(isinstance(item, int) and item > 0 for item in data): | ||
raise ValueError( | ||
f"{argument_name} must be a list of integers > 0. {argument_requirements}" | ||
) | ||
|
||
def _validate_formatter_argument(self, fromatter): | ||
if not callable(fromatter): | ||
raise TypeError("formatter must be a callable function") | ||
|
||
def moving_agg( | ||
self, | ||
aggs: Dict[str, List[str]], | ||
window_sizes: List[int], | ||
order_by: List[str], | ||
group_by: List[str], | ||
col_formatter: Callable[[str, str, int], str] = _default_col_formatter, | ||
) -> "snowflake.snowpark.dataframe.DataFrame": | ||
""" | ||
Applies moving aggregations to the specified columns of the DataFrame using defined window sizes, | ||
and grouping and ordering criteria. | ||
|
||
Args: | ||
aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions. | ||
Supported aggregation are listed here https://docs.snowflake.com/en/sql-reference/functions-analytic#list-of-functions-that-support-windows. | ||
window_sizes: A list of positive integers, each representing the size of the window for which to | ||
calculate the moving aggregate. | ||
order_by: A list of column names that specify the order in which rows are processed. | ||
group_by: A list of column names on which the DataFrame is partitioned for separate window calculations. | ||
col_formatter: An optional function for formatting output column names, defaulting to the format '<input_col>_<agg>_<window>'. | ||
This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation, | ||
and 'value' (int) for the window size, and returns a formatted string for the column name. | ||
|
||
Returns: | ||
A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation. | ||
|
||
Raises: | ||
ValueError: If an unsupported value is specified in arguments. | ||
TypeError: If an unsupported type is specified in arguments. | ||
SnowparkSQLException: If an unsupported aggregration is specified. | ||
|
||
Example: | ||
>>> data = [ | ||
... ["2023-01-01", 101, 200], | ||
... ["2023-01-02", 101, 100], | ||
... ["2023-01-03", 101, 300], | ||
... ["2023-01-04", 102, 250], | ||
... ] | ||
>>> df = session.create_dataframe(data).to_df( | ||
... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT" | ||
... ) | ||
>>> result = df.analytics.moving_agg( | ||
... aggs={"SALESAMOUNT": ["SUM", "AVG"]}, | ||
... window_sizes=[2, 3], | ||
... order_by=["ORDERDATE"], | ||
... group_by=["PRODUCTKEY"], | ||
... ) | ||
>>> result.show() | ||
-------------------------------------------------------------------------------------------------------------------------------------- | ||
|"ORDERDATE" |"PRODUCTKEY" |"SALESAMOUNT" |"SALESAMOUNT_SUM_2" |"SALESAMOUNT_AVG_2" |"SALESAMOUNT_SUM_3" |"SALESAMOUNT_AVG_3" | | ||
-------------------------------------------------------------------------------------------------------------------------------------- | ||
|2023-01-04 |102 |250 |250 |250.000 |250 |250.000 | | ||
|2023-01-01 |101 |200 |200 |200.000 |200 |200.000 | | ||
|2023-01-02 |101 |100 |300 |150.000 |300 |150.000 | | ||
|2023-01-03 |101 |300 |400 |200.000 |600 |200.000 | | ||
-------------------------------------------------------------------------------------------------------------------------------------- | ||
<BLANKLINE> | ||
""" | ||
# Validate input arguments | ||
self._validate_aggs_argument(aggs) | ||
self._validate_string_list_argument(order_by, "order_by") | ||
self._validate_string_list_argument(group_by, "group_by") | ||
self._validate_positive_integer_list_argument(window_sizes, "window_sizes") | ||
self._validate_formatter_argument(col_formatter) | ||
|
||
# Perform window aggregation | ||
agg_df = self._df | ||
for column, agg_funcs in aggs.items(): | ||
for window_size in window_sizes: | ||
for agg_func in agg_funcs: | ||
window_spec = ( | ||
Window.partition_by(group_by) | ||
.order_by(order_by) | ||
.rows_between(-window_size + 1, 0) | ||
) | ||
|
||
# Apply the user-specified aggregation function directly. Snowflake will handle any errors for invalid functions. | ||
agg_col = expr(f"{agg_func}({column})").over(window_spec) | ||
|
||
formatted_col_name = col_formatter(column, agg_func, window_size) | ||
agg_df = agg_df.with_column(formatted_col_name, agg_col) | ||
|
||
return agg_df |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are they all mandatory arguments?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, except col_formatter