Skip to content

Commit

Permalink
support datetimeindex.std and mean
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-vbudati committed Sep 13, 2024
1 parent 82eec96 commit c592a0f
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#### New Features

- Added support for `TimedeltaIndex.mean` method.
- Added support for `DatetimeIndex.mean` and `DatetimeIndex.std` methods.


## 1.22.1 (2024-09-11)
Expand Down
4 changes: 2 additions & 2 deletions docs/source/modin/supported/datetime_index_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``day_name`` | P | ``locale`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``mean`` | N | | |
| ``mean`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``std`` | N | | |
| ``std`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
51 changes: 31 additions & 20 deletions src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
)
from pandas.core.dtypes.common import is_datetime64_any_dtype

from snowflake.snowpark.modin.pandas import to_datetime, to_timedelta
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)
Expand Down Expand Up @@ -1500,6 +1501,8 @@ def mean(
skipna : bool, default True
Whether to ignore any NaT elements.
axis : int, optional, default 0
The axis to calculate the mean over.
This parameter is ignored - 0 is the only valid axis.
Returns
-------
Expand All @@ -1522,21 +1525,22 @@ def mean(
>>> idx.mean()
Timestamp('2001-01-02 00:00:00')
"""
return (
self.to_series()
.agg("mean", axis=axis, skipna=skipna)
.to_pandas()
.squeeze(axis=1)
# Need to convert timestamp to int value (nanoseconds) before aggregating.
# TODO: SNOW-1625233 When `tz` is supported, add a `tz` parameter to `to_datetime` for correct timezone result.
if axis not in [None, 0]:
raise ValueError(
f"axis={axis} is not supported, this parameter is ignored. 0 is the only valid axis."
)
return to_datetime(
self.to_series().astype("int64").agg("mean", axis=0, skipna=skipna)
)

def std(
self,
axis=None,
dtype=None,
out=None,
axis: AxisInt | None = None,
ddof: int = 1,
keepdims: bool = False,
skipna: bool = True,
**kwargs,
) -> timedelta:
"""
Return sample standard deviation over requested axis.
Expand All @@ -1546,11 +1550,12 @@ def std(
Parameters
----------
axis : int, optional
Axis for the function to be applied on. For :class:`pandas.Series`
this parameter is unused and defaults to ``None``.
The axis to calculate the standard deviation over.
This parameter is ignored - 0 is the only valid axis.
ddof : int, default 1
Degrees of Freedom. The divisor used in calculations is `N - ddof`,
where `N` represents the number of elements.
This parameter is not yet supported.
skipna : bool, default True
Exclude NA/null values. If an entire row/column is ``NA``, the result
will be ``NA``.
Expand All @@ -1575,13 +1580,19 @@ def std(
>>> idx.std()
Timedelta('1 days 00:00:00')
"""
kwargs = {
"dtype": dtype,
"out": out,
"ddof": ddof,
"keepdims": keepdims,
"skipna": skipna,
}
return (
self.to_series().agg("std", axis=axis, **kwargs).to_pandas().squeeze(axis=1)
if axis not in [None, 0]:
raise ValueError(
f"axis={axis} is not supported, this parameter is ignored. 0 is the only valid axis."
)
if ddof != 1:
raise NotImplementedError(
"`ddof` parameter is not yet supported for `std`."
)
# Need to convert timestamp to a float type to prevent overflow when aggregating.
# Cannot directly convert a timestamp to a float; therefore, first convert it to an int then a float.
return to_timedelta(
self.to_series()
.astype(int)
.astype(float)
.agg("std", axis=0, ddof=ddof, skipna=skipna, **kwargs)
)
94 changes: 94 additions & 0 deletions tests/integ/modin/index/test_datetime_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,97 @@ def test_floor_ceil_round_negative(func, freq, ambiguous, nonexistent):
getattr(snow_index, func)(
freq=freq, ambiguous=ambiguous, nonexistent=nonexistent
)


@pytest.mark.parametrize(
    "native_index",
    [
        native_pd.date_range("2021-01-01", periods=5),
        native_pd.date_range("2021-01-01", periods=5, freq="2D"),
        pytest.param(
            native_pd.DatetimeIndex(
                [
                    "2014-04-04 23:56:01.000000001",
                    "2014-07-18 21:24:02.000000002",
                    "2015-11-22 22:14:03.000000003",
                    "2015-11-23 20:12:04.1234567890",
                    pd.NaT,
                ],
                tz="US/Eastern",
            ),
            marks=pytest.mark.xfail(reason="Snowpark pandas does not support timezone"),
        ),
        native_pd.DatetimeIndex(
            [
                "2014-04-04 23:56",
                pd.NaT,
                "2014-07-18 21:24",
                "2015-11-22 22:14",
                pd.NaT,
            ]
        ),
    ],
)
@pytest.mark.parametrize("skipna", [True, False])
@sql_count_checker(query_count=1)
def test_datetime_index_mean(native_index, skipna):
    """Verify that Snowpark pandas DatetimeIndex.mean matches native pandas."""
    snow_index = pd.DatetimeIndex(native_index)
    expected = native_index.mean(skipna=skipna)
    actual = snow_index.mean(skipna=skipna)
    # NaT compares unequal to everything (itself included), so an identity
    # check is required for the all-missing / skipna=False case.
    if expected is pd.NaT:
        assert actual is pd.NaT
    else:
        assert actual == expected


@pytest.mark.parametrize(
    "native_index",
    [
        native_pd.date_range("2021-01-01", periods=5),
        native_pd.date_range("2021-01-01", periods=5, freq="2D"),
        # TODO: SNOW-1625233 Remove xfail when timezone is supported.
        pytest.param(
            native_pd.DatetimeIndex(
                [
                    "2014-04-04 23:56:01.000000001",
                    "2014-07-18 21:24:02.000000002",
                    "2015-11-22 22:14:03.000000003",
                    "2015-11-23 20:12:04.1234567890",
                    pd.NaT,
                ],
                tz="US/Eastern",
            ),
            marks=pytest.mark.xfail(reason="Snowpark pandas does not support timezone"),
        ),
        native_pd.DatetimeIndex(
            [
                "2014-04-04 23:56",
                pd.NaT,
                "2014-07-18 21:24",
                "2015-11-22 22:14",
                pd.NaT,
            ]
        ),
    ],
)
@pytest.mark.parametrize("ddof", [1])
@pytest.mark.parametrize("skipna", [True, False])
@sql_count_checker(query_count=1)
def test_datetime_index_std(native_index, ddof, skipna):
    """Verify that Snowpark pandas DatetimeIndex.std matches native pandas.

    The Snowpark pandas implementation converts timestamp values to floats
    before aggregating, so a small loss of precision is tolerated.
    """
    snow_index = pd.DatetimeIndex(native_index)
    native_res = native_index.std(ddof=ddof, skipna=skipna)
    snow_res = snow_index.std(ddof=ddof, skipna=skipna)
    # BUG FIX: the previous `pytest.approx(snow_res, native_res, nan_ok=True)`
    # only *constructed* an approx object (with `native_res` silently consumed
    # as the `rel` tolerance) and never asserted anything, so this test could
    # not fail. Assert explicitly instead.
    if native_res is pd.NaT:
        # NaT is not equal to itself; check identity, mirroring
        # test_datetime_index_mean.
        assert snow_res is pd.NaT
    else:
        # Compare the Timedelta results in seconds so pytest.approx can apply
        # its (numeric) relative tolerance.
        assert snow_res.total_seconds() == pytest.approx(native_res.total_seconds())


@pytest.mark.parametrize("ops", ["mean", "std"])
@sql_count_checker(query_count=0)
def test_datetime_index_agg_ops_axis_negative(ops):
snow_index = pd.DatetimeIndex(["2021-01-01", "2021-01-02", "2021-01-03"])
msg = (
"axis=1 is not supported, this parameter is ignored. 0 is the only valid axis."
)
with pytest.raises(ValueError, match=msg):
getattr(snow_index, ops)(axis=1)

0 comments on commit c592a0f

Please sign in to comment.