From a43f398d025e6bb66759b4b3d7bf84be7b94c886 Mon Sep 17 00:00:00 2001 From: Naresh Kumar <113932371+sfc-gh-nkumar@users.noreply.github.com> Date: Fri, 23 Aug 2024 13:48:51 -0700 Subject: [PATCH] SNOW-1558919: Add support for DatetimeIndex ceil, floor and round methods (#2135) Fixes SNOW-1558919 Added support for DatetimeIndex ceil, floor and round methods. Raise not implemented error if ambiguous or nonexistent parameter is set. --- CHANGELOG.md | 1 + .../supported/datetime_index_supported.rst | 6 +- .../compiler/snowflake_query_compiler.py | 230 ++++++++---------- .../modin/plugin/docstrings/series_utils.py | 32 ++- .../modin/plugin/extensions/datetime_index.py | 70 +++--- .../index/test_datetime_index_methods.py | 60 ++++- 6 files changed, 212 insertions(+), 187 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index afe0d3e862b..82e656db449 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,7 @@ - Added support for `pd.merge_asof`. - Added support for `Series.dt.normalize` and `DatetimeIndex.normalize`. - Added support for `Index.is_boolean`, `Index.is_integer`, `Index.is_floating`, `Index.is_numeric`, and `Index.is_object`. +- Added support for `DatetimeIndex.round`, `DatetimeIndex.floor` and `DatetimeIndex.ceil`. #### Bug Fixes diff --git a/docs/source/modin/supported/datetime_index_supported.rst b/docs/source/modin/supported/datetime_index_supported.rst index ccc553caaf2..68b1935da96 100644 --- a/docs/source/modin/supported/datetime_index_supported.rst +++ b/docs/source/modin/supported/datetime_index_supported.rst @@ -86,11 +86,11 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``tz_localize`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``round`` | N | | | +| ``round`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``floor`` | N | | | +| ``floor`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``ceil`` | N | | | +| ``ceil`` | P | ``ambiguous``, ``nonexistent`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``to_period`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 18a732c2f0f..aff1c155d6c 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -16087,7 +16087,11 @@ def dt_tz_convert(self, tz: Union[str, tzinfo]) -> None: ) def dt_ceil( - self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" + self, + freq: Frequency, + ambiguous: str = "raise", + nonexistent: str = "raise", + include_index: bool = False, ) -> "SnowflakeQueryCompiler": """ Args: @@ -16105,62 +16109,51 @@ def dt_ceil( - 'NaT' will return NaT where there are nonexistent times - timedelta objects will shift nonexistent times by the timedelta - 'raise' will raise an NonExistentTimeError if there are nonexistent times. + include_index: Whether to include the index columns in the operation. Returns: A new QueryCompiler with ceil values. """ + method_name = "DatetimeIndex.ceil" if include_index else "Series.dt.ceil" if ambiguous != "raise": - ErrorMessage.not_implemented( - "Snowpark pandas 'Series.dt.ceil' method doesn't yet support 'ambiguous' parameter" - ) + ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) if nonexistent != "raise": - ErrorMessage.not_implemented( - "Snowpark pandas 'Series.dt.ceil' method doesn't yet support 'nonexistent' parameter" - ) - internal_frame = self._modin_frame + ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit( rule=freq # type: ignore[arg-type] ) if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS: - ErrorMessage.not_implemented( - f"Snowpark pandas 'Series.dt.ceil' method doesn't support setting 'freq' parameter to '{freq}'" + ErrorMessage.parameter_not_implemented_error(f"freq='{freq}'", method_name) + + def ceil_func(col_id: str) -> SnowparkColumn: + base_column = col(col_id) + floor_column = builtin("time_slice")( + base_column, slice_length, slice_unit, "START" ) - base_column = col(internal_frame.data_column_snowflake_quoted_identifiers[0]) - floor_column = builtin("time_slice")( - base_column, slice_length, slice_unit, "START" - ) - ceil_column = builtin("time_slice")( - base_column, slice_length, slice_unit, "END" - ) - ceil_column = iff( - base_column.equal_null(floor_column), base_column, ceil_column - ) + ceil_column = builtin("time_slice")( + base_column, slice_length, slice_unit, "END" + ) + return iff(base_column.equal_null(floor_column), base_column, ceil_column) - internal_frame = internal_frame.append_column( - internal_frame.data_column_pandas_labels[0], ceil_column - ) + frame = self._modin_frame + snowflake_ids = frame.data_column_snowflake_quoted_identifiers[0:1] + if include_index: + snowflake_ids.extend(frame.index_column_snowflake_quoted_identifiers) return SnowflakeQueryCompiler( - InternalFrame.create( - ordered_dataframe=internal_frame.ordered_dataframe, - data_column_pandas_labels=[None], - data_column_pandas_index_names=internal_frame.data_column_pandas_index_names, - data_column_snowflake_quoted_identifiers=internal_frame.data_column_snowflake_quoted_identifiers[ - -1: - ], - index_column_pandas_labels=internal_frame.index_column_pandas_labels, - index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, - data_column_types=internal_frame.cached_data_column_snowpark_pandas_types[ - -1: - ], - index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, - ) + frame.update_snowflake_quoted_identifiers_with_expressions( + {col_id: ceil_func(col_id) for col_id in snowflake_ids} + ).frame ) def dt_round( - self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" + self, + freq: Frequency, + ambiguous: str = "raise", + nonexistent: str = "raise", + include_index: bool = False, ) -> "SnowflakeQueryCompiler": """ Args: @@ -16178,28 +16171,23 @@ def dt_round( - 'NaT' will return NaT where there are nonexistent times - timedelta objects will shift nonexistent times by the timedelta - 'raise' will raise an NonExistentTimeError if there are nonexistent times. + include_index: Whether to include the index columns in the operation. Returns: A new QueryCompiler with round values. """ + method_name = "DatetimeIndex.round" if include_index else "Series.dt.round" if ambiguous != "raise": - ErrorMessage.not_implemented( - "Snowpark pandas 'Series.dt.round' method doesn't yet support 'ambiguous' parameter" - ) + ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) if nonexistent != "raise": - ErrorMessage.not_implemented( - "Snowpark pandas 'Series.dt.round' method doesn't yet support 'nonexistent' parameter" - ) - internal_frame = self._modin_frame + ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit( rule=freq # type: ignore[arg-type] ) if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS or slice_unit == "second": - ErrorMessage.not_implemented( - f"Snowpark pandas 'Series.dt.round' method doesn't support setting 'freq' parameter to '{freq}'" - ) + ErrorMessage.parameter_not_implemented_error(f"freq={freq}", method_name) # We need to implement the algorithm for rounding half to even whenever # the date value is at half point of the slice: @@ -16227,72 +16215,79 @@ def down_level_freq(slice_length: int, slice_unit: str) -> tuple[int, str]: if slice_length % 2 == 1: slice_length, slice_unit = down_level_freq(slice_length, slice_unit) half_slice_length = int(slice_length / 2) - base_column = col(internal_frame.data_column_snowflake_quoted_identifiers[0]) - - # Second, we determine whether floor represents an even number of slices. - # To do so, we must divide the number of epoch seconds in it over the number - # of epoch seconds in one slice. This way, we can get the number of slices. - - floor_column = builtin("time_slice")( - base_column, slice_length, slice_unit, "START" - ) - ceil_column = builtin("time_slice")( - base_column, slice_length, slice_unit, "END" - ) def slice_length_when_unit_is_second(slice_length: int, slice_unit: str) -> int: while slice_unit != "second": slice_length, slice_unit = down_level_freq(slice_length, slice_unit) return slice_length - floor_epoch_seconds_column = builtin("extract")("epoch_second", floor_column) - floor_num_slices_column = cast( - floor_epoch_seconds_column - / pandas_lit(slice_length_when_unit_is_second(slice_length, slice_unit)), - IntegerType(), - ) + def round_func(col_id: str) -> SnowparkColumn: + base_column = col(col_id) - # Now that we know the number of slices, we can check if they are even or odd. + # Second, we determine whether floor represents an even number of slices. + # To do so, we must divide the number of epoch seconds in it over the number + # of epoch seconds in one slice. This way, we can get the number of slices. - floor_is_even = (floor_num_slices_column % pandas_lit(2)).equal_null( - pandas_lit(0) - ) + floor_column = builtin("time_slice")( + base_column, slice_length, slice_unit, "START" + ) + ceil_column = builtin("time_slice")( + base_column, slice_length, slice_unit, "END" + ) - # Accordingly, we can decide if the round column should be the floor or ceil - # of the slice. + floor_epoch_seconds_column = builtin("extract")( + "epoch_second", floor_column + ) + floor_num_slices_column = cast( + floor_epoch_seconds_column + / pandas_lit( + slice_length_when_unit_is_second(slice_length, slice_unit) + ), + IntegerType(), + ) - round_column_if_half_point = iff(floor_is_even, floor_column, ceil_column) + # Now that we know the number of slices, we can check if they are even or odd. + floor_is_even = (floor_num_slices_column % pandas_lit(2)).equal_null( + pandas_lit(0) + ) - # In case the date value is not at half point of the slice, then we shift it - # by half a slice, and take the floor from there. + # Accordingly, we can decide if the round column should be the floor or ceil + # of the slice. + round_column_if_half_point = iff(floor_is_even, floor_column, ceil_column) - base_plus_half_slice_column = dateadd( - slice_unit, pandas_lit(half_slice_length), base_column - ) - round_column_if_not_half_point = builtin("time_slice")( - base_plus_half_slice_column, slice_length, slice_unit, "START" - ) + # In case the date value is not at half point of the slice, then we shift it + # by half a slice, and take the floor from there. + base_plus_half_slice_column = dateadd( + slice_unit, pandas_lit(half_slice_length), base_column + ) + round_column_if_not_half_point = builtin("time_slice")( + base_plus_half_slice_column, slice_length, slice_unit, "START" + ) - # The final expression for the round column. + # The final expression for the round column. + return iff( + base_plus_half_slice_column.equal_null(ceil_column), + round_column_if_half_point, + round_column_if_not_half_point, + ) - round_column = iff( - base_plus_half_slice_column.equal_null(ceil_column), - round_column_if_half_point, - round_column_if_not_half_point, - ) + frame = self._modin_frame + snowflake_ids = frame.data_column_snowflake_quoted_identifiers[0:1] + if include_index: + snowflake_ids.extend(frame.index_column_snowflake_quoted_identifiers) return SnowflakeQueryCompiler( - internal_frame.update_snowflake_quoted_identifiers_with_expressions( - { - internal_frame.data_column_snowflake_quoted_identifiers[ - 0 - ]: round_column - } + frame.update_snowflake_quoted_identifiers_with_expressions( + {col_id: round_func(col_id) for col_id in snowflake_ids} ).frame ) def dt_floor( - self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" + self, + freq: Frequency, + ambiguous: str = "raise", + nonexistent: str = "raise", + include_index: bool = False, ) -> "SnowflakeQueryCompiler": """ Args: @@ -16310,52 +16305,35 @@ def dt_floor( - 'NaT' will return NaT where there are nonexistent times - timedelta objects will shift nonexistent times by the timedelta - 'raise' will raise an NonExistentTimeError if there are nonexistent times. + include_index: Whether to include the index columns in the operation. Returns: A new QueryCompiler with floor values. """ + method_name = "DatetimeIndex.floor" if include_index else "Series.dt.floor" if ambiguous != "raise": - ErrorMessage.not_implemented( - "Snowpark pandas 'Series.dt.floor' method doesn't yet support 'ambiguous' parameter" - ) + ErrorMessage.parameter_not_implemented_error("ambiguous", method_name) if nonexistent != "raise": - ErrorMessage.not_implemented( - "Snowpark pandas 'Series.dt.floor' method doesn't yet support 'nonexistent' parameter" - ) - internal_frame = self._modin_frame + ErrorMessage.parameter_not_implemented_error("nonexistent", method_name) slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit( rule=freq # type: ignore[arg-type] ) if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS: - ErrorMessage.not_implemented( - f"Snowpark pandas 'Series.dt.floor' method doesn't support setting 'freq' parameter to '{freq}'" - ) - snowpark_column = builtin("time_slice")( - col(internal_frame.data_column_snowflake_quoted_identifiers[0]), - slice_length, - slice_unit, - ) + ErrorMessage.parameter_not_implemented_error(f"freq='{freq}'", method_name) - internal_frame = internal_frame.append_column( - internal_frame.data_column_pandas_labels[0], snowpark_column - ) + frame = self._modin_frame + snowflake_ids = frame.data_column_snowflake_quoted_identifiers[0:1] + if include_index: + snowflake_ids.extend(frame.index_column_snowflake_quoted_identifiers) return SnowflakeQueryCompiler( - InternalFrame.create( - ordered_dataframe=internal_frame.ordered_dataframe, - data_column_pandas_labels=[None], - data_column_pandas_index_names=internal_frame.data_column_pandas_index_names, - data_column_snowflake_quoted_identifiers=internal_frame.data_column_snowflake_quoted_identifiers[ - -1: - ], - index_column_pandas_labels=internal_frame.index_column_pandas_labels, - index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers, - data_column_types=internal_frame.cached_data_column_snowpark_pandas_types[ - -1: - ], - index_column_types=internal_frame.cached_index_column_snowpark_pandas_types, - ) + frame.update_snowflake_quoted_identifiers_with_expressions( + { + col_id: builtin("time_slice")(col(col_id), slice_length, slice_unit) + for col_id in snowflake_ids + } + ).frame ) def dt_normalize(self, include_index: bool = False) -> "SnowflakeQueryCompiler": diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py index 802bdf16653..1bf38c7a241 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series_utils.py @@ -1859,11 +1859,11 @@ def round(): DatetimeIndex >>> rng = pd.date_range('1/1/2018 11:59:00', periods=3, freq='min') - >>> rng # doctest: +SKIP + >>> rng DatetimeIndex(['2018-01-01 11:59:00', '2018-01-01 12:00:00', '2018-01-01 12:01:00'], - dtype='datetime64[ns]', freq='min') - >>> rng.round('h') # doctest: +SKIP + dtype='datetime64[ns]', freq=None) + >>> rng.round('h') DatetimeIndex(['2018-01-01 12:00:00', '2018-01-01 12:00:00', '2018-01-01 12:00:00'], dtype='datetime64[ns]', freq=None) @@ -1929,14 +1929,14 @@ def floor(): DatetimeIndex >>> rng = pd.date_range('1/1/2018 11:59:00', periods=3, freq='min') - >>> rng # doctest: +SKIP + >>> rng DatetimeIndex(['2018-01-01 11:59:00', '2018-01-01 12:00:00', - '2018-01-01 12:01:00'], - dtype='datetime64[ns]', freq='min') - >>> rng.floor('h') # doctest: +SKIP + '2018-01-01 12:01:00'], + dtype='datetime64[ns]', freq=None) + >>> rng.floor('h') DatetimeIndex(['2018-01-01 11:00:00', '2018-01-01 12:00:00', - '2018-01-01 12:00:00'], - dtype='datetime64[ns]', freq=None) + '2018-01-01 12:00:00'], + dtype='datetime64[ns]', freq=None) Series @@ -1958,7 +1958,6 @@ def floor(): DatetimeIndex(['2021-10-31 02:00:00+02:00'], dtype='datetime64[ns, Europe/Amsterdam]', freq=None) """ - # TODO(SNOW-1486910): Unskip when date_range returns DatetimeIndex. def ceil(): """ @@ -2000,14 +1999,14 @@ def ceil(): DatetimeIndex >>> rng = pd.date_range('1/1/2018 11:59:00', periods=3, freq='min') - >>> rng # doctest: +SKIP + >>> rng DatetimeIndex(['2018-01-01 11:59:00', '2018-01-01 12:00:00', - '2018-01-01 12:01:00'], - dtype='datetime64[ns]', freq='min') - >>> rng.ceil('h') # doctest: +SKIP + '2018-01-01 12:01:00'], + dtype='datetime64[ns]', freq=None) + >>> rng.ceil('h') DatetimeIndex(['2018-01-01 12:00:00', '2018-01-01 12:00:00', - '2018-01-01 13:00:00'], - dtype='datetime64[ns]', freq=None) + '2018-01-01 13:00:00'], + dtype='datetime64[ns]', freq=None) Series @@ -2029,7 +2028,6 @@ def ceil(): DatetimeIndex(['2021-10-31 02:00:00+02:00'], dtype='datetime64[ns, Europe/Amsterdam]', freq=None) """ - # TODO(SNOW-1486910): Unskip when date_range returns DatetimeIndex. def month_name(): """ diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py index af12c554c22..0618a9dd648 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py @@ -1055,8 +1055,9 @@ def tz_localize( dtype='datetime64[ns]', freq=None) """ - @datetime_index_not_implemented() - def round(self, *args, **kwargs) -> DatetimeIndex: + def round( + self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" + ) -> DatetimeIndex: """ Perform round operation on the data to the specified `freq`. @@ -1093,21 +1094,12 @@ def round(self, *args, **kwargs) -> DatetimeIndex: Returns ------- - DatetimeIndex, TimedeltaIndex, or Series - Index of the same type for a DatetimeIndex or TimedeltaIndex, - or a Series with the same index for a Series. + DatetimeIndex with round values. Raises ------ ValueError if the `freq` cannot be converted. - Notes - ----- - If the timestamps have a timezone, {op}ing will take place relative to the - local ("wall") time and re-localized to the same timezone. When {op}ing - near daylight savings time, use ``nonexistent`` and ``ambiguous`` to - control the re-localization behavior. - Examples -------- **DatetimeIndex** @@ -1118,14 +1110,20 @@ def round(self, *args, **kwargs) -> DatetimeIndex: '2018-01-01 12:01:00'], dtype='datetime64[ns]', freq=None) - >>> rng.round('h') # doctest: +SKIP + >>> rng.round('h') DatetimeIndex(['2018-01-01 12:00:00', '2018-01-01 12:00:00', '2018-01-01 12:00:00'], dtype='datetime64[ns]', freq=None) """ + return DatetimeIndex( + query_compiler=self._query_compiler.dt_round( + freq, ambiguous, nonexistent, include_index=True + ) + ) - @datetime_index_not_implemented() - def floor(self, *args, **kwargs) -> DatetimeIndex: + def floor( + self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" + ) -> DatetimeIndex: """ Perform floor operation on the data to the specified `freq`. @@ -1162,21 +1160,12 @@ def floor(self, *args, **kwargs) -> DatetimeIndex: Returns ------- - DatetimeIndex, TimedeltaIndex, or Series - Index of the same type for a DatetimeIndex or TimedeltaIndex, - or a Series with the same index for a Series. + DatetimeIndex with floor values. Raises ------ ValueError if the `freq` cannot be converted. - Notes - ----- - If the timestamps have a timezone, {op}ing will take place relative to the - local ("wall") time and re-localized to the same timezone. When {op}ing - near daylight savings time, use ``nonexistent`` and ``ambiguous`` to - control the re-localization behavior. - Examples -------- **DatetimeIndex** @@ -1187,14 +1176,20 @@ def floor(self, *args, **kwargs) -> DatetimeIndex: '2018-01-01 12:01:00'], dtype='datetime64[ns]', freq=None) - >>> rng.floor('h') # doctest: +SKIP + >>> rng.floor('h') DatetimeIndex(['2018-01-01 11:00:00', '2018-01-01 12:00:00', '2018-01-01 12:00:00'], dtype='datetime64[ns]', freq=None) """ + return DatetimeIndex( + query_compiler=self._query_compiler.dt_floor( + freq, ambiguous, nonexistent, include_index=True + ) + ) - @datetime_index_not_implemented() - def ceil(self, *args, **kwargs) -> DatetimeIndex: + def ceil( + self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise" + ) -> DatetimeIndex: """ Perform ceil operation on the data to the specified `freq`. @@ -1231,21 +1226,12 @@ def ceil(self, *args, **kwargs) -> DatetimeIndex: Returns ------- - DatetimeIndex, TimedeltaIndex, or Series - Index of the same type for a DatetimeIndex or TimedeltaIndex, - or a Series with the same index for a Series. + DatetimeIndex with ceil values. Raises ------ ValueError if the `freq` cannot be converted. - Notes - ----- - If the timestamps have a timezone, {op}ing will take place relative to the - local ("wall") time and re-localized to the same timezone. When {op}ing - near daylight savings time, use ``nonexistent`` and ``ambiguous`` to - control the re-localization behavior. - Examples -------- **DatetimeIndex** @@ -1256,12 +1242,16 @@ def ceil(self, *args, **kwargs) -> DatetimeIndex: '2018-01-01 12:01:00'], dtype='datetime64[ns]', freq=None) - >>> rng.ceil('h') # doctest: +SKIP + >>> rng.ceil('h') DatetimeIndex(['2018-01-01 12:00:00', '2018-01-01 12:00:00', '2018-01-01 13:00:00'], dtype='datetime64[ns]', freq=None) - """ + return DatetimeIndex( + query_compiler=self._query_compiler.dt_ceil( + freq, ambiguous, nonexistent, include_index=True + ) + ) def month_name(self, locale: str = None) -> Index: """ diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py index d7bfee4d32d..9baa44196ab 100644 --- a/tests/integ/modin/index/test_datetime_index_methods.py +++ b/tests/integ/modin/index/test_datetime_index_methods.py @@ -7,7 +7,7 @@ import pytest import snowflake.snowpark.modin.plugin # noqa: F401 -from tests.integ.modin.sql_counter import sql_count_checker +from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker from tests.integ.modin.utils import ( assert_frame_equal, assert_index_equal, @@ -217,3 +217,61 @@ def test_normalize(): native_index, lambda i: i.normalize(), ) + + +@pytest.mark.parametrize( + "datetime_index_value", + [ + ["2014-04-04 23:56:20", "2014-07-18 21:24:30", "2015-11-22 22:14:40"], + ["04/04/2014", "07/18/2013", "11/22/2015"], + ["2014-04-04 23:56", pd.NaT, "2014-07-18 21:24", "2015-11-22 22:14", pd.NaT], + [ + pd.Timestamp(2017, 1, 1, 12), + pd.Timestamp(2018, 2, 1, 10), + pd.Timestamp(2000, 2, 1, 10), + ], + ], +) +@pytest.mark.parametrize("func", ["round", "floor", "ceil"]) +@pytest.mark.parametrize("freq", ["1d", "2d", "1h", "2h", "1min", "2min", "1s", "2s"]) +def test_floor_ceil_round(datetime_index_value, func, freq): + native_index = native_pd.DatetimeIndex(datetime_index_value) + snow_index = pd.DatetimeIndex(native_index) + if func == "round" and "s" in freq: + with SqlCounter(query_count=0): + msg = f"Snowpark pandas method DatetimeIndex.round does not yet support the 'freq={freq}' parameter" + with pytest.raises(NotImplementedError, match=msg): + snow_index.round(freq=freq) + else: + with SqlCounter(query_count=1): + eval_snowpark_pandas_result( + snow_index, native_index, lambda i: getattr(i, func)(freq) + ) + + +@pytest.mark.parametrize("func", ["floor", "ceil", "round"]) +@pytest.mark.parametrize( + "freq, ambiguous, nonexistent", + [ + ("1w", "raise", "raise"), + ("1h", "infer", "raise"), + ("1h", "raise", "shift_forward"), + ("1w", "infer", "shift_forward"), + ], +) +@sql_count_checker(query_count=0) +def test_floor_ceil_round_negative(func, freq, ambiguous, nonexistent): + datetime_index_value = [ + "2014-04-04 23:56", + pd.NaT, + "2014-07-18 21:24", + "2015-11-22 22:14", + pd.NaT, + ] + native_index = native_pd.DatetimeIndex(datetime_index_value) + snow_index = pd.DatetimeIndex(native_index) + msg = f"Snowpark pandas method DatetimeIndex.{func} does not yet support" + with pytest.raises(NotImplementedError, match=msg): + getattr(snow_index, func)( + freq=freq, ambiguous=ambiguous, nonexistent=nonexistent + )