Skip to content

Commit

Permalink
add merge tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-azhan committed Aug 27, 2024
1 parent 5163a98 commit 992bee9
Show file tree
Hide file tree
Showing 7 changed files with 81 additions and 60 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
#### New Features

- Added limited support for the `Timedelta` type, including the following features. Snowpark pandas will raise `NotImplementedError` for unsupported `Timedelta` use cases.
- supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `assign`, `bfill`, `ffill`, `fillna`, `compare`, `diff`, `drop`, `dropna`, `duplicated`, `empty`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `melt`, `merge`, `nlargest`, `nsmallest`.
- supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `assign`, `bfill`, `ffill`, `fillna`, `compare`, `diff`, `drop`, `dropna`, `duplicated`, `empty`, `equals`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `mask`, `melt`, `merge`, `nlargest`, `nsmallest`.
- converting non-timedelta to timedelta via `astype`.
- support for subtracting two timestamps to get a Timedelta.
- support indexing with Timedelta data columns.
Expand Down
24 changes: 24 additions & 0 deletions src/snowflake/snowpark/modin/plugin/_internal/join_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,30 @@ def join(
JoinTypeLit
), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}"

def assert_snowpark_pandas_types_match() -> None:
    """
    Verify that every pair of join keys has the same Snowpark pandas type.

    ``left_on`` and ``right_on`` (read from the enclosing scope) are walked in
    lockstep. For each pair, the Snowpark pandas type recorded on ``left`` and
    ``right`` is compared; a key with no recorded type compares as ``None``.

    Raises:
        ValueError: for the first pair of keys whose types differ. The message
            names both types (falling back to the Snowflake type of a key that
            has no Snowpark pandas type) and the key being merged on.
    """
    left_type_map = left.snowflake_quoted_identifier_to_snowpark_pandas_type
    right_type_map = right.snowflake_quoted_identifier_to_snowpark_pandas_type

    for left_on_id, right_on_id in zip(left_on, right_on):
        lt = left_type_map.get(left_on_id, None)
        rt = right_type_map.get(right_on_id, None)
        if lt != rt:
            # Report the pandas label of the key when it is a data column of
            # the left frame. list.index raises its own ValueError
            # ("... is not in list") for an identifier that is not a data
            # column, which would hide the real problem; fall back to the
            # quoted identifier in that case.
            # NOTE(review): presumably this happens when joining on index
            # columns -- confirm against callers.
            try:
                idx = left.data_column_snowflake_quoted_identifiers.index(
                    left_on_id
                )
            except ValueError:
                key = left_on_id
            else:
                key = left.data_column_pandas_labels[idx]

            # A side without a Snowpark pandas type is described by its
            # Snowflake type so the message always names two concrete types.
            lt = lt if lt is not None else left.get_snowflake_type(left_on_id)
            rt = rt if rt is not None else right.get_snowflake_type(right_on_id)
            raise ValueError(
                f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. "
                f"If you wish to proceed you should use pd.concat"
            )

assert_snowpark_pandas_types_match()

# Re-project the active columns to make sure all active columns of the internal frame participate
# in the join operation, and unnecessary columns are dropped from the projected columns.
left = left.select_active_columns()
Expand Down
7 changes: 1 addition & 6 deletions tests/integ/modin/frame/test_equals.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,10 @@
([1, 2, None], [1, 2, None], True), # nulls are considered equal
([1, 2, 3], [1.0, 2.0, 3.0], False), # float and integer types are not equal
([1, 2, 3], ["1", "2", "3"], False), # integer and string types are not equal
pytest.param(
(
[1, 2, 3],
pandas.timedelta_range(1, periods=3),
False, # timedelta and integer types are not equal
marks=pytest.mark.xfail(
strict=True,
raises=NotImplementedError,
reason="TODO(SNOW-1637101, SNOW-1637102): Support these cases.",
),
),
],
)
Expand Down
21 changes: 10 additions & 11 deletions tests/integ/modin/frame/test_isin.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,15 @@ def test_isin_dataframe_values_type_negative():
df.isin(values="abcdef")


@sql_count_checker(query_count=6)
def test_isin_timedelta():
@sql_count_checker(query_count=3)
@pytest.mark.parametrize(
"values",
[
pytest.param([2, 3], id="integers"),
pytest.param([pd.Timedelta(2), pd.Timedelta(3)], id="timedeltas"),
],
)
def test_isin_timedelta(values):
native_df = native_pd.DataFrame({"a": [1, 2, 3], "b": [None, 4, 2]}).astype(
"timedelta64[ns]"
)
Expand All @@ -260,13 +267,5 @@ def test_isin_timedelta():
eval_snowpark_pandas_result(
snow_df,
native_df,
lambda df: _test_isin_with_snowflake_logic(df, [2, 3], query_count=1),
)

