diff --git a/CHANGELOG.md b/CHANGELOG.md index d8efe82596f..0bd719dcb8c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ - Added support for `TimedeltaIndex.mean` method. - Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. +- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`. ## 1.22.1 (2024-09-11) diff --git a/docs/source/modin/supported/general_supported.rst b/docs/source/modin/supported/general_supported.rst index 797ef3bbd59..95d9610202b 100644 --- a/docs/source/modin/supported/general_supported.rst +++ b/docs/source/modin/supported/general_supported.rst @@ -38,8 +38,7 @@ Data manipulations +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge`` | P | ``validate`` | ``N`` if param ``validate`` is given | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``merge_asof`` | P | ``by``, ``left_by``, ``right_by``| ``N`` if param ``direction`` is ``nearest``. | -| | | , ``left_index``, ``right_index``| | +| ``merge_asof`` | P | ``left_index``, ``right_index`` | ``N`` if param ``direction`` is ``nearest``. 
| | | | , ``suffixes``, ``tolerance`` | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``merge_ordered`` | N | | | diff --git a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py index 882dc79d2a8..4eaf98d9b29 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/cut_utils.py @@ -189,8 +189,6 @@ def compute_bin_indices( values_frame, cuts_frame, how="asof", - left_on=[], - right_on=[], left_match_col=values_frame.data_column_snowflake_quoted_identifiers[0], right_match_col=cuts_frame.data_column_snowflake_quoted_identifiers[0], match_comparator=MatchComparator.LESS_THAN_OR_EQUAL_TO diff --git a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py index c2c224e404c..6207bd2399a 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py @@ -584,8 +584,6 @@ def _get_adjusted_key_frame_by_row_pos_int_frame( key, count_frame, "cross", - left_on=[], - right_on=[], inherit_join_index=InheritJoinIndex.FROM_LEFT, ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index 79f063b9ece..d07211dbcf5 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -103,12 +103,57 @@ class JoinOrAlignInternalFrameResult(NamedTuple): result_column_mapper: JoinOrAlignResultColumnMapper +def assert_snowpark_pandas_types_match( + left: InternalFrame, + right: InternalFrame, + left_join_identifiers: list[str], + right_join_identifiers: list[str], +) -> None: + """ + If Snowpark pandas types do not match for the given 
identifiers, then a ValueError will be raised. + + Args: + left: An internal frame to use on left side of join. + right: An internal frame to use on right side of join. + left_join_identifiers: List of snowflake identifiers to check types from 'left' frame. + right_join_identifiers: List of snowflake identifiers to check types from 'right' frame. + left_join_identifiers and right_join_identifiers must be lists of equal length. + + Returns: None + + Raises: ValueError + """ + left_types = [ + left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in left_join_identifiers + ] + right_types = [ + right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in right_join_identifiers + ] + for i, (lt, rt) in enumerate(zip(left_types, right_types)): + if lt != rt: + left_on_id = left_join_identifiers[i] + idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) + key = left.data_column_pandas_labels[idx] + lt = lt if lt is not None else left.get_snowflake_type(left_on_id) + rt = ( + rt + if rt is not None + else right.get_snowflake_type(right_join_identifiers[i]) + ) + raise ValueError( + f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. " + f"If you wish to proceed you should use pd.concat" + ) + + def join( left: InternalFrame, right: InternalFrame, how: JoinTypeLit, - left_on: list[str], - right_on: list[str], + left_on: Optional[list[str]] = None, + right_on: Optional[list[str]] = None, left_match_col: Optional[str] = None, right_match_col: Optional[str] = None, match_comparator: Optional[MatchComparator] = None, @@ -161,40 +206,48 @@ def join( include mapping for index + data columns, ordering columns and row position column if exists. 
""" - assert len(left_on) == len( - right_on - ), "left_on and right_on must be of same length or both be None" - if join_key_coalesce_config is not None: - assert len(join_key_coalesce_config) == len( - left_on - ), "join_key_coalesce_config must be of same length as left_on and right_on" assert how in get_args( JoinTypeLit ), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}" - def assert_snowpark_pandas_types_match() -> None: - """If Snowpark pandas types do not match, then a ValueError will be raised.""" - left_types = [ - left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in left_on - ] - right_types = [ - right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) - for id in right_on - ] - for i, (lt, rt) in enumerate(zip(left_types, right_types)): - if lt != rt: - left_on_id = left_on[i] - idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) - key = left.data_column_pandas_labels[idx] - lt = lt if lt is not None else left.get_snowflake_type(left_on_id) - rt = rt if rt is not None else right.get_snowflake_type(right_on[i]) - raise ValueError( - f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. 
" - f"If you wish to proceed you should use pd.concat" - ) + left_on = left_on or [] + right_on = right_on or [] + assert len(left_on) == len( + right_on + ), "left_on and right_on must be of same length or both be None" - assert_snowpark_pandas_types_match() + if how == "asof": + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + left_join_key = [left_match_col] + right_join_key = [right_match_col] + left_join_key.extend(left_on) + right_join_key.extend(right_on) + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "ASOF join join_key_coalesce_config must be of same length as left_join_key and right_join_key" + else: + left_join_key = left_on + right_join_key = right_on + assert ( + left_match_col is None + and right_match_col is None + and match_comparator is None + ), f"match condition should not be provided for {how} join" + if join_key_coalesce_config is not None: + assert len(join_key_coalesce_config) == len( + left_join_key + ), "join_key_coalesce_config must be of same length as left_on and right_on" + + assert_snowpark_pandas_types_match(left, right, left_join_key, right_join_key) # Re-project the active columns to make sure all active columns of the internal frame participate # in the join operation, and unnecessary columns are dropped from the projected columns. 
@@ -210,14 +263,13 @@ def assert_snowpark_pandas_types_match() -> None: match_comparator=match_comparator, how=how, ) - return _create_internal_frame_with_join_or_align_result( joined_ordered_dataframe, left, right, how, - left_on, - right_on, + left_join_key, + right_join_key, sort, join_key_coalesce_config, inherit_join_index, @@ -1402,6 +1454,9 @@ def _sort_on_join_keys(self) -> None: ) elif self._how == "right": ordering_column_identifiers = mapped_right_on + elif self._how == "asof": + # Order only by the left match_condition column + ordering_column_identifiers = [mapped_left_on[0]] else: # left join, inner join, left align, coalesce align ordering_column_identifiers = mapped_left_on diff --git a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py index f7ae87c2a5d..91537d98e30 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py @@ -1197,22 +1197,29 @@ def join( # get the new mapped right on identifier right_on_cols = [right_identifiers_rename_map[key] for key in right_on_cols] - # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' 
- on = None - for left_col, right_col in zip(left_on_cols, right_on_cols): - eq = Column(left_col).equal_null(Column(right_col)) - on = eq if on is None else on & eq - if how == "asof": - assert left_match_col, "left_match_col was not provided to ASOF Join" + assert ( + left_match_col + ), "ASOF join was not provided a column identifier to match on for the left table" left_match_col = Column(left_match_col) # Get the new mapped right match condition identifier - assert right_match_col, "right_match_col was not provided to ASOF Join" + assert ( + right_match_col + ), "ASOF join was not provided a column identifier to match on for the right table" right_match_col = Column(right_identifiers_rename_map[right_match_col]) # ASOF Join requires the use of match_condition - assert match_comparator, "match_comparator was not provided to ASOF Join" + assert ( + match_comparator + ), "ASOF join was not provided a comparator for the match condition" + + on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).__eq__(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right=right_snowpark_dataframe_ref.snowpark_dataframe, + on=on, how=how, match_condition=getattr(left_match_col, match_comparator.value)( right_match_col @@ -1224,6 +1231,12 @@ def join( right_snowpark_dataframe_ref.snowpark_dataframe, how=how ) else: + # Generate sql ON clause 'EQUAL_NULL(col1, col2) and EQUAL_NULL(col3, col4) ...' 
+ on = None + for left_col, right_col in zip(left_on_cols, right_on_cols): + eq = Column(left_col).equal_null(Column(right_col)) + on = eq if on is None else on & eq + snowpark_dataframe = left_snowpark_dataframe_ref.snowpark_dataframe.join( right_snowpark_dataframe_ref.snowpark_dataframe, on, how ) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py index de83e0429bf..ba8ceedec5e 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/resample_utils.py @@ -649,8 +649,6 @@ def perform_asof_join_on_frame( left=preserving_frame, right=referenced_frame, how="asof", - left_on=[], - right_on=[], left_match_col=left_timecol_snowflake_quoted_identifier, right_match_col=right_timecol_snowflake_quoted_identifier, match_comparator=( diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 8ef3bdf9bee..48f91ab40dd 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -7381,28 +7381,34 @@ def merge_asof( SnowflakeQueryCompiler """ # TODO: SNOW-1634547: Implement remaining parameters by leveraging `merge` implementation - if ( - by - or left_by - or right_by - or left_index - or right_index - or tolerance - or suffixes != ("_x", "_y") - ): + if left_index or right_index or tolerance or suffixes != ("_x", "_y"): ErrorMessage.not_implemented( "Snowpark pandas merge_asof method does not currently support parameters " - + "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - + "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ) if direction not in ("backward", "forward"): ErrorMessage.not_implemented( "Snowpark pandas merge_asof method only supports 
directions 'forward' and 'backward'" ) + if direction == "backward": + match_comparator = ( + MatchComparator.GREATER_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.GREATER_THAN + ) + else: + match_comparator = ( + MatchComparator.LESS_THAN_OR_EQUAL_TO + if allow_exact_matches + else MatchComparator.LESS_THAN + ) + left_frame = self._modin_frame right_frame = right._modin_frame - left_keys, right_keys = join_utils.get_join_keys( + # Get the left and right matching key and quoted identifier corresponding to the match_condition + # There will only be one matching key/identifier for each table as there is only a single match condition + left_match_keys, right_match_keys = join_utils.get_join_keys( left=left_frame, right=right_frame, on=on, @@ -7411,42 +7417,62 @@ def merge_asof( left_index=left_index, right_index=right_index, ) - left_match_col = ( + left_match_identifier = ( left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - left_keys + left_match_keys )[0][0] ) - right_match_col = ( + right_match_identifier = ( right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - right_keys + right_match_keys )[0][0] ) - - if direction == "backward": - match_comparator = ( - MatchComparator.GREATER_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.GREATER_THAN + coalesce_config = join_utils.get_coalesce_config( + left_keys=left_match_keys, + right_keys=right_match_keys, + external_join_keys=[], + ) + + # Get the left and right matching keys and quoted identifiers corresponding to the 'on' condition + if by or (left_by and right_by): + left_on_keys, right_on_keys = join_utils.get_join_keys( + left=left_frame, + right=right_frame, + on=by, + left_on=left_by, + right_on=right_by, + ) + left_on_identifiers = [ + ids[0] + for ids in left_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + left_on_keys + ) + ] + right_on_identifiers = [ + ids[0] + for ids in 
right_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( + right_on_keys + ) + ] + coalesce_config.extend( + join_utils.get_coalesce_config( + left_keys=left_on_keys, + right_keys=right_on_keys, + external_join_keys=[], + ) ) else: - match_comparator = ( - MatchComparator.LESS_THAN_OR_EQUAL_TO - if allow_exact_matches - else MatchComparator.LESS_THAN - ) - - coalesce_config = join_utils.get_coalesce_config( - left_keys=left_keys, right_keys=right_keys, external_join_keys=[] - ) + left_on_identifiers = [] + right_on_identifiers = [] joined_frame, _ = join_utils.join( left=left_frame, right=right_frame, + left_on=left_on_identifiers, + right_on=right_on_identifiers, how="asof", - left_on=[left_match_col], - right_on=[right_match_col], - left_match_col=left_match_col, - right_match_col=right_match_col, + left_match_col=left_match_identifier, + right_match_col=right_match_identifier, match_comparator=match_comparator, join_key_coalesce_config=coalesce_config, sort=True, diff --git a/tests/integ/modin/test_merge_asof.py b/tests/integ/modin/test_merge_asof.py index 681d339da90..51dda7889e7 100644 --- a/tests/integ/modin/test_merge_asof.py +++ b/tests/integ/modin/test_merge_asof.py @@ -105,6 +105,7 @@ def left_right_timestamp_data(): pd.Timestamp("2016-05-25 13:30:00.072"), pd.Timestamp("2016-05-25 13:30:00.075"), ], + "ticker": ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"], "bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01], "ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03], } @@ -118,6 +119,7 @@ def left_right_timestamp_data(): pd.Timestamp("2016-05-25 13:30:00.048"), pd.Timestamp("2016-05-25 13:30:00.048"), ], + "ticker": ["MSFT", "MSFT", "GOOG", "GOOG", "AAPL"], "price": [51.95, 51.95, 720.77, 720.92, 98.0], "quantity": [75, 155, 100, 100, 100], } @@ -229,14 +231,39 @@ def test_merge_asof_left_right_on( assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) +@pytest.mark.parametrize("by", 
["ticker", ["ticker"]]) @sql_count_checker(query_count=1, join_count=1) -def test_merge_asof_timestamps(left_right_timestamp_data): +def test_merge_asof_by(left_right_timestamp_data, by): left_native_df, right_native_df = left_right_timestamp_data left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( right_native_df ) - native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time") - snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time") + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", by=by + ) + snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by=by) + assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) + + +@pytest.mark.parametrize( + "left_by, right_by", + [ + ("ticker", "ticker"), + (["ticker", "bid"], ["ticker", "price"]), + ], +) +@sql_count_checker(query_count=1, join_count=1) +def test_merge_asof_left_right_by(left_right_timestamp_data, left_by, right_by): + left_native_df, right_native_df = left_right_timestamp_data + left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( + right_native_df + ) + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", left_by=left_by, right_by=right_by + ) + snow_output = pd.merge_asof( + left_snow_df, right_snow_df, on="time", left_by=left_by, right_by=right_by + ) assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) @@ -248,8 +275,10 @@ def test_merge_asof_date(left_right_timestamp_data): left_snow_df, right_snow_df = pd.DataFrame(left_native_df), pd.DataFrame( right_native_df ) - native_output = native_pd.merge_asof(left_native_df, right_native_df, on="time") - snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time") + native_output = native_pd.merge_asof( + left_native_df, right_native_df, on="time", by="ticker" + ) + snow_output = pd.merge_asof(left_snow_df, right_snow_df, on="time", by="ticker") 
assert_snowpark_pandas_equal_to_pandas(snow_output, native_output) @@ -360,9 +389,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): with pytest.raises( NotImplementedError, match=( - "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + "Snowpark pandas merge_asof method only supports directions 'forward' and 'backward'" ), ): pd.merge_asof( @@ -372,19 +399,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" - ), - ): - pd.merge_asof( - left_snow_df, right_snow_df, on="time", left_by="price", right_by="quantity" - ) - with pytest.raises( - NotImplementedError, - match=( - "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ), ): pd.merge_asof(left_snow_df, right_snow_df, left_index=True, right_index=True) @@ -392,8 +407,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ), ): pd.merge_asof( @@ -406,8 +420,7 @@ def test_merge_asof_params_unsupported(left_right_timestamp_data): NotImplementedError, match=( "Snowpark pandas merge_asof method does not currently support parameters " - "'by', 'left_by', 'right_by', 'left_index', 'right_index', " - "'suffixes', or 'tolerance'" + + "'left_index', 'right_index', 'suffixes', or 'tolerance'" ), ): pd.merge_asof(