From 6840bfdb9de752a09f76f33bec5d39442499f7d9 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Wed, 22 Nov 2023 21:25:38 -0800 Subject: [PATCH 01/27] changes --- src/snowflake/snowpark/__init__.py | 2 + src/snowflake/snowpark/dataframe.py | 6 + .../snowpark/dataframe_transform_functions.py | 113 ++++++++++++++++ tests/integ/test_df_transform.py | 122 ++++++++++++++++++ 4 files changed, 243 insertions(+) create mode 100644 src/snowflake/snowpark/dataframe_transform_functions.py create mode 100644 tests/integ/test_df_transform.py diff --git a/src/snowflake/snowpark/__init__.py b/src/snowflake/snowpark/__init__.py index d2b28205757..560693d9522 100644 --- a/src/snowflake/snowpark/__init__.py +++ b/src/snowflake/snowpark/__init__.py @@ -19,6 +19,7 @@ "GetResult", "DataFrame", "DataFrameStatFunctions", + "DataFrameTransformFunctions", "DataFrameNaFunctions", "DataFrameWriter", "DataFrameReader", @@ -49,6 +50,7 @@ from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions from snowflake.snowpark.dataframe_reader import DataFrameReader from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions +from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions from snowflake.snowpark.dataframe_writer import DataFrameWriter from snowflake.snowpark.file_operation import FileOperation, GetResult, PutResult from snowflake.snowpark.query_history import QueryHistory, QueryRecord diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index ca5064bf13d..d6a8457d456 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -123,6 +123,7 @@ from snowflake.snowpark.column import Column, _to_col_if_sql_expr, _to_col_if_str from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions +from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions from snowflake.snowpark.dataframe_writer import DataFrameWriter from snowflake.snowpark.exceptions import SnowparkDataframeException from snowflake.snowpark.functions import ( @@ -521,6 +522,7 @@ def __init__( self._writer = DataFrameWriter(self) self._stat = DataFrameStatFunctions(self) + self._transform = DataFrameTransformFunctions(self) self.approxQuantile = self.approx_quantile = self._stat.approx_quantile self.corr = self._stat.corr self.cov = self._stat.cov @@ -538,6 +540,10 @@ def __init__( def stat(self) -> DataFrameStatFunctions: return self._stat + @property + def transform(self) -> DataFrameTransformFunctions: + return self._transform + @overload def collect( self, diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py new file mode 100644 index 00000000000..f3382ccaffb --- /dev/null +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -0,0 +1,113 @@ +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +from typing import Callable, Dict, List + +import snowflake.snowpark +from snowflake.snowpark.functions import expr +from snowflake.snowpark.window import Window + + +class DataFrameTransformFunctions: + """Provides data transformation functions for DataFrames. + To access an object of this class, use :attr:`DataFrame.transform`. + """ + + def __init__(self, df: "snowflake.snowpark.DataFrame") -> None: + self._df = df + + def _default_col_formatter(input_col: str, agg: str, window: int) -> str: + return f"{input_col}_{agg}_{window}" + + def _validate_aggs_argument(self, data): + if not isinstance(data, dict): + raise TypeError("aggs must be a dictionary") + if not data or not all( + isinstance(key, str) and isinstance(val, list) and val + for key, val in data.items() + ): + raise ValueError( + "aggs must be a non-empty dictionary with strings as keys and non-empty lists of strings as values" + ) + + def _validate_column_names_argument(self, data, argument_name): + if not isinstance(data, list): + raise TypeError(f"{argument_name} must be a list") + if not data or not all(isinstance(item, str) for item in data): + raise ValueError(f"{argument_name} must be a non-empty list of strings") + + def _validate_formatter_argument(self, data): + if not callable(data): + raise TypeError("formatter must be a callable function") + + def moving_agg( + self, + aggs: Dict[str, List[str]], + window_sizes: List[int], + order_by: List[str], + group_by: List[str], + col_formatter: Callable[[str, str, int], str] = _default_col_formatter, + ) -> "snowflake.snowpark.dataframe.DataFrame": + """ + Applies moving aggregations to the specified columns of the DataFrame using defined window sizes, + and grouping and ordering criteria. + + Args: + df: The Snowflake DataFrame to which the moving aggregations are applied. + aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions. + window_sizes: A list of positive integers, each representing the size of the window for which to + calculate the moving aggregate. + order_by: A list of column names that specify the order in which rows are processed. + group_by: A list of column names on which the DataFrame is partitioned for separate window calculations. + col_formatter: An optional function to format the output column names. Defaults to a built-in formatter + that outputs column names in the format "__". + + Returns: + A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation. + + Raises: + ValueError: If an unsupported aggregation function is specified in 'aggs'. + + Example: + aggregated_df = moving_agg( + df, + aggs={"SALESAMOUNT": ['SUM', 'AVG']}, + window_sizes=[1, 2, 3, 7], + order_by=['ORDERDATE'], + group_by=['PRODUCTKEY'] + ) + """ + # Validate input arguments + self._validate_aggs_argument(aggs) + self._validate_column_names_argument(order_by, "order_by") + self._validate_column_names_argument(group_by, "group_by") + self._validate_formatter_argument(col_formatter) + + if not isinstance(window_sizes, list): + raise TypeError("window_sizes must be a list") + if not window_sizes or not all( + isinstance(item, int) and item > 0 for item in window_sizes + ): + raise ValueError( + "window_sizes must be a non-empty list of positive integers" + ) + + # Perform window aggregation + agg_df = self._df + for column, agg_funcs in aggs.items(): + for window_size in window_sizes: + for agg_func in agg_funcs: + window_spec = ( + Window.partition_by(group_by) + .order_by(order_by) + .rows_between(-window_size + 1, 0) + ) + + # Apply the user-specified aggregation function directly. Snowflake will handle any errors for invalid functions. + agg_col = expr(f"{agg_func}({column})").over(window_spec) + + formatted_col_name = col_formatter(column, agg_func, window_size) + agg_df = agg_df.with_column(formatted_col_name, agg_col) + + return agg_df diff --git a/tests/integ/test_df_transform.py b/tests/integ/test_df_transform.py new file mode 100644 index 00000000000..a578ea7635f --- /dev/null +++ b/tests/integ/test_df_transform.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 +# +# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. +# + +import pandas as pd +import pytest +from pandas.testing import assert_frame_equal + +from snowflake.snowpark.exceptions import SnowparkSQLException + + +def get_sample_dataframe(session): + data = [ + ["2023-01-01", 101, 200], + ["2023-01-02", 101, 100], + ["2023-01-03", 101, 300], + ["2023-01-04", 102, 250], + ] + return session.create_dataframe(data).to_df( + "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT" + ) + + +def test_moving_agg(session): + """Tests df.transform.moving_agg() happy path.""" + + df = get_sample_dataframe(session) + + res = df.transform.moving_agg( + aggs={"SALESAMOUNT": ["SUM", "AVG"]}, + window_sizes=[2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ) + + expected_data = { + "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "PRODUCTKEY": [101, 101, 101, 102], + "SALESAMOUNT": [200, 100, 300, 250], + "SALESAMOUNT_SUM_2": [200, 300, 400, 250], + "SALESAMOUNT_AVG_2": [200.0, 150.0, 200.0, 250.0], + "SALESAMOUNT_SUM_3": [200, 300, 600, 250], + "SALESAMOUNT_AVG_3": [200.0, 150.0, 200.0, 250.0], + } + expected_df = pd.DataFrame(expected_data) + assert_frame_equal( + res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 + ) + + +def test_moving_agg_custom_formatting(session): + """Tests df.transform.moving_agg() with custom formatting of output columns.""" + + df = get_sample_dataframe(session) + + def custom_formatter(input_col, agg, window): + return f"{window}_{agg}_{input_col}" + + res = df.transform.moving_agg( + aggs={"SALESAMOUNT": ["SUM", "AVG"]}, + window_sizes=[2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + col_formatter=custom_formatter, + ) + + expected_data = { + "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "PRODUCTKEY": [101, 101, 101, 102], + "SALESAMOUNT": [200, 100, 300, 250], + "2_SUM_SALESAMOUNT": [200, 300, 400, 250], + "2_AVG_SALESAMOUNT": [200.0, 150.0, 200.0, 250.0], + "3_SUM_SALESAMOUNT": [200, 300, 600, 250], + "3_AVG_SALESAMOUNT": [200.0, 150.0, 200.0, 250.0], + } + expected_df = pd.DataFrame(expected_data) + assert_frame_equal( + res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 + ) + + +def test_moving_agg_invalid_inputs(session): + """Tests df.transform.moving_agg() with invalid window sizes.""" + + df = get_sample_dataframe(session) + + with pytest.raises(ValueError) as exc: + df.transform.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[-1, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "window_sizes must be a non-empty list of positive integers" in str(exc) + + with pytest.raises(ValueError) as exc: + df.transform.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[0, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "window_sizes must be a non-empty list of positive integers" in str(exc) + + with pytest.raises(ValueError) as exc: + df.transform.moving_agg( + aggs={"SALESAMOUNT": []}, + window_sizes=[0, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "non-empty lists of strings as values" in str(exc) + + with pytest.raises(SnowparkSQLException) as exc: + df.transform.moving_agg( + aggs={"SALESAMOUNT": ["INVALID_FUNC"]}, + window_sizes=[1], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "Sliding window frame unsupported for function" in str(exc) From 266a2816782d4476c117e6839c7a2e846d6205f7 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 09:16:04 -0800 Subject: [PATCH 02/27] updating changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f489b24fc90..72efc3ac928 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ - Add the `conn_error` attribute to `SnowflakeSQLException` that stores the whole underlying exception from `snowflake-connector-python` +### New Features + +- Added moving_agg function in DataFrame.transform for time series analysis, enabling moving aggregations like sums and averages with multiple window sizes. + ### Bug Fixes - DataFrame column names qouting check now supports newline characters. From 2499a6d70a29c71c008d2e09f956268f83b6b353 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 09:32:41 -0800 Subject: [PATCH 03/27] fixing comment --- src/snowflake/snowpark/dataframe_transform_functions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index f3382ccaffb..2abb312f374 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -54,7 +54,6 @@ def moving_agg( and grouping and ordering criteria. Args: - df: The Snowflake DataFrame to which the moving aggregations are applied. aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions. window_sizes: A list of positive integers, each representing the size of the window for which to calculate the moving aggregate. @@ -67,7 +66,9 @@ def moving_agg( A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation. Raises: - ValueError: If an unsupported aggregation function is specified in 'aggs'. + ValueError: If an unsupported value is specified in arguments. + TypeError: If an unsupported type is specified in arguments. + SnowparkSQLException: If an unsupported aggregration is specified. Example: aggregated_df = moving_agg( From 4ba1a64c6e6a71e291a1344435ff865a50a752f3 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 10:28:32 -0800 Subject: [PATCH 04/27] generalizing default formatter --- src/snowflake/snowpark/dataframe_transform_functions.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index 2abb312f374..0138b187b19 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -17,8 +17,12 @@ class DataFrameTransformFunctions: def __init__(self, df: "snowflake.snowpark.DataFrame") -> None: self._df = df - def _default_col_formatter(input_col: str, agg: str, window: int) -> str: - return f"{input_col}_{agg}_{window}" + def _default_col_formatter(input_col: str, operation: str, *args) -> str: + additional_args_str = "_".join(map(str, args)) + formatted_name = f"{input_col}_{operation}" + if additional_args_str: + formatted_name += f"_{additional_args_str}" + return formatted_name def _validate_aggs_argument(self, data): if not isinstance(data, dict): From 810bfb85f10242a25df5c4709ae2391adfa447d2 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 10:29:35 -0800 Subject: [PATCH 05/27] generalizing default formatter 2 --- src/snowflake/snowpark/dataframe_transform_functions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index 0138b187b19..dedc3fd6f7e 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -18,10 +18,10 @@ def __init__(self, df: "snowflake.snowpark.DataFrame") -> None: self._df = df def _default_col_formatter(input_col: str, operation: str, *args) -> str: - additional_args_str = "_".join(map(str, args)) + args_str = "_".join(map(str, args)) formatted_name = f"{input_col}_{operation}" - if additional_args_str: - formatted_name += f"_{additional_args_str}" + if args_str: + formatted_name += f"_{args_str}" return formatted_name def _validate_aggs_argument(self, data): From b43ffec7fe0857a698dffb13d1b3ed76a1f8c706 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 10:51:41 -0800 Subject: [PATCH 06/27] fix comment --- src/snowflake/snowpark/dataframe_transform_functions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index dedc3fd6f7e..148014d41f7 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -76,7 +76,6 @@ def moving_agg( Example: aggregated_df = moving_agg( - df, aggs={"SALESAMOUNT": ['SUM', 'AVG']}, window_sizes=[1, 2, 3, 7], order_by=['ORDERDATE'], From 5a007b2e6fd78eab7d9f57a83df4c35088a68eca Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 12:06:59 -0800 Subject: [PATCH 07/27] cleaning argument checks --- .../snowpark/dataframe_transform_functions.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index 148014d41f7..5bf6b327482 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -27,7 +27,9 @@ def _default_col_formatter(input_col: str, operation: str, *args) -> str: def _validate_aggs_argument(self, data): if not isinstance(data, dict): raise TypeError("aggs must be a dictionary") - if not data or not all( + if not data: + raise ValueError("aggs must not be empty") + if not all( isinstance(key, str) and isinstance(val, list) and val for key, val in data.items() ): @@ -38,7 +40,9 @@ def _validate_aggs_argument(self, data): def _validate_column_names_argument(self, data, argument_name): if not isinstance(data, list): raise TypeError(f"{argument_name} must be a list") - if not data or not all(isinstance(item, str) for item in data): + if not data: + raise ValueError(f"{argument_name} must not be empty") + if not all(isinstance(item, str) for item in data): raise ValueError(f"{argument_name} must be a non-empty list of strings") def _validate_formatter_argument(self, data): @@ -90,9 +94,9 @@ def moving_agg( if not isinstance(window_sizes, list): raise TypeError("window_sizes must be a list") - if not window_sizes or not all( - isinstance(item, int) and item > 0 for item in window_sizes - ): + if not window_sizes: + raise ValueError("window_sizes must not be empty") + if not all(isinstance(item, int) and item > 0 for item in window_sizes): raise ValueError( "window_sizes must be a non-empty list of positive integers" ) From 91ff44240b791cc6f8c927109e52787c46a3a719 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 12:10:57 -0800 Subject: [PATCH 08/27] cleaning argument checks --- .../snowpark/dataframe_transform_functions.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index 5bf6b327482..ca22da730f0 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -24,14 +24,14 @@ def _default_col_formatter(input_col: str, operation: str, *args) -> str: formatted_name += f"_{args_str}" return formatted_name - def _validate_aggs_argument(self, data): - if not isinstance(data, dict): + def _validate_aggs_argument(self, aggs): + if not isinstance(aggs, dict): raise TypeError("aggs must be a dictionary") - if not data: + if not aggs: raise ValueError("aggs must not be empty") if not all( isinstance(key, str) and isinstance(val, list) and val - for key, val in data.items() + for key, val in aggs.items() ): raise ValueError( "aggs must be a non-empty dictionary with strings as keys and non-empty lists of strings as values" @@ -45,8 +45,8 @@ def _validate_column_names_argument(self, data, argument_name): if not all(isinstance(item, str) for item in data): raise ValueError(f"{argument_name} must be a non-empty list of strings") - def _validate_formatter_argument(self, data): - if not callable(data): + def _validate_formatter_argument(self, fromatter): + if not callable(fromatter): raise TypeError("formatter must be a callable function") def moving_agg( From faccd9b074de07f135f58ce8e78dfd69d70a984e Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 12:26:10 -0800 Subject: [PATCH 09/27] refactor --- .../snowpark/dataframe_transform_functions.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index ca22da730f0..86cf336b7eb 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -34,7 +34,7 @@ def _validate_aggs_argument(self, aggs): for key, val in aggs.items() ): raise ValueError( - "aggs must be a non-empty dictionary with strings as keys and non-empty lists of strings as values" + "aggs must have strings as keys and non-empty lists of strings as values" ) def _validate_column_names_argument(self, data, argument_name): @@ -43,7 +43,15 @@ def _validate_column_names_argument(self, data, argument_name): if not data: raise ValueError(f"{argument_name} must not be empty") if not all(isinstance(item, str) for item in data): - raise ValueError(f"{argument_name} must be a non-empty list of strings") + raise ValueError(f"{argument_name} must be a list of strings") + + def _validate_positive_integer_list(self, data, argument_name): + if not isinstance(data, list): + raise TypeError(f"{argument_name} must be a list") + if not data: + raise ValueError(f"{argument_name} must not be empty") + if not all(isinstance(item, int) and item > 0 for item in data): + raise ValueError(f"{argument_name} must be a list of positive integers") def _validate_formatter_argument(self, fromatter): if not callable(fromatter): @@ -90,17 +98,9 @@ def moving_agg( self._validate_aggs_argument(aggs) self._validate_column_names_argument(order_by, "order_by") self._validate_column_names_argument(group_by, "group_by") + self._validate_positive_integer_list(window_sizes, "window_sizes") self._validate_formatter_argument(col_formatter) - if not isinstance(window_sizes, list): - raise TypeError("window_sizes must be a list") - if not window_sizes: - raise ValueError("window_sizes must not be empty") - if not all(isinstance(item, int) and item > 0 for item in window_sizes): - raise ValueError( - "window_sizes must be a non-empty list of positive integers" - ) - # Perform window aggregation agg_df = self._df for column, agg_funcs in aggs.items(): From e27f0eaafd949e4230c2cbb242bb2ee8e48cc6a7 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 12:36:35 -0800 Subject: [PATCH 10/27] changes --- .../snowpark/dataframe_transform_functions.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index 86cf336b7eb..be98eb998a2 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -45,13 +45,15 @@ def _validate_column_names_argument(self, data, argument_name): if not all(isinstance(item, str) for item in data): raise ValueError(f"{argument_name} must be a list of strings") - def _validate_positive_integer_list(self, data, argument_name): + def _validate_integer_list(self, data, argument_name, min_value=1): if not isinstance(data, list): raise TypeError(f"{argument_name} must be a list") if not data: raise ValueError(f"{argument_name} must not be empty") - if not all(isinstance(item, int) and item > 0 for item in data): - raise ValueError(f"{argument_name} must be a list of positive integers") + if not all(isinstance(item, int) and item >= min_value for item in data): + raise ValueError( + f"{argument_name} must be a list of integers >= {min_value}" + ) def _validate_formatter_argument(self, fromatter): if not callable(fromatter): @@ -98,7 +100,7 @@ def moving_agg( self._validate_aggs_argument(aggs) self._validate_column_names_argument(order_by, "order_by") self._validate_column_names_argument(group_by, "group_by") - self._validate_positive_integer_list(window_sizes, "window_sizes") + self._validate_integer_list(window_sizes, "window_sizes") self._validate_formatter_argument(col_formatter) # Perform window aggregation From e8d4345af423b3311c2c7ffe0bd39ca99db4cad3 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 12:58:52 -0800 Subject: [PATCH 11/27] fix test --- tests/integ/test_df_transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integ/test_df_transform.py b/tests/integ/test_df_transform.py index a578ea7635f..43efca28923 100644 --- a/tests/integ/test_df_transform.py +++ b/tests/integ/test_df_transform.py @@ -92,7 +92,7 @@ def test_moving_agg_invalid_inputs(session): order_by=["ORDERDATE"], group_by=["PRODUCTKEY"], ).collect() - assert "window_sizes must be a non-empty list of positive integers" in str(exc) + assert "window_sizes must be a list of integers >= 1" in str(exc) with pytest.raises(ValueError) as exc: df.transform.moving_agg( @@ -101,7 +101,7 @@ def test_moving_agg_invalid_inputs(session): order_by=["ORDERDATE"], group_by=["PRODUCTKEY"], ).collect() - assert "window_sizes must be a non-empty list of positive integers" in str(exc) + assert "window_sizes must be a list of integers >= 1" in str(exc) with pytest.raises(ValueError) as exc: df.transform.moving_agg( From ccf665533492befdc2931621759a2e27a6da67da Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 13:33:23 -0800 Subject: [PATCH 12/27] changes --- .../snowpark/dataframe_transform_functions.py | 10 ++++------ tests/integ/test_df_transform.py | 4 ++-- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index be98eb998a2..8e5303a7375 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -45,15 +45,13 @@ def _validate_column_names_argument(self, data, argument_name): if not all(isinstance(item, str) for item in data): raise ValueError(f"{argument_name} must be a list of strings") - def _validate_integer_list(self, data, argument_name, min_value=1): + def _validate_positive_integer_list(self, data, argument_name): if not isinstance(data, list): raise TypeError(f"{argument_name} must be a list") if not data: raise ValueError(f"{argument_name} must not be empty") - if not all(isinstance(item, int) and item >= min_value for item in data): - raise ValueError( - f"{argument_name} must be a list of integers >= {min_value}" - ) + if not all(isinstance(item, int) and item > 0 for item in data): + raise ValueError(f"{argument_name} must be a list of integers > 0") def _validate_formatter_argument(self, fromatter): if not callable(fromatter): @@ -100,7 +98,7 @@ def moving_agg( self._validate_aggs_argument(aggs) self._validate_column_names_argument(order_by, "order_by") self._validate_column_names_argument(group_by, "group_by") - self._validate_integer_list(window_sizes, "window_sizes") + self._validate_positive_integer_list(window_sizes, "window_sizes") self._validate_formatter_argument(col_formatter) # Perform window aggregation diff --git a/tests/integ/test_df_transform.py b/tests/integ/test_df_transform.py index 43efca28923..888a412db7d 100644 --- a/tests/integ/test_df_transform.py +++ b/tests/integ/test_df_transform.py @@ -92,7 +92,7 @@ def test_moving_agg_invalid_inputs(session): order_by=["ORDERDATE"], group_by=["PRODUCTKEY"], ).collect() - assert "window_sizes must be a list of integers >= 1" in str(exc) + assert "window_sizes must be a list of integers > 0" in str(exc) with pytest.raises(ValueError) as exc: df.transform.moving_agg( @@ -101,7 +101,7 @@ def test_moving_agg_invalid_inputs(session): order_by=["ORDERDATE"], group_by=["PRODUCTKEY"], ).collect() - assert "window_sizes must be a list of integers >= 1" in str(exc) + assert "window_sizes must be a list of integers > 0" in str(exc) with pytest.raises(ValueError) as exc: df.transform.moving_agg( From 8e7adad9c2099381e393b0cfe4072621b5eb4d6d Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 13:39:07 -0800 Subject: [PATCH 13/27] changes --- src/snowflake/snowpark/dataframe_transform_functions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index 8e5303a7375..5533f822a53 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -37,7 +37,7 @@ def _validate_aggs_argument(self, aggs): "aggs must have strings as keys and non-empty lists of strings as values" ) - def _validate_column_names_argument(self, data, argument_name): + def _validate_string_list_argument(self, data, argument_name): if not isinstance(data, list): raise TypeError(f"{argument_name} must be a list") if not data: @@ -45,7 +45,7 @@ def _validate_column_names_argument(self, data, argument_name): if not all(isinstance(item, str) for item in data): raise ValueError(f"{argument_name} must be a list of strings") - def _validate_positive_integer_list(self, data, argument_name): + def _validate_positive_integer_list_argument(self, data, argument_name): if not isinstance(data, list): raise TypeError(f"{argument_name} must be a list") if not data: @@ -96,8 +96,8 @@ def moving_agg( """ # Validate input arguments self._validate_aggs_argument(aggs) - self._validate_column_names_argument(order_by, "order_by") - self._validate_column_names_argument(group_by, "group_by") + self._validate_string_list_argument(order_by, "order_by") + self._validate_string_list_argument(group_by, "group_by") self._validate_positive_integer_list(window_sizes, "window_sizes") self._validate_formatter_argument(col_formatter) From 9a9a0457c6ba47610a51c1447b238f1cb3d1f81f Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 14:19:17 -0800 Subject: [PATCH 14/27] changes --- .../snowpark/dataframe_transform_functions.py | 7 ++++--- tests/integ/test_df_transform.py | 13 +++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index 5533f822a53..ed1787f723c 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -75,8 +75,9 @@ def moving_agg( calculate the moving aggregate. order_by: A list of column names that specify the order in which rows are processed. group_by: A list of column names on which the DataFrame is partitioned for separate window calculations. - col_formatter: An optional function to format the output column names. Defaults to a built-in formatter - that outputs column names in the format "__". + col_formatter: An optional function for formatting output column names, defaulting to the format '__'. + This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation, + and 'value' (int) for the operation's numerical value, and returns a formatted string for the column name. Returns: A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation. @@ -98,7 +99,7 @@ def moving_agg( self._validate_aggs_argument(aggs) self._validate_string_list_argument(order_by, "order_by") self._validate_string_list_argument(group_by, "group_by") - self._validate_positive_integer_list(window_sizes, "window_sizes") + self._validate_positive_integer_list_argument(window_sizes, "window_sizes") self._validate_formatter_argument(col_formatter) # Perform window aggregation diff --git a/tests/integ/test_df_transform.py b/tests/integ/test_df_transform.py index 888a412db7d..916b21b6165 100644 --- a/tests/integ/test_df_transform.py +++ b/tests/integ/test_df_transform.py @@ -120,3 +120,16 @@ def test_moving_agg_invalid_inputs(session): group_by=["PRODUCTKEY"], ).collect() assert "Sliding window frame unsupported for function" in str(exc) + + def bad_formatter(input_col, agg): + return f"{agg}_{input_col}" + + with pytest.raises(TypeError) as exc: + df.transform.moving_agg( + aggs={"SALESAMOUNT": ["SUM"]}, + window_sizes=[1], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + col_formatter=bad_formatter, + ).collect() + assert "positional arguments but 3 were given" in str(exc) From 706f15fafa1b25a8d805b3773c942671b86e5135 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Thu, 23 Nov 2023 14:26:48 -0800 Subject: [PATCH 15/27] changes --- src/snowflake/snowpark/dataframe_transform_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index ed1787f723c..2ffb0f3a58c 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -77,7 +77,7 @@ def moving_agg( group_by: A list of column names on which the DataFrame is partitioned for separate window calculations. col_formatter: An optional function for formatting output column names, defaulting to the format '__'. This function takes three arguments: 'input_col' (str) for the column name, 'operation' (str) for the applied operation, - and 'value' (int) for the operation's numerical value, and returns a formatted string for the column name. + and 'value' (int) for the window size, and returns a formatted string for the column name. Returns: A Snowflake DataFrame with additional columns corresponding to each specified moving aggregation. From 39a3e8acd0f7f25c26195f0fe008e34a5cec474f Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Mon, 18 Dec 2023 18:01:21 +0530 Subject: [PATCH 16/27] fix comment --- src/snowflake/snowpark/dataframe_transform_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index 2ffb0f3a58c..a498efa894e 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -71,6 +71,7 @@ def moving_agg( Args: aggs: A dictionary where keys are column names and values are lists of the desired aggregation functions. + Supported aggregation are listed here https://docs.snowflake.com/en/sql-reference/functions-analytic#list-of-functions-that-support-windows. window_sizes: A list of positive integers, each representing the size of the window for which to calculate the moving aggregate. order_by: A list of column names that specify the order in which rows are processed. From 25991679163c1018733bfca568a1d3f29094b906 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Tue, 16 Jan 2024 16:45:31 -0800 Subject: [PATCH 17/27] skip tests when pandas are not available --- tests/integ/test_df_transform.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/integ/test_df_transform.py b/tests/integ/test_df_transform.py index 916b21b6165..859b85439e4 100644 --- a/tests/integ/test_df_transform.py +++ b/tests/integ/test_df_transform.py @@ -3,7 +3,13 @@ # Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved. # -import pandas as pd +try: + import pandas as pd + + is_pandas_available = True +except ImportError: + is_pandas_available = False + import pytest from pandas.testing import assert_frame_equal @@ -22,6 +28,7 @@ def get_sample_dataframe(session): ) +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_moving_agg(session): """Tests df.transform.moving_agg() happy path.""" @@ -49,6 +56,7 @@ def test_moving_agg(session): ) +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_moving_agg_custom_formatting(session): """Tests df.transform.moving_agg() with custom formatting of output columns.""" From ce061bd41d899bea4843c93f56dd8a511fc4cbd2 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Tue, 16 Jan 2024 16:50:13 -0800 Subject: [PATCH 18/27] update change log --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aa178202e0b..d55c77ed238 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ - `SessionBuilder.getOrCreate` will now attempt to replace the singleton it returns when token expiration has been detected. - Added support for new function(s) in `snowflake.snowpark.functions`: - `array_except` +- Added moving_agg function in DataFrame.transform for time series analysis, enabling moving aggregations like sums and averages with multiple window sizes. ### Bug Fixes @@ -35,8 +36,6 @@ ### New Features -- Added moving_agg function in DataFrame.transform for time series analysis, enabling moving aggregations like sums and averages with multiple window sizes. - ### Bug Fixes - Fixed a bug that numpy should not be imported at the top level of mock module. From 5c00182e829b0d025a60b1f39c2d21344783523b Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Wed, 17 Jan 2024 05:42:06 -0800 Subject: [PATCH 19/27] changes --- tests/integ/test_df_transform.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/test_df_transform.py b/tests/integ/test_df_transform.py index 859b85439e4..8a71e0c2fdc 100644 --- a/tests/integ/test_df_transform.py +++ b/tests/integ/test_df_transform.py @@ -88,6 +88,7 @@ def custom_formatter(input_col, agg, window): ) +@pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_moving_agg_invalid_inputs(session): """Tests df.transform.moving_agg() with invalid window sizes.""" From e0e18bfd42e326ae98296c3bdc25131b6efcb0be Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Wed, 17 Jan 2024 10:52:40 -0800 Subject: [PATCH 20/27] changes --- tests/integ/test_df_transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/test_df_transform.py b/tests/integ/test_df_transform.py index 8a71e0c2fdc..7c4c8d305a8 100644 --- a/tests/integ/test_df_transform.py +++ b/tests/integ/test_df_transform.py @@ -5,13 +5,13 @@ try: import pandas as pd + from pandas.testing import assert_frame_equal is_pandas_available = True except ImportError: is_pandas_available = False import pytest -from pandas.testing import assert_frame_equal from snowflake.snowpark.exceptions import SnowparkSQLException From 2d05cb8c7ffa64a7260dcc21ac33b8ffffacf3ce Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Wed, 17 Jan 2024 19:37:14 -0800 Subject: [PATCH 21/27] adding doctest --- .../snowpark/dataframe_transform_functions.py | 31 +++++++++++++++---- 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_transform_functions.py index a498efa894e..556e51cfa3b 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_transform_functions.py @@ -89,12 +89,31 @@ def moving_agg( SnowparkSQLException: If an unsupported aggregration is specified. Example: - aggregated_df = moving_agg( - aggs={"SALESAMOUNT": ['SUM', 'AVG']}, - window_sizes=[1, 2, 3, 7], - order_by=['ORDERDATE'], - group_by=['PRODUCTKEY'] - ) + >>> data = [ + ... ["2023-01-01", 101, 200], + ... ["2023-01-02", 101, 100], + ... ["2023-01-03", 101, 300], + ... ["2023-01-04", 102, 250], + ... ] + >>> df = session.create_dataframe(data).to_df( + ... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT" + ... ) + >>> result = df.transform.moving_agg( + ... aggs={"SALESAMOUNT": ["SUM", "AVG"]}, + ... window_sizes=[2, 3], + ... order_by=["ORDERDATE"], + ... group_by=["PRODUCTKEY"], + ... ) + >>> result.show() + +-----------+-----------+----------------+----------------+-----------------+-----------------+ + | ORDERDATE | PRODUCTKEY| SALESAMOUNT_SUM_2 | SALESAMOUNT_AVG_2 | SALESAMOUNT_SUM_3 | SALESAMOUNT_AVG_3 | + +-----------+-----------+----------------+----------------+-----------------+-----------------+ + | 2023-01-01| 101| 200| 200.0| 200| 200.0| + | 2023-01-02| 101| 300| 150.0| 300| 150.0| + | 2023-01-03| 101| 400| 200.0| 600| 200.0| + | 2023-01-04| 102| 250| 250.0| 250| 250.0| + +-----------+-----------+----------------+----------------+-----------------+-----------------+ + """ # Validate input arguments self._validate_aggs_argument(aggs) From 9423599036cb4efdafa1bd67e045902e408fbfc1 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Tue, 23 Jan 2024 14:39:45 -0800 Subject: [PATCH 22/27] renaming --- CHANGELOG.md | 2 +- src/snowflake/snowpark/__init__.py | 4 ++-- src/snowflake/snowpark/dataframe.py | 8 ++++---- ...form_functions.py => dataframe_analytics_functions.py} | 8 ++++---- .../integ/{test_df_transform.py => test_df_analytics.py} | 0 5 files changed, 11 insertions(+), 11 deletions(-) rename src/snowflake/snowpark/{dataframe_transform_functions.py => dataframe_analytics_functions.py} (96%) rename tests/integ/{test_df_transform.py => test_df_analytics.py} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index d55c77ed238..eee0d43ad3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ - `SessionBuilder.getOrCreate` will now attempt to replace the singleton it returns when token expiration has been detected. - Added support for new function(s) in `snowflake.snowpark.functions`: - `array_except` -- Added moving_agg function in DataFrame.transform for time series analysis, enabling moving aggregations like sums and averages with multiple window sizes. +- Added moving_agg function in DataFrame.analytics for enabling moving aggregations like sums and averages with multiple window sizes. ### Bug Fixes diff --git a/src/snowflake/snowpark/__init__.py b/src/snowflake/snowpark/__init__.py index 560693d9522..bf049c957a6 100644 --- a/src/snowflake/snowpark/__init__.py +++ b/src/snowflake/snowpark/__init__.py @@ -19,7 +19,7 @@ "GetResult", "DataFrame", "DataFrameStatFunctions", - "DataFrameTransformFunctions", + "DataFrameAnalyticsFunctions", "DataFrameNaFunctions", "DataFrameWriter", "DataFrameReader", @@ -47,10 +47,10 @@ from snowflake.snowpark.async_job import AsyncJob from snowflake.snowpark.column import CaseExpr, Column from snowflake.snowpark.dataframe import DataFrame +from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions from snowflake.snowpark.dataframe_reader import DataFrameReader from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions -from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions from snowflake.snowpark.dataframe_writer import DataFrameWriter from snowflake.snowpark.file_operation import FileOperation, GetResult, PutResult from snowflake.snowpark.query_history import QueryHistory, QueryRecord diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index efdee0f6637..2f1d2b208c2 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -121,9 +121,9 @@ ) from snowflake.snowpark.async_job import AsyncJob, _AsyncResultType from snowflake.snowpark.column import Column, _to_col_if_sql_expr, _to_col_if_str +from snowflake.snowpark.dataframe_analytics_functions import DataFrameAnalyticsFunctions from snowflake.snowpark.dataframe_na_functions import DataFrameNaFunctions from snowflake.snowpark.dataframe_stat_functions import DataFrameStatFunctions -from snowflake.snowpark.dataframe_transform_functions import DataFrameTransformFunctions from snowflake.snowpark.dataframe_writer import DataFrameWriter from snowflake.snowpark.exceptions import SnowparkDataframeException from snowflake.snowpark.functions import ( @@ -523,7 +523,7 @@ def __init__( self._writer = DataFrameWriter(self) self._stat = DataFrameStatFunctions(self) - self._transform = DataFrameTransformFunctions(self) + self._analytics = DataFrameAnalyticsFunctions(self) self.approxQuantile = self.approx_quantile = self._stat.approx_quantile self.corr = self._stat.corr self.cov = self._stat.cov @@ -542,8 +542,8 @@ def stat(self) -> DataFrameStatFunctions: return self._stat @property - def transform(self) -> DataFrameTransformFunctions: - return self._transform + def transform(self) -> DataFrameAnalyticsFunctions: + return self._analytics @overload def collect( diff --git a/src/snowflake/snowpark/dataframe_transform_functions.py b/src/snowflake/snowpark/dataframe_analytics_functions.py similarity index 96% rename from src/snowflake/snowpark/dataframe_transform_functions.py rename to src/snowflake/snowpark/dataframe_analytics_functions.py index 556e51cfa3b..fe556dcde90 100644 --- a/src/snowflake/snowpark/dataframe_transform_functions.py +++ b/src/snowflake/snowpark/dataframe_analytics_functions.py @@ -9,9 +9,9 @@ from snowflake.snowpark.window import Window -class DataFrameTransformFunctions: - """Provides data transformation functions for DataFrames. - To access an object of this class, use :attr:`DataFrame.transform`. +class DataFrameAnalyticsFunctions: + """Provides data analytics functions for DataFrames. + To access an object of this class, use :attr:`DataFrame.analytics`. """ def __init__(self, df: "snowflake.snowpark.DataFrame") -> None: @@ -98,7 +98,7 @@ def moving_agg( >>> df = session.create_dataframe(data).to_df( ... "ORDERDATE", "PRODUCTKEY", "SALESAMOUNT" ... ) - >>> result = df.transform.moving_agg( + >>> result = df.analytics.moving_agg( ... aggs={"SALESAMOUNT": ["SUM", "AVG"]}, ... window_sizes=[2, 3], ... order_by=["ORDERDATE"], diff --git a/tests/integ/test_df_transform.py b/tests/integ/test_df_analytics.py similarity index 100% rename from tests/integ/test_df_transform.py rename to tests/integ/test_df_analytics.py From 721383f831a6bc68302bc15b58cc3f798ff4469b Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Tue, 23 Jan 2024 14:42:42 -0800 Subject: [PATCH 23/27] renaming 2 --- src/snowflake/snowpark/dataframe.py | 2 +- tests/integ/test_df_analytics.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 2f1d2b208c2..1e23487abe5 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -542,7 +542,7 @@ def stat(self) -> DataFrameStatFunctions: return self._stat @property - def transform(self) -> DataFrameAnalyticsFunctions: + def analytics(self) -> DataFrameAnalyticsFunctions: return self._analytics @overload diff --git a/tests/integ/test_df_analytics.py b/tests/integ/test_df_analytics.py index 7c4c8d305a8..ed3208a82e4 100644 --- a/tests/integ/test_df_analytics.py +++ b/tests/integ/test_df_analytics.py @@ -34,7 +34,7 @@ def test_moving_agg(session): df = get_sample_dataframe(session) - res = df.transform.moving_agg( + res = df.analytics.moving_agg( aggs={"SALESAMOUNT": ["SUM", "AVG"]}, window_sizes=[2, 3], order_by=["ORDERDATE"], @@ -65,7 +65,7 @@ def test_moving_agg_custom_formatting(session): def custom_formatter(input_col, agg, window): return f"{window}_{agg}_{input_col}" - res = df.transform.moving_agg( + res = df.analytics.moving_agg( aggs={"SALESAMOUNT": ["SUM", "AVG"]}, window_sizes=[2, 3], order_by=["ORDERDATE"], @@ -95,7 +95,7 @@ def test_moving_agg_invalid_inputs(session): df = get_sample_dataframe(session) with pytest.raises(ValueError) as exc: - df.transform.moving_agg( + df.analytics.moving_agg( aggs={"SALESAMOUNT": ["AVG"]}, window_sizes=[-1, 2, 3], order_by=["ORDERDATE"], @@ -104,7 +104,7 @@ def test_moving_agg_invalid_inputs(session): assert "window_sizes must be a list of integers > 0" in str(exc) with pytest.raises(ValueError) as exc: - df.transform.moving_agg( + df.analytics.moving_agg( aggs={"SALESAMOUNT": ["AVG"]}, window_sizes=[0, 2, 3], order_by=["ORDERDATE"], @@ -113,7 +113,7 @@ def test_moving_agg_invalid_inputs(session): assert "window_sizes must be a list of integers > 0" in str(exc) with pytest.raises(ValueError) as exc: - df.transform.moving_agg( + df.analytics.moving_agg( aggs={"SALESAMOUNT": []}, window_sizes=[0, 2, 3], order_by=["ORDERDATE"], @@ -122,7 +122,7 @@ def test_moving_agg_invalid_inputs(session): assert "non-empty lists of strings as values" in str(exc) with pytest.raises(SnowparkSQLException) as exc: - df.transform.moving_agg( + df.analytics.moving_agg( aggs={"SALESAMOUNT": ["INVALID_FUNC"]}, window_sizes=[1], order_by=["ORDERDATE"], @@ -134,7 +134,7 @@ def bad_formatter(input_col, agg): return f"{agg}_{input_col}" with pytest.raises(TypeError) as exc: - df.transform.moving_agg( + df.analytics.moving_agg( aggs={"SALESAMOUNT": ["SUM"]}, window_sizes=[1], order_by=["ORDERDATE"], From cfb5e4a393594ffa672804a4a7794a99fa822296 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Tue, 23 Jan 2024 15:09:52 -0800 Subject: [PATCH 24/27] fix comments --- tests/integ/test_df_analytics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/integ/test_df_analytics.py b/tests/integ/test_df_analytics.py index ed3208a82e4..e304f8c6aa1 100644 --- a/tests/integ/test_df_analytics.py +++ b/tests/integ/test_df_analytics.py @@ -30,7 +30,7 @@ def get_sample_dataframe(session): @pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_moving_agg(session): - """Tests df.transform.moving_agg() happy path.""" + """Tests df.analytics.moving_agg() happy path.""" df = get_sample_dataframe(session) @@ -58,7 +58,7 @@ def test_moving_agg(session): @pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_moving_agg_custom_formatting(session): - """Tests df.transform.moving_agg() with custom formatting of output columns.""" + """Tests df.analytics.moving_agg() with custom formatting of output columns.""" df = get_sample_dataframe(session) @@ -90,7 +90,7 @@ def custom_formatter(input_col, agg, window): @pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_moving_agg_invalid_inputs(session): - """Tests df.transform.moving_agg() with invalid window sizes.""" + """Tests df.analytics.moving_agg() with invalid window sizes.""" df = get_sample_dataframe(session) From 9583161e0d775de95a7948e0fedd8e3763ead1dc Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Tue, 23 Jan 2024 16:11:03 -0800 Subject: [PATCH 25/27] fix error message --- .../snowpark/dataframe_analytics_functions.py | 46 +++++++++++++++---- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_analytics_functions.py b/src/snowflake/snowpark/dataframe_analytics_functions.py index fe556dcde90..fbdcf555f7a 100644 --- a/src/snowflake/snowpark/dataframe_analytics_functions.py +++ b/src/snowflake/snowpark/dataframe_analytics_functions.py @@ -25,33 +25,61 @@ def _default_col_formatter(input_col: str, operation: str, *args) -> str: return formatted_name def _validate_aggs_argument(self, aggs): + argument_requirements = ( + "The 'aggs' argument must adhere to the following rules: " + "1) It must be a dictionary. " + "2) It must not be empty. " + "3) All keys must be strings. " + "4) All values must be non-empty lists of strings." + ) + if not isinstance(aggs, dict): - raise TypeError("aggs must be a dictionary") + raise TypeError(f"aggs must be a dictionary. {argument_requirements}") if not aggs: - raise ValueError("aggs must not be empty") + raise ValueError(f"aggs must not be empty. {argument_requirements}") if not all( isinstance(key, str) and isinstance(val, list) and val for key, val in aggs.items() ): raise ValueError( - "aggs must have strings as keys and non-empty lists of strings as values" + f"aggs must have strings as keys and non-empty lists of strings as values. {argument_requirements}" ) def _validate_string_list_argument(self, data, argument_name): + argument_requirements = ( + f"The '{argument_name}' argument must adhere to the following rules: " + "1) It must be a list. " + "2) It must not be empty. " + "3) All items in the list must be strings." + ) if not isinstance(data, list): - raise TypeError(f"{argument_name} must be a list") + raise TypeError(f"{argument_name} must be a list. {argument_requirements}") if not data: - raise ValueError(f"{argument_name} must not be empty") + raise ValueError( + f"{argument_name} must not be empty. {argument_requirements}" + ) if not all(isinstance(item, str) for item in data): - raise ValueError(f"{argument_name} must be a list of strings") + raise ValueError( + f"{argument_name} must be a list of strings. {argument_requirements}" + ) def _validate_positive_integer_list_argument(self, data, argument_name): + argument_requirements = ( + f"The '{argument_name}' argument must adhere to the following criteria: " + "1) It must be a list. " + "2) It must not be empty. " + "3) All items in the list must be positive integers." + ) if not isinstance(data, list): - raise TypeError(f"{argument_name} must be a list") + raise TypeError(f"{argument_name} must be a list. {argument_requirements}") if not data: - raise ValueError(f"{argument_name} must not be empty") + raise ValueError( + f"{argument_name} must not be empty. {argument_requirements}" + ) if not all(isinstance(item, int) and item > 0 for item in data): - raise ValueError(f"{argument_name} must be a list of integers > 0") + raise ValueError( + f"{argument_name} must be a list of integers > 0. {argument_requirements}" + ) def _validate_formatter_argument(self, fromatter): if not callable(fromatter): From 045f2161d57de157839ce8163819900f291a233d Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Wed, 24 Jan 2024 04:29:02 -0800 Subject: [PATCH 26/27] fix merge --- CHANGELOG.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d67e63f39e..d7281d60f6c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,8 +46,6 @@ ## 1.11.1 (2023-12-07) -### New Features - ### Bug Fixes - Fixed a bug that numpy should not be imported at the top level of mock module. From 05a2da5fd6f47658c35b388b77846703fea0ff18 Mon Sep 17 00:00:00 2001 From: Ravi Kumar Suresh Babu Date: Wed, 24 Jan 2024 05:54:53 -0800 Subject: [PATCH 27/27] fix code coverage --- .../snowpark/dataframe_analytics_functions.py | 16 +-- tests/integ/test_df_analytics.py | 102 +++++++++++++++++- 2 files changed, 107 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/dataframe_analytics_functions.py b/src/snowflake/snowpark/dataframe_analytics_functions.py index fbdcf555f7a..9f69c6c327a 100644 --- a/src/snowflake/snowpark/dataframe_analytics_functions.py +++ b/src/snowflake/snowpark/dataframe_analytics_functions.py @@ -133,14 +133,14 @@ def moving_agg( ... group_by=["PRODUCTKEY"], ... ) >>> result.show() - +-----------+-----------+----------------+----------------+-----------------+-----------------+ - | ORDERDATE | PRODUCTKEY| SALESAMOUNT_SUM_2 | SALESAMOUNT_AVG_2 | SALESAMOUNT_SUM_3 | SALESAMOUNT_AVG_3 | - +-----------+-----------+----------------+----------------+-----------------+-----------------+ - | 2023-01-01| 101| 200| 200.0| 200| 200.0| - | 2023-01-02| 101| 300| 150.0| 300| 150.0| - | 2023-01-03| 101| 400| 200.0| 600| 200.0| - | 2023-01-04| 102| 250| 250.0| 250| 250.0| - +-----------+-----------+----------------+----------------+-----------------+-----------------+ + -------------------------------------------------------------------------------------------------------------------------------------- + |"ORDERDATE" |"PRODUCTKEY" |"SALESAMOUNT" |"SALESAMOUNT_SUM_2" |"SALESAMOUNT_AVG_2" |"SALESAMOUNT_SUM_3" |"SALESAMOUNT_AVG_3" | + -------------------------------------------------------------------------------------------------------------------------------------- + |2023-01-04 |102 |250 |250 |250.000 |250 |250.000 | + |2023-01-01 |101 |200 |200 |200.000 |200 |200.000 | + |2023-01-02 |101 |100 |300 |150.000 |300 |150.000 | + |2023-01-03 |101 |300 |400 |200.000 |600 |200.000 | + -------------------------------------------------------------------------------------------------------------------------------------- """ # Validate input arguments diff --git a/tests/integ/test_df_analytics.py b/tests/integ/test_df_analytics.py index e304f8c6aa1..f747523b28e 100644 --- a/tests/integ/test_df_analytics.py +++ b/tests/integ/test_df_analytics.py @@ -87,6 +87,29 @@ def custom_formatter(input_col, agg, window): res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 ) + # With default formatter + res = df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["SUM", "AVG"]}, + window_sizes=[2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ) + + expected_data = { + "ORDERDATE": ["2023-01-01", "2023-01-02", "2023-01-03", "2023-01-04"], + "PRODUCTKEY": [101, 101, 101, 102], + "SALESAMOUNT": [200, 100, 300, 250], + "SALESAMOUNT_SUM_2": [200, 300, 400, 250], + "SALESAMOUNT_AVG_2": [200.0, 150.0, 200.0, 250.0], + "SALESAMOUNT_SUM_3": [200, 300, 600, 250], + "SALESAMOUNT_AVG_3": [200.0, 150.0, 200.0, 250.0], + } + + expected_df = pd.DataFrame(expected_data) + assert_frame_equal( + res.order_by("ORDERDATE").to_pandas(), expected_df, check_dtype=False, atol=1e-1 + ) + @pytest.mark.skipif(not is_pandas_available, reason="pandas is required") def test_moving_agg_invalid_inputs(session): @@ -94,6 +117,60 @@ def test_moving_agg_invalid_inputs(session): df = get_sample_dataframe(session) + with pytest.raises(TypeError) as exc: + df.analytics.moving_agg( + aggs=["AVG"], + window_sizes=[1, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "aggs must be a dictionary" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={}, + window_sizes=[1, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "aggs must not be empty" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": []}, + window_sizes=[1, 2, 3], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "non-empty lists of strings as values" in str(exc) + + with pytest.raises(TypeError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[1, 2, 3], + order_by="ORDERDATE", + group_by=["PRODUCTKEY"], + ).collect() + assert "order_by must be a list" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[1, 2, 3], + order_by=[], + group_by=["PRODUCTKEY"], + ).collect() + assert "order_by must not be empty" in str(exc) + + with pytest.raises(ValueError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[1, 2, 3], + order_by=[1], + group_by=["PRODUCTKEY"], + ).collect() + assert "order_by must be a list of strings" in str(exc) + with pytest.raises(ValueError) as exc: df.analytics.moving_agg( aggs={"SALESAMOUNT": ["AVG"]}, @@ -112,14 +189,23 @@ def test_moving_agg_invalid_inputs(session): ).collect() assert "window_sizes must be a list of integers > 0" in str(exc) + with pytest.raises(TypeError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=0, + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + ).collect() + assert "window_sizes must be a list" in str(exc) + with pytest.raises(ValueError) as exc: df.analytics.moving_agg( - aggs={"SALESAMOUNT": []}, - window_sizes=[0, 2, 3], + aggs={"SALESAMOUNT": ["AVG"]}, + window_sizes=[], order_by=["ORDERDATE"], group_by=["PRODUCTKEY"], ).collect() - assert "non-empty lists of strings as values" in str(exc) + assert "window_sizes must not be empty" in str(exc) with pytest.raises(SnowparkSQLException) as exc: df.analytics.moving_agg( @@ -142,3 +228,13 @@ def bad_formatter(input_col, agg): col_formatter=bad_formatter, ).collect() assert "positional arguments but 3 were given" in str(exc) + + with pytest.raises(TypeError) as exc: + df.analytics.moving_agg( + aggs={"SALESAMOUNT": ["SUM"]}, + window_sizes=[1], + order_by=["ORDERDATE"], + group_by=["PRODUCTKEY"], + col_formatter="bad_formatter", + ).collect() + assert "formatter must be a callable function" in str(exc)