Skip to content

Commit

Permalink
SNOW-1660952, SNOW-1660954: Add support for DatetimeIndex.tz_localize…
Browse files Browse the repository at this point in the history
…/tz_convert (#2281)

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.
   
Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

   Fixes SNOW-1660952, SNOW-1660954

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.

3. Please describe how your code solves the related issue.

SNOW-1660952, SNOW-1660954: Add support for
DatetimeIndex.tz_localize/tz_convert.
  • Loading branch information
sfc-gh-helmeleegy authored Sep 13, 2024
1 parent c7be18c commit 5b5c03b
Showing 6 changed files with 157 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -125,6 +125,7 @@ This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for det
- Added support for `Series.dt.total_seconds` method.
- Added support for `DataFrame.apply(axis=0)`.
- Added support for `Series.dt.tz_convert` and `Series.dt.tz_localize`.
- Added support for `DatetimeIndex.tz_convert` and `DatetimeIndex.tz_localize`.

#### Improvements

4 changes: 2 additions & 2 deletions docs/source/modin/supported/datetime_index_supported.rst
Original file line number Diff line number Diff line change
@@ -82,9 +82,9 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``snap`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``tz_convert`` | N | | |
| ``tz_convert`` | Y | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``tz_localize`` | N | | |
| ``tz_localize`` | P | ``ambiguous``, ``nonexistent`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``round`` | P | ``ambiguous``, ``nonexistent`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Original file line number Diff line number Diff line change
@@ -525,7 +525,7 @@ def tz_convert_column(column: Column, tz: Union[str, dt.tzinfo]) -> Column:
The column after conversion to the specified timezone
"""
if tz is None:
return convert_timezone(pandas_lit("UTC"), column)
return to_timestamp_ntz(convert_timezone(pandas_lit("UTC"), column))
else:
if isinstance(tz, dt.tzinfo):
tz_name = tz.tzname(None)
Original file line number Diff line number Diff line change
@@ -16660,46 +16660,58 @@ def dt_tz_localize(
tz: Union[str, tzinfo],
ambiguous: str = "raise",
nonexistent: str = "raise",
include_index: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Localize tz-naive to tz-aware.
Args:
tz : str, pytz.timezone, optional
ambiguous : {"raise", "infer", "NaT"} or bool mask, default: "raise"
nonexistent : {"raise", "shift_forward", "shift_backward", "NaT"} or pandas.timedelta, default: "raise"
include_index: Whether to include the index columns in the operation.

Returns:
BaseQueryCompiler
New QueryCompiler containing values with localized time zone.
"""
dtype = self.index_dtypes[0] if include_index else self.dtypes[0]
if not include_index:
method_name = "Series.dt.tz_localize"
else:
assert is_datetime64_any_dtype(dtype), "column must be datetime"
method_name = "DatetimeIndex.tz_localize"

if not isinstance(ambiguous, str) or ambiguous != "raise":
ErrorMessage.parameter_not_implemented_error(
"ambiguous", "Series.dt.tz_localize"
)
ErrorMessage.parameter_not_implemented_error("ambiguous", method_name)
if not isinstance(nonexistent, str) or nonexistent != "raise":
ErrorMessage.parameter_not_implemented_error(
"nonexistent", "Series.dt.tz_localize"
)
ErrorMessage.parameter_not_implemented_error("nonexistent", method_name)

return SnowflakeQueryCompiler(
self._modin_frame.apply_snowpark_function_to_columns(
lambda column: tz_localize_column(column, tz)
lambda column: tz_localize_column(column, tz),
include_index,
)
)

def dt_tz_convert(self, tz: Union[str, tzinfo]) -> "SnowflakeQueryCompiler":
def dt_tz_convert(
self,
tz: Union[str, tzinfo],
include_index: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Convert time-series data to the specified time zone.

Args:
tz : str, pytz.timezone
include_index: Whether to include the index columns in the operation.

Returns:
A new QueryCompiler containing values with converted time zone.
"""
return SnowflakeQueryCompiler(
self._modin_frame.apply_snowpark_function_to_columns(
lambda column: tz_convert_column(column, tz)
lambda column: tz_convert_column(column, tz),
include_index,
)
)

29 changes: 21 additions & 8 deletions src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
Original file line number Diff line number Diff line change
@@ -960,7 +960,6 @@ def snap(self, freq: Frequency = "S") -> DatetimeIndex:
DatetimeIndex(['2023-01-01', '2023-01-01', '2023-02-01', '2023-02-01'], dtype='datetime64[ns]', freq=None)
"""

@datetime_index_not_implemented()
def tz_convert(self, tz) -> DatetimeIndex:
"""
Convert tz-aware Datetime Array/Index from one time zone to another.
@@ -1025,8 +1024,14 @@ def tz_convert(self, tz) -> DatetimeIndex:
'2014-08-01 09:00:00'],
dtype='datetime64[ns]', freq='h')
"""
# TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests.
return DatetimeIndex(
query_compiler=self._query_compiler.dt_tz_convert(
tz,
include_index=True,
)
)

@datetime_index_not_implemented()
def tz_localize(
self,
tz,
@@ -1104,21 +1109,29 @@ def tz_localize(
Localize DatetimeIndex in US/Eastern time zone:
>>> tz_aware = tz_naive.tz_localize(tz='US/Eastern') # doctest: +SKIP
>>> tz_aware # doctest: +SKIP
DatetimeIndex(['2018-03-01 09:00:00-05:00',
'2018-03-02 09:00:00-05:00',
>>> tz_aware = tz_naive.tz_localize(tz='US/Eastern')
>>> tz_aware
DatetimeIndex(['2018-03-01 09:00:00-05:00', '2018-03-02 09:00:00-05:00',
'2018-03-03 09:00:00-05:00'],
dtype='datetime64[ns, US/Eastern]', freq=None)
dtype='datetime64[ns, UTC-05:00]', freq=None)
With the ``tz=None``, we can remove the time zone information
while keeping the local time (not converted to UTC):
>>> tz_aware.tz_localize(None) # doctest: +SKIP
>>> tz_aware.tz_localize(None)
DatetimeIndex(['2018-03-01 09:00:00', '2018-03-02 09:00:00',
'2018-03-03 09:00:00'],
dtype='datetime64[ns]', freq=None)
"""
# TODO (SNOW-1660843): Support tz in pd.date_range and unskip the doctests.
return DatetimeIndex(
query_compiler=self._query_compiler.dt_tz_localize(
tz,
ambiguous,
nonexistent,
include_index=True,
)
)

def round(
self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise"
111 changes: 111 additions & 0 deletions tests/integ/modin/index/test_datetime_index_methods.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@
import numpy as np
import pandas as native_pd
import pytest
import pytz

import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
@@ -17,6 +18,46 @@
eval_snowpark_pandas_result,
)

timezones = pytest.mark.parametrize(
"tz",
[
None,
# Use a subset of pytz.common_timezones containing a few timezones in each
*[
param_for_one_tz
for tz in [
"Africa/Abidjan",
"Africa/Timbuktu",
"America/Adak",
"America/Yellowknife",
"Antarctica/Casey",
"Asia/Dhaka",
"Asia/Manila",
"Asia/Shanghai",
"Atlantic/Stanley",
"Australia/Sydney",
"Canada/Pacific",
"Europe/Chisinau",
"Europe/Luxembourg",
"Indian/Christmas",
"Pacific/Chatham",
"Pacific/Wake",
"US/Arizona",
"US/Central",
"US/Eastern",
"US/Hawaii",
"US/Mountain",
"US/Pacific",
"UTC",
]
for param_for_one_tz in (
pytz.timezone(tz),
tz,
)
],
],
)


@sql_count_checker(query_count=0)
def test_datetime_index_construction():
@@ -233,6 +274,76 @@ def test_normalize():
)


@sql_count_checker(query_count=1, join_count=1)
@timezones
def test_tz_convert(tz):
native_index = native_pd.date_range(
start="2021-01-01", periods=5, freq="7h", tz="US/Eastern"
)
native_index = native_index.append(
native_pd.DatetimeIndex([pd.NaT], tz="US/Eastern")
)
snow_index = pd.DatetimeIndex(native_index)

    # Using eval_snowpark_pandas_result() was not possible because currently
    # Snowpark pandas DatetimeIndex only maintains a timezone-naive dtype
    # even if the data contains a timezone.
assert snow_index.tz_convert(tz).equals(
pd.DatetimeIndex(native_index.tz_convert(tz))
)


@sql_count_checker(query_count=1, join_count=1)
@timezones
def test_tz_localize(tz):
native_index = native_pd.DatetimeIndex(
[
"2014-04-04 23:56:01.000000001",
"2014-07-18 21:24:02.000000002",
"2015-11-22 22:14:03.000000003",
"2015-11-23 20:12:04.1234567890",
pd.NaT,
],
)
snow_index = pd.DatetimeIndex(native_index)

    # Using eval_snowpark_pandas_result() was not possible because currently
    # Snowpark pandas DatetimeIndex only maintains a timezone-naive dtype
    # even if the data contains a timezone.
assert snow_index.tz_localize(tz).equals(
pd.DatetimeIndex(native_index.tz_localize(tz))
)


@pytest.mark.parametrize(
"ambiguous, nonexistent",
[
("infer", "raise"),
("NaT", "raise"),
(np.array([True, True, False]), "raise"),
("raise", "shift_forward"),
("raise", "shift_backward"),
("raise", "NaT"),
("raise", pd.Timedelta("1h")),
("infer", "shift_forward"),
],
)
@sql_count_checker(query_count=0)
def test_tz_localize_negative(ambiguous, nonexistent):
native_index = native_pd.DatetimeIndex(
[
"2014-04-04 23:56:01.000000001",
"2014-07-18 21:24:02.000000002",
"2015-11-22 22:14:03.000000003",
"2015-11-23 20:12:04.1234567890",
pd.NaT,
],
)
snow_index = pd.DatetimeIndex(native_index)
with pytest.raises(NotImplementedError):
snow_index.tz_localize(tz=None, ambiguous=ambiguous, nonexistent=nonexistent)


@pytest.mark.parametrize(
"datetime_index_value",
[

0 comments on commit 5b5c03b

Please sign in to comment.