diff --git a/CHANGELOG.md b/CHANGELOG.md
index e0589d4a358..d436e0d596c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@
 #### Improvements

 - Improved `to_pandas` to persist the original timezone offset for TIMESTAMP_TZ type.
+- Improved `dtype` results for TIMESTAMP_TZ type to show correct timezone offset.

 #### New Features
diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py
index c672f04da63..c69d349636a 100644
--- a/src/snowflake/snowpark/modin/pandas/indexing.py
+++ b/src/snowflake/snowpark/modin/pandas/indexing.py
@@ -50,11 +50,9 @@
 from pandas.api.types import is_bool, is_list_like
 from pandas.core.dtypes.common import (
     is_bool_dtype,
-    is_datetime64_any_dtype,
     is_integer,
     is_integer_dtype,
     is_numeric_dtype,
-    is_timedelta64_dtype,
     pandas_dtype,
 )
 from pandas.core.indexing import IndexingError
@@ -846,7 +844,7 @@ def _try_partial_string_indexing_for_string(
         period = pd.Period(parsed, freq=reso.attr_abbrev)

         # partial string indexing only works for DatetimeIndex
-        if is_datetime64_any_dtype(self.df._query_compiler.index_dtypes[0]):
+        if self.df._query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
             return slice(
                 pd.Timestamp(period.start_time, tzinfo=tzinfo),
                 pd.Timestamp(period.end_time, tzinfo=tzinfo),
@@ -927,7 +925,7 @@ def __getitem__(
         row_loc = self._try_partial_string_indexing(row_loc)

         # Check if self or its index is a TimedeltaIndex. `index_dtypes` retrieves the dtypes of the index columns.
-        if is_timedelta64_dtype(self.df._query_compiler.index_dtypes[0]):
+        if self.df._query_compiler.is_timedelta64_dtype(idx=0, is_index=True):
             # Convert row_loc to timedelta format to perform exact matching for TimedeltaIndex.
             row_loc = self._convert_to_timedelta(row_loc)
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/frame.py b/src/snowflake/snowpark/modin/plugin/_internal/frame.py
index 25ca2fb8d23..a80f9c65687 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/frame.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/frame.py
@@ -9,6 +9,7 @@
 from typing import Any, Callable, NamedTuple, Optional, Union

 import pandas as native_pd
+from pandas import DatetimeTZDtype
 from pandas._typing import IndexLabel

 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
@@ -31,6 +32,9 @@
 from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import (
     SnowparkPandasType,
 )
+from snowflake.snowpark.modin.plugin._internal.type_utils import (
+    _get_timezone_from_timestamp_tz,
+)
 from snowflake.snowpark.modin.plugin._internal.utils import (
     DEFAULT_DATA_COLUMN_LABEL,
     INDEX_LABEL,
@@ -1473,5 +1477,17 @@ def is_quoted_identifier_normalized(
         )
         return self.rename_snowflake_identifiers(renamed_quoted_identifier_mapping)

+    def get_datetime64tz_from_timestamp_tz(
+        self, timestamp_tz_snowflake_quoted_identifier: str
+    ) -> DatetimeTZDtype:
+        """
+        Map a Snowpark TIMESTAMP_TZ column to the corresponding pandas datetime64 dtype.
+        """
+
+        return _get_timezone_from_timestamp_tz(
+            self.ordered_dataframe._dataframe_ref.snowpark_dataframe,
+            timestamp_tz_snowflake_quoted_identifier,
+        )
+
     # END: Internal Frame mutation APIs.
     ###########################################################################
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/type_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/type_utils.py
index 67ddfe3abde..ee4dbe78f1d 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/type_utils.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/type_utils.py
@@ -1,7 +1,7 @@
 #
 # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
 #
-from functools import reduce
+from functools import lru_cache, reduce
 from typing import Any, Callable, Union

 import numpy as np
@@ -31,6 +31,7 @@
 from snowflake.snowpark import Column
 from snowflake.snowpark._internal.type_utils import infer_type, merge_type
+from snowflake.snowpark.dataframe import DataFrame as SnowparkDataFrame
 from snowflake.snowpark.functions import (
     builtin,
     cast,
@@ -39,6 +40,7 @@
     floor,
     iff,
     length,
+    to_char,
     to_varchar,
     to_variant,
 )
@@ -446,3 +448,23 @@ def is_compatible_snowpark_types(sp_type_1: DataType, sp_type_2: DataType) -> bool:
     if isinstance(sp_type_1, StringType) and isinstance(sp_type_2, StringType):
         return True
     return False
+
+
+@lru_cache
+def _get_timezone_from_timestamp_tz(
+    snowpark_dataframe: SnowparkDataFrame, snowflake_quoted_identifier: str
+) -> Union[str, DatetimeTZDtype]:
+    tz_df = (
+        snowpark_dataframe.filter(col(snowflake_quoted_identifier).is_not_null())
+        .select(to_char(col(snowflake_quoted_identifier), format="TZHTZM").as_("tz"))
+        .group_by(["tz"])
+        .agg()
+        .limit(2)
+        .to_pandas()
+    )
+    assert (
+        len(tz_df) > 0
+    ), f"col {snowflake_quoted_identifier} does not contain valid timezone offset"
+    if len(tz_df) == 2:  # multi timezone cases
+        return "object"
+    return DatetimeTZDtype(tz="UTC" + tz_df.iloc[0, 0].replace("Z", ""))
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index 2f6ff69be6c..b630978445c 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -365,6 +365,7 @@
     PandasDataFrameType,
     PandasSeriesType,
     StringType,
+    TimestampTimeZone,
     TimestampType,
     TimeType,
     VariantType,
@@ -487,18 +488,19 @@ def dtypes(self) -> native_pd.Series:
            pandas.Series
                Series with dtypes of each column.
        """
+        type_map = self._modin_frame.quoted_identifier_to_snowflake_type(
+            self._modin_frame.data_column_snowflake_quoted_identifiers
+        )
         types = [
-            TypeMapper.to_pandas(t)
-            for t in self._modin_frame.get_snowflake_type(
-                self._modin_frame.data_column_snowflake_quoted_identifiers
-            )
+            self._modin_frame.get_datetime64tz_from_timestamp_tz(i)
+            if t == TimestampType(TimestampTimeZone.TZ)
+            else TypeMapper.to_pandas(t)
+            for i, t in type_map.items()
         ]
-        from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native
-
         return native_pd.Series(
             data=types,
-            index=try_convert_index_to_native(self._modin_frame.data_columns_index),
+            index=self._modin_frame.data_columns_index,
             dtype=object,
         )
@@ -512,13 +514,65 @@ def index_dtypes(self) -> list[Union[np.dtype, ExtensionDtype]]:
            pandas.Series
                Series with dtypes of each column.
""" + type_map = self._modin_frame.quoted_identifier_to_snowflake_type( + self._modin_frame.index_column_snowflake_quoted_identifiers + ) return [ - TypeMapper.to_pandas(t) - for t in self._modin_frame.get_snowflake_type( - self._modin_frame.index_column_snowflake_quoted_identifiers - ) + self._modin_frame.get_datetime64tz_from_timestamp_tz(i) + if t == TimestampType(TimestampTimeZone.TZ) + else TypeMapper.to_pandas(t) + for i, t in type_map.items() ] + def is_timestamp_type(self, idx: int, is_index: bool = True) -> bool: + """Return True if index is TIMESTAMP TYPE. + + Args: + idx: the index of the column + is_index: whether it is an index or data column + """ + return isinstance( + self._modin_frame.get_snowflake_type( + self._modin_frame.index_column_snowflake_quoted_identifiers + if is_index + else self._modin_frame.data_column_snowflake_quoted_identifiers + )[idx], + TimestampType, + ) + + def is_datetime64_any_dtype(self, idx: int, is_index: bool = True) -> bool: + """Helper method similar to is_datetime64_any_dtype, but it avoids extra query for DatetimeTZDtype. + + Args: + idx: the index of the column + is_index: whether it is an index or data column + """ + return self.is_timestamp_type(idx, is_index) or is_datetime64_any_dtype( + self.index_dtypes[idx] if is_index else self.dtypes[idx] + ) + + def is_timedelta64_dtype(self, idx: int, is_index: bool = True) -> bool: + """Helper method similar to is_timedelta_dtype, but it avoids extra query for DatetimeTZDtype. + + Args: + idx: the index of the column + is_index: whether it is an index or data column + """ + return not self.is_timestamp_type(idx, is_index) and is_timedelta64_dtype( + self.index_dtypes[idx] if is_index else self.dtypes[idx] + ) + + def is_string_dtype(self, idx: int, is_index: bool = True) -> bool: + """Helper method similar to is_timedelta_dtype, but it avoids extra query for DatetimeTZDtype + + Args: + idx: the index of the column + is_index: whether it is an index or data column + """ + return not self.is_timestamp_type(idx, is_index) and is_string_dtype( + self.index_dtypes[idx] if is_index else self.dtypes[idx] + ) + @classmethod def from_pandas( cls, df: native_pd.DataFrame, *args: Any, **kwargs: Any @@ -11154,7 +11208,7 @@ def dt_property( """ if not include_index: assert len(self.columns) == 1, "dt only works for series" - if not is_datetime64_any_dtype(self.dtypes[0]): + if not self.is_datetime64_any_dtype(idx=0, is_index=False): raise AttributeError( f"'TimedeltaProperties' object has no attribute '{property_name}'" ) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py index df136af1a34..25c0316539c 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py @@ -41,7 +41,6 @@ TimeAmbiguous, TimeNonexistent, ) -from pandas.core.dtypes.common import is_datetime64_any_dtype from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( SnowflakeQueryCompiler, @@ -136,8 +135,7 @@ def __new__( """ if query_compiler: # Raise error if underlying type is not a TimestampType. - current_dtype = query_compiler.index_dtypes[0] - if not current_dtype == np.dtype("datetime64[ns]"): + if not query_compiler.is_datetime64_any_dtype(idx=0, is_index=True): raise ValueError( "DatetimeIndex can only be created from a query compiler with TimestampType." 
                 )
@@ -158,7 +156,7 @@
             data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs
         )
         # Convert to datetime64 if not already.
-        if not is_datetime64_any_dtype(query_compiler.index_dtypes[0]):
+        if not query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
             query_compiler = query_compiler.series_to_datetime(include_index=True)
         index._query_compiler = query_compiler
         # `_parent` keeps track of any Series or DataFrame that this Index is a part of.
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py
index 12710224de7..98d54595e27 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py
@@ -44,7 +44,6 @@
     is_integer_dtype,
     is_numeric_dtype,
     is_object_dtype,
-    is_timedelta64_dtype,
     pandas_dtype,
 )
 from pandas.core.dtypes.inference import is_hashable
@@ -127,10 +126,9 @@ def __new__(
         query_compiler = cls._init_query_compiler(
             data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs
         )
-        dtype = query_compiler.index_dtypes[0]
-        if is_datetime64_any_dtype(dtype):
+        if query_compiler.is_datetime64_any_dtype(idx=0, is_index=True):
             return DatetimeIndex(query_compiler=query_compiler)
-        if is_timedelta64_dtype(dtype):
+        if query_compiler.is_timedelta64_dtype(idx=0, is_index=True):
             return TimedeltaIndex(query_compiler=query_compiler)
         index = object.__new__(cls)
         # Initialize the Index
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
index 5011defa685..cf5328c3b45 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
@@ -39,11 +39,6 @@
     Renamer,
     Scalar,
 )
-from pandas.api.types import (
-    is_datetime64_any_dtype,
-    is_string_dtype,
-    is_timedelta64_dtype,
-)
 from pandas.core.common import apply_if_callable, is_bool_indexer
 from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like
 from pandas.util._validators import validate_bool_kwarg
@@ -1190,10 +1185,9 @@ def dt(self):  # noqa: RT01, D200
         Accessor object for datetimelike properties of the Series values.
         """
         # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
-        current_dtype = self.dtype
-        if not is_datetime64_any_dtype(current_dtype) and not is_timedelta64_dtype(
-            current_dtype
-        ):
+        if not self._query_compiler.is_datetime64_any_dtype(
+            idx=0, is_index=False
+        ) and not self._query_compiler.is_timedelta64_dtype(idx=0, is_index=False):
             raise AttributeError("Can only use .dt accessor with datetimelike values")

         from modin.pandas.series_utils import DatetimeProperties
@@ -1210,8 +1204,7 @@ def _str(self):  # noqa: RT01, D200
         Vectorized string functions for Series and Index.
""" # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - current_dtype = self.dtype - if not is_string_dtype(current_dtype): + if not self._query_compiler.is_string_dtype(idx=0, is_index=False): raise AttributeError("Can only use .str accessor with string values!") from modin.pandas.series_utils import StringMethods diff --git a/tests/integ/modin/frame/test_dtypes.py b/tests/integ/modin/frame/test_dtypes.py index b078b31f6c5..56bb30d0a15 100644 --- a/tests/integ/modin/frame/test_dtypes.py +++ b/tests/integ/modin/frame/test_dtypes.py @@ -6,7 +6,7 @@ import numpy as np import pandas as native_pd import pytest -from pandas.core.dtypes.common import is_integer_dtype +from pandas.core.dtypes.common import is_datetime64_any_dtype, is_integer_dtype import snowflake.snowpark.modin.plugin # noqa: F401 from snowflake.snowpark.types import ( @@ -18,7 +18,7 @@ StringType, VariantType, ) -from tests.integ.modin.sql_counter import sql_count_checker +from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker from tests.integ.modin.utils import ( assert_frame_equal, assert_series_equal, @@ -352,7 +352,7 @@ def test_insert_multiindex_multi_label(label1, label2): ], "datetime64[ns, America/Los_Angeles]", "datetime64[ns, UTC-08:00]", - "datetime64[ns]", + "datetime64[ns, UTC-08:00]", ), ( [ @@ -373,21 +373,27 @@ def test_insert_multiindex_multi_label(label1, label2): ], "object", "datetime64[ns, UTC-08:00]", - "datetime64[ns]", + "datetime64[ns, UTC-08:00]", ), ], ) -@sql_count_checker(query_count=1) def test_time(dataframe_input, input_dtype, expected_dtype, logical_dtype): expected = native_pd.Series(dataframe_input, dtype=expected_dtype) - created = pd.Series(dataframe_input, dtype=input_dtype) - # For snowpark pandas type mapping - assert created.dtype == logical_dtype - roundtripped = created.to_pandas() - assert_series_equal( - roundtripped, expected, check_dtype=False, check_index_type=False + qc = ( + 2 + if is_datetime64_any_dtype(expected.dtype) + and getattr(expected.dtype, "tz", None) is not None + else 1 ) - assert roundtripped.dtype == expected.dtype + with SqlCounter(query_count=qc): + created = pd.Series(dataframe_input, dtype=input_dtype) + # For snowpark pandas type mapping + assert created.dtype == logical_dtype + roundtripped = created.to_pandas() + assert_series_equal( + roundtripped, expected, check_dtype=False, check_index_type=False + ) + assert roundtripped.dtype == expected.dtype @pytest.mark.parametrize( @@ -528,3 +534,33 @@ def test_str_float_type_with_nan( assert native_se.dtype == to_pandas_dtype expected = native_pd.Series(input_data, dtype=to_pandas_dtype) assert_series_equal(native_se, expected, check_index_type=False) + + +@pytest.mark.parametrize( + "ts_data", + [ + native_pd.date_range("2020-01-01", periods=10), + native_pd.date_range("2020-01-01", periods=10, tz="US/Pacific"), + native_pd.date_range("2020-01-01", periods=10, tz="UTC"), + native_pd.date_range("2020-01-01", periods=10, tz="Asia/Tokyo"), + native_pd.date_range("2020-01-01", periods=10, tz="UTC+1000"), + native_pd.date_range("2020-01-01", periods=10, tz="UTC+1000").append( + native_pd.date_range("2020-01-01", periods=10, tz="UTC") + ), + ], +) +def test_tz_dtype(ts_data): + with SqlCounter( + query_count=1 + if is_datetime64_any_dtype(ts_data.dtype) and ts_data.tz is None + else 2 + ): + s = pd.Series(ts_data) + assert s.dtype == s.to_pandas().dtype + + +@sql_count_checker(query_count=1) +def test_tz_dtype_cache(): + s = pd.Series(native_pd.date_range("2020-10-01", periods=5, 
tz="UTC")) + for _ in range(50): + assert s.dtype == "datetime64[ns, UTC]" diff --git a/tests/integ/modin/index/test_index_methods.py b/tests/integ/modin/index/test_index_methods.py index 8d0434915ac..bbef57eca43 100644 --- a/tests/integ/modin/index/test_index_methods.py +++ b/tests/integ/modin/index/test_index_methods.py @@ -347,14 +347,12 @@ def test_df_index_to_frame(native_df, index, name): ) -@sql_count_checker(query_count=0) @pytest.mark.parametrize("native_index", NATIVE_INDEX_TEST_DATA) def test_index_dtype(native_index): - snow_index = pd.Index(native_index) - if isinstance(native_index, native_pd.DatetimeIndex): - # Snowpark pandas does not include timezone info in dtype datetime64[ns], - assert snow_index.dtype == "datetime64[ns]" - else: + with SqlCounter( + query_count=1 if getattr(native_index.dtype, "tz", None) is not None else 0 + ): + snow_index = pd.Index(native_index) assert snow_index.dtype == native_index.dtype diff --git a/tests/integ/modin/series/test_astype.py b/tests/integ/modin/series/test_astype.py index 030416d65c5..26a29f4ed85 100644 --- a/tests/integ/modin/series/test_astype.py +++ b/tests/integ/modin/series/test_astype.py @@ -10,6 +10,7 @@ import numpy as np import pandas as native_pd import pytest +from pandas import DatetimeTZDtype from pandas.core.arrays.boolean import BooleanDtype from pandas.core.arrays.floating import Float32Dtype, Float64Dtype from pandas.core.arrays.integer import ( @@ -218,11 +219,11 @@ def test_astype_to_DatetimeTZDtype(from_dtype, to_tz): # use dypte=str instead of StringDType native_pd.Series(seed, dtype=str).astype(to_dtype) else: - with SqlCounter(query_count=1): + with SqlCounter(query_count=2): s = pd.Series(seed, dtype=from_dtype).astype(to_dtype) # Snowflake timestamp_tz column's metadata does not contain the tzinfo so it cannot provide dtype as # datetime64[ns, UTC], so its dtype returns datetime64 or 1 else 4 + expected_query_count = ( + 9 + if isinstance(samples, list) and len(samples) > 1 + else 5 + if "timestamp_tz" in col_name_type + else 4 + ) with SqlCounter(query_count=expected_query_count): Utils.create_table(session, test_table_name, col_name_type, is_temporary=True) if not isinstance(samples, list): diff --git a/tests/integ/modin/test_from_pandas_to_pandas.py b/tests/integ/modin/test_from_pandas_to_pandas.py index ceef588410d..c5edfffe754 100644 --- a/tests/integ/modin/test_from_pandas_to_pandas.py +++ b/tests/integ/modin/test_from_pandas_to_pandas.py @@ -308,7 +308,7 @@ def test_from_to_pandas_datetime64_support(): ) -@sql_count_checker(query_count=3) +@sql_count_checker(query_count=4) def test_rw_datetimeindex(): test_datetime_index = native_pd.DatetimeIndex( ["2017-12-31 16:00:00", "2017-12-31 17:00:00", "2017-12-31 18:00:00"], @@ -327,7 +327,9 @@ def test_rw_datetimeindex(): df = pd.DataFrame({"ntz": test_datetime_index, "tz": test_datetime_index_tz}) assert_series_equal( df.dtypes, - native_pd.Series(["datetime64[ns]", "datetime64[ns]"], index=["ntz", "tz"]), + native_pd.Series( + ["datetime64[ns]", "datetime64[ns, UTC-08:00]"], index=["ntz", "tz"] + ), ) assert_series_equal( df.to_pandas().dtypes, diff --git a/tests/unit/modin/conftest.py b/tests/unit/modin/conftest.py index 042a0ec0d46..4ed9ddf242f 100644 --- a/tests/unit/modin/conftest.py +++ b/tests/unit/modin/conftest.py @@ -25,7 +25,9 @@ def mock_single_col_query_compiler() -> SnowflakeQueryCompiler: '"A"': None } mock_internal_frame.get_snowflake_type.return_value = [StringType()] - + 
+    mock_internal_frame.quoted_identifier_to_snowflake_type.return_value = {
+        '"A"': StringType()
+    }
     fake_query_compiler = SnowflakeQueryCompiler(mock_internal_frame)

     return fake_query_compiler
diff --git a/tests/unit/modin/test_series_dt.py b/tests/unit/modin/test_series_dt.py
index 0b5572f0592..d3265c5554a 100644
--- a/tests/unit/modin/test_series_dt.py
+++ b/tests/unit/modin/test_series_dt.py
@@ -22,6 +22,9 @@ def mock_query_compiler_for_dt_series() -> SnowflakeQueryCompiler:
     mock_internal_frame.data_columns_index = native_pd.Index(["A"], name="B")
     mock_internal_frame.data_column_snowflake_quoted_identifiers = ['"A"']
     mock_internal_frame.get_snowflake_type.return_value = [TimestampType()]
+    mock_internal_frame.quoted_identifier_to_snowflake_type.return_value = {
+        '"A"': TimestampType()
+    }
     fake_query_compiler = SnowflakeQueryCompiler(mock_internal_frame)

     return fake_query_compiler