Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1662105, SNOW-1662657: Support by, left_by, right_by for pd.merge_asof #2284

Merged
merged 7 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#### New Features

- Added support for `TimedeltaIndex.mean` method.
- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`.

## 1.22.0 (2024-09-10)

Expand Down
3 changes: 1 addition & 2 deletions docs/source/modin/supported/general_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ Data manipulations
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge`` | P | ``validate`` | ``N`` if param ``validate`` is given |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge_asof`` | P | ``by``, ``left_by``, ``right_by``| ``N`` if param ``direction`` is ``nearest``. |
| | | , ``left_index``, ``right_index``| |
| ``merge_asof`` | P | ``left_index``, ``right_index``, | ``N`` if param ``direction`` is ``nearest``. |
| | | , ``suffixes``, ``tolerance`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge_ordered`` | N | | |
Expand Down
2 changes: 0 additions & 2 deletions src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ def compute_bin_indices(
values_frame,
cuts_frame,
how="asof",
left_on=[],
right_on=[],
left_match_col=values_frame.data_column_snowflake_quoted_identifiers[0],
right_match_col=cuts_frame.data_column_snowflake_quoted_identifiers[0],
match_comparator=MatchComparator.LESS_THAN_OR_EQUAL_TO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,6 @@ def _get_adjusted_key_frame_by_row_pos_int_frame(
key,
count_frame,
"cross",
left_on=[],
right_on=[],
inherit_join_index=InheritJoinIndex.FROM_LEFT,
)

Expand Down
123 changes: 89 additions & 34 deletions src/snowflake/snowpark/modin/plugin/_internal/join_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,57 @@ class JoinOrAlignInternalFrameResult(NamedTuple):
result_column_mapper: JoinOrAlignResultColumnMapper


def assert_snowpark_pandas_types_match(
left: InternalFrame,
right: InternalFrame,
left_join_identifiers: list[str],
right_join_identifiers: list[str],
) -> None:
"""
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved
If Snowpark pandas types do not match for the given identifiers, then a ValueError will be raised.

Args:
left: An internal frame to use on left side of join.
right: An internal frame to use on right side of join.
left_join_identifiers: List of snowflake identifiers to check types from 'left' frame.
right_join_identifiers: List of snowflake identifiers to check types from 'right' frame.
left_identifiers and right_identifiers must be lists of equal length.

Returns: None

Raises: ValueError
"""
left_types = [
left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in left_join_identifiers
]
right_types = [
right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in right_join_identifiers
]
for i, (lt, rt) in enumerate(zip(left_types, right_types)):
if lt != rt:
left_on_id = left_join_identifiers[i]
idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
key = left.data_column_pandas_labels[idx]
lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
rt = (
rt
if rt is not None
else right.get_snowflake_type(right_join_identifiers[i])
)
raise ValueError(
f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
f"If you wish to proceed you should use pd.concat"
)


def join(
left: InternalFrame,
right: InternalFrame,
how: JoinTypeLit,
left_on: list[str],
right_on: list[str],
left_on: Optional[list[str]] = None,
right_on: Optional[list[str]] = None,
left_match_col: Optional[str] = None,
right_match_col: Optional[str] = None,
match_comparator: Optional[MatchComparator] = None,
Expand Down Expand Up @@ -161,40 +206,48 @@ def join(
include mapping for index + data columns, ordering columns and row position column
if exists.
"""
assert len(left_on) == len(
right_on
), "left_on and right_on must be of same length or both be None"
if join_key_coalesce_config is not None:
assert len(join_key_coalesce_config) == len(
left_on
), "join_key_coalesce_config must be of same length as left_on and right_on"
assert how in get_args(
JoinTypeLit
), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}"

def assert_snowpark_pandas_types_match() -> None:
"""If Snowpark pandas types do not match, then a ValueError will be raised."""
left_types = [
left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in left_on
]
right_types = [
right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in right_on
]
for i, (lt, rt) in enumerate(zip(left_types, right_types)):
if lt != rt:
left_on_id = left_on[i]
idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
key = left.data_column_pandas_labels[idx]
lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
rt = rt if rt is not None else right.get_snowflake_type(right_on[i])
raise ValueError(
f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
f"If you wish to proceed you should use pd.concat"
)
left_on = left_on or []
right_on = right_on or []
assert len(left_on) == len(
right_on
), "left_on and right_on must be of same length or both be None"

assert_snowpark_pandas_types_match()
if how == "asof":
assert (
left_match_col
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved
), "ASOF join was not provided a column identifier to match on for the left table"
assert (
right_match_col
), "ASOF join was not provided a column identifier to match on for the right table"
assert (
match_comparator
), "ASOF join was not provided a comparator for the match condition"
left_join_key = [left_match_col]
right_join_key = [right_match_col]
left_join_key.extend(left_on)
right_join_key.extend(right_on)
if join_key_coalesce_config is not None:
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved
assert len(join_key_coalesce_config) == len(
left_join_key
), "ASOF join join_key_coalesce_config must be of same length as left_join_key and right_join_key"
assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key)
else:
left_join_key = left_on
right_join_key = right_on
assert (
left_match_col is None
and right_match_col is None
and match_comparator is None
), f"match condition should not be provided for {how} join"
if join_key_coalesce_config is not None:
assert len(join_key_coalesce_config) == len(
left_join_key
), "join_key_coalesce_config must be of same length as left_on and right_on"
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved
assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key)
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved

# Re-project the active columns to make sure all active columns of the internal frame participate
# in the join operation, and unnecessary columns are dropped from the projected columns.
Expand All @@ -210,14 +263,13 @@ def assert_snowpark_pandas_types_match() -> None:
match_comparator=match_comparator,
how=how,
)

return _create_internal_frame_with_join_or_align_result(
joined_ordered_dataframe,
left,
right,
how,
left_on,
right_on,
left_join_key,
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved
right_join_key,
sort,
join_key_coalesce_config,
inherit_join_index,
Expand Down Expand Up @@ -1402,6 +1454,9 @@ def _sort_on_join_keys(self) -> None:
)
elif self._how == "right":
ordering_column_identifiers = mapped_right_on
elif self._how == "asof":
# Order only by the left match_condition column
ordering_column_identifiers = [mapped_left_on[0]]
else: # left join, inner join, left align, coalesce align
ordering_column_identifiers = mapped_left_on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1197,22 +1197,29 @@ def join(
# get the new mapped right on identifier
right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols]

# Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...'
on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).equal_null(Column(right_col))
on = eq if on is None else on & eq

if how == "asof":
assert left_match_col, "left_match_col was not provided to ASOF Join"
assert (
left_match_col
), "ASOF join was not provided a column identifier to match on for the left table"
left_match_col = Column(left_match_col)
# Get the new mapped right match condition identifier
assert right_match_col, "right_match_col was not provided to ASOF Join"
assert (
right_match_col
), "ASOF join was not provided a column identifier to match on for the right table"
right_match_col = Column(right_identifiers_rename_map[right_match_col])
# ASOF Join requires the use of match_condition
assert match_comparator, "match_comparator was not provided to ASOF Join"
assert (
match_comparator
), "ASOF join was not provided a comparator for the match condition"

on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).__eq__(Column(right_col))
sfc-gh-azhan marked this conversation as resolved.
Show resolved Hide resolved
on = eq if on is None else on & eq

snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join(
right=right_snowpark_dataframe_ref.snowpark_dataframe,
on=on,
sfc-gh-nkrishna marked this conversation as resolved.
Show resolved Hide resolved
how=how,
match_condition=getattr(left_match_col, match_comparator.value)(
right_match_col
Expand All @@ -1224,6 +1231,12 @@ def join(
right_snowpark_dataframe_ref.snowpark_dataframe, how=how
)
else:
# Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...'
on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).equal_null(Column(right_col))
on = eq if on is None else on & eq

snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join(
right_snowpark_dataframe_ref.snowpark_dataframe, on, how
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,6 @@ def perform_asof_join_on_frame(
left=preserving_frame,
right=referenced_frame,
how="asof",
left_on=[],
right_on=[],
left_match_col=left_timecol_snowflake_quoted_identifier,
right_match_col=right_timecol_snowflake_quoted_identifier,
match_comparator=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7395,28 +7395,34 @@ def merge_asof(
SnowflakeQueryCompiler
"""
# TODO: SNOW-1634547: Implement remaining parameters by leveraging `merge` implementation
if (
by
or left_by
or right_by
or left_index
or right_index
or tolerance
or suffixes != ("_x", "_y")
):
if left_index or right_index or tolerance or suffixes != ("_x", "_y"):
ErrorMessage.not_implemented(
"Snowpark pandas merge_asof method does not currently support parameters "
+ "'by', 'left_by', 'right_by', 'left_index', 'right_index', "
+ "'suffixes', or 'tolerance'"
+ "'left_index', 'right_index', 'suffixes', or 'tolerance'"
)
if direction not in ("backward", "forward"):
ErrorMessage.not_implemented(
"Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'"
)

if direction == "backward":
match_comparator = (
MatchComparator.GREATER_THAN_OR_EQUAL_TO
if allow_exact_matches
else MatchComparator.GREATER_THAN
)
else:
match_comparator = (
MatchComparator.LESS_THAN_OR_EQUAL_TO
if allow_exact_matches
else MatchComparator.LESS_THAN
)

left_frame = self._modin_frame
right_frame = right._modin_frame
left_keys, right_keys = join_utils.get_join_keys(
# Get the left and right matching key and quoted identifier corresponding to the match_condition
# There will only be matching key/identifier for each table as there is only a single match condition
left_match_keys, right_match_keys = join_utils.get_join_keys(
left=left_frame,
right=right_frame,
on=on,
Expand All @@ -7425,42 +7431,62 @@ def merge_asof(
left_index=left_index,
right_index=right_index,
)
left_match_col = (
left_match_identifier = (
left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
left_keys
left_match_keys
)[0][0]
)
right_match_col = (
right_match_identifier = (
right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
right_keys
right_match_keys
)[0][0]
)

if direction == "backward":
match_comparator = (
MatchComparator.GREATER_THAN_OR_EQUAL_TO
if allow_exact_matches
else MatchComparator.GREATER_THAN
coalesce_config = join_utils.get_coalesce_config(
left_keys=left_match_keys,
right_keys=right_match_keys,
external_join_keys=[],
)

# Get the left and right matching keys and quoted identifiers corresponding to the 'on' condition
if by or (left_by and right_by):
left_on_keys, right_on_keys = join_utils.get_join_keys(
left=left_frame,
right=right_frame,
on=by,
left_on=left_by,
right_on=right_by,
)
left_on_identifiers = [
ids[0]
for ids in left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
left_on_keys
)
]
right_on_identifiers = [
ids[0]
for ids in right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
right_on_keys
)
]
coalesce_config.extend(
join_utils.get_coalesce_config(
left_keys=left_on_keys,
right_keys=right_on_keys,
external_join_keys=[],
)
)
else:
match_comparator = (
MatchComparator.LESS_THAN_OR_EQUAL_TO
if allow_exact_matches
else MatchComparator.LESS_THAN
)

coalesce_config = join_utils.get_coalesce_config(
left_keys=left_keys, right_keys=right_keys, external_join_keys=[]
)
left_on_identifiers = []
right_on_identifiers = []

joined_frame, _ = join_utils.join(
left=left_frame,
right=right_frame,
left_on=left_on_identifiers,
right_on=right_on_identifiers,
how="asof",
left_on=[left_match_col],
right_on=[right_match_col],
left_match_col=left_match_col,
right_match_col=right_match_col,
left_match_col=left_match_identifier,
right_match_col=right_match_identifier,
match_comparator=match_comparator,
join_key_coalesce_config=coalesce_config,
sort=True,
Expand Down
Loading
Loading