eval_snowpark_pandas_result(
snow_df,
native_df,
lambda df: _test_isin_with_snowflake_logic(
df, [pd.Timedelta(2), pd.Timedelta(3)], query_count=1
),
lambda df: _test_isin_with_snowflake_logic(df, values, query_count=1),
)
2 changes: 1 addition & 1 deletion tests/integ/modin/frame/test_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,7 @@ def perform_mask(df):
)


@pytest.mark.xfail(reason="TODO(SNOW-1637101, SNOW-1637102): Support these cases.")
@sql_count_checker(query_count=1)
def test_mask_timedelta(test_data):
native_df = native_pd.DataFrame(test_data, dtype="timedelta64[ns]")
snow_df = pd.DataFrame(native_df)
Expand Down
11 changes: 4 additions & 7 deletions tests/integ/modin/frame/test_melt.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,9 @@ def test_everything():
)


@sql_count_checker(query_count=2)
def test_melt_timedelta():
@sql_count_checker(query_count=1)
@pytest.mark.parametrize("value_vars", [["B"], ["B", "C"]])
def test_melt_timedelta(value_vars):
native_df = npd.DataFrame(
{
"A": {0: "a", 1: "b", 2: "c"},
Expand All @@ -316,9 +317,5 @@ def test_melt_timedelta():
).astype({"B": "timedelta64[ns]", "C": "timedelta64[ns]"})
snow_df = pd.DataFrame(native_df)
eval_snowpark_pandas_result(
snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=["B"])
)

eval_snowpark_pandas_result(
snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=["B", "C"])
snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=value_vars)
)
74 changes: 40 additions & 34 deletions tests/integ/modin/frame/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1158,8 +1158,8 @@ def test_merge_validate_negative(lvalues, rvalues, validate):
left.merge(right, left_on="A", right_on="B", validate=validate)


@sql_count_checker(query_count=4, join_count=4)
def test_merge_timedelta():
@sql_count_checker(query_count=1, join_count=1)
def test_merge_timedelta_on():
left_df = native_pd.DataFrame(
{"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]}
).astype({"value": "timedelta64[ns]"})
Expand All @@ -1176,36 +1176,42 @@ def test_merge_timedelta():
),
)

left_df = native_pd.DataFrame({"a": ["foo", "bar"], "b": [1, 2]}).astype(
{"b": "timedelta64[ns]"}
)
right_df = native_pd.DataFrame({"a": ["foo", "baz"], "c": [3, 4]}).astype(
{"c": "timedelta64[ns]"}
)
eval_snowpark_pandas_result(
pd.DataFrame(left_df),
left_df,
lambda df: df.merge(
pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df,
how="inner",
on="a",
),
)

eval_snowpark_pandas_result(
pd.DataFrame(left_df),
left_df,
lambda df: df.merge(
pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df,
how="right",
on="a",
),
)
eval_snowpark_pandas_result(
pd.DataFrame(left_df),
left_df,
lambda df: df.merge(
pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df,
how="cross",
),
)
@pytest.mark.parametrize(
    "kwargs",
    [
        {"how": "inner", "on": "a"},
        {"how": "right", "on": "a"},
        {"how": "right", "on": "b"},
        {"how": "left", "on": "c"},
        {"how": "cross"},
    ],
)
def test_merge_timedelta_how(kwargs):
    """
    Merge frames holding timedelta columns with various ``how``/``on`` settings.

    Column "c" is an integer column on the left frame and a timedelta column on
    the right frame, so merging on "c" is expected to raise a ValueError and to
    issue no queries or joins. Every other case expects one query and one join.
    """
    left_native = native_pd.DataFrame(
        {"a": ["foo", "bar"], "b": [1, 2], "c": [3, 5]}
    ).astype({"b": "timedelta64[ns]"})
    right_native = native_pd.DataFrame(
        {"a": ["foo", "baz"], "b": [1, 3], "c": [3, 4]}
    ).astype({"b": "timedelta64[ns]", "c": "timedelta64[ns]"})

    # merge timedelta with int exception
    merges_int_with_timedelta = kwargs.get("on", None) == "c"
    expected_count = 0 if merges_int_with_timedelta else 1

    def merge_with_right(df):
        # Pair each left frame with a right frame from the same library.
        if isinstance(df, pd.DataFrame):
            other = pd.DataFrame(right_native)
        else:
            other = right_native
        return df.merge(other, **kwargs)

    with SqlCounter(query_count=expected_count, join_count=expected_count):
        eval_snowpark_pandas_result(
            pd.DataFrame(left_native),
            left_native,
            merge_with_right,
            expect_exception=merges_int_with_timedelta,
            expect_exception_match="You are trying to merge on LongType and TimedeltaType columns for key 'c'. If you "
            "wish to proceed you should use pd.concat",
            expect_exception_type=ValueError,
            # pandas exception: You are trying to merge on int64 and
            # timedelta64[ns] columns for key 'c'. If you wish to proceed you
            # should use pd.concat
            assert_exception_equal=False,
        )

0 comments on commit 992bee9

Please sign in to comment.