diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index 331901f1a67..ec09485bd45 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -172,6 +172,30 @@ def join( JoinTypeLit ), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}" + def assert_snowpark_pandas_types_match() -> None: + """If Snowpark pandas types do not match, then a ValueError will be raised.""" + left_types = [ + left.snowflake_quoted_identifier_to_snowpark_pandas_type[id] + for id in left_on + ] + right_types = [ + right.snowflake_quoted_identifier_to_snowpark_pandas_type[id] + for id in right_on + ] + for i, (lt, rt) in enumerate(zip(left_types, right_types)): + if lt != rt: + left_on_id = left_on[i] + idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) + key = left.data_column_pandas_labels[idx] + lt = lt if lt is not None else left.get_snowflake_type(left_on_id) + rt = rt if rt is not None else right.get_snowflake_type(right_on[i]) + raise ValueError( + f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. " + f"If you wish to proceed you should use pd.concat" + ) + + assert_snowpark_pandas_types_match() + # Re-project the active columns to make sure all active columns of the internal frame participate # in the join operation, and unnecessary columns are dropped from the projected columns. 
left = left.select_active_columns() diff --git a/tests/integ/modin/frame/test_isin.py b/tests/integ/modin/frame/test_isin.py index 5fb960518a2..cc6113c7466 100644 --- a/tests/integ/modin/frame/test_isin.py +++ b/tests/integ/modin/frame/test_isin.py @@ -250,8 +250,15 @@ def test_isin_dataframe_values_type_negative(): df.isin(values="abcdef") -@sql_count_checker(query_count=6) -def test_isin_timedelta(): +@sql_count_checker(query_count=3) +@pytest.mark.parametrize( + "values", + [ + pytest.param([2, 3], id="integers"), + pytest.param([pd.Timedelta(2), pd.Timedelta(3)], id="timedeltas"), + ], +) +def test_isin_timedelta(values): native_df = native_pd.DataFrame({"a": [1, 2, 3], "b": [None, 4, 2]}).astype( "timedelta64[ns]" ) @@ -260,13 +267,5 @@ def test_isin_timedelta(): eval_snowpark_pandas_result( snow_df, native_df, - lambda df: _test_isin_with_snowflake_logic(df, [2, 3], query_count=1), - ) - - eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: _test_isin_with_snowflake_logic( - df, [pd.Timedelta(2), pd.Timedelta(3)], query_count=1 - ), + lambda df: _test_isin_with_snowflake_logic(df, values, query_count=1), ) diff --git a/tests/integ/modin/frame/test_melt.py b/tests/integ/modin/frame/test_melt.py index 29728f26956..0812bb2c60c 100644 --- a/tests/integ/modin/frame/test_melt.py +++ b/tests/integ/modin/frame/test_melt.py @@ -305,8 +305,9 @@ def test_everything(): ) -@sql_count_checker(query_count=2) -def test_melt_timedelta(): +@sql_count_checker(query_count=1) +@pytest.mark.parametrize("value_vars", [["B"], ["B", "C"]]) +def test_melt_timedelta(value_vars): native_df = npd.DataFrame( { "A": {0: "a", 1: "b", 2: "c"}, @@ -316,9 +317,5 @@ def test_melt_timedelta(): ).astype({"B": "timedelta64[ns]", "C": "timedelta64[ns]"}) snow_df = pd.DataFrame(native_df) eval_snowpark_pandas_result( - snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=["B"]) - ) - - eval_snowpark_pandas_result( - snow_df, native_df, lambda df: df.melt(id_vars=["A"], 
value_vars=["B", "C"]) + snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=value_vars) ) diff --git a/tests/integ/modin/frame/test_merge.py b/tests/integ/modin/frame/test_merge.py index e1c75d1d853..c1ced99fc67 100644 --- a/tests/integ/modin/frame/test_merge.py +++ b/tests/integ/modin/frame/test_merge.py @@ -1158,8 +1158,8 @@ def test_merge_validate_negative(lvalues, rvalues, validate): left.merge(right, left_on="A", right_on="B", validate=validate) -@sql_count_checker(query_count=4, join_count=4) -def test_merge_timedelta(): +@sql_count_checker(query_count=1, join_count=1) +def test_merge_timedelta_on(): left_df = native_pd.DataFrame( {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]} ).astype({"value": "timedelta64[ns]"}) @@ -1176,36 +1176,42 @@ def test_merge_timedelta(): ), ) - left_df = native_pd.DataFrame({"a": ["foo", "bar"], "b": [1, 2]}).astype( - {"b": "timedelta64[ns]"} - ) - right_df = native_pd.DataFrame({"a": ["foo", "baz"], "c": [3, 4]}).astype( - {"c": "timedelta64[ns]"} - ) - eval_snowpark_pandas_result( - pd.DataFrame(left_df), - left_df, - lambda df: df.merge( - pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, - how="inner", - on="a", - ), - ) - eval_snowpark_pandas_result( - pd.DataFrame(left_df), - left_df, - lambda df: df.merge( - pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, - how="right", - on="a", - ), - ) - eval_snowpark_pandas_result( - pd.DataFrame(left_df), - left_df, - lambda df: df.merge( - pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, - how="cross", - ), - ) +@pytest.mark.parametrize( + "kwargs", + [ + {"how": "inner", "on": "a"}, + {"how": "right", "on": "a"}, + {"how": "right", "on": "b"}, + {"how": "left", "on": "c"}, + {"how": "cross"}, + ], +) +def test_merge_timedelta_how(kwargs): + left_df = native_pd.DataFrame( + {"a": ["foo", "bar"], "b": [1, 2], "c": [3, 5]} + ).astype({"b": "timedelta64[ns]"}) + right_df = native_pd.DataFrame( + 
{"a": ["foo", "baz"], "b": [1, 3], "c": [3, 4]} + ).astype({"b": "timedelta64[ns]", "c": "timedelta64[ns]"}) + count = 1 + expect_exception = False + if "c" == kwargs.get("on", None): # merge timedelta with int exception + expect_exception = True + count = 0 + + with SqlCounter(query_count=count, join_count=count): + eval_snowpark_pandas_result( + pd.DataFrame(left_df), + left_df, + lambda df: df.merge( + pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, + **kwargs, + ), + expect_exception=expect_exception, + expect_exception_match="You are trying to merge on LongType and TimedeltaType columns for key 'c'. If you " + "wish to proceed you should use pd.concat", + expect_exception_type=ValueError, + assert_exception_equal=False, # pandas exception: You are trying to merge on int64 and timedelta64[ns] + # columns for key 'c'. If you wish to proceed you should use pd.concat + )