From a07902c0289f956600847855248cdea76619c200 Mon Sep 17 00:00:00 2001 From: Varnika Budati Date: Wed, 18 Sep 2024 13:14:15 -0700 Subject: [PATCH] SNOW-1559025 Implement `DatetimeIndex.mean` and `DatetimeIndex.std` (#2292) 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. Fixes SNOW-1559025 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. 3. Please describe how your code solves the related issue. Implemented `DatetimeIndex.mean` and `DatetimeIndex.std`. --------- Co-authored-by: Andong Zhan --- CHANGELOG.md | 1 + .../supported/datetime_index_supported.rst | 4 +- .../modin/plugin/extensions/datetime_index.py | 53 +++++++-- .../index/test_datetime_index_methods.py | 107 ++++++++++++++++++ 4 files changed, 151 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c0d0dd643a..7788637a7e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ - Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. - Added support for `by`, `left_by`, `right_by`, `left_index`, and `right_index` for `pd.merge_asof`. - Added support for passing parameter `include_describe` to `Session.query_history`. +- Added support for `DatetimeIndex.mean` and `DatetimeIndex.std` methods. 
#### Bug Fixes diff --git a/docs/source/modin/supported/datetime_index_supported.rst b/docs/source/modin/supported/datetime_index_supported.rst index 3afe671aee7..9ebf6935f77 100644 --- a/docs/source/modin/supported/datetime_index_supported.rst +++ b/docs/source/modin/supported/datetime_index_supported.rst @@ -100,7 +100,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``day_name`` | P | ``locale`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``mean`` | N | | | +| ``mean`` | Y | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``std`` | N | | | +| ``std`` | P | ``ddof`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py index 38edb9f7bee..16c6ebdc1d0 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py @@ -26,7 +26,7 @@ from __future__ import annotations -from datetime import tzinfo +from datetime import timedelta, tzinfo import modin import numpy as np @@ -43,6 +43,7 @@ ) from pandas.core.dtypes.common import is_datetime64_any_dtype +from snowflake.snowpark.modin.pandas import to_datetime, to_timedelta from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( SnowflakeQueryCompiler, ) @@ -1502,7 +1503,6 @@ def to_pydatetime(self) -> np.ndarray: datetime.datetime(2018, 3, 1, 0, 0)], dtype=object) """ - @datetime_index_not_implemented() def mean( self, *, 
skipna: bool = True, axis: AxisInt | None = 0 ) -> native_pd.Timestamp: @@ -1514,6 +1514,8 @@ def mean( skipna : bool, default True Whether to ignore any NaT elements. axis : int, optional, default 0 + The axis to calculate the mean over. + This parameter is ignored - 0 is the only valid axis. Returns ------- @@ -1533,20 +1535,26 @@ def mean( >>> idx = pd.date_range('2001-01-01 00:00', periods=3) >>> idx DatetimeIndex(['2001-01-01', '2001-01-02', '2001-01-03'], dtype='datetime64[ns]', freq=None) - >>> idx.mean() # doctest: +SKIP + >>> idx.mean() Timestamp('2001-01-02 00:00:00') """ + # Need to convert timestamp to int value (nanoseconds) before aggregating. + # TODO: SNOW-1625233 When `tz` is supported, add a `tz` parameter to `to_datetime` for correct timezone result. + if axis not in [None, 0]: + raise ValueError( + f"axis={axis} is not supported, this parameter is ignored. 0 is the only valid axis." + ) + return to_datetime( + self.to_series().astype("int64").agg("mean", axis=0, skipna=skipna) + ) - @datetime_index_not_implemented() def std( self, - axis=None, - dtype=None, - out=None, + axis: AxisInt | None = None, ddof: int = 1, - keepdims: bool = False, skipna: bool = True, - ): + **kwargs, + ) -> timedelta: """ Return sample standard deviation over requested axis. @@ -1555,11 +1563,12 @@ def std( Parameters ---------- axis : int, optional - Axis for the function to be applied on. For :class:`pandas.Series` - this parameter is unused and defaults to ``None``. + The axis to calculate the standard deviation over. + This parameter is ignored - 0 is the only valid axis. ddof : int, default 1 Degrees of Freedom. The divisor used in calculations is `N - ddof`, where `N` represents the number of elements. + This parameter is not yet supported. skipna : bool, default True Exclude NA/null values. If an entire row/column is ``NA``, the result will be ``NA``. 
@@ -1581,6 +1590,26 @@ def std( >>> idx = pd.date_range('2001-01-01 00:00', periods=3) >>> idx DatetimeIndex(['2001-01-01', '2001-01-02', '2001-01-03'], dtype='datetime64[ns]', freq=None) - >>> idx.std() # doctest: +SKIP + >>> idx.std() Timedelta('1 days 00:00:00') """ + if axis not in [None, 0]: + raise ValueError( + f"axis={axis} is not supported, this parameter is ignored. 0 is the only valid axis." + ) + if ddof != 1: + raise NotImplementedError( + "`ddof` parameter is not yet supported for `std`." + ) + # Snowflake cannot directly perform `std` on a timestamp; therefore, convert the timestamp to an integer. + # By default, the integer version of a timestamp is in nanoseconds. Directly performing computations with + # nanoseconds can lead to results with integer size much larger than the original integer size. Therefore, + # convert the nanoseconds to seconds and then compute the standard deviation. + # The timestamp is converted to seconds instead of the float version of nanoseconds since that can lead to + # floating point precision issues + return to_timedelta( + (self.to_series().astype(int) // 1_000_000_000).agg( + "std", axis=0, ddof=ddof, skipna=skipna, **kwargs + ) + * 1_000_000_000 + ) diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py index 98d1a041c3b..da9294cdf30 100644 --- a/tests/integ/modin/index/test_datetime_index_methods.py +++ b/tests/integ/modin/index/test_datetime_index_methods.py @@ -405,3 +405,110 @@ def test_floor_ceil_round_negative(func, freq, ambiguous, nonexistent): getattr(snow_index, func)( freq=freq, ambiguous=ambiguous, nonexistent=nonexistent ) + + +@pytest.mark.parametrize( + "native_index", + [ + native_pd.date_range("2021-01-01", periods=5), + native_pd.date_range("2021-01-01", periods=5, freq="2D"), + pytest.param( + native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + 
"2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + tz="US/Eastern", + ), + marks=pytest.mark.xfail( + reason="TODO: SNOW-1625233 Snowpark pandas to_datetime does not support timezone" + ), + ), + native_pd.DatetimeIndex( + [ + "2014-04-04 23:56", + pd.NaT, + "2014-07-18 21:24", + "2015-11-22 22:14", + pd.NaT, + ] + ), + ], +) +@pytest.mark.parametrize("skipna", [True, False]) +@sql_count_checker(query_count=1) +def test_datetime_index_mean(native_index, skipna): + snow_index = pd.DatetimeIndex(native_index) + native_res = native_index.mean(skipna=skipna) + snow_res = snow_index.mean(skipna=skipna) + if native_res is pd.NaT: + assert snow_res is pd.NaT + else: + assert snow_res == native_res + + +@pytest.mark.parametrize( + "native_index", + [ + native_pd.date_range("2021-01-01", periods=5), + native_pd.date_range("2021-01-01", periods=5, freq="2D"), + # TODO: SNOW-1625233 Remove xfail when timezone is supported. + pytest.param( + native_pd.DatetimeIndex( + [ + "2014-04-04 23:56:01.000000001", + "2014-07-18 21:24:02.000000002", + "2015-11-22 22:14:03.000000003", + "2015-11-23 20:12:04.1234567890", + pd.NaT, + ], + tz="US/Eastern", + ), + marks=pytest.mark.xfail( + reason="SNOW-1664175 Snowpark pandas `to_datetime` does not support tz" + ), + ), + native_pd.DatetimeIndex( + [ + "2014-04-04 23:56", + pd.NaT, + "2014-07-18 21:24", + "2015-11-22 22:14", + pd.NaT, + ] + ), + ], +) +@pytest.mark.parametrize("ddof", [1]) +@pytest.mark.parametrize("skipna", [True, False]) +@sql_count_checker(query_count=1) +def test_datetime_index_std(native_index, ddof, skipna): + snow_index = pd.DatetimeIndex(native_index) + native_res = native_index.std(ddof=ddof, skipna=skipna) + snow_res = snow_index.std(ddof=ddof, skipna=skipna) + # Since the Snowpark pandas implementation converts timestamp values to float values, + # there is some loss in accuracy. Hence, we use approx to compare the results. 
+ if native_res is pd.NaT: + assert snow_res is pd.NaT + else: + assert snow_res.total_seconds() == pytest.approx(native_res.total_seconds()) + + +@pytest.mark.parametrize("ops", ["mean", "std"]) +@sql_count_checker(query_count=0) +def test_datetime_index_agg_ops_axis_negative(ops): + snow_index = pd.DatetimeIndex(["2021-01-01", "2021-01-02", "2021-01-03"]) + with pytest.raises( + ValueError, + match="axis=1 is not supported, this parameter is ignored. 0 is the only valid axis.", + ): + getattr(snow_index, ops)(axis=1) + + +@sql_count_checker(query_count=0) +def test_datetime_index_std_ddof_negative(): + snow_index = pd.DatetimeIndex(["2021-01-01", "2021-01-02", "2021-01-03"]) + with pytest.raises( + NotImplementedError, match="`ddof` parameter is not yet supported for `std`." + ): + snow_index.std(ddof=2)