From d13f06055215e6ad0c52f41cb046f32e8e96e213 Mon Sep 17 00:00:00 2001 From: Naresh Kumar Date: Wed, 28 Aug 2024 22:48:41 -0700 Subject: [PATCH] SNOW-1637945: Add support for TimedeltaIndex attributes --- CHANGELOG.md | 1 + .../supported/timedelta_index_supported.rst | 8 +-- .../modin/plugin/_internal/timestamp_utils.py | 4 +- .../compiler/snowflake_query_compiler.py | 45 ++++++++++++++ .../plugin/extensions/timedelta_index.py | 60 ++++++++++++------- .../index/test_timedelta_index_methods.py | 17 +++++- 6 files changed, 103 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e7a827cb1e..d0d2ccd7589 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -63,6 +63,7 @@ - support for lazy `TimedeltaIndex`. - support for `pd.to_timedelta`. - support for `GroupBy` aggregations `min`, `max`, `mean`, `idxmax`, `idxmin`, `std`, `sum`, `median`, `count`, `any`, `all`, `size`, `nunique`. + - support for `TimedeltaIndex` attributes: `days`, `seconds`, `microseconds` and `nanoseconds`. - Added support for index's arithmetic and comparison operators. - Added support for `Series.dt.round`. - Added documentation pages for `DatetimeIndex`. diff --git a/docs/source/modin/supported/timedelta_index_supported.rst b/docs/source/modin/supported/timedelta_index_supported.rst index 73abe530fd7..cd5e64b8c98 100644 --- a/docs/source/modin/supported/timedelta_index_supported.rst +++ b/docs/source/modin/supported/timedelta_index_supported.rst @@ -15,13 +15,13 @@ Attributes +-----------------------------+---------------------------------+----------------------------------------------------+ | TimedeltaIndex attribute | Snowpark implemented? 
(Y/N/P/D) | Notes for current implementation | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``days`` | N | | +| ``days`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``seconds`` | N | | +| ``seconds`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``microseconds`` | N | | +| ``microseconds`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ -| ``nanoseconds`` | N | | +| ``nanoseconds`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``components`` | N | | +-----------------------------+---------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py index 4860baf4acb..8c53f88049e 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/timestamp_utils.py @@ -21,9 +21,9 @@ cast, convert_timezone, date_part, - floor, iff, to_decimal, + trunc, ) from snowflake.snowpark.modin.plugin._internal.utils import pandas_lit from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage @@ -171,7 +171,7 @@ def col_to_timedelta(col: Column, unit: str) -> Column: if not td_unit: # Same error as native pandas. 
raise ValueError(f"invalid unit abbreviation: {unit}") - return cast(floor(col * TIMEDELTA_UNIT_MULTIPLIER[td_unit]), LongType()) + return trunc(col * TIMEDELTA_UNIT_MULTIPLIER[td_unit]) PANDAS_DATETIME_FORMAT_TO_SNOWFLAKE_MAPPING = { diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index bbebbec1783..195c6914da7 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -135,6 +135,7 @@ to_variant, translate, trim, + trunc, uniform, upper, when, @@ -382,6 +383,12 @@ SUPPORTED_DT_FLOOR_CEIL_FREQS = ["day", "hour", "minute", "second"] +SECONDS_PER_DAY = 86400 +NANOSECONDS_PER_SECOND = 10**9 +NANOSECONDS_PER_MICROSECOND = 10**3 +MICROSECONDS_PER_SECOND = 10**6 +NANOSECONDS_PER_DAY = SECONDS_PER_DAY * NANOSECONDS_PER_SECOND + class SnowflakeQueryCompiler(BaseQueryCompiler): """based on: https://modin.readthedocs.io/en/0.11.0/flow/modin/backends/base/query_compiler.html @@ -17498,3 +17505,41 @@ def tz_convert(self, *args: Any, **kwargs: Any) -> None: def tz_localize(self, *args: Any, **kwargs: Any) -> None: ErrorMessage.method_not_implemented_error("tz_convert", "BasePandasDataset") + + def timedelta_property( + self, property_name: str, include_index: bool = False + ) -> "SnowflakeQueryCompiler": + """ + Extract a specified component from Timedelta. + + Parameters + ---------- + property_name : {'days', 'seconds', 'microseconds', 'nanoseconds'} + The component to extract. + include_index: Whether to include the index columns in the operation. + + Returns + ------- + A new SnowflakeQueryCompiler with the extracted component.
+ """ + if not include_index: + assert len(self.columns) == 1, "dt only works for series" + + # mapping from the property name to the corresponding snowpark function + property_to_func_map = { + "days": lambda column: trunc(column / NANOSECONDS_PER_DAY), + "seconds": lambda column: trunc(column / NANOSECONDS_PER_SECOND) + % SECONDS_PER_DAY, + "microseconds": lambda column: trunc(column / NANOSECONDS_PER_MICROSECOND) + % MICROSECONDS_PER_SECOND, + "nanoseconds": lambda column: column % NANOSECONDS_PER_MICROSECOND, + } + func = property_to_func_map.get(property_name) + if not func: + class_prefix = "TimedeltaIndex" if include_index else "Series.dt" + raise ErrorMessage.not_implemented( + f"Snowpark pandas doesn't yet support the property '{class_prefix}.{property_name}'" + ) + return SnowflakeQueryCompiler( + self._modin_frame.apply_snowpark_function_to_columns(func, include_index) + ) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py index 86ed2a5ded4..dac1a78f740 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py @@ -130,7 +130,6 @@ def __init__( } self._init_index(data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs) - @timedelta_index_not_implemented() @property def days(self) -> Index: """ @@ -142,15 +141,18 @@ def days(self) -> Index: Examples -------- - >>> idx = pd.to_timedelta(["0 days", "10 days", "20 days"]) # doctest: +SKIP - >>> idx # doctest: +SKIP - TimedeltaIndex(['0 days', '10 days', '20 days'], - dtype='timedelta64[ns]', freq=None) - >>> idx.days # doctest: +SKIP + >>> idx = pd.to_timedelta(["0 days", "10 days", "20 days"]) + >>> idx + TimedeltaIndex(['0 days', '10 days', '20 days'], dtype='timedelta64[ns]', freq=None) + >>> idx.days Index([0, 10, 20], dtype='int64') """ + return Index( + query_compiler=self._query_compiler.timedelta_property( + "days", 
include_index=True + ) + ) - @timedelta_index_not_implemented() @property def seconds(self) -> Index: """ @@ -162,15 +164,18 @@ def seconds(self) -> Index: Examples -------- - >>> idx = pd.to_timedelta([1, 2, 3], unit='s') # doctest: +SKIP - >>> idx # doctest: +SKIP - TimedeltaIndex(['0 days 00:00:01', '0 days 00:00:02', '0 days 00:00:03'], - dtype='timedelta64[ns]', freq=None) - >>> idx.seconds # doctest: +SKIP - Index([1, 2, 3], dtype='int32') + >>> idx = pd.to_timedelta([1, 2, 3], unit='s') + >>> idx + TimedeltaIndex(['0 days 00:00:01', '0 days 00:00:02', '0 days 00:00:03'], dtype='timedelta64[ns]', freq=None) + >>> idx.seconds + Index([1, 2, 3], dtype='int64') """ + return Index( + query_compiler=self._query_compiler.timedelta_property( + "seconds", include_index=True + ) + ) - @timedelta_index_not_implemented() @property def microseconds(self) -> Index: """ @@ -182,16 +187,20 @@ def microseconds(self) -> Index: Examples -------- - >>> idx = pd.to_timedelta([1, 2, 3], unit='us') # doctest: +SKIP - >>> idx # doctest: +SKIP + >>> idx = pd.to_timedelta([1, 2, 3], unit='us') + >>> idx TimedeltaIndex(['0 days 00:00:00.000001', '0 days 00:00:00.000002', '0 days 00:00:00.000003'], dtype='timedelta64[ns]', freq=None) - >>> idx.microseconds # doctest: +SKIP - Index([1, 2, 3], dtype='int32') + >>> idx.microseconds + Index([1, 2, 3], dtype='int64') """ + return Index( + query_compiler=self._query_compiler.timedelta_property( + "microseconds", include_index=True + ) + ) - @timedelta_index_not_implemented() @property def nanoseconds(self) -> Index: """ @@ -203,14 +212,19 @@ def nanoseconds(self) -> Index: Examples -------- - >>> idx = pd.to_timedelta([1, 2, 3], unit='ns') # doctest: +SKIP - >>> idx # doctest: +SKIP + >>> idx = pd.to_timedelta([1, 2, 3], unit='ns') + >>> idx TimedeltaIndex(['0 days 00:00:00.000000001', '0 days 00:00:00.000000002', '0 days 00:00:00.000000003'], dtype='timedelta64[ns]', freq=None) - >>> idx.nanoseconds # doctest: +SKIP - Index([1, 2, 3], 
dtype='int32') + >>> idx.nanoseconds + Index([1, 2, 3], dtype='int64') """ + return Index( + query_compiler=self._query_compiler.timedelta_property( + "nanoseconds", include_index=True + ) + ) @timedelta_index_not_implemented() @property diff --git a/tests/integ/modin/index/test_timedelta_index_methods.py b/tests/integ/modin/index/test_timedelta_index_methods.py index 1baafed24d2..c68ab9653fa 100644 --- a/tests/integ/modin/index/test_timedelta_index_methods.py +++ b/tests/integ/modin/index/test_timedelta_index_methods.py @@ -8,6 +8,7 @@ import snowflake.snowpark.modin.plugin # noqa: F401 from tests.integ.modin.sql_counter import sql_count_checker +from tests.integ.modin.utils import assert_index_equal @sql_count_checker(query_count=3) @@ -54,12 +55,22 @@ def test_non_default_args(kwargs): pd.TimedeltaIndex(query_compiler=idx._query_compiler, **kwargs) -@pytest.mark.parametrize( - "property", ["days", "seconds", "microseconds", "nanoseconds", "inferred_freq"] -) +@pytest.mark.parametrize("property", ["components", "inferred_freq"]) @sql_count_checker(query_count=0) def test_property_not_implemented(property): snow_index = pd.TimedeltaIndex(["1 days", "2 days"]) msg = f"Snowpark pandas does not yet support the property TimedeltaIndex.{property}" with pytest.raises(NotImplementedError, match=msg): getattr(snow_index, property) + + +@pytest.mark.parametrize("attr", ["days", "seconds", "microseconds", "nanoseconds"]) +@sql_count_checker(query_count=1) +def test_timedelta_index_properties(attr): + native_index = native_pd.TimedeltaIndex( + ["1d", "1h", "60s", "1s", "800ms", "5us", "6ns", "1d 3s", "9m 15s 8us"] + ) + snow_index = pd.Index(native_index) + assert_index_equal( + getattr(snow_index, attr), getattr(native_index, attr), exact=False + )