Skip to content

Commit

Permalink
SNOW-1558919: Add support for DatetimeIndex ceil, floor and round met…
Browse files Browse the repository at this point in the history
…hods (#2135)

Fixes SNOW-1558919

Added support for DatetimeIndex ceil, floor and round methods. Raise not
implemented error if ambiguous or nonexistent parameter is set.
  • Loading branch information
sfc-gh-nkumar authored Aug 23, 2024
1 parent 7cbad6f commit a43f398
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 187 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
- Added support for `pd.merge_asof`.
- Added support for `Series.dt.normalize` and `DatetimeIndex.normalize`.
- Added support for `Index.is_boolean`, `Index.is_integer`, `Index.is_floating`, `Index.is_numeric`, and `Index.is_object`.
- Added support for `DatetimeIndex.round`, `DatetimeIndex.floor` and `DatetimeIndex.ceil`.

#### Bug Fixes

Expand Down
6 changes: 3 additions & 3 deletions docs/source/modin/supported/datetime_index_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``tz_localize`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``round`` | N | | |
| ``round`` | P | ``ambiguous``, ``nonexistent`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``floor`` | N | | |
| ``floor`` | P | ``ambiguous``, ``nonexistent`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``ceil`` | N | | |
| ``ceil`` | P | ``ambiguous``, ``nonexistent`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``to_period`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
230 changes: 104 additions & 126 deletions src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16087,7 +16087,11 @@ def dt_tz_convert(self, tz: Union[str, tzinfo]) -> None:
)

def dt_ceil(
self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise"
self,
freq: Frequency,
ambiguous: str = "raise",
nonexistent: str = "raise",
include_index: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Args:
Expand All @@ -16105,62 +16109,51 @@ def dt_ceil(
- 'NaT' will return NaT where there are nonexistent times
- timedelta objects will shift nonexistent times by the timedelta
- 'raise' will raise an NonExistentTimeError if there are nonexistent times.
include_index: Whether to include the index columns in the operation.
Returns:
A new QueryCompiler with ceil values.

"""
method_name = "DatetimeIndex.ceil" if include_index else "Series.dt.ceil"
if ambiguous != "raise":
ErrorMessage.not_implemented(
"Snowpark pandas 'Series.dt.ceil' method doesn't yet support 'ambiguous' parameter"
)
ErrorMessage.parameter_not_implemented_error("ambiguous", method_name)
if nonexistent != "raise":
ErrorMessage.not_implemented(
"Snowpark pandas 'Series.dt.ceil' method doesn't yet support 'nonexistent' parameter"
)
internal_frame = self._modin_frame
ErrorMessage.parameter_not_implemented_error("nonexistent", method_name)

slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit(
rule=freq # type: ignore[arg-type]
)

if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS:
ErrorMessage.not_implemented(
f"Snowpark pandas 'Series.dt.ceil' method doesn't support setting 'freq' parameter to '{freq}'"
ErrorMessage.parameter_not_implemented_error(f"freq='{freq}'", method_name)

def ceil_func(col_id: str) -> SnowparkColumn:
base_column = col(col_id)
floor_column = builtin("time_slice")(
base_column, slice_length, slice_unit, "START"
)
base_column = col(internal_frame.data_column_snowflake_quoted_identifiers[0])
floor_column = builtin("time_slice")(
base_column, slice_length, slice_unit, "START"
)
ceil_column = builtin("time_slice")(
base_column, slice_length, slice_unit, "END"
)
ceil_column = iff(
base_column.equal_null(floor_column), base_column, ceil_column
)
ceil_column = builtin("time_slice")(
base_column, slice_length, slice_unit, "END"
)
return iff(base_column.equal_null(floor_column), base_column, ceil_column)

internal_frame = internal_frame.append_column(
internal_frame.data_column_pandas_labels[0], ceil_column
)
frame = self._modin_frame
snowflake_ids = frame.data_column_snowflake_quoted_identifiers[0:1]
if include_index:
snowflake_ids.extend(frame.index_column_snowflake_quoted_identifiers)

return SnowflakeQueryCompiler(
InternalFrame.create(
ordered_dataframe=internal_frame.ordered_dataframe,
data_column_pandas_labels=[None],
data_column_pandas_index_names=internal_frame.data_column_pandas_index_names,
data_column_snowflake_quoted_identifiers=internal_frame.data_column_snowflake_quoted_identifiers[
-1:
],
index_column_pandas_labels=internal_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers,
data_column_types=internal_frame.cached_data_column_snowpark_pandas_types[
-1:
],
index_column_types=internal_frame.cached_index_column_snowpark_pandas_types,
)
frame.update_snowflake_quoted_identifiers_with_expressions(
{col_id: ceil_func(col_id) for col_id in snowflake_ids}
).frame
)

def dt_round(
self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise"
self,
freq: Frequency,
ambiguous: str = "raise",
nonexistent: str = "raise",
include_index: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Args:
Expand All @@ -16178,28 +16171,23 @@ def dt_round(
- 'NaT' will return NaT where there are nonexistent times
- timedelta objects will shift nonexistent times by the timedelta
- 'raise' will raise an NonExistentTimeError if there are nonexistent times.
include_index: Whether to include the index columns in the operation.
Returns:
A new QueryCompiler with round values.

"""
method_name = "DatetimeIndex.round" if include_index else "Series.dt.round"
if ambiguous != "raise":
ErrorMessage.not_implemented(
"Snowpark pandas 'Series.dt.round' method doesn't yet support 'ambiguous' parameter"
)
ErrorMessage.parameter_not_implemented_error("ambiguous", method_name)
if nonexistent != "raise":
ErrorMessage.not_implemented(
"Snowpark pandas 'Series.dt.round' method doesn't yet support 'nonexistent' parameter"
)
internal_frame = self._modin_frame
ErrorMessage.parameter_not_implemented_error("nonexistent", method_name)

slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit(
rule=freq # type: ignore[arg-type]
)

if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS or slice_unit == "second":
ErrorMessage.not_implemented(
f"Snowpark pandas 'Series.dt.round' method doesn't support setting 'freq' parameter to '{freq}'"
)
ErrorMessage.parameter_not_implemented_error(f"freq={freq}", method_name)

# We need to implement the algorithm for rounding half to even whenever
# the date value is at half point of the slice:
Expand Down Expand Up @@ -16227,72 +16215,79 @@ def down_level_freq(slice_length: int, slice_unit: str) -> tuple[int, str]:
if slice_length % 2 == 1:
slice_length, slice_unit = down_level_freq(slice_length, slice_unit)
half_slice_length = int(slice_length / 2)
base_column = col(internal_frame.data_column_snowflake_quoted_identifiers[0])

# Second, we determine whether floor represents an even number of slices.
# To do so, we must divide the number of epoch seconds in it over the number
# of epoch seconds in one slice. This way, we can get the number of slices.

floor_column = builtin("time_slice")(
base_column, slice_length, slice_unit, "START"
)
ceil_column = builtin("time_slice")(
base_column, slice_length, slice_unit, "END"
)

def slice_length_when_unit_is_second(slice_length: int, slice_unit: str) -> int:
while slice_unit != "second":
slice_length, slice_unit = down_level_freq(slice_length, slice_unit)
return slice_length

floor_epoch_seconds_column = builtin("extract")("epoch_second", floor_column)
floor_num_slices_column = cast(
floor_epoch_seconds_column
/ pandas_lit(slice_length_when_unit_is_second(slice_length, slice_unit)),
IntegerType(),
)
def round_func(col_id: str) -> SnowparkColumn:
base_column = col(col_id)

# Now that we know the number of slices, we can check if they are even or odd.
# Second, we determine whether floor represents an even number of slices.
# To do so, we must divide the number of epoch seconds in it over the number
# of epoch seconds in one slice. This way, we can get the number of slices.

floor_is_even = (floor_num_slices_column % pandas_lit(2)).equal_null(
pandas_lit(0)
)
floor_column = builtin("time_slice")(
base_column, slice_length, slice_unit, "START"
)
ceil_column = builtin("time_slice")(
base_column, slice_length, slice_unit, "END"
)

# Accordingly, we can decide if the round column should be the floor or ceil
# of the slice.
floor_epoch_seconds_column = builtin("extract")(
"epoch_second", floor_column
)
floor_num_slices_column = cast(
floor_epoch_seconds_column
/ pandas_lit(
slice_length_when_unit_is_second(slice_length, slice_unit)
),
IntegerType(),
)

round_column_if_half_point = iff(floor_is_even, floor_column, ceil_column)
# Now that we know the number of slices, we can check if they are even or odd.
floor_is_even = (floor_num_slices_column % pandas_lit(2)).equal_null(
pandas_lit(0)
)

# In case the date value is not at half point of the slice, then we shift it
# by half a slice, and take the floor from there.
# Accordingly, we can decide if the round column should be the floor or ceil
# of the slice.
round_column_if_half_point = iff(floor_is_even, floor_column, ceil_column)

base_plus_half_slice_column = dateadd(
slice_unit, pandas_lit(half_slice_length), base_column
)
round_column_if_not_half_point = builtin("time_slice")(
base_plus_half_slice_column, slice_length, slice_unit, "START"
)
# In case the date value is not at half point of the slice, then we shift it
# by half a slice, and take the floor from there.
base_plus_half_slice_column = dateadd(
slice_unit, pandas_lit(half_slice_length), base_column
)
round_column_if_not_half_point = builtin("time_slice")(
base_plus_half_slice_column, slice_length, slice_unit, "START"
)

# The final expression for the round column.
# The final expression for the round column.
return iff(
base_plus_half_slice_column.equal_null(ceil_column),
round_column_if_half_point,
round_column_if_not_half_point,
)

round_column = iff(
base_plus_half_slice_column.equal_null(ceil_column),
round_column_if_half_point,
round_column_if_not_half_point,
)
frame = self._modin_frame
snowflake_ids = frame.data_column_snowflake_quoted_identifiers[0:1]
if include_index:
snowflake_ids.extend(frame.index_column_snowflake_quoted_identifiers)

return SnowflakeQueryCompiler(
internal_frame.update_snowflake_quoted_identifiers_with_expressions(
{
internal_frame.data_column_snowflake_quoted_identifiers[
0
]: round_column
}
frame.update_snowflake_quoted_identifiers_with_expressions(
{col_id: round_func(col_id) for col_id in snowflake_ids}
).frame
)

def dt_floor(
self, freq: Frequency, ambiguous: str = "raise", nonexistent: str = "raise"
self,
freq: Frequency,
ambiguous: str = "raise",
nonexistent: str = "raise",
include_index: bool = False,
) -> "SnowflakeQueryCompiler":
"""
Args:
Expand All @@ -16310,52 +16305,35 @@ def dt_floor(
- 'NaT' will return NaT where there are nonexistent times
- timedelta objects will shift nonexistent times by the timedelta
- 'raise' will raise an NonExistentTimeError if there are nonexistent times.
include_index: Whether to include the index columns in the operation.
Returns:
A new QueryCompiler with floor values.
"""
method_name = "DatetimeIndex.floor" if include_index else "Series.dt.floor"
if ambiguous != "raise":
ErrorMessage.not_implemented(
"Snowpark pandas 'Series.dt.floor' method doesn't yet support 'ambiguous' parameter"
)
ErrorMessage.parameter_not_implemented_error("ambiguous", method_name)
if nonexistent != "raise":
ErrorMessage.not_implemented(
"Snowpark pandas 'Series.dt.floor' method doesn't yet support 'nonexistent' parameter"
)
internal_frame = self._modin_frame
ErrorMessage.parameter_not_implemented_error("nonexistent", method_name)

slice_length, slice_unit = rule_to_snowflake_width_and_slice_unit(
rule=freq # type: ignore[arg-type]
)

if slice_unit not in SUPPORTED_DT_FLOOR_CEIL_FREQS:
ErrorMessage.not_implemented(
f"Snowpark pandas 'Series.dt.floor' method doesn't support setting 'freq' parameter to '{freq}'"
)
snowpark_column = builtin("time_slice")(
col(internal_frame.data_column_snowflake_quoted_identifiers[0]),
slice_length,
slice_unit,
)
ErrorMessage.parameter_not_implemented_error(f"freq='{freq}'", method_name)

internal_frame = internal_frame.append_column(
internal_frame.data_column_pandas_labels[0], snowpark_column
)
frame = self._modin_frame
snowflake_ids = frame.data_column_snowflake_quoted_identifiers[0:1]
if include_index:
snowflake_ids.extend(frame.index_column_snowflake_quoted_identifiers)

return SnowflakeQueryCompiler(
InternalFrame.create(
ordered_dataframe=internal_frame.ordered_dataframe,
data_column_pandas_labels=[None],
data_column_pandas_index_names=internal_frame.data_column_pandas_index_names,
data_column_snowflake_quoted_identifiers=internal_frame.data_column_snowflake_quoted_identifiers[
-1:
],
index_column_pandas_labels=internal_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=internal_frame.index_column_snowflake_quoted_identifiers,
data_column_types=internal_frame.cached_data_column_snowpark_pandas_types[
-1:
],
index_column_types=internal_frame.cached_index_column_snowpark_pandas_types,
)
frame.update_snowflake_quoted_identifiers_with_expressions(
{
col_id: builtin("time_slice")(col(col_id), slice_length, slice_unit)
for col_id in snowflake_ids
}
).frame
)

def dt_normalize(self, include_index: bool = False) -> "SnowflakeQueryCompiler":
Expand Down
Loading

0 comments on commit a43f398

Please sign in to comment.