diff --git a/CHANGELOG.md b/CHANGELOG.md index 75d86cbe90b..d7281d60f6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ - `array_except` - `create_map` - `sign`/`signum` +- Added moving_agg function in DataFrame.analytics for enabling moving aggregations like sums and averages with multiple window sizes. ### Bug Fixes diff --git a/src/snowflake/snowpark/__init__.py b/src/snowflake/snowpark/__init__.py index d2b28205757..bf049c957a6 100644 --- a/src/snowflake/snowpark/__init__.py +++ b/src/snowflake/snowpark/__init__.py @@ -19,6 +19,7 @@ "GetResult", "DataFrame", "DataFrameStatFunctions", + "DataFrameAnalyticsFunctions", "DataFrameNaFunctions", "DataFrameWriter", "DataFrameReader", @@ -46,6 +47,7 @@ from snowflake.snowpark.async_job import AsyncJob from snowflake.snowpark.column import CaseExpr, Column from snowflake.snowpark.dataframe import DataFrame +from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions from snowflake.snowpark.dataframe_reader import DataFrameReader from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index fbd4b225442..6f4c447675b 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -121,6 +121,7 @@ ) from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType from snowflake.snowpark.column import Column, _to_col_if_sql_expr, _to_col_if_str +from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions from snowflake.snowpark.dataframe_writer import DataFrameWriter @@ -522,6 +523,7 @@ def __init__( self._writer = DataFrameWriter(self) self._stat = DataFrameStatFunctions(self) + self._analytics = DataFrameAnalyticsFunctions(self) self.approxQuantile = self.approx_quantile = self._stat.approx_quantile self.corr = self._stat.corr self.cov = self._stat.cov @@ -539,6 +541,10 @@ def __init__( def stat(self) -> DataFrameStatFunctions: return self._stat + @property + def analytics(self) -> DataFrameAnalyticsFunctions: + return self._analytics + @overload def collect( self, diff --git a/src/snowflake/snowpark/dataframe_analytics_functions.py b/src/snowflake/snowpark/dataframe_analytics_functions.py new file mode 100644 index 00000000000..9f69c6c327a --- /dev/null +++ b/src/snowflake/snowpark/dataframe_analytics_functions.py @@ -0,0 +1,170 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from typing import Callable, Dict, List + +import snowflake.snowpark +from snowflake.snowpark.functions import expr +from snowflake.snowpark.window import Window + + +class DataFrameAnalyticsFunctions: + """Provides data analytics functions for DataFrames. + To access an object of this class, use :attr:`DataFrame.analytics`. + """ + + def __init__(self, df: "snowflake.snowpark.DataFrame") -> None: + self._df = df + + def _default_col_formatter(input_col: str, operation: str, *args) -> str: + args_str = "_".join(map(str, args)) + formatted_name = f"{input_col}_{operation}" + if args_str: + formatted_name += f"_{args_str}" + return formatted_name + + def _validate_aggs_argument(self, aggs): + argument_requirements = ( + "The 'aggs' argument must adhere to the following rules: " + "1) It must be a dictionary. " + "2) It must not be empty. " + "3) All keys must be strings. " + "4) All values must be non-empty lists of strings." + ) + + if not isinstance(aggs, dict): + raise TypeError(f"aggs must be a dictionary. {argument_requirements}") + if not aggs: + raise ValueError(f"aggs must not be empty. {argument_requirements}") + if not all( + isinstance(key, str) and isinstance(val, list) and val + for key, val in aggs.items() + ): + raise ValueError( + f"aggs must have strings as keys and non-empty lists of strings as values. {argument_requirements}" + ) + + def _validate_string_list_argument(self, data, argument_name): + argument_requirements = ( + f"The '{argument_name}' argument must adhere to the following rules: " + "1) It must be a list. " + "2) It must not be empty. " + "3) All items in the list must be strings." + ) + if not isinstance(data, list): + raise TypeError(f"{argument_name} must be a list. {argument_requirements}") + if not data: + raise ValueError( + f"{argument_name} must not be empty. {argument_requirements}" + ) + if not all(isinstance(item, str) for item in data): + raise ValueError( + f"{argument_name} must be a list of strings. {argument_requirements}" + ) + + def _validate_positive_integer_list_argument(self, data, argument_name): + argument_requirements = ( + f"The '{argument_name}' argument must adhere to the following criteria: " + "1) It must be a list. " + "2) It must not be empty. " + "3) All items in the list must be positive integers." + ) + if not isinstance(data, list): + raise TypeError(f"{argument_name} must be a list. {argument_requirements}") + if not data: + raise ValueError( + f"{argument_name} must not be empty. {argument_requirements}" + ) + if not all(isinstance(item, int) and item > 0 for item in data): + raise ValueError( + f"{argument_name} must be a list of integers > 0. {argument_requirements}" + ) + + def _validate_formatter_argument(self, fromatter): + if not callable(fromatter): + raise TypeError("formatter must be a callable function") + + def moving_agg( + self, + aggs: Dict[str, List[str]], + window_sizes: List[int], + order_by: List[str], + group_by: List[str], + col_formatter: Callable[[str, str, int], str] = _default_col_formatter, + ) -> "snowflake.snowpark.dataframe.DataFrame": + """ + Applies moving aggregations to the specified columns of the DataFrame using defined window sizes, + and grouping and ordering criteria. + + Args: + aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions. + Supported aggregation are listed here https://docs.snowflake.com/en/sql-reference/functions-analytic#list-of-functions-that-support-windows. + window_sizes: A list of positive integers, each representing the size of the window for which to + calculate the moving aggregate. + order_by: A list of column names that specify the order in which rows are processed. + group_by: A list of column names on which the DataFrame is partitioned for separate window calculations. + col_formatter: An optional function for formatting output column names, defaulting to the format '__'. + This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation, + and 'value' (int) for the window size, and returns a formatted string for the column name. + + Returns: + A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation. + + Raises: + ValueError: If an unsupported value is specified in arguments. + TypeError: If an unsupported type is specified in arguments. + SnowparkSQLException: If an unsupported aggregration is specified. + + Example: + >>> data = [ + ... ["2023-01-01", 101, 200], + ... ["2023-01-02", 101, 100], + ... ["2023-01-03", 101, 300], + ... ["2023-01-04", 102, 250], + ... ] + >>> df = session.create_dataframe(data).to_df( + ... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT" + ... ) + >>> result = df.analytics.moving_agg( + ... aggs={"SALESAMOUNT": ["SUM", "AVG"]}, + ... window_sizes=[2, 3], + ... order_by=["ORDERDATE"], + ... group_by=["PRODUCTKEY"], + ... ) + >>> result.show() + -------------------------------------------------------------------------------------------------------------------------------------- + |"ORDERDATE" |"PRODUCTKEY" |"SALESAMOUNT" |"SALESAMOUNT_SUM_2" |"SALESAMOUNT_AVG_2" |"SALESAMOUNT_SUM_3" |"SALESAMOUNT_AVG_3" | + -------------------------------------------------------------------------------------------------------------------------------------- + |2023-01-04 |102 |250 |250 |250.000 |250 |250.000 | + |2023-01-01 |101 |200 |200 |200.000 |200 |200.000 | + |2023-01-02 |101 |100 |300 |150.000 |300 |150.000 | + |2023-01-03 |101 |300 |400 |200.000 |600 |200.000 | + -------------------------------------------------------------------------------------------------------------------------------------- + + """ + # Validate input arguments + self._validate_aggs_argument(aggs) + self._validate_string_list_argument(order_by, "order_by") + self._validate_string_list_argument(group_by, "group_by") + self._validate_positive_integer_list_argument(window_sizes, "window_sizes") + self._validate_formatter_argument(col_formatter) + + # Perform window aggregation + agg_df = self._df + for column, agg_funcs in aggs.items(): + for window_size in window_sizes: + for agg_func in agg_funcs: + window_spec = ( + Window.partition_by(group_by) + .order_by(order_by) + .rows_between(-window_size + 1, 0) + ) + + # Apply the user-specified aggregation function directly. Snowflake will handle any errors for invalid functions. + agg_col = expr(f"{agg_func}({column})").over(window_spec) + + formatted_col_name = col_formatter(column, agg_func, window_size) + agg_df = agg_df.with_column(formatted_col_name, agg_col) + + return agg_df diff --git a/tests/integ/test_df_analytics.py b/tests/integ/test_df_analytics.py new file mode 100644 index 00000000000..f747523b28e --- /dev/null +++ b/tests/integ/test_df_analytics.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +try: + import pandas as pd + from pandas.testing import assert_frame_equal + + is_pandas_available = True +except ImportError: + is_pandas_available = False + +import pytest + +from snowflake.snowpark.exceptions import SnowparkSQLException + + +def get_sample_dataframe(session): + data = [ + ["2023-01-01", 101, 200], + ["2023-01-02", 101, 100], + ["2023-01-03", 101, 300], + ["2023-01-04", 102, 250], + ] + return session.create_dataframe(data).to_df( + "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT" + ) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_moving_agg(session): + """Tests df.analytics.moving_agg() happy path.""" + + df = get_sample_dataframe(session) + + res = df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["SUM", "AVG"]}, + window_sizes=[2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ) + + expected_data = { + "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "PRODUCTKEY": [101, 101, 101, 102], + "SALESAMOUNT": [200, 100, 300, 250], + "SALESAMOUNT_SUM_2": [200, 300, 400, 250], + "SALESAMOUNT_AVG_2": [200.0, 150.0, 200.0, 250.0], + "SALESAMOUNT_SUM_3": [200, 300, 600, 250], + "SALESAMOUNT_AVG_3": [200.0, 150.0, 200.0, 250.0], + } + expected_df = pd.DataFrame(expected_data) + assert_frame_equal( + res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 + ) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_moving_agg_custom_formatting(session): + """Tests df.analytics.moving_agg() with custom formatting of output columns.""" + + df = get_sample_dataframe(session) + + def custom_formatter(input_col, agg, window): + return f"{window}_{agg}_{input_col}" + + res = df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["SUM", "AVG"]}, + window_sizes=[2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + col_formatter=custom_formatter, + ) + + expected_data = { + "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "PRODUCTKEY": [101, 101, 101, 102], + "SALESAMOUNT": [200, 100, 300, 250], + "2_SUM_SALESAMOUNT": [200, 300, 400, 250], + "2_AVG_SALESAMOUNT": [200.0, 150.0, 200.0, 250.0], + "3_SUM_SALESAMOUNT": [200, 300, 600, 250], + "3_AVG_SALESAMOUNT": [200.0, 150.0, 200.0, 250.0], + } + expected_df = pd.DataFrame(expected_data) + assert_frame_equal( + res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 + ) + + # With default formatter + res = df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["SUM", "AVG"]}, + window_sizes=[2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ) + + expected_data = { + "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "PRODUCTKEY": [101, 101, 101, 102], + "SALESAMOUNT": [200, 100, 300, 250], + "SALESAMOUNT_SUM_2": [200, 300, 400, 250], + "SALESAMOUNT_AVG_2": [200.0, 150.0, 200.0, 250.0], + "SALESAMOUNT_SUM_3": [200, 300, 600, 250], + "SALESAMOUNT_AVG_3": [200.0, 150.0, 200.0, 250.0], + } + + expected_df = pd.DataFrame(expected_data) + assert_frame_equal( + res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 + ) + + +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") +def test_moving_agg_invalid_inputs(session): + """Tests df.analytics.moving_agg() with invalid window sizes.""" + + df = get_sample_dataframe(session) + + with pytest.raises(TypeError) as exc: + df.analytics.moving_agg( + aggs=["AVG"], + window_sizes=[1, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "aggs must be a dictionary" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={}, + window_sizes=[1, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "aggs must not be empty" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": []}, + window_sizes=[1, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "non-empty lists of strings as values" in str(exc) + + with pytest.raises(TypeError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[1, 2, 3], + order_by="ORDERDATE", + group_by=["PRODUCTKEY"], + ).collect() + assert "order_by must be a list" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[1, 2, 3], + order_by=[], + group_by=["PRODUCTKEY"], + ).collect() + assert "order_by must not be empty" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[1, 2, 3], + order_by=[1], + group_by=["PRODUCTKEY"], + ).collect() + assert "order_by must be a list of strings" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[-1, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "window_sizes must be a list of integers > 0" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[0, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "window_sizes must be a list of integers > 0" in str(exc) + + with pytest.raises(TypeError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=0, + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "window_sizes must be a list" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "window_sizes must not be empty" in str(exc) + + with pytest.raises(SnowparkSQLException) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["INVALID_FUNC"]}, + window_sizes=[1], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "Sliding window frame unsupported for function" in str(exc) + + def bad_formatter(input_col, agg): + return f"{agg}_{input_col}" + + with pytest.raises(TypeError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["SUM"]}, + window_sizes=[1], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + col_formatter=bad_formatter, + ).collect() + assert "positional arguments but 3 were given" in str(exc) + + with pytest.raises(TypeError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["SUM"]}, + window_sizes=[1], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + col_formatter="bad_formatter", + ).collect() + assert "formatter must be a callable function" in str(exc)