Skip to content

Commit

Permalink
SNOW-1657460 Improved results for TIMESTAMP_TZ type to show correct t…
Browse files Browse the repository at this point in the history
…imezone offset. (#2290)

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.
   
Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

   Fixes SNOW-1657460

2. Fill out the following pre-review checklist:

- [ ] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.

3. Please describe how your code solves the related issue.

Please write a short description of how your code change solves the
related issue.

Previously we use "datetime64[ns]" as the dtype for datetime with
timezone which is missing the tz info. This change will extract the
timezone from timestamp_tz column and show the timezone offset when
calling dtypes.
  • Loading branch information
sfc-gh-azhan authored Sep 18, 2024
1 parent 6841488 commit 20d9738
Show file tree
Hide file tree
Showing 16 changed files with 211 additions and 77 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#### Improvements

- Improved `to_pandas` to persist the original timezone offset for TIMESTAMP_TZ type.
- Improved `dtype` results for TIMESTAMP_TZ type to show correct timezone offset.

#### New Features

Expand Down
8 changes: 3 additions & 5 deletions src/snowflake/snowpark/modin/pandas/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,9 @@
from pandas.api.types import is_bool, is_list_like
from pandas.core.dtypes.common import (
is_bool_dtype,
is_datetime64_any_dtype,
is_integer,
is_integer_dtype,
is_numeric_dtype,
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.indexing import IndexingError
Expand Down Expand Up @@ -846,7 +844,7 @@ def _try_partial_string_indexing_for_string(
period = pd.Period(parsed, freq=reso.attr_abbrev)

# partial string indexing only works for DatetimeIndex
if is_datetime64_any_dtype(self.df._query_compiler.index_dtypes[0]):
if self.df._query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
return slice(
pd.Timestamp(period.start_time, tzinfo=tzinfo),
pd.Timestamp(period.end_time, tzinfo=tzinfo),
Expand Down Expand Up @@ -926,8 +924,8 @@ def __getitem__(
row_loc, col_loc = self._parse_get_row_and_column_locators(key)
row_loc = self._try_partial_string_indexing(row_loc)

# Check if self or its index is a TimedeltaIndex. `index_dtypes` retrieves the dtypes of the index columns.
if is_timedelta64_dtype(self.df._query_compiler.index_dtypes[0]):
# Check if self or its index is a TimedeltaIndex.
if self.df._query_compiler.is_timedelta64_dtype(idx=0, is_index=True):
# Convert row_loc to timedelta format to perform exact matching for TimedeltaIndex.
row_loc = self._convert_to_timedelta(row_loc)

Expand Down
16 changes: 16 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, NamedTuple, Optional, Union

import pandas as native_pd
from pandas import DatetimeTZDtype
from pandas._typing import IndexLabel

from snowflake.snowpark._internal.analyzer.analyzer_utils import (
Expand All @@ -31,6 +32,9 @@
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
SnowparkPandasType,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import (
_get_timezone_from_timestamp_tz,
)
from snowflake.snowpark.modin.plugin._internal.utils import (
DEFAULT_DATA_COLUMN_LABEL,
INDEX_LABEL,
Expand Down Expand Up @@ -1473,5 +1477,17 @@ def is_quoted_identifier_normalized(
)
return self.rename_snowflake_identifiers(renamed_quoted_identifier_mapping)

def get_datetime64tz_from_timestamp_tz(
self, timestamp_tz_snowfalke_quoted_identifier: str
) -> DatetimeTZDtype:
"""
map a snowpark timestamp type to datetime64 type.
"""

return _get_timezone_from_timestamp_tz(
self.ordered_dataframe._dataframe_ref.snowpark_dataframe,
timestamp_tz_snowfalke_quoted_identifier,
)

# END: Internal Frame mutation APIs.
###########################################################################
24 changes: 23 additions & 1 deletion src/snowflake/snowpark/modin/plugin/_internal/type_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
from functools import reduce
from functools import lru_cache, reduce
from typing import Any, Callable, Union

import numpy as np
Expand Down Expand Up @@ -31,6 +31,7 @@

from snowflake.snowpark import Column
from snowflake.snowpark._internal.type_utils import infer_type, merge_type
from snowflake.snowpark.dataframe import DataFrame as SnowparkDataFrame
from snowflake.snowpark.functions import (
builtin,
cast,
Expand All @@ -39,6 +40,7 @@
floor,
iff,
length,
to_char,
to_varchar,
to_variant,
)
Expand Down Expand Up @@ -446,3 +448,23 @@ def is_compatible_snowpark_types(sp_type_1: DataType, sp_type_2: DataType) -> bo
if isinstance(sp_type_1, StringType) and isinstance(sp_type_2, StringType):
return True
return False


@lru_cache
def _get_timezone_from_timestamp_tz(
snowpark_dataframe: SnowparkDataFrame, snowflake_quoted_identifier: str
) -> Union[str, DatetimeTZDtype]:
tz_df = (
snowpark_dataframe.filter(col(snowflake_quoted_identifier).is_not_null())
.select(to_char(col(snowflake_quoted_identifier), format="TZHTZM").as_("tz"))
.group_by(["tz"])
.agg()
.limit(2) # only need 2 to check whether it contains multiple timezones
.to_pandas()
)
assert (
len(tz_df) > 0
), f"col {snowflake_quoted_identifier} does not contain valid timezone offset"
if len(tz_df) == 2: # multi timezone cases
return "object"
return DatetimeTZDtype(tz="UTC" + tz_df.iloc[0, 0].replace("Z", ""))
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@
PandasDataFrameType,
PandasSeriesType,
StringType,
TimestampTimeZone,
TimestampType,
TimeType,
VariantType,
Expand Down Expand Up @@ -477,6 +478,28 @@ def wrap(*args, **kwargs): # type: ignore

return wrap

def _get_dtypes(
self, snowflake_quoted_identifiers: List[str]
) -> List[Union[np.dtype, ExtensionDtype]]:
"""
Get dtypes for the input columns.

Args:
snowflake_quoted_identifiers: input column identifiers

Returns:
a list of the dtypes.
"""
type_map = self._modin_frame.quoted_identifier_to_snowflake_type(
snowflake_quoted_identifiers
)
return [
self._modin_frame.get_datetime64tz_from_timestamp_tz(i)
if t == TimestampType(TimestampTimeZone.TZ)
else TypeMapper.to_pandas(t)
for i, t in type_map.items()
]

@property
def dtypes(self) -> native_pd.Series:
"""
Expand All @@ -487,18 +510,11 @@ def dtypes(self) -> native_pd.Series:
pandas.Series
Series with dtypes of each column.
"""
types = [
TypeMapper.to_pandas(t)
for t in self._modin_frame.get_snowflake_type(
self._modin_frame.data_column_snowflake_quoted_identifiers
)
]

from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native

return native_pd.Series(
data=types,
index=try_convert_index_to_native(self._modin_frame.data_columns_index),
data=self._get_dtypes(
self._modin_frame.data_column_snowflake_quoted_identifiers
),
index=self._modin_frame.data_columns_index,
dtype=object,
)

Expand All @@ -512,12 +528,59 @@ def index_dtypes(self) -> list[Union[np.dtype, ExtensionDtype]]:
pandas.Series
Series with dtypes of each column.
"""
return [
TypeMapper.to_pandas(t)
for t in self._modin_frame.get_snowflake_type(
return self._get_dtypes(
self._modin_frame.index_column_snowflake_quoted_identifiers
)

def is_timestamp_type(self, idx: int, is_index: bool = True) -> bool:
"""Return True if column at the index is TIMESTAMP TYPE.

Args:
idx: the index of the column
is_index: whether it is an index or data column
"""
return isinstance(
self._modin_frame.get_snowflake_type(
self._modin_frame.index_column_snowflake_quoted_identifiers
)
]
if is_index
else self._modin_frame.data_column_snowflake_quoted_identifiers
)[idx],
TimestampType,
)

def is_datetime64_any_dtype(self, idx: int, is_index: bool = True) -> bool:
"""Helper method similar to is_datetime64_any_dtype, but it avoids extra query for DatetimeTZDtype.

Args:
idx: the index of the column
is_index: whether it is an index or data column
"""
return self.is_timestamp_type(idx, is_index)

def is_timedelta64_dtype(self, idx: int, is_index: bool = True) -> bool:
"""Helper method similar to is_timedelta_dtype, but it avoids extra query for DatetimeTZDtype.

Args:
idx: the index of the column
is_index: whether it is an index or data column
"""
id = (
self._modin_frame.index_column_snowflake_quoted_identifiers[idx]
if is_index
else self._modin_frame.data_column_snowflake_quoted_identifiers[idx]
)
return self._modin_frame.get_snowflake_type(id) == TimedeltaType()

def is_string_dtype(self, idx: int, is_index: bool = True) -> bool:
"""Helper method similar to is_timedelta_dtype, but it avoids extra query for DatetimeTZDtype.

Args:
idx: the index of the column
is_index: whether it is an index or data column
"""
return not self.is_timestamp_type(idx, is_index) and is_string_dtype(
self.index_dtypes[idx] if is_index else self.dtypes[idx]
)

@classmethod
def from_pandas(
Expand Down Expand Up @@ -9390,9 +9453,6 @@ def astype(
ErrorMessage.not_implemented(
f"Snowpark pandas astype API doesn't yet support errors == '{errors}'"
)
col_dtypes_curr = {
k: v for k, v in self.dtypes.to_dict().items() if k in col_dtypes_map
}

astype_mapping = {}
labels = list(col_dtypes_map.keys())
Expand All @@ -9411,17 +9471,18 @@ def astype(
for id in ids:
to_dtype = col_dtypes_map[label]
to_sf_type = TypeMapper.to_snowflake(to_dtype)
from_dtype = col_dtypes_curr[label]
from_sf_type = self._modin_frame.get_snowflake_type(id)
if isinstance(from_sf_type, StringType) and isinstance(
to_sf_type, TimedeltaType
):
# Raise NotImplementedError as there is no Snowflake SQL function converting
# string (e.g. 1 day, 3 hours, 2 minutes) to Timedelta
from_dtype = self.dtypes.to_dict()[label]
ErrorMessage.not_implemented(
f"dtype {pandas_dtype(from_dtype)} cannot be converted to {pandas_dtype(to_dtype)}"
)
elif is_astype_type_error(from_sf_type, to_sf_type):
from_dtype = self.dtypes.to_dict()[label]
raise TypeError(
f"dtype {pandas_dtype(from_dtype)} cannot be converted to {pandas_dtype(to_dtype)}"
)
Expand Down Expand Up @@ -11191,7 +11252,7 @@ def dt_property(
"""
if not include_index:
assert len(self.columns) == 1, "dt only works for series"
if not is_datetime64_any_dtype(self.dtypes[0]):
if not self.is_datetime64_any_dtype(idx=0, is_index=False):
raise AttributeError(
f"'TimedeltaProperties' object has no attribute '{property_name}'"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
TimeAmbiguous,
TimeNonexistent,
)
from pandas.core.dtypes.common import is_datetime64_any_dtype

from snowflake.snowpark.modin.pandas import to_datetime, to_timedelta
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
Expand Down Expand Up @@ -137,8 +136,7 @@ def __new__(
"""
if query_compiler:
# Raise error if underlying type is not a TimestampType.
current_dtype = query_compiler.index_dtypes[0]
if not current_dtype == np.dtype("datetime64[ns]"):
if not query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
raise ValueError(
"DatetimeIndex can only be created from a query compiler with TimestampType."
)
Expand All @@ -159,7 +157,7 @@ def __new__(
data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs
)
# Convert to datetime64 if not already.
if not is_datetime64_any_dtype(query_compiler.index_dtypes[0]):
if not query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
query_compiler = query_compiler.series_to_datetime(include_index=True)
index._query_compiler = query_compiler
# `_parent` keeps track of any Series or DataFrame that this Index is a part of.
Expand Down
6 changes: 2 additions & 4 deletions src/snowflake/snowpark/modin/plugin/extensions/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
is_integer_dtype,
is_numeric_dtype,
is_object_dtype,
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.dtypes.inference import is_hashable
Expand Down Expand Up @@ -156,10 +155,9 @@ def __new__(
query_compiler = cls._init_query_compiler(
data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs
)
dtype = query_compiler.index_dtypes[0]
if is_datetime64_any_dtype(dtype):
if query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
return DatetimeIndex(query_compiler=query_compiler)
if is_timedelta64_dtype(dtype):
if query_compiler.is_timedelta64_dtype(idx=0, is_index=True):
return TimedeltaIndex(query_compiler=query_compiler)
index = object.__new__(cls)
# Initialize the Index
Expand Down
15 changes: 4 additions & 11 deletions src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@
Renamer,
Scalar,
)
from pandas.api.types import (
is_datetime64_any_dtype,
is_string_dtype,
is_timedelta64_dtype,
)
from pandas.core.common import apply_if_callable, is_bool_indexer
from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like
from pandas.util._validators import validate_ascending, validate_bool_kwarg
Expand Down Expand Up @@ -1183,10 +1178,9 @@ def dt(self): # noqa: RT01, D200
Accessor object for datetimelike properties of the Series values.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
current_dtype = self.dtype
if not is_datetime64_any_dtype(current_dtype) and not is_timedelta64_dtype(
current_dtype
):
if not self._query_compiler.is_datetime64_any_dtype(
idx=0, is_index=False
) and not self._query_compiler.is_timedelta64_dtype(idx=0, is_index=False):
raise AttributeError("Can only use .dt accessor with datetimelike values")

from modin.pandas.series_utils import DatetimeProperties
Expand All @@ -1203,8 +1197,7 @@ def _str(self): # noqa: RT01, D200
Vectorized string functions for Series and Index.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
current_dtype = self.dtype
if not is_string_dtype(current_dtype):
if not self._query_compiler.is_string_dtype(idx=0, is_index=False):
raise AttributeError("Can only use .str accessor with string values!")

from modin.pandas.series_utils import StringMethods
Expand Down
Loading

0 comments on commit 20d9738

Please sign in to comment.