diff --git a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py index 87ac427d6dc..f2ca17d2038 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py @@ -14,25 +14,25 @@ from snowflake.snowpark.column import Column from snowflake.snowpark.functions import ( builtin, - coalesce, - col, dateadd, datediff, - lag, - lead, lit, row_number, to_timestamp_ntz, ) +from snowflake.snowpark.modin.plugin._internal import join_utils from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame -from snowflake.snowpark.modin.plugin._internal.join_utils import InheritJoinIndex, join +from snowflake.snowpark.modin.plugin._internal.join_utils import ( + InheritJoinIndex, + MatchComparator, + join, +) from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import ( DataFrameReference, OrderedDataFrame, ) from snowflake.snowpark.modin.plugin._internal.utils import ( generate_snowflake_quoted_identifiers_helper, - pandas_lit, ) from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage from snowflake.snowpark.types import DateType, TimestampType @@ -622,7 +622,6 @@ def fill_missing_resample_bins_for_frame( ) -# TODO: SNOW-989398 Migrate function to ASOF join def perform_asof_join_on_frame( preserving_frame: InternalFrame, referenced_frame: InternalFrame, fill_method: str ) -> InternalFrame: @@ -652,7 +651,9 @@ def perform_asof_join_on_frame( frame : InternalFrame A new frame that holds the result of an ASOF join. """ - # Consider the following example: + # Consider the following example where we want to perform an ASOF JOIN of preserving_frame + # and referenced_frame where __resample_index__ >= __index__ if forward fill + # or __resample_index__ <= __index__ if backward fill: # # preserved_frame: # __resample_index__ @@ -671,119 +672,32 @@ def perform_asof_join_on_frame( # 2023-01-07 02:00:00 NaN # 2023-01-10 00:00:00 6 - # We want to perform an ASOF JOIN of preserving_frame and referenced_frame. Here - # are the steps to take: - - # 1. Construct right_frame using referenced_frame, which has a - # temporary column, interval_end_col, that holds the closest - # following timestamp to every value in __index__. The last value in - # interval_end_col is dummy value that represents the smallest or largest - # (e.g. bfill or ffill) possible date in Snowflake. - interval_end_pandas_label = "interval_end_col" - interval_start_snowflake_quoted_identifier = ( - get_snowflake_quoted_identifier_for_resample_index_col(referenced_frame) - ) - if fill_method == "bfill": - # Snowflake recommends using 1582 as the smallest year for date or timestamp type - # due to limits on the Gregorian calendar. See https://docs.snowflake.com/en/sql-reference/data-types-datetime - interval_end_col = coalesce( - lag(col(interval_start_snowflake_quoted_identifier)).over( - Window.order_by(col(interval_start_snowflake_quoted_identifier).asc()) - ), - pandas_lit("1582-01-01 00:00:00"), - ) - else: - # Snowflake recommends using 9999 as the largest year for date or timestamp type - # due to limits on the Gregorian calendar. See https://docs.snowflake.com/en/sql-reference/data-types-datetime - assert fill_method == "ffill", "`fill_method` can only be 'bfill' or 'ffill'" - interval_end_col = coalesce( - lead(col(interval_start_snowflake_quoted_identifier)).over( - Window.order_by(col(interval_start_snowflake_quoted_identifier).asc()) - ), - pandas_lit("9999-12-31 23:59:59"), - ) - right_frame = referenced_frame.append_column( - interval_end_pandas_label, interval_end_col - ) - # right_frame: - # a interval_end_col - # __index__ - # 2023-01-03 01:00:00 1 2023-01-04 00:00:00 - # 2023-01-04 00:00:00 2 2023-01-05 23:00:00 - # 2023-01-05 23:00:00 3 2023-01-06 00:00:00 - # 2023-01-06 00:00:00 4 2023-01-07 02:00:00 - # 2023-01-07 02:00:00 NaN 2023-01-10 00:00:00 - # 2023-01-10 00:00:00 6 9999-01-01 00:00:00 - - # 2. Get the Snowflake identifiers needed for the join condition. - # interval_start_snowflake_quoted_identifier is needed as well, - # but has already been fetched above. left_timecol_snowflake_quoted_identifier = ( get_snowflake_quoted_identifier_for_resample_index_col(preserving_frame) ) - interval_end_snowflake_quoted_identifier = ( - right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - pandas_labels=[interval_end_pandas_label] - )[0][0] - ) - - # 3. Convert both preserved_frame and right_frame to Snowpark DataFrames to perform - # a non-equi-join. - left_snowpark_df = ( - preserving_frame.ordered_dataframe.to_projected_snowpark_dataframe() - ) - right_snowpark_df = right_frame.ordered_dataframe.to_projected_snowpark_dataframe() - - # 4. Join left_snowpark_df and right_snowpark_df using the following logic: - # For each element left_frame's __resample_index__, join it with a single row - # in right_frame whose __index__ value is less/greater than or equal to it and is closest in time. - # If a row cannot be found, pad the joined columns from right_frame with null. - if fill_method == "bfill": - on_expr = ( - left_snowpark_df[left_timecol_snowflake_quoted_identifier] - <= right_snowpark_df[interval_start_snowflake_quoted_identifier] - ) & ( - left_snowpark_df[left_timecol_snowflake_quoted_identifier] - > right_snowpark_df[interval_end_snowflake_quoted_identifier] - ) - else: - assert fill_method == "ffill", f"invalid fill_method {fill_method}" - on_expr = ( - left_snowpark_df[left_timecol_snowflake_quoted_identifier] - >= right_snowpark_df[interval_start_snowflake_quoted_identifier] - ) & ( - left_snowpark_df[left_timecol_snowflake_quoted_identifier] - < right_snowpark_df[interval_end_snowflake_quoted_identifier] - ) - joined_snowpark_df = left_snowpark_df.join( - right=right_snowpark_df, - on=on_expr, - how="left", + right_timecol_snowflake_quoted_identifier = ( + get_snowflake_quoted_identifier_for_resample_index_col(referenced_frame) ) - # joined_snowpark_df: - # - # __resample_index__ __index__ a interval_end_col - # 2023-01-03 00:00:00 NULL NULL NULL - # 2023-01-05 00:00:00 2023-01-04 00:00:00 2 2023-01-05 23:00:00 - # 2023-01-07 00:00:00 2023-01-06 00:00:00 4 2023-01-07 02:00:00 - # 2023-01-09 00:00:00 2023-01-07 02:00:00 NULL 2023-01-10 00:00:00 - - # 5. Construct a final result with correct frame metadata. - # a - # __resample_index__ - # 2023-01-03 00:00:00 NaN - # 2023-01-05 00:00:00 2 - # 2023-01-07 00:00:00 4 - # 2023-01-09 00:00:00 NaN - return InternalFrame.create( - ordered_dataframe=OrderedDataFrame(DataFrameReference(joined_snowpark_df)), - data_column_pandas_labels=referenced_frame.data_column_pandas_labels, - data_column_snowflake_quoted_identifiers=referenced_frame.data_column_snowflake_quoted_identifiers, - index_column_pandas_labels=referenced_frame.index_column_pandas_labels, - index_column_snowflake_quoted_identifiers=[ - left_timecol_snowflake_quoted_identifier - ], - data_column_pandas_index_names=referenced_frame.data_column_pandas_index_names, - data_column_types=referenced_frame.cached_data_column_snowpark_pandas_types, - index_column_types=referenced_frame.cached_index_column_snowpark_pandas_types, + output_frame, _ = join_utils.join( + left=preserving_frame, + right=referenced_frame, + how="asof", + left_on=[], + right_on=[], + left_match_col=left_timecol_snowflake_quoted_identifier, + right_match_col=right_timecol_snowflake_quoted_identifier, + match_comparator=( + MatchComparator.GREATER_THAN_OR_EQUAL_TO + if fill_method == "ffill" + else MatchComparator.LESS_THAN_OR_EQUAL_TO + ), + sort=True, ) + # output_frame: + # a + # __resample_index__ + # 2023-01-03 00:00:00 NULL + # 2023-01-05 00:00:00 2 + # 2023-01-07 00:00:00 4 + # 2023-01-09 00:00:00 NULL + return output_frame diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 17c0bb2dbe5..e5f514289e2 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -11628,8 +11628,11 @@ def resample( # The output frame's DatetimeIndex is identical to expected_frame's. For each date in the DatetimeIndex, # a single row is selected from the input frame, where its date is the closest match in time based on # the filling method. We perform an ASOF join to accomplish this. - frame = perform_asof_join_on_frame(expected_frame, frame, resample_method) - + index_name = frame.index_column_pandas_labels + output_frame = perform_asof_join_on_frame( + expected_frame, frame, resample_method + ) + return SnowflakeQueryCompiler(output_frame).set_index_names(index_name) elif resample_method in IMPLEMENTED_AGG_METHODS: frame = perform_resample_binning_on_frame(frame, start_date, rule) if resample_method == "size":