diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f72f90dbd1..1f242c881d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,11 +41,14 @@ - Added support for `Series.dt.dayofweek`, `Series.dt.day_of_week`, `Series.dt.dayofyear`, and `Series.dt.day_of_year`. - Added support for `Series.str.__getitem__` (`Series.str[...]`). - Added support for `Series.str.lstrip` and `Series.str.rstrip`. +- Added support for `DataFrame.expanding` and `Series.expanding` for aggregations `count`, `sum`, `min`, `max`, `mean`, `std`, and `var` with `axis=0`. +- Added support for `DataFrame.rolling` and `Series.rolling` for aggregation `count` with `axis=0`. #### Bug Fixes - Fixed a bug that causes output of GroupBy.aggregate's columns to be ordered incorrectly. - Fixed a bug where `DataFrame.describe` on a frame with duplicate columns of differing dtypes could cause an error or incorrect results. +- Fixed a bug in `DataFrame.rolling` and `Series.rolling` so `window=0` now throws `NotImplementedError` instead of `ValueError` #### Improvements diff --git a/docs/source/modin/supported/dataframe_supported.rst b/docs/source/modin/supported/dataframe_supported.rst index 928a12bf541..a89e38b1e73 100644 --- a/docs/source/modin/supported/dataframe_supported.rst +++ b/docs/source/modin/supported/dataframe_supported.rst @@ -169,7 +169,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``ewm`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``expanding`` | N | | | +| ``expanding`` | P | ``method`` is ignored | ``N`` if ``axis = 1`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``explode`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ @@ -351,8 +351,8 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``rmul`` | P | ``level`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``rolling`` | P | | Supports integer ``window``, ``min_periods >= 1``, | -| | | | and ``center`` for ``axis = 0`` | +| ``rolling`` | P | ``method`` is ignored, ``step``, | ``N`` for non-integer ``window``, ``axis = 1``, | +| | | ``win_type``, ``closed``, ``on`` | or ``min_periods = 0`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``round`` | P | | ``N`` if ``decimals`` is Series | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/series_supported.rst b/docs/source/modin/supported/series_supported.rst index ccd77769e69..c96a03dafb1 100644 --- a/docs/source/modin/supported/series_supported.rst +++ b/docs/source/modin/supported/series_supported.rst @@ -180,7 +180,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``ewm`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``expanding`` | N | | | +| ``expanding`` | P | ``method`` is ignored | ``N`` if ``axis = 1`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``explode`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ @@ -344,8 +344,8 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``rmul`` | P | ``level`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``rolling`` | P | | Supports integer ``window``, ``min_periods >= 1``, | -| | | | and ``center`` | +| ``rolling`` | P | ``method`` is ignored, ``step``, | ``N`` for non-integer ``window``, ``axis = 1``, | +| | | ``win_type``, ``closed``, ``on`` | or ``min_periods = 0`` | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``round`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/docs/source/modin/supported/window_supported.rst b/docs/source/modin/supported/window_supported.rst index f86e9d242a2..a3844915e89 100644 --- a/docs/source/modin/supported/window_supported.rst +++ b/docs/source/modin/supported/window_supported.rst @@ -22,22 +22,23 @@ Rolling window functions +-----------------------------+---------------------------------+----------------------------------------------------+ | ``corr`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``count`` | N | | +| ``count`` | P | ``N`` for non-integer ``window``, ``axis = 1``, | +| | | or ``min_periods = 0`` | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``cov`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``kurt`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``max`` | P | Supports integer ``window``, ``min_periods >= 1``, | -| | | and ``center`` for ``axis = 0`` | +| ``max`` | P | ``N`` for non-integer ``window``, ``axis = 1``, | +| | | or ``min_periods = 0`` | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``mean`` | P | Supports integer ``window``, ``min_periods >= 1``, | -| | | and ``center`` for ``axis = 0`` | +| ``mean`` | P | ``N`` for non-integer ``window``, ``axis = 1``, | +| | | or ``min_periods = 0`` | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``median`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``min`` | P | Supports integer ``window``, ``min_periods >= 1``, | -| | | and ``center`` for ``axis = 0`` | +| ``min`` | P | ``N`` for non-integer ``window``, ``axis = 1``, | +| | | or ``min_periods = 0`` | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``quantile`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ @@ -47,14 +48,14 @@ Rolling window functions +-----------------------------+---------------------------------+----------------------------------------------------+ | ``skew`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``std`` | P | Supports integer ``window``, ``min_periods >= 1``, | -| | | and ``center`` for ``axis = 0`` | +| ``std`` | P | ``N`` for non-integer ``window``, ``axis = 1``, | +| | | or ``min_periods = 0`` | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``sum`` | P | Supports integer ``window``, ``min_periods >= 1``, | -| | | and ``center`` for ``axis = 0`` | +| ``sum`` | P | ``N`` for non-integer ``window``, ``axis = 1``, | +| | | or ``min_periods = 0`` | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``var`` | P | Supports integer ``window``, ``min_periods >= 1``, | -| | | and ``center`` for ``axis = 0`` | +| ``var`` | P | ``N`` for non-integer ``window``, ``axis = 1``, | +| | | or ``min_periods = 0`` | +-----------------------------+---------------------------------+----------------------------------------------------+ Weighted window functions @@ -82,19 +83,19 @@ Expanding window functions +-----------------------------+---------------------------------+----------------------------------------------------+ | ``corr`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``count`` | N | | +| ``count`` | P | ``N`` if ``axis = 1`` | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``cov`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``kurt`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``max`` | N | | +| ``max`` | P | ``N`` if ``axis = 1`` | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``mean`` | N | | +| ``mean`` | P | ``N`` if ``axis = 1`` | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``median`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``min`` | N | | +| ``min`` | P | ``N`` if ``axis = 1`` | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``quantile`` | N | | | | | | @@ -107,11 +108,11 @@ Expanding window functions | ``skew`` | N | | | | | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``std`` | N | | +| ``std`` | N | ``N`` if ``axis = 1`` | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``sum`` | N | | +| ``sum`` | P | ``N`` if ``axis = 1`` | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``var`` | N | | +| ``var`` | N | ``N`` if ``axis = 1`` | +-----------------------------+---------------------------------+----------------------------------------------------+ Exponentially-weighted window functions diff --git a/docs/source/modin/window.rst b/docs/source/modin/window.rst index fb7d301ec6f..ee90e2352a3 100644 --- a/docs/source/modin/window.rst +++ b/docs/source/modin/window.rst @@ -5,11 +5,25 @@ Window .. currentmodule:: snowflake.snowpark.modin.pandas.window .. rubric:: :doc:`All supported window APIs ` +.. rubric:: Expanding window functions + +.. autosummary:: + :toctree: pandas_api/ + + Expanding.count + Expanding.max + Expanding.mean + Expanding.min + Expanding.std + Expanding.sum + Expanding.var + .. rubric:: Rolling window functions .. autosummary:: :toctree: pandas_api/ + Rolling.count Rolling.max Rolling.mean Rolling.min diff --git a/src/snowflake/snowpark/modin/pandas/window.py b/src/snowflake/snowpark/modin/pandas/window.py index 54f2e6d2a78..159f2ccf828 100644 --- a/src/snowflake/snowpark/modin/pandas/window.py +++ b/src/snowflake/snowpark/modin/pandas/window.py @@ -30,6 +30,7 @@ # add these two lines to enable doc tests to run from snowflake.snowpark.modin import pandas as pd # noqa: F401 from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage from snowflake.snowpark.modin.utils import ( _inherit_docstrings, doc_replace_dataframe_with_link, @@ -121,7 +122,7 @@ def __init__( ) -> None: # TODO: SNOW-1063358: Modin upgrade - modin.pandas.window.Rolling # Raise ValueError when invalid parameter values/combinations - if (isinstance(window, int) and window <= 0) or window is None: + if (isinstance(window, int) and window < 0) or window is None: raise ValueError("window must be an integer 0 or greater") if not isinstance(center, bool): raise ValueError("center must be a boolean") @@ -150,6 +151,12 @@ def __init__( "method": method, } self.axis = axis + if method != "single": + WarningMessage.ignored_argument( + operation="Rolling", + argument="method", + message="Snowpark pandas API executes on Snowflake. Ignoring engine related arguments to select a different execution engine.", + ) def _call_qc_method(self, method_name, *args, **kwargs): """ @@ -471,6 +478,11 @@ def __init__( method: str = "single", ) -> None: # TODO: SNOW-1063366: Modin upgrade - modin.pandas.window.Expanding + if min_periods is not None and not isinstance(min_periods, int): + raise ValueError("min_periods must be an integer") + if isinstance(min_periods, int) and min_periods < 0: + raise ValueError("min_periods must be >= 0") + self._dataframe = dataframe self._query_compiler = dataframe._query_compiler self.expanding_kwargs = { @@ -479,6 +491,12 @@ def __init__( "method": method, } self.axis = axis + if method != "single": + WarningMessage.ignored_argument( + operation="Expanding", + argument="method", + message="Snowpark pandas API executes on Snowflake. Ignoring engine related arguments to select a different execution engine.", + ) def count( self, diff --git a/src/snowflake/snowpark/modin/plugin/_internal/window_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/window_utils.py index ba3b5a2b04b..a4c6e33f5c6 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/window_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/window_utils.py @@ -5,18 +5,31 @@ # This file contains utils functions used by the groupby functionalities. # # +from enum import Enum from typing import Any from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage -IMPLEMENTED_ROLLING_AGG_FUNCS = ["sum", "mean", "var", "std", "min", "max"] +class WindowFunction(Enum): + """ + Type of window function. + + Attributes: + EXPANDING (str): Represents the expanding window. + ROLLING (str): Represents the rolling window. + """ + + EXPANDING = "expanding" + ROLLING = "rolling" -def check_is_rolling_window_supported_by_snowflake( + +def check_and_raise_error_rolling_window_supported_by_snowflake( rolling_kwargs: dict[str, Any] ) -> None: """ Check if execution with snowflake engine is available for the rolling window operation. + If not, raise NotImplementedError. Parameters ---------- @@ -53,13 +66,7 @@ def check_is_rolling_window_supported_by_snowflake( step: int, default None Evaluate the window at every step result, equivalent to slicing as [::step]. window must be an integer. Using a step argument other than None or 1 will produce a result with a different shape than the input. method: str {‘single’, ‘table’}, default ‘single’ - Execute the rolling operation per single column or row ('single') or over the entire object ('table'). - This argument is only implemented when specifying engine='numba' in the method call. - - Returns - ------- - bool - Whether operations can be executed with snowflake sql engine. + **This parameter is ignored in Snowpark pandas since the execution engine will always be Snowflake.** """ # Snowflake pandas implementation only supports integer window_size, min_periods >= 1, and center on axis = 0 window = rolling_kwargs.get("window") @@ -69,7 +76,6 @@ def check_is_rolling_window_supported_by_snowflake( axis = rolling_kwargs.get("axis", 0) closed = rolling_kwargs.get("closed") step = rolling_kwargs.get("step") - # Method is only used for the numba engine, so no need to check the param/raise a warning to the user. # Raise not implemented error for unsupported params if not isinstance(window, int): @@ -101,3 +107,33 @@ def check_is_rolling_window_supported_by_snowflake( ErrorMessage.method_not_implemented_error( name="step", class_="Rolling" ) # pragma: no cover + + +def check_and_raise_error_expanding_window_supported_by_snowflake( + expanding_kwargs: dict[str, Any] +) -> None: + """ + Check if execution with snowflake engine is available for the expanding window operation. + If not, raise NotImplementedError. + + Parameters + ---------- + expanding_kwargs: keyword arguments passed to expanding. The expanding keywords handled in the + function contains: + min_periods: int, default 1. + Minimum number of observations in window required to have a value; otherwise, result is np.nan. + axis: int or str, default 0 + If 0 or 'index', roll across the rows. + If 1 or 'columns', roll across the columns. + For Series this parameter is unused and defaults to 0. + method: str {‘single’, ‘table’}, default ‘single’ + **This parameter is ignored in Snowpark pandas since the execution engine will always be Snowflake.** + """ + + axis = expanding_kwargs.get("axis", 0) + + if axis not in (0, "index"): + # Note that this is deprecated since pandas 2.1.0 + ErrorMessage.method_not_implemented_error( + name="axis = 1", class_="Expanding" + ) # pragma: no cover diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 241e503124c..b3f00bf4e7f 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -306,7 +306,9 @@ validate_expected_boolean_data_columns, ) from snowflake.snowpark.modin.plugin._internal.window_utils import ( - check_is_rolling_window_supported_by_snowflake, + WindowFunction, + check_and_raise_error_expanding_window_supported_by_snowflake, + check_and_raise_error_rolling_window_supported_by_snowflake, ) from snowflake.snowpark.modin.plugin._typing import ( DropKeep, @@ -10540,8 +10542,13 @@ def rolling_count( numeric_only: bool = False, *args: Any, **kwargs: Any, - ) -> None: - ErrorMessage.method_not_implemented_error(name="count", class_="Rolling") + ) -> "SnowflakeQueryCompiler": + return self._window_agg( + window_func=WindowFunction.ROLLING, + agg_func="count", + window_kwargs=rolling_kwargs, + agg_kwargs=dict(numeric_only=numeric_only), + ) def rolling_sum( self, @@ -10556,9 +10563,10 @@ def rolling_sum( WarningMessage.warning_if_engine_args_is_set( "rolling_sum", engine, engine_kwargs ) - return self._rolling_agg( + return self._window_agg( + window_func=WindowFunction.ROLLING, agg_func="sum", - rolling_kwargs=rolling_kwargs, + window_kwargs=rolling_kwargs, agg_kwargs=dict(numeric_only=numeric_only), ) @@ -10575,9 +10583,10 @@ def rolling_mean( WarningMessage.warning_if_engine_args_is_set( "rolling_mean", engine, engine_kwargs ) - return self._rolling_agg( + return self._window_agg( + window_func=WindowFunction.ROLLING, agg_func="mean", - rolling_kwargs=rolling_kwargs, + window_kwargs=rolling_kwargs, agg_kwargs=dict(numeric_only=numeric_only), ) @@ -10606,9 +10615,10 @@ def rolling_var( WarningMessage.warning_if_engine_args_is_set( "rolling_var", engine, engine_kwargs ) - return self._rolling_agg( + return self._window_agg( + window_func=WindowFunction.ROLLING, agg_func="var", - rolling_kwargs=rolling_kwargs, + window_kwargs=rolling_kwargs, agg_kwargs=dict(ddof=ddof, numeric_only=numeric_only), ) @@ -10626,9 +10636,10 @@ def rolling_std( WarningMessage.warning_if_engine_args_is_set( "rolling_var", engine, engine_kwargs ) - return self._rolling_agg( + return self._window_agg( + window_func=WindowFunction.ROLLING, agg_func="std", - rolling_kwargs=rolling_kwargs, + window_kwargs=rolling_kwargs, agg_kwargs=dict(ddof=ddof, numeric_only=numeric_only), ) @@ -10645,9 +10656,10 @@ def rolling_min( WarningMessage.warning_if_engine_args_is_set( "rolling_min", engine, engine_kwargs ) - return self._rolling_agg( + return self._window_agg( + window_func=WindowFunction.ROLLING, agg_func="min", - rolling_kwargs=rolling_kwargs, + window_kwargs=rolling_kwargs, agg_kwargs=dict(numeric_only=numeric_only), ) @@ -10664,9 +10676,10 @@ def rolling_max( WarningMessage.warning_if_engine_args_is_set( "rolling_max", engine, engine_kwargs ) - return self._rolling_agg( + return self._window_agg( + window_func=WindowFunction.ROLLING, agg_func="max", - rolling_kwargs=rolling_kwargs, + window_kwargs=rolling_kwargs, agg_kwargs=dict(numeric_only=numeric_only), ) @@ -10769,15 +10782,17 @@ def rolling_rank( ) -> None: ErrorMessage.method_not_implemented_error(name="rank", class_="Rolling") - def _rolling_agg( + def _window_agg( self, + window_func: WindowFunction, agg_func: AggFuncType, - rolling_kwargs: dict[str, Any], + window_kwargs: dict[str, Any], agg_kwargs: dict[str, Any], ) -> "SnowflakeQueryCompiler": """ Compute rolling window with given aggregation. Args: + window_func: the type of window function to apply. agg_func: callable, str, list or dict. the aggregation function used. rolling_kwargs: keyword arguments passed to rolling. agg_kwargs: keyword arguments passed for the aggregation function. @@ -10785,9 +10800,9 @@ def _rolling_agg( SnowflakeQueryCompiler: with a newly constructed internal dataframe """ - window = rolling_kwargs.get("window") - min_periods = rolling_kwargs.get("min_periods") - center = rolling_kwargs.get("center") + window = window_kwargs.get("window") + min_periods = window_kwargs.get("min_periods") + center = window_kwargs.get("center") numeric_only = agg_kwargs.get("numeric_only", False) query_compiler = self if numeric_only: @@ -10797,7 +10812,11 @@ def _rolling_agg( ) # Throw NotImplementedError if any parameter is unsupported - check_is_rolling_window_supported_by_snowflake(rolling_kwargs) + if window_func == WindowFunction.ROLLING: + check_and_raise_error_rolling_window_supported_by_snowflake(window_kwargs) + elif window_func == WindowFunction.EXPANDING: + check_and_raise_error_expanding_window_supported_by_snowflake(window_kwargs) + frame = query_compiler._modin_frame.ensure_row_position_column() row_position_quoted_identifier = frame.row_position_snowflake_quoted_identifier if center: @@ -10805,21 +10824,34 @@ def _rolling_agg( rows_between_start = -(window // 2) # type: ignore rows_between_end = (window - 1) // 2 # type: ignore else: - # 1 - window is equivalent to window - 1 PRECEDING - rows_between_start = 1 - window # type: ignore + if window_func == WindowFunction.ROLLING: + # 1 - window is equivalent to window - 1 PRECEDING + rows_between_start = 1 - window # type: ignore + else: + rows_between_start = Window.UNBOUNDED_PRECEDING rows_between_end = Window.CURRENT_ROW window_expr = Window.orderBy(col(row_position_quoted_identifier)).rows_between( rows_between_start, rows_between_end ) + # Handle case where min_periods = None + min_periods = 0 if min_periods is None else min_periods # Perform Aggregation over the window_expr new_frame = frame.update_snowflake_quoted_identifiers_with_expressions( { + # If aggregation is count use count on row_position_quoted_identifier + # to include NULL values for min_periods comparison quoted_identifier: iff( - count(col(quoted_identifier)).over(window_expr) >= min_periods, + count(col(row_position_quoted_identifier)).over(window_expr) + >= min_periods + if agg_func == "count" + else count(col(quoted_identifier)).over(window_expr) >= min_periods, get_snowflake_agg_func(agg_func, agg_kwargs)( - col(quoted_identifier) + # Expanding is cumulative so replace NULL with 0 for sum aggregation + builtin("zeroifnull")(col(quoted_identifier)) + if window_func == WindowFunction.EXPANDING and agg_func == "sum" + else col(quoted_identifier) ).over(window_expr), pandas_lit(None), ) @@ -10833,8 +10865,13 @@ def expanding_count( fold_axis: Union[int, str], expanding_kwargs: dict, numeric_only: bool = False, - ) -> None: - ErrorMessage.method_not_implemented_error(name="count", class_="Expanding") + ) -> "SnowflakeQueryCompiler": + return self._window_agg( + window_func=WindowFunction.EXPANDING, + agg_func="count", + window_kwargs=expanding_kwargs, + agg_kwargs=dict(numeric_only=numeric_only), + ) def expanding_sum( self, @@ -10843,8 +10880,16 @@ def expanding_sum( numeric_only: bool = False, engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, - ) -> None: - ErrorMessage.method_not_implemented_error(name="sum", class_="Expanding") + ) -> "SnowflakeQueryCompiler": + WarningMessage.warning_if_engine_args_is_set( + "expanding_sum", engine, engine_kwargs + ) + return self._window_agg( + window_func=WindowFunction.EXPANDING, + agg_func="sum", + window_kwargs=expanding_kwargs, + agg_kwargs=dict(numeric_only=numeric_only), + ) def expanding_mean( self, @@ -10853,8 +10898,16 @@ def expanding_mean( numeric_only: bool = False, engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, - ) -> None: - ErrorMessage.method_not_implemented_error(name="mean", class_="Expanding") + ) -> "SnowflakeQueryCompiler": + WarningMessage.warning_if_engine_args_is_set( + "expanding_mean", engine, engine_kwargs + ) + return self._window_agg( + window_func=WindowFunction.EXPANDING, + agg_func="mean", + window_kwargs=expanding_kwargs, + agg_kwargs=dict(numeric_only=numeric_only), + ) def expanding_median( self, @@ -10874,8 +10927,16 @@ def expanding_var( numeric_only: bool = False, engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, - ) -> None: - ErrorMessage.method_not_implemented_error(name="var", class_="Expanding") + ) -> "SnowflakeQueryCompiler": + WarningMessage.warning_if_engine_args_is_set( + "rolling_var", engine, engine_kwargs + ) + return self._window_agg( + window_func=WindowFunction.EXPANDING, + agg_func="var", + window_kwargs=expanding_kwargs, + agg_kwargs=dict(ddof=ddof, numeric_only=numeric_only), + ) def expanding_std( self, @@ -10885,8 +10946,16 @@ def expanding_std( numeric_only: bool = False, engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, - ) -> None: - ErrorMessage.method_not_implemented_error(name="std", class_="Expanding") + ) -> "SnowflakeQueryCompiler": + WarningMessage.warning_if_engine_args_is_set( + "rolling_std", engine, engine_kwargs + ) + return self._window_agg( + window_func=WindowFunction.EXPANDING, + agg_func="std", + window_kwargs=expanding_kwargs, + agg_kwargs=dict(ddof=ddof, numeric_only=numeric_only), + ) def expanding_min( self, @@ -10895,8 +10964,16 @@ def expanding_min( numeric_only: bool = False, engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, - ) -> None: - ErrorMessage.method_not_implemented_error(name="min", class_="Expanding") + ) -> "SnowflakeQueryCompiler": + WarningMessage.warning_if_engine_args_is_set( + "expanding_min", engine, engine_kwargs + ) + return self._window_agg( + window_func=WindowFunction.EXPANDING, + agg_func="min", + window_kwargs=expanding_kwargs, + agg_kwargs=dict(numeric_only=numeric_only), + ) def expanding_max( self, @@ -10905,8 +10982,16 @@ def expanding_max( numeric_only: bool = False, engine: Optional[Literal["cython", "numba"]] = None, engine_kwargs: Optional[dict[str, bool]] = None, - ) -> None: - ErrorMessage.method_not_implemented_error(name="max", class_="Expanding") + ) -> "SnowflakeQueryCompiler": + WarningMessage.warning_if_engine_args_is_set( + "expanding_max", engine, engine_kwargs + ) + return self._window_agg( + window_func=WindowFunction.EXPANDING, + agg_func="max", + window_kwargs=expanding_kwargs, + agg_kwargs=dict(numeric_only=numeric_only), + ) def expanding_corr( self, diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/base.py b/src/snowflake/snowpark/modin/plugin/docstrings/base.py index 15ebea3f7de..8e219d682be 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/base.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/base.py @@ -1053,6 +1053,18 @@ def ewm(): def expanding(): """ Provide expanding window calculations. + Currently, ``axis = 1`` is not supported. + + Parameters + ---------- + min_periods: int, default 1. + Minimum number of observations in window required to have a value; otherwise, result is np.nan. + axis: int or str, default 0 + If 0 or 'index', roll across the rows. + If 1 or 'columns', roll across the columns. + For Series this parameter is unused and defaults to 0. + method: str {‘single’, ‘table’}, default ‘single’ + **This parameter is ignored in Snowpark pandas since the execution engine will always be Snowflake.** """ def ffill(): @@ -2558,8 +2570,7 @@ def rolling(): step: int, default None Evaluate the window at every step result, equivalent to slicing as [::step]. window must be an integer. Using a step argument other than None or 1 will produce a result with a different shape than the input. method: str {‘single’, ‘table’}, default ‘single’ - Execute the rolling operation per single column or row ('single') or over the entire object ('table'). - This argument is only implemented when specifying engine='numba' in the method call. + **This parameter is ignored in Snowpark pandas since the execution engine will always be Snowflake.** """ def round(): diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/window.py b/src/snowflake/snowpark/modin/plugin/docstrings/window.py index 6991b7a143c..499b8867597 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/window.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/window.py @@ -9,8 +9,8 @@ from pandas.util._decorators import doc -_rolling_agg_method_engine_template = """ -Compute the rolling {fname}. +_window_agg_method_engine_template = """ +Compute the {win_type} {fname}. Parameters ---------- @@ -35,13 +35,12 @@ **This parameter is ignored in Snowpark pandas. The execution engine will always be Snowflake.** -**kwargs - Keyword arguments to be passed into func. +{kwargs} Returns ------- :class:`~snowflake.snowpark.modin.pandas.Series` or :class:`~snowflake.snowpark.modin.pandas.DataFrame` - Computed rolling {fname} of values. + Computed {win_type} {fname} of values. Examples -------- @@ -61,10 +60,10 @@ - list of functions and/or function names, e.g. ``[np.sum, 'mean']`` - dict of axis labels -> functions, function names or list of such. -*args +*args : tuple Positional arguments to pass to func. -**kwargs +**kwargs : dict Keyword arguments to be passed into func. Returns @@ -112,18 +111,77 @@ class Rolling: + + """ + Compute the rolling count. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + + *args : tuple + Positional arguments to pass to func. + + **kwargs : dict + Keyword arguments to be passed into func. + + Returns + ------- + :class:`~snowflake.snowpark.modin.pandas.Series` or :class:`~snowflake.snowpark.modin.pandas.DataFrame` + Computed rolling count of values. + + Examples + -------- + >>> df = pd.DataFrame({'B': [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + >>> df.rolling(2, min_periods=1).count() + B + 0 1 + 1 2 + 2 2 + 3 1 + 4 1 + >>> df.rolling(2, min_periods=2).count() + B + 0 NaN + 1 2.0 + 2 2.0 + 3 1.0 + 4 1.0 + >>> df.rolling(3, min_periods=1, center=True).count() + B + 0 2 + 1 3 + 2 2 + 3 2 + 4 1 + """ + def count(): pass @doc( - _rolling_agg_method_engine_template, + _window_agg_method_engine_template, + win_type="rolling", fname="sum", no=False, args=dedent( """\ - *args + *args : tuple Positional arguments to pass to func.""" ), + kwargs=dedent( + """\ + **kwargs : dict + Keyword arguments to be passed into func.""" + ), e=None, ek=None, example=dedent( @@ -163,13 +221,19 @@ def sum(): pass @doc( - _rolling_agg_method_engine_template, + _window_agg_method_engine_template, + win_type="rolling", fname="mean", args=dedent( """\ - *args + *args : tuple Positional arguments to pass to func.""" ), + kwargs=dedent( + """\ + **kwargs : dict + Keyword arguments to be passed into func.""" + ), no=False, e=None, ek=None, @@ -211,9 +275,15 @@ def mean(): # TODO: SNOW-1419071 API not implemented - uncomment when done. # @doc( - # _rolling_agg_method_engine_template, + # _window_agg_method_engine_template, + # win_type="rolling", # fname="median", # args=None, + # kwargs=dedent( + # """\ + # **kwargs : dict + # Keyword arguments to be passed into func.""" + # ), # no=False, # e=None, # ek=None, @@ -240,13 +310,19 @@ def median(): pass @doc( - _rolling_agg_method_engine_template, + _window_agg_method_engine_template, + win_type="rolling", fname="var", args=dedent( """\ - *args + *args : tuple Positional arguments to pass to func.""" ), + kwargs=dedent( + """\ + **kwargs : dict + Keyword arguments to be passed into func.""" + ), no=False, e=None, ek=None, @@ -287,13 +363,19 @@ def var(): pass @doc( - _rolling_agg_method_engine_template, + _window_agg_method_engine_template, + win_type="rolling", fname="std", args=dedent( """\ - *args + *args : tuple Positional arguments to pass to func.""" ), + kwargs=dedent( + """\ + **kwargs : dict + Keyword arguments to be passed into func.""" + ), no=False, e=None, ek=None, @@ -334,13 +416,19 @@ def std(): pass @doc( - _rolling_agg_method_engine_template, + _window_agg_method_engine_template, + win_type="rolling", fname="min", args=dedent( """\ - *args + *args : tuple Positional arguments to pass to func.""" ), + kwargs=dedent( + """\ + **kwargs : dict + Keyword arguments to be passed into func.""" + ), no=False, e=None, ek=None, @@ -367,13 +455,19 @@ def min(): pass @doc( - _rolling_agg_method_engine_template, + _window_agg_method_engine_template, + win_type="rolling", fname="max", args=dedent( """\ - *args + *args : tuple Positional arguments to pass to func.""" ), + kwargs=dedent( + """\ + **kwargs : dict + Keyword arguments to be passed into func.""" + ), no=False, e=None, ek=None, @@ -434,27 +528,228 @@ def rank(): class Expanding: + + """ + Compute the expanding count. + + Parameters + ---------- + numeric_only : bool, default False + Include only float, int, boolean columns. + + Returns + ------- + :class:`~snowflake.snowpark.modin.pandas.Series` or :class:`~snowflake.snowpark.modin.pandas.DataFrame` + Computed expanding count of values. + + Examples + -------- + >>> df = pd.DataFrame({'B': [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + >>> df.expanding(2).count() + B + 0 NaN + 1 2.0 + 2 3.0 + 3 3.0 + 4 4.0 + """ + def count(): pass + @doc( + _window_agg_method_engine_template, + win_type="expanding", + fname="sum", + no=False, + args=None, + kwargs=None, + e=None, + ek=None, + example=dedent( + """\ + >>> df = pd.DataFrame({'B': [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + >>> df.expanding(2).sum() + B + 0 NaN + 1 1.0 + 2 3.0 + 3 3.0 + 4 7.0""" + ), + ) def sum(): pass + @doc( + _window_agg_method_engine_template, + win_type="expanding", + fname="mean", + no=False, + args=None, + kwargs=None, + e=None, + ek=None, + example=dedent( + """\ + >>> df = pd.DataFrame({'B': [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + >>> df.expanding(2).mean() + B + 0 NaN + 1 0.50 + 2 1.00 + 3 1.00 + 4 1.75""" + ), + ) def mean(): pass def median(): pass + @doc( + _window_agg_method_engine_template, + win_type="expanding", + fname="var", + no=False, + args=None, + kwargs=None, + e=None, + ek=None, + example=dedent( + """\ + >>> df = pd.DataFrame({'B': [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + >>> df.expanding(2).var() + B + 0 NaN + 1 0.500000 + 2 1.000000 + 3 1.000000 + 4 2.916667""" + ), + ) def var(): pass + @doc( + _window_agg_method_engine_template, + win_type="expanding", + fname="std", + no=False, + args=None, + kwargs=None, + e=None, + ek=None, + example=dedent( + """\ + >>> df = pd.DataFrame({'B': [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + >>> df.expanding(2).std() + B + 0 NaN + 1 0.707107 + 2 1.000000 + 3 1.000000 + 4 1.707825""" + ), + ) def std(): pass + @doc( + _window_agg_method_engine_template, + win_type="expanding", + fname="std", + no=False, + args=None, + kwargs=None, + e=None, + ek=None, + example=dedent( + """\ + >>> df = pd.DataFrame({'B': [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + >>> df.expanding(2).min() + B + 0 NaN + 1 0.0 + 2 0.0 + 3 0.0 + 4 0.0""" + ), + ) def min(): pass + @doc( + _window_agg_method_engine_template, + win_type="expanding", + fname="std", + no=False, + args=None, + kwargs=None, + e=None, + ek=None, + example=dedent( + """\ + >>> df = pd.DataFrame({'B': [0, 1, 2, np.nan, 4]}) + >>> df + B + 0 0.0 + 1 1.0 + 2 2.0 + 3 NaN + 4 4.0 + >>> df.expanding(2).max() + B + 0 NaN + 1 1.0 + 2 2.0 + 3 2.0 + 4 4.0""" + ), + ) def max(): pass diff --git a/tests/integ/modin/window/test_expanding.py b/tests/integ/modin/window/test_expanding.py index 11cf8ff2af4..12b84c5da67 100644 --- a/tests/integ/modin/window/test_expanding.py +++ b/tests/integ/modin/window/test_expanding.py @@ -4,22 +4,113 @@ import modin.pandas as pd import numpy as np +import pandas as native_pd import pytest +import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker +from tests.integ.modin.utils import eval_snowpark_pandas_result + +agg_func = pytest.mark.parametrize( + "agg_func", ["count", "sum", "mean", "var", "std", "min", "max"] +) +min_periods = pytest.mark.parametrize("min_periods", [None, 0, 1, 2, 10]) + + +@agg_func +@min_periods +@sql_count_checker(query_count=1) +def test_expanding_dataframe(agg_func, min_periods): + native_df = native_pd.DataFrame( + {"A": ["h", "e", "l", "l", "o"], "B": [0, -1, 2.5, np.nan, 4]} + ) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: getattr( + df.expanding(min_periods), + agg_func, + )(numeric_only=True), + ) + + +@agg_func +@min_periods +@sql_count_checker(query_count=1) +def test_expanding_null_dataframe(agg_func, min_periods): + native_df = native_pd.DataFrame( + { + "A": ["h", np.nan, "l", "l", "o"], + "B": [np.nan, np.nan, np.nan, np.nan, np.nan], + } + ) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: getattr( + df.expanding(min_periods), + agg_func, + )(numeric_only=True), + ) + + +@agg_func +@min_periods +@sql_count_checker(query_count=1) +def test_expanding_series(agg_func, min_periods): + native_series = native_pd.Series([0, -1, 2.5, np.nan, 4]) + snow_series = pd.Series(native_series) + eval_snowpark_pandas_result( + snow_series, + native_series, + lambda df: getattr( + df.expanding(min_periods), + agg_func, + )(), + ) + + +@sql_count_checker(query_count=1) +def test_expanding_min_periods_default(): + native_df = native_pd.DataFrame( + {"A": ["h", "e", "l", "l", "o"], "B": [0, -1, 2.5, np.nan, 4]} + ) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.expanding().min(numeric_only=True), + ) + + +@sql_count_checker(query_count=0) +def test_expanding_min_periods_negative(): + native_df = native_pd.DataFrame({"B": [0, 1, 2, np.nan, 4]}) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.expanding("invalid_value").min(), + expect_exception=True, + expect_exception_type=ValueError, + expect_exception_match="min_periods must be an integer", + ) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.expanding(-2).min(), + expect_exception=True, + expect_exception_type=ValueError, + expect_exception_match="min_periods must be >= 0", + ) @pytest.mark.parametrize( "agg_func, agg_func_kwargs", [ - ("count", None), - ("sum", None), - ("mean", None), ("median", None), - ("var", 1), - ("std", 1), - ("min", None), - ("max", None), ("corr", None), ("cov", None), ("skew", None), @@ -41,14 +132,7 @@ def test_expanding_aggregation_dataframe_unsupported(agg_func, agg_func_kwargs): @pytest.mark.parametrize( "agg_func, agg_func_kwargs", [ - ("count", None), - ("sum", None), - ("mean", None), ("median", None), - ("var", 1), - ("std", 1), - ("min", None), - ("max", None), ("corr", None), ("cov", None), ("skew", None), diff --git a/tests/integ/modin/window/test_rolling.py b/tests/integ/modin/window/test_rolling.py index b6537e74f63..958de2ecc37 100644 --- a/tests/integ/modin/window/test_rolling.py +++ b/tests/integ/modin/window/test_rolling.py @@ -8,13 +8,12 @@ import pytest import snowflake.snowpark.modin.plugin # noqa: F401 -from snowflake.snowpark.modin.plugin._internal.window_utils import ( - IMPLEMENTED_ROLLING_AGG_FUNCS, -) from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker from tests.integ.modin.utils import eval_snowpark_pandas_result -agg_func = pytest.mark.parametrize("agg_func", IMPLEMENTED_ROLLING_AGG_FUNCS) +agg_func = pytest.mark.parametrize( + "agg_func", ["count", "sum", "mean", "var", "std", "min", "max"] +) window = pytest.mark.parametrize("window", [1, 2, 3, 4, 6]) min_periods = pytest.mark.parametrize("min_periods", [1, 2]) center = pytest.mark.parametrize("center", [True, False]) @@ -211,6 +210,7 @@ def test_rolling_window_unsupported(): lambda df: df.rolling(2, axis=1).sum(), lambda df: df.rolling(2, closed="left").sum(), lambda df: df.rolling(2, step=2).sum(), + lambda df: df.rolling(0, min_periods=0).sum(), ], ) @sql_count_checker(query_count=0) @@ -223,7 +223,6 @@ def test_rolling_params_unsupported(function): @pytest.mark.parametrize( "agg_func, agg_func_kwargs", [ - ("count", None), ("sem", None), ("median", None), ("corr", None),