Skip to content

Commit

Permalink
SNOW-1654414: Add support Series.dt.total_seconds and TimedeltaIndex.…
Browse files Browse the repository at this point in the history
…total_seconds (#2253)

SNOW-1654414: Add support Series.dt.total_seconds and
TimedeltaIndex.total_seconds
  • Loading branch information
sfc-gh-nkumar authored Sep 10, 2024
1 parent a03da97 commit 1c85c75
Show file tree
Hide file tree
Showing 9 changed files with 159 additions and 42 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
- support for `TimedeltaIndex` attributes: `days`, `seconds`, `microseconds` and `nanoseconds`.
- support for `diff` with timestamp columns on `axis=0` and `axis=1`
- support for `TimedeltaIndex` methods: `ceil`, `floor` and `round`.
- support for `TimedeltaIndex.total_seconds` method.
- Added support for index's arithmetic and comparison operators.
- Added support for `Series.dt.round`.
- Added documentation pages for `DatetimeIndex`.
Expand All @@ -104,6 +105,7 @@
- Added support for `Series.dt.days`, `Series.dt.seconds`, `Series.dt.microseconds`, and `Series.dt.nanoseconds`.
- Added support for creating a `DatetimeIndex` from an `Index` of numeric or string type.
- Added support for string indexing with `Timedelta` objects.
- Added support for `Series.dt.total_seconds` method.

#### Improvements

Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/series_dt_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ the method in the left column.
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``day_name`` | P | ``N`` if `locale` is set. |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``total_seconds`` | N | |
| ``total_seconds`` | Y | |
+-----------------------------+---------------------------------+----------------------------------------------------+
| ``to_pytimedelta`` | N | |
+-----------------------------+---------------------------------+----------------------------------------------------+
Expand Down
2 changes: 2 additions & 0 deletions docs/source/modin/supported/timedelta_index_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,5 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+
| ``mean`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+
| ``total_seconds`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+-------------------------------------------+
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@
BooleanType,
DataType,
DateType,
DecimalType,
DoubleType,
FloatType,
IntegerType,
Expand Down Expand Up @@ -16882,14 +16883,26 @@ def day_name_func(column: SnowparkColumn) -> SnowparkColumn:
)
)

