Skip to content

Commit

Permalink
SNOW-989398: Refactor resample.fillna to use ASOF Join (#2196)
Browse files Browse the repository at this point in the history
SNOW-989398

This PR refactors `resample.fillna()` to use the `ASOF` Join.

---------

Signed-off-by: Naren Krishna <[email protected]>
  • Loading branch information
sfc-gh-nkrishna authored Sep 4, 2024
1 parent e99a790 commit b80d0e6
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 121 deletions.
152 changes: 33 additions & 119 deletions src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,25 @@
from snowflake.snowpark.column import Column
from snowflake.snowpark.functions import (
builtin,
coalesce,
col,
dateadd,
datediff,
lag,
lead,
lit,
row_number,
to_timestamp_ntz,
)
from snowflake.snowpark.modin.plugin._internal import join_utils
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.join_utils import InheritJoinIndex, join
from snowflake.snowpark.modin.plugin._internal.join_utils import (
InheritJoinIndex,
MatchComparator,
join,
)
from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import (
DataFrameReference,
OrderedDataFrame,
)
from snowflake.snowpark.modin.plugin._internal.utils import (
generate_snowflake_quoted_identifiers_helper,
pandas_lit,
)
from snowflake.snowpark.modin.plugin.utils.error_message import ErrorMessage
from snowflake.snowpark.types import DateType, TimestampType
Expand Down Expand Up @@ -622,7 +622,6 @@ def fill_missing_resample_bins_for_frame(
)


# TODO: SNOW-989398 Migrate function to ASOF join
def perform_asof_join_on_frame(
preserving_frame: InternalFrame, referenced_frame: InternalFrame, fill_method: str
) -> InternalFrame:
Expand Down Expand Up @@ -652,7 +651,9 @@ def perform_asof_join_on_frame(
frame : InternalFrame
A new frame that holds the result of an ASOF join.
"""
# Consider the following example:
# Consider the following example where we want to perform an ASOF JOIN of preserving_frame
# and referenced_frame where __resample_index__ >= __index__ if forward fill
# or __resample_index__ <= __index__ if backward fill:
#
# preserved_frame:
# __resample_index__
Expand All @@ -671,119 +672,32 @@ def perform_asof_join_on_frame(
# 2023-01-07 02:00:00 NaN
# 2023-01-10 00:00:00 6

# We want to perform an ASOF JOIN of preserving_frame and referenced_frame. Here
# are the steps to take:

# 1. Construct right_frame using referenced_frame, which has a
# temporary column, interval_end_col, that holds the closest
# following timestamp to every value in __index__. The last value in
# interval_end_col is dummy value that represents the smallest or largest
# (e.g. bfill or ffill) possible date in Snowflake.
interval_end_pandas_label = "interval_end_col"
interval_start_snowflake_quoted_identifier = (
get_snowflake_quoted_identifier_for_resample_index_col(referenced_frame)
)
if fill_method == "bfill":
# Snowflake recommends using 1582 as the smallest year for date or timestamp type
# due to limits on the Gregorian calendar. See https://docs.snowflake.com/en/sql-reference/data-types-datetime
interval_end_col = coalesce(
lag(col(interval_start_snowflake_quoted_identifier)).over(
Window.order_by(col(interval_start_snowflake_quoted_identifier).asc())
),
pandas_lit("1582-01-01 00:00:00"),
)
else:
# Snowflake recommends using 9999 as the largest year for date or timestamp type
# due to limits on the Gregorian calendar. See https://docs.snowflake.com/en/sql-reference/data-types-datetime
assert fill_method == "ffill", "`fill_method` can only be 'bfill' or 'ffill'"
interval_end_col = coalesce(
lead(col(interval_start_snowflake_quoted_identifier)).over(
Window.order_by(col(interval_start_snowflake_quoted_identifier).asc())
),
pandas_lit("9999-12-31 23:59:59"),
)
right_frame = referenced_frame.append_column(
interval_end_pandas_label, interval_end_col
)
# right_frame:
# a interval_end_col
# __index__
# 2023-01-03 01:00:00 1 2023-01-04 00:00:00
# 2023-01-04 00:00:00 2 2023-01-05 23:00:00
# 2023-01-05 23:00:00 3 2023-01-06 00:00:00
# 2023-01-06 00:00:00 4 2023-01-07 02:00:00
# 2023-01-07 02:00:00 NaN 2023-01-10 00:00:00
# 2023-01-10 00:00:00 6 9999-01-01 00:00:00

# 2. Get the Snowflake identifiers needed for the join condition.
# interval_start_snowflake_quoted_identifier is needed as well,
# but has already been fetched above.
left_timecol_snowflake_quoted_identifier = (
get_snowflake_quoted_identifier_for_resample_index_col(preserving_frame)
)
interval_end_snowflake_quoted_identifier = (
right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
pandas_labels=[interval_end_pandas_label]
)[0][0]
)

# 3. Convert both preserved_frame and right_frame to Snowpark DataFrames to perform
# a non-equi-join.
left_snowpark_df = (
preserving_frame.ordered_dataframe.to_projected_snowpark_dataframe()
)
right_snowpark_df = right_frame.ordered_dataframe.to_projected_snowpark_dataframe()

# 4. Join left_snowpark_df and right_snowpark_df using the following logic:
# For each element left_frame's __resample_index__, join it with a single row
# in right_frame whose __index__ value is less/greater than or equal to it and is closest in time.
# If a row cannot be found, pad the joined columns from right_frame with null.
if fill_method == "bfill":
on_expr = (
left_snowpark_df[left_timecol_snowflake_quoted_identifier]
<= right_snowpark_df[interval_start_snowflake_quoted_identifier]
) & (
left_snowpark_df[left_timecol_snowflake_quoted_identifier]
> right_snowpark_df[interval_end_snowflake_quoted_identifier]
)
else:
assert fill_method == "ffill", f"invalid fill_method {fill_method}"
on_expr = (
left_snowpark_df[left_timecol_snowflake_quoted_identifier]
>= right_snowpark_df[interval_start_snowflake_quoted_identifier]
) & (
left_snowpark_df[left_timecol_snowflake_quoted_identifier]
< right_snowpark_df[interval_end_snowflake_quoted_identifier]
)
joined_snowpark_df = left_snowpark_df.join(
right=right_snowpark_df,
on=on_expr,
how="left",
right_timecol_snowflake_quoted_identifier = (
get_snowflake_quoted_identifier_for_resample_index_col(referenced_frame)
)
# joined_snowpark_df:
#
# __resample_index__ __index__ a interval_end_col
# 2023-01-03 00:00:00 NULL NULL NULL
# 2023-01-05 00:00:00 2023-01-04 00:00:00 2 2023-01-05 23:00:00
# 2023-01-07 00:00:00 2023-01-06 00:00:00 4 2023-01-07 02:00:00
# 2023-01-09 00:00:00 2023-01-07 02:00:00 NULL 2023-01-10 00:00:00

# 5. Construct a final result with correct frame metadata.
# a
# __resample_index__
# 2023-01-03 00:00:00 NaN
# 2023-01-05 00:00:00 2
# 2023-01-07 00:00:00 4
# 2023-01-09 00:00:00 NaN
return InternalFrame.create(
ordered_dataframe=OrderedDataFrame(DataFrameReference(joined_snowpark_df)),
data_column_pandas_labels=referenced_frame.data_column_pandas_labels,
data_column_snowflake_quoted_identifiers=referenced_frame.data_column_snowflake_quoted_identifiers,
index_column_pandas_labels=referenced_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=[
left_timecol_snowflake_quoted_identifier
],
data_column_pandas_index_names=referenced_frame.data_column_pandas_index_names,
data_column_types=referenced_frame.cached_data_column_snowpark_pandas_types,
index_column_types=referenced_frame.cached_index_column_snowpark_pandas_types,
output_frame, _ = join_utils.join(
left=preserving_frame,
right=referenced_frame,
how="asof",
left_on=[],
right_on=[],
left_match_col=left_timecol_snowflake_quoted_identifier,
right_match_col=right_timecol_snowflake_quoted_identifier,
match_comparator=(
MatchComparator.GREATER_THAN_OR_EQUAL_TO
if fill_method == "ffill"
else MatchComparator.LESS_THAN_OR_EQUAL_TO
),
sort=True,
)
# output_frame:
# a
# __resample_index__
# 2023-01-03 00:00:00 NULL
# 2023-01-05 00:00:00 2
# 2023-01-07 00:00:00 4
# 2023-01-09 00:00:00 NULL
return output_frame
Original file line number Diff line number Diff line change
Expand Up @@ -11628,8 +11628,11 @@ def resample(
# The output frame's DatetimeIndex is identical to expected_frame's. For each date in the DatetimeIndex,
# a single row is selected from the input frame, where its date is the closest match in time based on
# the filling method. We perform an ASOF join to accomplish this.
frame = perform_asof_join_on_frame(expected_frame, frame, resample_method)

index_name = frame.index_column_pandas_labels
output_frame = perform_asof_join_on_frame(
expected_frame, frame, resample_method
)
return SnowflakeQueryCompiler(output_frame).set_index_names(index_name)
elif resample_method in IMPLEMENTED_AGG_METHODS:
frame = perform_resample_binning_on_frame(frame, start_date, rule)
if resample_method == "size":
Expand Down

0 comments on commit b80d0e6

Please sign in to comment.