Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-989398: Refactor resample.fillna to use ASOF Join #2196

Merged
merged 11 commits into from
Sep 4, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@

- Refactored `quoted_identifier_to_snowflake_type` to avoid making metadata queries if the types have been cached locally.
- Improved `pd.to_datetime` to handle all local input cases.
- Refactored `resample.fillna` implementation to use `OrderedDataFrame` join utility.

#### Bug Fixes

Expand Down
8 changes: 7 additions & 1 deletion src/snowflake/snowpark/modin/plugin/_internal/join_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class MatchComparator(Enum):
GREATER_THAN = "__gt__"
LESS_THAN_OR_EQUAL_TO = "__le__"
LESS_THAN = "__lt__"
EQUAL_NULL = "equal_null"


class InheritJoinIndex(IntFlag):
Expand Down Expand Up @@ -109,6 +110,7 @@ def join(
how: JoinTypeLit,
left_on: list[str],
right_on: list[str],
on_comparators: Optional[list[MatchComparator]] = None,
left_match_col: Optional[str] = None,
right_match_col: Optional[str] = None,
match_comparator: Optional[MatchComparator] = None,
Expand All @@ -126,11 +128,14 @@ def join(
left_on: List of snowflake identifiers to join on from 'left' frame.
right_on: List of snowflake identifiers to join on from 'right' frame.
left_on and right_on must be lists of equal length.
on_comparators: list of MatchComparator {"__ge__", "__gt__", "__le__", "__lt__", "equal_null"}
Comparing the 'left_on' and 'right_on' columns. Defaults to list of "equal_null"
of the same length as 'left_on' and 'right_on'.
left_match_col: Snowflake identifier to match condition on from 'left' frame.
Only applicable for 'asof' join.
right_match_col: Snowflake identifier to match condition on from 'right' frame.
Only applicable for 'asof' join.
match_comparator: MatchComparator {"__ge__", "__gt__", "__le__", "__lt__"}
match_comparator: MatchComparator {"__ge__", "__gt__", "__le__", "__lt__", "equal_null"}
Only applicable for 'asof' join, the operation to compare 'left_match_condition'
and 'right_match_condition'.
sort: If True order merged frame on join keys. If False, ordering behavior
Expand Down Expand Up @@ -205,6 +210,7 @@ def assert_snowpark_pandas_types_match() -> None:
right=right.ordered_dataframe,
left_on_cols=left_on,
right_on_cols=right_on,
on_comparators=on_comparators,
left_match_col=left_match_col,
right_match_col=right_match_col,
match_comparator=match_comparator,
Expand Down
27 changes: 22 additions & 5 deletions src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,7 @@ def join(
right: "OrderedDataFrame",
left_on_cols: Optional[list[str]] = None,
right_on_cols: Optional[list[str]] = None,
on_comparators: Optional[list["MatchComparator"]] = None, # type: ignore[name-defined] # noqa: F821
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved
left_match_col: Optional[str] = None,
right_match_col: Optional[str] = None,
match_comparator: Optional[ # type: ignore[name-defined]
Expand All @@ -1073,11 +1074,14 @@ def join(
right: The other OrderedDataFrame to join.
left_on_cols: A list of column names from self OrderedDataFrame to be used for the join.
right_on_cols: A list of column names from right OrderedDataFrame to be used for the join.
on_comparators: list of MatchComparator {"__ge__", "__gt__", "__le__", "__lt__", "equal_null"}
Comparing the 'left_on' and 'right_on' columns. Defaults to list of "equal_null"
of the same length as 'left_on' and 'right_on'.
left_match_col: Snowflake identifier to match condition on from 'left' frame.
Only applicable for 'asof' join.
right_match_col: Snowflake identifier to match condition on from 'right' frame.
Only applicable for 'asof' join.
match_comparator: MatchComparator {"__ge__", "__gt__", "__le__", "__lt__"}
match_comparator: MatchComparator {"__ge__", "__gt__", "__le__", "__lt__", "equal_null"}
Only applicable for 'asof' join, the operation to compare 'left_match_condition'
and 'right_match_condition'.
how: We support the following join types:
Expand Down Expand Up @@ -1197,11 +1201,24 @@ def join(
# get the new mapped right on identifier
right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols]

# Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...'
on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).equal_null(Column(right_col))
on = eq if on is None else on & eq

from snowflake.snowpark.modin.plugin._internal.join_utils import MatchComparator

# Use EQUAL_NULL as default to compare left and right "on" columns
on_comparators = (
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved
[MatchComparator.EQUAL_NULL] * len(left_on_cols)
if not on_comparators
else on_comparators
)
# Generate sql ON clause comparing left and right columns
for left_col, right_col, on_comparator in zip(
left_on_cols, right_on_cols, on_comparators
):
column_comparison = getattr(Column(left_col), on_comparator.value)(
Column(right_col)
)
on = column_comparison if on is None else on & column_comparison

if how == "asof":
assert left_match_col, "left_match_col was not provided to ASOF Join"
Expand Down
91 changes: 36 additions & 55 deletions src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,14 @@
row_number,
to_timestamp_ntz,
)
from snowflake.snowpark.modin.plugin._internal import join_utils
from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame
from snowflake.snowpark.modin.plugin._internal.join_utils import InheritJoinIndex, join
from snowflake.snowpark.modin.plugin._internal.join_utils import (
InheritJoinIndex,
JoinKeyCoalesceConfig,
MatchComparator,
join,
)
from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import (
DataFrameReference,
OrderedDataFrame,
Expand Down Expand Up @@ -727,63 +733,38 @@ def perform_asof_join_on_frame(
)[0][0]
)

# 3. Convert both preserved_frame and right_frame to Snowpark DataFrames to perform
# a non-equi-join.
left_snowpark_df = (
preserving_frame.ordered_dataframe.to_projected_snowpark_dataframe()
)
right_snowpark_df = right_frame.ordered_dataframe.to_projected_snowpark_dataframe()

# 4. Join left_snowpark_df and right_snowpark_df using the following logic:
# 3. Join left_snowpark_df and right_snowpark_df using the following logic:
# For each element left_frame's __resample_index__, join it with a single row
# in right_frame whose __index__ value is less/greater than or equal to it and is closest in time.
# If a row cannot be found, pad the joined columns from right_frame with null.
if fill_method == "bfill":
on_expr = (
left_snowpark_df[left_timecol_snowflake_quoted_identifier]
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved
<= right_snowpark_df[interval_start_snowflake_quoted_identifier]
) & (
left_snowpark_df[left_timecol_snowflake_quoted_identifier]
> right_snowpark_df[interval_end_snowflake_quoted_identifier]
)
else:
assert fill_method == "ffill", f"invalid fill_method {fill_method}"
on_expr = (
left_snowpark_df[left_timecol_snowflake_quoted_identifier]
>= right_snowpark_df[interval_start_snowflake_quoted_identifier]
) & (
left_snowpark_df[left_timecol_snowflake_quoted_identifier]
< right_snowpark_df[interval_end_snowflake_quoted_identifier]
)
joined_snowpark_df = left_snowpark_df.join(
right=right_snowpark_df,
on=on_expr,
output_frame, _ = join_utils.join(
left=preserving_frame,
right=right_frame,
how="left",
)
# joined_snowpark_df:
#
# __resample_index__ __index__ a interval_end_col
# 2023-01-03 00:00:00 NULL NULL NULL
# 2023-01-05 00:00:00 2023-01-04 00:00:00 2 2023-01-05 23:00:00
# 2023-01-07 00:00:00 2023-01-06 00:00:00 4 2023-01-07 02:00:00
# 2023-01-09 00:00:00 2023-01-07 02:00:00 NULL 2023-01-10 00:00:00

# 5. Construct a final result with correct frame metadata.
# a
# __resample_index__
# 2023-01-03 00:00:00 NaN
# 2023-01-05 00:00:00 2
# 2023-01-07 00:00:00 4
# 2023-01-09 00:00:00 NaN
return InternalFrame.create(
ordered_dataframe=OrderedDataFrame(DataFrameReference(joined_snowpark_df)),
data_column_pandas_labels=referenced_frame.data_column_pandas_labels,
data_column_snowflake_quoted_identifiers=referenced_frame.data_column_snowflake_quoted_identifiers,
index_column_pandas_labels=referenced_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=[
left_timecol_snowflake_quoted_identifier
left_on=[
left_timecol_snowflake_quoted_identifier,
left_timecol_snowflake_quoted_identifier,
],
right_on=[
interval_start_snowflake_quoted_identifier,
interval_end_snowflake_quoted_identifier,
],
on_comparators=(
[MatchComparator.GREATER_THAN_OR_EQUAL_TO, MatchComparator.LESS_THAN]
if fill_method == "ffill"
else [MatchComparator.LESS_THAN_OR_EQUAL_TO, MatchComparator.GREATER_THAN]
),
sort=True,
join_key_coalesce_config=[
JoinKeyCoalesceConfig.LEFT,
JoinKeyCoalesceConfig.LEFT,
],
data_column_pandas_index_names=referenced_frame.data_column_pandas_index_names,
data_column_types=referenced_frame.cached_data_column_snowpark_pandas_types,
index_column_types=referenced_frame.cached_index_column_snowpark_pandas_types,
)
# output_frame:
# a
# __resample_index__
# 2023-01-03 00:00:00 NULL
# 2023-01-05 00:00:00 2
# 2023-01-07 00:00:00 4
# 2023-01-09 00:00:00 NULL
return output_frame
Original file line number Diff line number Diff line change
Expand Up @@ -11359,8 +11359,11 @@ def resample(
# The output frame's DatetimeIndex is identical to expected_frame's. For each date in the DatetimeIndex,
# a single row is selected from the input frame, where its date is the closest match in time based on
# the filling method. We perform an ASOF join to accomplish this.
frame = perform_asof_join_on_frame(expected_frame, frame, resample_method)

index_name = frame.index_column_pandas_labels
output_frame = perform_asof_join_on_frame(
expected_frame, frame, resample_method
)
return SnowflakeQueryCompiler(output_frame).set_index_names(index_name)
elif resample_method in IMPLEMENTED_AGG_METHODS:
frame = perform_resample_binning_on_frame(frame, start_date, rule)
if resample_method == "size":
Expand Down
Loading