def dt_total_seconds(self) -> None:
def dt_total_seconds(self, include_index: bool = False) -> "SnowflakeQueryCompiler":
"""
Return total duration of each element expressed in seconds.
Args:
include_index: Whether to include the index columns in the operation.
Returns:
New QueryCompiler containing total seconds.
"""
ErrorMessage.not_implemented(
"Snowpark pandas doesn't yet support the method 'Series.dt.total_seconds'"
# This method is only applicable to timedelta types.
dtype = self.index_dtypes[0] if include_index else self.dtypes[0]
if not is_timedelta64_dtype(dtype):
raise AttributeError(
"'DatetimeProperties' object has no attribute 'total_seconds'"
)
return SnowflakeQueryCompiler(
self._modin_frame.apply_snowpark_function_to_columns(
# Cast the column to decimal of scale 9 to ensure no precision loss.
lambda x: x.cast(DecimalType(scale=9)) / 1_000_000_000,
include_index,
)
)

def dt_strftime(self, date_format: str) -> None:
Expand Down
53 changes: 52 additions & 1 deletion src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2166,7 +2166,58 @@ def day_name():
"""

def total_seconds():
pass
"""
Return total duration of each element expressed in seconds.
This method is available directly on TimedeltaArray, TimedeltaIndex
and on Series containing timedelta values under the ``.dt`` namespace.
Returns
-------
ndarray, Index or Series
When the calling object is a TimedeltaArray, the return type
is ndarray. When the calling object is a TimedeltaIndex,
the return type is an Index with a float64 dtype. When the calling object
is a Series, the return type is Series of type `float64` whose
index is the same as the original.
See Also
--------
datetime.timedelta.total_seconds : Standard library version
of this method.
TimedeltaIndex.components : Return a DataFrame with components of
each Timedelta.
Examples
--------
**Series**
>>> s = pd.Series(pd.to_timedelta(np.arange(5), unit='d'))
>>> s
0 0 days
1 1 days
2 2 days
3 3 days
4 4 days
dtype: timedelta64[ns]
>>> s.dt.total_seconds()
0 0.0
1 86400.0
2 172800.0
3 259200.0
4 345600.0
dtype: float64
**TimedeltaIndex**
>>> idx = pd.to_timedelta(np.arange(5), unit='d')
>>> idx
TimedeltaIndex(['0 days', '1 days', '2 days', '3 days', '4 days'], dtype='timedelta64[ns]', freq=None)
>>> idx.total_seconds()
Index([0.0, 86400.0, 172800.0, 259200.0, 345600.0], dtype='float64')
"""

def to_pytimedelta():
pass
Expand Down
20 changes: 20 additions & 0 deletions src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,3 +441,23 @@ def as_unit(self, unit: str) -> TimedeltaIndex:
>>> idx.as_unit('s') # doctest: +SKIP
TimedeltaIndex(['1 days 00:03:00'], dtype='timedelta64[s]', freq=None)
"""

def total_seconds(self) -> Index:
"""
Return total duration of each element expressed in seconds.
Returns
-------
An Index with float type.
Examples:
--------
>>> idx = pd.to_timedelta(np.arange(5), unit='d')
>>> idx
TimedeltaIndex(['0 days', '1 days', '2 days', '3 days', '4 days'], dtype='timedelta64[ns]', freq=None)
>>> idx.total_seconds()
Index([0.0, 86400.0, 172800.0, 259200.0, 345600.0], dtype='float64')
"""
return Index(
query_compiler=self._query_compiler.dt_total_seconds(include_index=True)
)
64 changes: 28 additions & 36 deletions tests/integ/modin/index/test_timedelta_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,24 @@
from tests.integ.modin.sql_counter import sql_count_checker
from tests.integ.modin.utils import assert_index_equal, eval_snowpark_pandas_result

TIMEDELTA_INDEX_DATA = [
"0ns",
"1d",
"1h",
"5h",
"9h",
"60s",
"1s",
"800ms",
"900ms",
"5us",
"6ns",
"1ns",
"1d 3s",
"9m 15s 8us",
None,
]


@sql_count_checker(query_count=0)
def test_timedelta_index_construction():
Expand Down Expand Up @@ -67,9 +85,7 @@ def test_property_not_implemented(property):
@pytest.mark.parametrize("attr", ["days", "seconds", "microseconds", "nanoseconds"])
@sql_count_checker(query_count=1)
def test_timedelta_index_properties(attr):
native_index = native_pd.TimedeltaIndex(
["1d", "1h", "60s", "1s", "800ms", "5us", "6ns", "1d 3s", "9m 15s 8us", None]
)
native_index = native_pd.TimedeltaIndex(TIMEDELTA_INDEX_DATA)
snow_index = pd.Index(native_index)
assert_index_equal(
getattr(snow_index, attr), getattr(native_index, attr), exact=False
Expand All @@ -82,24 +98,7 @@ def test_timedelta_index_properties(attr):
)
@sql_count_checker(query_count=1)
def test_timedelta_floor_ceil_round(method, freq):
native_index = native_pd.TimedeltaIndex(
[
"0ns" "1d",
"1h",
"5h",
"9h",
"60s",
"1s",
"800ms",
"900ms",
"5us",
"6ns",
"1ns",
"1d 3s",
"9m 15s 8us",
None,
]
)
native_index = native_pd.TimedeltaIndex(TIMEDELTA_INDEX_DATA)
snow_index = pd.Index(native_index)
eval_snowpark_pandas_result(
snow_index, native_index, lambda x: getattr(x, method)(freq)
Expand All @@ -112,21 +111,7 @@ def test_timedelta_floor_ceil_round(method, freq):
)
@sql_count_checker(query_count=0)
def test_timedelta_floor_ceil_round_negative(method, freq):
native_index = native_pd.TimedeltaIndex(
[
"0ns",
"1d",
"5h",
"60s",
"1s",
"900ms",
"5us",
"1ns",
"1d 3s",
"9m 15s 8us",
None,
]
)
native_index = native_pd.TimedeltaIndex(TIMEDELTA_INDEX_DATA)
snow_index = pd.Index(native_index)
eval_snowpark_pandas_result(
snow_index,
Expand All @@ -136,3 +121,10 @@ def test_timedelta_floor_ceil_round_negative(method, freq):
expect_exception_type=ValueError,
expect_exception_match=f"Invalid frequency: {freq}",
)


@sql_count_checker(query_count=1)
def test_timedelta_total_seconds():
native_index = native_pd.TimedeltaIndex(TIMEDELTA_INDEX_DATA)
snow_index = pd.Index(native_index)
eval_snowpark_pandas_result(snow_index, native_index, lambda x: x.total_seconds())
38 changes: 38 additions & 0 deletions tests/integ/modin/series/test_dt_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,3 +435,41 @@ def test_dt_invalid_dtype_property_combo(data, data_type, property_name):
expect_exception=True,
expect_exception_match="object has no attribute",
)


@sql_count_checker(query_count=1)
def test_dt_total_seconds():
data = [
"0ns",
"1d",
"1h",
"5h",
"9h",
"60s",
"1s",
"800ms",
"900ms",
"5us",
"6ns",
"1ns",
"1d 3s",
"9m 15s 8us",
None,
]
native_ser = native_pd.Series(native_pd.TimedeltaIndex(data))
snow_ser = pd.Series(native_ser)
eval_snowpark_pandas_result(snow_ser, native_ser, lambda x: x.dt.total_seconds())


@sql_count_checker(query_count=0)
def test_timedelta_total_seconds_type_error():
native_ser = native_pd.Series(native_pd.DatetimeIndex(["2024-01-01"]))
snow_ser = pd.Series(native_ser)
eval_snowpark_pandas_result(
snow_ser,
native_ser,
lambda x: x.dt.total_seconds(),
expect_exception=True,
expect_exception_type=AttributeError,
expect_exception_match="'DatetimeProperties' object has no attribute 'total_seconds'",
)
1 change: 0 additions & 1 deletion tests/unit/modin/test_series_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def mock_query_compiler_for_dt_series() -> SnowflakeQueryCompiler:
(lambda s: s.dt.tz_localize(tz="UTC"), "tz_localize"),
(lambda s: s.dt.tz_convert(tz="UTC"), "tz_convert"),
(lambda s: s.dt.strftime(date_format="YY/MM/DD"), "strftime"),
(lambda s: s.dt.total_seconds(), "total_seconds"),
(lambda s: s.dt.qyear, "qyear"),
(lambda s: s.dt.start_time, "start_time"),
(lambda s: s.dt.end_time, "end_time"),
Expand Down

0 comments on commit 1c85c75

Please sign in to comment.