Skip to content

Commit

Permalink
SNOW-1662105, SNOW-1662657: Support by, left_by, right_by for `…
Browse files Browse the repository at this point in the history
…pd.merge_asof` (#2284)

SNOW-1662105, SNOW-1662657

This PR refactors `join_utils.py` to make `left_on` and `right_on`
optional arguments, as they are not required for "cross" or "asof"
joins. It also support `by`, `left_by`, `right_by` for `pd.merge_asof`.

---------

Signed-off-by: Naren Krishna <[email protected]>
  • Loading branch information
sfc-gh-nkrishna authored Sep 13, 2024
1 parent 08ff293 commit a3586c8
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 111 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

- Added support for `TimedeltaIndex.mean` method.
- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`.
- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`.


## 1.22.1 (2024-09-11)
Expand Down
3 changes: 1 addition & 2 deletions docs/source/modin/supported/general_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ Data manipulations
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge`` | P | ``validate`` | ``N`` if param ``validate`` is given |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge_asof`` | P | ``by``, ``left_by``, ``right_by``| ``N`` if param ``direction`` is ``nearest``. |
| | | , ``left_index``, ``right_index``| |
| ``merge_asof`` | P | ``left_index``, ``right_index``, | ``N`` if param ``direction`` is ``nearest``. |
| | | , ``suffixes``, ``tolerance`` | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``merge_ordered`` | N | | |
Expand Down
2 changes: 0 additions & 2 deletions src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,6 @@ def compute_bin_indices(
values_frame,
cuts_frame,
how="asof",
left_on=[],
right_on=[],
left_match_col=values_frame.data_column_snowflake_quoted_identifiers[0],
right_match_col=cuts_frame.data_column_snowflake_quoted_identifiers[0],
match_comparator=MatchComparator.LESS_THAN_OR_EQUAL_TO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,6 @@ def _get_adjusted_key_frame_by_row_pos_int_frame(
key,
count_frame,
"cross",
left_on=[],
right_on=[],
inherit_join_index=InheritJoinIndex.FROM_LEFT,
)

Expand Down
123 changes: 89 additions & 34 deletions src/snowflake/snowpark/modin/plugin/_internal/join_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,57 @@ class JoinOrAlignInternalFrameResult(NamedTuple):
result_column_mapper: JoinOrAlignResultColumnMapper


def assert_snowpark_pandas_types_match(
left: InternalFrame,
right: InternalFrame,
left_join_identifiers: list[str],
right_join_identifiers: list[str],
) -> None:
"""
If Snowpark pandas types do not match for the given identifiers, then a ValueError will be raised.
Args:
left: An internal frame to use on left side of join.
right: An internal frame to use on right side of join.
left_join_identifiers: List of snowflake identifiers to check types from 'left' frame.
right_join_identifiers: List of snowflake identifiers to check types from 'right' frame.
left_identifiers and right_identifiers must be lists of equal length.
Returns: None
Raises: ValueError
"""
left_types = [
left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in left_join_identifiers
]
right_types = [
right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in right_join_identifiers
]
for i, (lt, rt) in enumerate(zip(left_types, right_types)):
if lt != rt:
left_on_id = left_join_identifiers[i]
idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
key = left.data_column_pandas_labels[idx]
lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
rt = (
rt
if rt is not None
else right.get_snowflake_type(right_join_identifiers[i])
)
raise ValueError(
f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
f"If you wish to proceed you should use pd.concat"
)


def join(
left: InternalFrame,
right: InternalFrame,
how: JoinTypeLit,
left_on: list[str],
right_on: list[str],
left_on: Optional[list[str]] = None,
right_on: Optional[list[str]] = None,
left_match_col: Optional[str] = None,
right_match_col: Optional[str] = None,
match_comparator: Optional[MatchComparator] = None,
Expand Down Expand Up @@ -161,40 +206,48 @@ def join(
include mapping for index + data columns, ordering columns and row position column
if exists.
"""
assert len(left_on) == len(
right_on
), "left_on and right_on must be of same length or both be None"
if join_key_coalesce_config is not None:
assert len(join_key_coalesce_config) == len(
left_on
), "join_key_coalesce_config must be of same length as left_on and right_on"
assert how in get_args(
JoinTypeLit
), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}"

def assert_snowpark_pandas_types_match() -> None:
"""If Snowpark pandas types do not match, then a ValueError will be raised."""
left_types = [
left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in left_on
]
right_types = [
right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None)
for id in right_on
]
for i, (lt, rt) in enumerate(zip(left_types, right_types)):
if lt != rt:
left_on_id = left_on[i]
idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id)
key = left.data_column_pandas_labels[idx]
lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
rt = rt if rt is not None else right.get_snowflake_type(right_on[i])
raise ValueError(
f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
f"If you wish to proceed you should use pd.concat"
)
left_on = left_on or []
right_on = right_on or []
assert len(left_on) == len(
right_on
), "left_on and right_on must be of same length or both be None"

assert_snowpark_pandas_types_match()
if how == "asof":
assert (
left_match_col
), "ASOF join was not provided a column identifier to match on for the left table"
assert (
right_match_col
), "ASOF join was not provided a column identifier to match on for the right table"
assert (
match_comparator
), "ASOF join was not provided a comparator for the match condition"
left_join_key = [left_match_col]
right_join_key = [right_match_col]
left_join_key.extend(left_on)
right_join_key.extend(right_on)
if join_key_coalesce_config is not None:
assert len(join_key_coalesce_config) == len(
left_join_key
), "ASOF join join_key_coalesce_config must be of same length as left_join_key and right_join_key"
else:
left_join_key = left_on
right_join_key = right_on
assert (
left_match_col is None
and right_match_col is None
and match_comparator is None
), f"match condition should not be provided for {how} join"
if join_key_coalesce_config is not None:
assert len(join_key_coalesce_config) == len(
left_join_key
), "join_key_coalesce_config must be of same length as left_on and right_on"

assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key)

# Re-project the active columns to make sure all active columns of the internal frame participate
# in the join operation, and unnecessary columns are dropped from the projected columns.
Expand All @@ -210,14 +263,13 @@ def assert_snowpark_pandas_types_match() -> None:
match_comparator=match_comparator,
how=how,
)

return _create_internal_frame_with_join_or_align_result(
joined_ordered_dataframe,
left,
right,
how,
left_on,
right_on,
left_join_key,
right_join_key,
sort,
join_key_coalesce_config,
inherit_join_index,
Expand Down Expand Up @@ -1402,6 +1454,9 @@ def _sort_on_join_keys(self) -> None:
)
elif self._how == "right":
ordering_column_identifiers = mapped_right_on
elif self._how == "asof":
# Order only by the left match_condition column
ordering_column_identifiers = [mapped_left_on[0]]
else: # left join, inner join, left align, coalesce align
ordering_column_identifiers = mapped_left_on

Expand Down
31 changes: 22 additions & 9 deletions src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,22 +1197,29 @@ def join(
# get the new mapped right on identifier
right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols]

# Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...'
on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).equal_null(Column(right_col))
on = eq if on is None else on & eq

if how == "asof":
assert left_match_col, "left_match_col was not provided to ASOF Join"
assert (
left_match_col
), "ASOF join was not provided a column identifier to match on for the left table"
left_match_col = Column(left_match_col)
# Get the new mapped right match condition identifier
assert right_match_col, "right_match_col was not provided to ASOF Join"
assert (
right_match_col
), "ASOF join was not provided a column identifier to match on for the right table"
right_match_col = Column(right_identifiers_rename_map[right_match_col])
# ASOF Join requires the use of match_condition
assert match_comparator, "match_comparator was not provided to ASOF Join"
assert (
match_comparator
), "ASOF join was not provided a comparator for the match condition"

on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).__eq__(Column(right_col))
on = eq if on is None else on & eq

snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join(
right=right_snowpark_dataframe_ref.snowpark_dataframe,
on=on,
how=how,
match_condition=getattr(left_match_col, match_comparator.value)(
right_match_col
Expand All @@ -1224,6 +1231,12 @@ def join(
right_snowpark_dataframe_ref.snowpark_dataframe, how=how
)
else:
# Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...'
on = None
for left_col, right_col in zip(left_on_cols, right_on_cols):
eq = Column(left_col).equal_null(Column(right_col))
on = eq if on is None else on & eq

snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join(
right_snowpark_dataframe_ref.snowpark_dataframe, on, how
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,6 @@ def perform_asof_join_on_frame(
left=preserving_frame,
right=referenced_frame,
how="asof",
left_on=[],
right_on=[],
left_match_col=left_timecol_snowflake_quoted_identifier,
right_match_col=right_timecol_snowflake_quoted_identifier,
match_comparator=(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7381,28 +7381,34 @@ def merge_asof(
SnowflakeQueryCompiler
"""
# TODO: SNOW-1634547: Implement remaining parameters by leveraging `merge` implementation
if (
by
or left_by
or right_by
or left_index
or right_index
or tolerance
or suffixes != ("_x", "_y")
):
if left_index or right_index or tolerance or suffixes != ("_x", "_y"):
ErrorMessage.not_implemented(
"Snowpark pandas merge_asof method does not currently support parameters "
+ "'by', 'left_by', 'right_by', 'left_index', 'right_index', "
+ "'suffixes', or 'tolerance'"
+ "'left_index', 'right_index', 'suffixes', or 'tolerance'"
)
if direction not in ("backward", "forward"):
ErrorMessage.not_implemented(
"Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'"
)

if direction == "backward":
match_comparator = (
MatchComparator.GREATER_THAN_OR_EQUAL_TO
if allow_exact_matches
else MatchComparator.GREATER_THAN
)
else:
match_comparator = (
MatchComparator.LESS_THAN_OR_EQUAL_TO
if allow_exact_matches
else MatchComparator.LESS_THAN
)

left_frame = self._modin_frame
right_frame = right._modin_frame
left_keys, right_keys = join_utils.get_join_keys(
# Get the left and right matching key and quoted identifier corresponding to the match_condition
# There will only be matching key/identifier for each table as there is only a single match condition
left_match_keys, right_match_keys = join_utils.get_join_keys(
left=left_frame,
right=right_frame,
on=on,
Expand All @@ -7411,42 +7417,62 @@ def merge_asof(
left_index=left_index,
right_index=right_index,
)
left_match_col = (
left_match_identifier = (
left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
left_keys
left_match_keys
)[0][0]
)
right_match_col = (
right_match_identifier = (
right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
right_keys
right_match_keys
)[0][0]
)

if direction == "backward":
match_comparator = (
MatchComparator.GREATER_THAN_OR_EQUAL_TO
if allow_exact_matches
else MatchComparator.GREATER_THAN
coalesce_config = join_utils.get_coalesce_config(
left_keys=left_match_keys,
right_keys=right_match_keys,
external_join_keys=[],
)

# Get the left and right matching keys and quoted identifiers corresponding to the 'on' condition
if by or (left_by and right_by):
left_on_keys, right_on_keys = join_utils.get_join_keys(
left=left_frame,
right=right_frame,
on=by,
left_on=left_by,
right_on=right_by,
)
left_on_identifiers = [
ids[0]
for ids in left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
left_on_keys
)
]
right_on_identifiers = [
ids[0]
for ids in right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
right_on_keys
)
]
coalesce_config.extend(
join_utils.get_coalesce_config(
left_keys=left_on_keys,
right_keys=right_on_keys,
external_join_keys=[],
)
)
else:
match_comparator = (
MatchComparator.LESS_THAN_OR_EQUAL_TO
if allow_exact_matches
else MatchComparator.LESS_THAN
)

coalesce_config = join_utils.get_coalesce_config(
left_keys=left_keys, right_keys=right_keys, external_join_keys=[]
)
left_on_identifiers = []
right_on_identifiers = []

joined_frame, _ = join_utils.join(
left=left_frame,
right=right_frame,
left_on=left_on_identifiers,
right_on=right_on_identifiers,
how="asof",
left_on=[left_match_col],
right_on=[right_match_col],
left_match_col=left_match_col,
right_match_col=right_match_col,
left_match_col=left_match_identifier,
right_match_col=right_match_identifier,
match_comparator=match_comparator,
join_key_coalesce_config=coalesce_config,
sort=True,
Expand Down
Loading

0 comments on commit a3586c8

Please sign in to comment.