Skip to content

Commit

Permalink
SNOW-1657460 Improved results for TIMESTAMP_TZ type to show correct t…
Browse files Browse the repository at this point in the history
…z offset
  • Loading branch information
sfc-gh-azhan committed Sep 13, 2024
1 parent 5b5c03b commit 4b901bb
Show file tree
Hide file tree
Showing 16 changed files with 198 additions and 66 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#### Improvements

- Improved `to_pandas` to persist the original timezone offset for TIMESTAMP_TZ type.
- Improved `dtype` results for TIMESTAMP_TZ type to show correct timezone offset.

#### New Features

Expand Down
6 changes: 2 additions & 4 deletions src/snowflake/snowpark/modin/pandas/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,9 @@
from pandas.api.types import is_bool, is_list_like
from pandas.core.dtypes.common import (
is_bool_dtype,
is_datetime64_any_dtype,
is_integer,
is_integer_dtype,
is_numeric_dtype,
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.indexing import IndexingError
Expand Down Expand Up @@ -846,7 +844,7 @@ def _try_partial_string_indexing_for_string(
period = pd.Period(parsed, freq=reso.attr_abbrev)

# partial string indexing only works for DatetimeIndex
if is_datetime64_any_dtype(self.df._query_compiler.index_dtypes[0]):
if self.df._query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
return slice(
pd.Timestamp(period.start_time, tzinfo=tzinfo),
pd.Timestamp(period.end_time, tzinfo=tzinfo),
Expand Down Expand Up @@ -927,7 +925,7 @@ def __getitem__(
row_loc = self._try_partial_string_indexing(row_loc)

# Check if self or its index is a TimedeltaIndex. `index_dtypes` retrieves the dtypes of the index columns.
if is_timedelta64_dtype(self.df._query_compiler.index_dtypes[0]):
if self.df._query_compiler.is_timedelta64_dtype(idx=0, is_index=True):
# Convert row_loc to timedelta format to perform exact matching for TimedeltaIndex.
row_loc = self._convert_to_timedelta(row_loc)

Expand Down
16 changes: 16 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Callable, NamedTuple, Optional, Union

import pandas as native_pd
from pandas import DatetimeTZDtype
from pandas._typing import IndexLabel

from snowflake.snowpark._internal.analyzer.analyzer_utils import (
Expand All @@ -31,6 +32,9 @@
from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
SnowparkPandasType,
)
from snowflake.snowpark.modin.plugin._internal.type_utils import (
_get_timezone_from_timestamp_tz,
)
from snowflake.snowpark.modin.plugin._internal.utils import (
DEFAULT_DATA_COLUMN_LABEL,
INDEX_LABEL,
Expand Down Expand Up @@ -1473,5 +1477,17 @@ def is_quoted_identifier_normalized(
)
return self.rename_snowflake_identifiers(renamed_quoted_identifier_mapping)

def get_datetime64tz_from_timestamp_tz(
self, timestamp_tz_snowfalke_quoted_identifier: str
) -> DatetimeTZDtype:
"""
map a snowpark timestamp type to datetime64 type.
"""

return _get_timezone_from_timestamp_tz(
self.ordered_dataframe._dataframe_ref.snowpark_dataframe,
timestamp_tz_snowfalke_quoted_identifier,
)

# END: Internal Frame mutation APIs.
###########################################################################
24 changes: 23 additions & 1 deletion src/snowflake/snowpark/modin/plugin/_internal/type_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#
from functools import reduce
from functools import lru_cache, reduce
from typing import Any, Callable, Union

import numpy as np
Expand Down Expand Up @@ -31,6 +31,7 @@

from snowflake.snowpark import Column
from snowflake.snowpark._internal.type_utils import infer_type, merge_type
from snowflake.snowpark.dataframe import DataFrame as SnowparkDataFrame
from snowflake.snowpark.functions import (
builtin,
cast,
Expand All @@ -39,6 +40,7 @@
floor,
iff,
length,
to_char,
to_varchar,
to_variant,
)
Expand Down Expand Up @@ -446,3 +448,23 @@ def is_compatible_snowpark_types(sp_type_1: DataType, sp_type_2: DataType) -> bo
if isinstance(sp_type_1, StringType) and isinstance(sp_type_2, StringType):
return True
return False


@lru_cache
def _get_timezone_from_timestamp_tz(
snowpark_dataframe: SnowparkDataFrame, snowflake_quoted_identifier: str
) -> Union[str, DatetimeTZDtype]:
tz_df = (
snowpark_dataframe.filter(col(snowflake_quoted_identifier).is_not_null())
.select(to_char(col(snowflake_quoted_identifier), format="TZHTZM").as_("tz"))
.group_by(["tz"])
.agg()
.limit(2)
.to_pandas()
)
assert (
len(tz_df) > 0
), f"col {snowflake_quoted_identifier} does not contain valid timezone offset"
if len(tz_df) == 2: # multi timezone cases
return "object"
return DatetimeTZDtype(tz="UTC" + tz_df.iloc[0, 0].replace("Z", ""))
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@
PandasDataFrameType,
PandasSeriesType,
StringType,
TimestampTimeZone,
TimestampType,
TimeType,
VariantType,
Expand Down Expand Up @@ -487,18 +488,19 @@ def dtypes(self) -> native_pd.Series:
pandas.Series
Series with dtypes of each column.
"""
type_map = self._modin_frame.quoted_identifier_to_snowflake_type(
self._modin_frame.data_column_snowflake_quoted_identifiers
)
types = [
TypeMapper.to_pandas(t)
for t in self._modin_frame.get_snowflake_type(
self._modin_frame.data_column_snowflake_quoted_identifiers
)
self._modin_frame.get_datetime64tz_from_timestamp_tz(i)
if t == TimestampType(TimestampTimeZone.TZ)
else TypeMapper.to_pandas(t)
for i, t in type_map.items()
]

from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native

return native_pd.Series(
data=types,
index=try_convert_index_to_native(self._modin_frame.data_columns_index),
index=self._modin_frame.data_columns_index,
dtype=object,
)

Expand All @@ -512,13 +514,65 @@ def index_dtypes(self) -> list[Union[np.dtype, ExtensionDtype]]:
pandas.Series
Series with dtypes of each column.
"""
type_map = self._modin_frame.quoted_identifier_to_snowflake_type(
self._modin_frame.index_column_snowflake_quoted_identifiers
)
return [
TypeMapper.to_pandas(t)
for t in self._modin_frame.get_snowflake_type(
self._modin_frame.index_column_snowflake_quoted_identifiers
)
self._modin_frame.get_datetime64tz_from_timestamp_tz(i)
if t == TimestampType(TimestampTimeZone.TZ)
else TypeMapper.to_pandas(t)
for i, t in type_map.items()
]

def is_timestamp_type(self, idx: int, is_index: bool = True) -> bool:
"""Return True if index is TIMESTAMP TYPE.

Args:
idx: the index of the column
is_index: whether it is an index or data column
"""
return isinstance(
self._modin_frame.get_snowflake_type(
self._modin_frame.index_column_snowflake_quoted_identifiers
if is_index
else self._modin_frame.data_column_snowflake_quoted_identifiers
)[idx],
TimestampType,
)

def is_datetime64_any_dtype(self, idx: int, is_index: bool = True) -> bool:
"""Helper method similar to is_datetime64_any_dtype, but it avoids extra query for DatetimeTZDtype.

Args:
idx: the index of the column
is_index: whether it is an index or data column
"""
return self.is_timestamp_type(idx, is_index) or is_datetime64_any_dtype(
self.index_dtypes[idx] if is_index else self.dtypes[idx]
)

def is_timedelta64_dtype(self, idx: int, is_index: bool = True) -> bool:
"""Helper method similar to is_timedelta_dtype, but it avoids extra query for DatetimeTZDtype.

Args:
idx: the index of the column
is_index: whether it is an index or data column
"""
return not self.is_timestamp_type(idx, is_index) and is_timedelta64_dtype(
self.index_dtypes[idx] if is_index else self.dtypes[idx]
)

def is_string_dtype(self, idx: int, is_index: bool = True) -> bool:
"""Helper method similar to is_timedelta_dtype, but it avoids extra query for DatetimeTZDtype.

Args:
idx: the index of the column
is_index: whether it is an index or data column
"""
return not self.is_timestamp_type(idx, is_index) and is_string_dtype(
self.index_dtypes[idx] if is_index else self.dtypes[idx]
)

@classmethod
def from_pandas(
cls, df: native_pd.DataFrame, *args: Any, **kwargs: Any
Expand Down Expand Up @@ -11163,7 +11217,7 @@ def dt_property(
"""
if not include_index:
assert len(self.columns) == 1, "dt only works for series"
if not is_datetime64_any_dtype(self.dtypes[0]):
if not self.is_datetime64_any_dtype(idx=0, is_index=False):
raise AttributeError(
f"'TimedeltaProperties' object has no attribute '{property_name}'"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
TimeAmbiguous,
TimeNonexistent,
)
from pandas.core.dtypes.common import is_datetime64_any_dtype

from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
Expand Down Expand Up @@ -136,8 +135,7 @@ def __new__(
"""
if query_compiler:
# Raise error if underlying type is not a TimestampType.
current_dtype = query_compiler.index_dtypes[0]
if not current_dtype == np.dtype("datetime64[ns]"):
if not query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
raise ValueError(
"DatetimeIndex can only be created from a query compiler with TimestampType."
)
Expand All @@ -158,7 +156,7 @@ def __new__(
data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs
)
# Convert to datetime64 if not already.
if not is_datetime64_any_dtype(query_compiler.index_dtypes[0]):
if not query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
query_compiler = query_compiler.series_to_datetime(include_index=True)
index._query_compiler = query_compiler
# `_parent` keeps track of any Series or DataFrame that this Index is a part of.
Expand Down
6 changes: 2 additions & 4 deletions src/snowflake/snowpark/modin/plugin/extensions/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
is_integer_dtype,
is_numeric_dtype,
is_object_dtype,
is_timedelta64_dtype,
pandas_dtype,
)
from pandas.core.dtypes.inference import is_hashable
Expand Down Expand Up @@ -127,10 +126,9 @@ def __new__(
query_compiler = cls._init_query_compiler(
data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs
)
dtype = query_compiler.index_dtypes[0]
if is_datetime64_any_dtype(dtype):
if query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
return DatetimeIndex(query_compiler=query_compiler)
if is_timedelta64_dtype(dtype):
if query_compiler.is_timedelta64_dtype(idx=0, is_index=True):
return TimedeltaIndex(query_compiler=query_compiler)
index = object.__new__(cls)
# Initialize the Index
Expand Down
15 changes: 4 additions & 11 deletions src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,6 @@
Renamer,
Scalar,
)
from pandas.api.types import (
is_datetime64_any_dtype,
is_string_dtype,
is_timedelta64_dtype,
)
from pandas.core.common import apply_if_callable, is_bool_indexer
from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like
from pandas.util._validators import validate_bool_kwarg
Expand Down Expand Up @@ -1190,10 +1185,9 @@ def dt(self): # noqa: RT01, D200
Accessor object for datetimelike properties of the Series values.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
current_dtype = self.dtype
if not is_datetime64_any_dtype(current_dtype) and not is_timedelta64_dtype(
current_dtype
):
if not self._query_compiler.is_datetime64_any_dtype(
idx=0, is_index=False
) and not self._query_compiler.is_timedelta64_dtype(idx=0, is_index=False):
raise AttributeError("Can only use .dt accessor with datetimelike values")

from modin.pandas.series_utils import DatetimeProperties
Expand All @@ -1210,8 +1204,7 @@ def _str(self): # noqa: RT01, D200
Vectorized string functions for Series and Index.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
current_dtype = self.dtype
if not is_string_dtype(current_dtype):
if not self._query_compiler.is_string_dtype(idx=0, is_index=False):
raise AttributeError("Can only use .str accessor with string values!")

from modin.pandas.series_utils import StringMethods
Expand Down
Loading

0 comments on commit 4b901bb

Please sign in to comment.