From c592a0f43d696d515fd3607fa60a2e767c4b19f0 Mon Sep 17 00:00:00 2001
From: Varnika Budati
Date: Fri, 13 Sep 2024 14:49:36 -0700
Subject: [PATCH] support datetimeindex.std and mean

---
 CHANGELOG.md                                  |  1 +
 .../supported/datetime_index_supported.rst    |  4 +-
 .../modin/plugin/extensions/datetime_index.py | 51 ++++++----
 .../index/test_datetime_index_methods.py      | 94 +++++++++++++++++++
 4 files changed, 128 insertions(+), 22 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e0589d4a358..c3da8d2dd0b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@
 #### New Features

 - Added support for `TimedeltaIndex.mean` method.
+- Added support for `DatetimeIndex.mean` and `DatetimeIndex.std` methods.

 ## 1.22.1 (2024-09-11)

diff --git a/docs/source/modin/supported/datetime_index_supported.rst b/docs/source/modin/supported/datetime_index_supported.rst
index 68b1935da96..325da109877 100644
--- a/docs/source/modin/supported/datetime_index_supported.rst
+++ b/docs/source/modin/supported/datetime_index_supported.rst
@@ -100,7 +100,7 @@ Methods
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
 | ``day_name``                | P                               | ``locale``                       |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``mean``                    | N                               |                                  |                                                    |
+| ``mean``                    | Y                               |                                  |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
-| ``std``                     | N                               |                                  |                                                    |
+| ``std``                     | Y                               |                                  |                                                    |
 +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
index e02e02810c8..2ad902e8a4e 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
@@ -43,6 +43,7 @@
 )
 from pandas.core.dtypes.common import is_datetime64_any_dtype

+from snowflake.snowpark.modin.pandas import to_datetime, to_timedelta
 from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
     SnowflakeQueryCompiler,
 )
@@ -1500,6 +1501,8 @@
         skipna : bool, default True
             Whether to ignore any NaT elements.
         axis : int, optional, default 0
+            The axis to calculate the mean over.
+            This parameter is ignored - 0 is the only valid axis.

         Returns
         -------
@@ -1522,21 +1525,22 @@
         >>> idx.mean()
         Timestamp('2001-01-02 00:00:00')
         """
-        return (
-            self.to_series()
-            .agg("mean", axis=axis, skipna=skipna)
-            .to_pandas()
-            .squeeze(axis=1)
+        # Need to convert timestamp to int value (nanoseconds) before aggregating.
+        # TODO: SNOW-1625233 When `tz` is supported, add a `tz` parameter to `to_datetime` for correct timezone result.
+        if axis not in [None, 0]:
+            raise ValueError(
+                f"axis={axis} is not supported, this parameter is ignored. 0 is the only valid axis."
+            )
+        return to_datetime(
+            self.to_series().astype("int64").agg("mean", axis=0, skipna=skipna)
         )

     def std(
         self,
-        axis=None,
-        dtype=None,
-        out=None,
+        axis: AxisInt | None = None,
         ddof: int = 1,
-        keepdims: bool = False,
         skipna: bool = True,
+        **kwargs,
     ) -> timedelta:
         """
         Return sample standard deviation over requested axis.
@@ -1546,11 +1550,12 @@ def std(
         Parameters
         ----------
         axis : int, optional
-            Axis for the function to be applied on. For :class:`pandas.Series`
-            this parameter is unused and defaults to ``None``.
+            The axis to calculate the standard deviation over.
+            This parameter is ignored - 0 is the only valid axis.
         ddof : int, default 1
             Degrees of Freedom. The divisor used in calculations is `N - ddof`,
             where `N` represents the number of elements.
+            Only the default value of 1 is currently supported.
         skipna : bool, default True
             Exclude NA/null values. If an entire row/column is ``NA``, the
             result will be ``NA``.
@@ -1575,13 +1580,19 @@
         >>> idx.std()
         Timedelta('1 days 00:00:00')
         """
-        kwargs = {
-            "dtype": dtype,
-            "out": out,
-            "ddof": ddof,
-            "keepdims": keepdims,
-            "skipna": skipna,
-        }
-        return (
-            self.to_series().agg("std", axis=axis, **kwargs).to_pandas().squeeze(axis=1)
+        if axis not in [None, 0]:
+            raise ValueError(
+                f"axis={axis} is not supported, this parameter is ignored. 0 is the only valid axis."
+            )
+        if ddof != 1:
+            raise NotImplementedError(
+                "`ddof` parameter is not yet supported for `std`."
+            )
+        # Need to convert timestamp to a float type to prevent overflow when aggregating.
+        # Cannot directly convert a timestamp to a float; therefore, first convert it to an int then a float.
+        return to_timedelta(
+            self.to_series()
+            .astype(int)
+            .astype(float)
+            .agg("std", axis=0, ddof=ddof, skipna=skipna, **kwargs)
         )
diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py
index 143e1d74080..a01e740ee84 100644
--- a/tests/integ/modin/index/test_datetime_index_methods.py
+++ b/tests/integ/modin/index/test_datetime_index_methods.py
@@ -294,3 +294,97 @@ def test_floor_ceil_round_negative(func, freq, ambiguous, nonexistent):
         getattr(snow_index, func)(
             freq=freq, ambiguous=ambiguous, nonexistent=nonexistent
         )
+
+
+@pytest.mark.parametrize(
+    "native_index",
+    [
+        native_pd.date_range("2021-01-01", periods=5),
+        native_pd.date_range("2021-01-01", periods=5, freq="2D"),
+        pytest.param(
+            native_pd.DatetimeIndex(
+                [
+                    "2014-04-04 23:56:01.000000001",
+                    "2014-07-18 21:24:02.000000002",
+                    "2015-11-22 22:14:03.000000003",
+                    "2015-11-23 20:12:04.1234567890",
+                    pd.NaT,
+                ],
+                tz="US/Eastern",
+            ),
+            marks=pytest.mark.xfail(reason="Snowpark pandas does not support timezone"),
+        ),
+        native_pd.DatetimeIndex(
+            [
+                "2014-04-04 23:56",
+                pd.NaT,
+                "2014-07-18 21:24",
+                "2015-11-22 22:14",
+                pd.NaT,
+            ]
+        ),
+    ],
+)
+@pytest.mark.parametrize("skipna", [True, False])
+@sql_count_checker(query_count=1)
+def test_datetime_index_mean(native_index, skipna):
+    snow_index = pd.DatetimeIndex(native_index)
+    native_res = native_index.mean(skipna=skipna)
+    snow_res = snow_index.mean(skipna=skipna)
+    if native_res is pd.NaT:
+        assert snow_res is pd.NaT
+    else:
+        assert snow_res == native_res
+
+
+@pytest.mark.parametrize(
+    "native_index",
+    [
+        native_pd.date_range("2021-01-01", periods=5),
+        native_pd.date_range("2021-01-01", periods=5, freq="2D"),
+        # TODO: SNOW-1625233 Remove xfail when timezone is supported.
+        pytest.param(
+            native_pd.DatetimeIndex(
+                [
+                    "2014-04-04 23:56:01.000000001",
+                    "2014-07-18 21:24:02.000000002",
+                    "2015-11-22 22:14:03.000000003",
+                    "2015-11-23 20:12:04.1234567890",
+                    pd.NaT,
+                ],
+                tz="US/Eastern",
+            ),
+            marks=pytest.mark.xfail(reason="Snowpark pandas does not support timezone"),
+        ),
+        native_pd.DatetimeIndex(
+            [
+                "2014-04-04 23:56",
+                pd.NaT,
+                "2014-07-18 21:24",
+                "2015-11-22 22:14",
+                pd.NaT,
+            ]
+        ),
+    ],
+)
+@pytest.mark.parametrize("ddof", [1])
+@pytest.mark.parametrize("skipna", [True, False])
+@sql_count_checker(query_count=1)
+def test_datetime_index_std(native_index, ddof, skipna):
+    snow_index = pd.DatetimeIndex(native_index)
+    native_res = native_index.std(ddof=ddof, skipna=skipna)
+    snow_res = snow_index.std(ddof=ddof, skipna=skipna)
+    # Since the Snowpark pandas implementation converts timestamp values to float values,
+    # there is some loss in accuracy. Hence, compare the results within a small tolerance
+    # rather than requiring exact equality.
+    if native_res is pd.NaT:
+        assert snow_res is pd.NaT
+    else:
+        assert abs(snow_res - native_res) <= native_pd.Timedelta("1s")
+
+
+@pytest.mark.parametrize("ops", ["mean", "std"])
+@sql_count_checker(query_count=0)
+def test_datetime_index_agg_ops_axis_negative(ops):
+    snow_index = pd.DatetimeIndex(["2021-01-01", "2021-01-02", "2021-01-03"])
+    msg = (
+        "axis=1 is not supported, this parameter is ignored. 0 is the only valid axis."
+    )
+    with pytest.raises(ValueError, match=msg):
+        getattr(snow_index, ops)(axis=1)
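
Note (illustration only, not part of the patch): both new methods rely on the same epoch-nanosecond trick. Each timestamp is reinterpreted as an int64 count of nanoseconds since the Unix epoch, those integers are aggregated, and the scalar result is converted back with `to_datetime` / `to_timedelta`. The sketch below shows the trick with native pandas only; the variable names (`idx`, `ns`, `mean_ts`, `std_td`) are illustrative and nothing in it is Snowpark-specific.

# Minimal sketch of the epoch-nanosecond approach, using native pandas only.
import pandas as pd

idx = pd.DatetimeIndex(["2001-01-01", "2001-01-02", "2001-01-03"])
ns = idx.to_series().astype("int64")  # nanoseconds since the Unix epoch

# mean: average the int64 values, then reinterpret the result as a timestamp.
mean_ts = pd.to_datetime(ns.mean())  # Timestamp('2001-01-02 00:00:00')

# std: cast to float first, then reinterpret the result as a duration in nanoseconds.
std_td = pd.to_timedelta(ns.astype(float).std(ddof=1), unit="ns")  # Timedelta('1 days 00:00:00')

The expected values match the docstring examples in the patch for the same three-day range.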
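
The float cast in `std` lines up with the patch's comment about overflow: squared nanosecond-scale deviations exceed the int64 range, so the variance cannot be accumulated in integers. A quick check, also illustrative and not part of the patch:

# Even a single squared one-day deviation is far outside the int64 range.
import numpy as np

one_day_ns = 86_400_000_000_000  # one day, in nanoseconds
print(one_day_ns**2 > np.iinfo(np.int64).max)  # True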