From 3eeda0dbd61304822212d1cb219eadd1fd925a82 Mon Sep 17 00:00:00 2001 From: azhan Date: Tue, 27 Aug 2024 10:27:17 -0700 Subject: [PATCH 1/6] SNOW-1625379 Test coverage for timedelta under modin/integ/frame part 1 --- CHANGELOG.md | 1 + .../modin/plugin/_internal/isin_utils.py | 10 ++++ .../plugin/_internal/snowpark_pandas_types.py | 15 +++++- .../modin/plugin/_internal/unpivot_utils.py | 21 ++++++-- .../compiler/snowflake_query_compiler.py | 30 +++++++---- tests/integ/modin/data.py | 12 ++++- tests/integ/modin/frame/test_assign.py | 21 ++++++++ tests/integ/modin/frame/test_bfill_ffill.py | 23 +++++++- tests/integ/modin/frame/test_compare.py | 15 ++---- tests/integ/modin/frame/test_diff.py | 14 +++++ tests/integ/modin/frame/test_drop.py | 12 +++++ tests/integ/modin/frame/test_dropna.py | 1 + tests/integ/modin/frame/test_duplicated.py | 15 +++++- tests/integ/modin/frame/test_empty.py | 9 +++- tests/integ/modin/frame/test_equals.py | 2 + tests/integ/modin/frame/test_fillna.py | 17 ++++++ tests/integ/modin/frame/test_idxmax_idxmin.py | 20 +++++++ tests/integ/modin/frame/test_insert.py | 31 +++++++++++ tests/integ/modin/frame/test_isin.py | 22 ++++++++ tests/integ/modin/frame/test_items.py | 1 + tests/integ/modin/frame/test_iterrows.py | 1 + tests/integ/modin/frame/test_itertuples.py | 1 + tests/integ/modin/frame/test_join.py | 20 +++++++ tests/integ/modin/frame/test_len.py | 1 + tests/integ/modin/frame/test_mask.py | 9 ++++ tests/integ/modin/frame/test_melt.py | 19 +++++++ tests/integ/modin/frame/test_merge.py | 53 +++++++++++++++++++ .../modin/frame/test_nlargest_nsmallest.py | 18 +++++++ .../unit/modin/test_snowpark_pandas_types.py | 7 +++ 29 files changed, 392 insertions(+), 29 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cf5845439c..169776af310 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,6 +51,7 @@ #### New Features - Added limited support for the `Timedelta` type, including the following features. 
Snowpark pandas will raise `NotImplementedError` for unsupported `Timedelta` use cases. + - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `fillna`, `diff`, `duplicated`, `empty`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `melt`, `nlargest`, `nsmallest`. - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`. - converting non-timedelta to timedelta via `astype`. - `NotImplementedError` will be raised for the rest of methods that do not support `Timedelta`. diff --git a/src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py index 26d50a8d53c..48edba416c6 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/isin_utils.py @@ -14,6 +14,9 @@ ) from snowflake.snowpark.modin.plugin._internal.frame import InternalFrame from snowflake.snowpark.modin.plugin._internal.indexing_utils import set_frame_2d_labels +from snowflake.snowpark.modin.plugin._internal.snowpark_pandas_types import ( + SnowparkPandasType, +) from snowflake.snowpark.modin.plugin._internal.type_utils import infer_series_type from snowflake.snowpark.modin.plugin._internal.utils import ( append_columns, @@ -100,6 +103,13 @@ def scalar_isin_expression( for literal_expr in values ] + # Case 4: If column's and values' data type differs and any of the type is SnowparkPandasType + elif values_dtype != column_dtype and ( + isinstance(values_dtype, SnowparkPandasType) + or isinstance(column_dtype, SnowparkPandasType) + ): + return pandas_lit(False) + values = array_construct(*values) # to_variant is a requirement for array_contains, else an error is produced. 
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py b/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py index 20f5d8b61de..a8806d7c90d 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py @@ -101,6 +101,13 @@ def get_snowpark_pandas_type_for_pandas_type( return _type_to_snowpark_pandas_type[pandas_type]() return None + def type_match(self, value: Any) -> bool: + """Return True if the value's type matches self.""" + val_type = SnowparkPandasType.get_snowpark_pandas_type_for_pandas_type( + type(value) + ) + return self == val_type + class SnowparkPandasColumn(NamedTuple): """A Snowpark Column that has an optional SnowparkPandasType.""" @@ -111,7 +118,7 @@ class SnowparkPandasColumn(NamedTuple): snowpark_pandas_type: Optional[SnowparkPandasType] -class TimedeltaType(SnowparkPandasType, LongType): +class TimedeltaType(SnowparkPandasType): """ Timedelta represents the difference between two times. @@ -133,6 +140,12 @@ def __init__(self) -> None: WarningMessage.single_warning(TIMEDELTA_WARNING_MESSAGE) super().__init__() + def __eq__(self, other: Any) -> bool: + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + def __ne__(self, other: Any) -> bool: + return not self.__eq__(other) + @staticmethod def to_pandas(value: int) -> native_pd.Timedelta: """ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py index 905f2b23c91..9f1ca22180a 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/unpivot_utils.py @@ -735,12 +735,16 @@ def _simple_unpivot( # create the initial set of columns to be retained as identifiers and those # which will be unpivoted. Collect data type information. 
unpivot_quoted_columns = [] + unpivot_quoted_column_types = [] + ordering_decode_conditions = [] id_col_names = [] id_col_quoted_identifiers = [] - for (pandas_label, snowflake_quoted_identifier) in zip( + id_col_types = [] + for (pandas_label, snowflake_quoted_identifier, sp_pandas_type) in zip( frame.data_column_pandas_labels, frame.data_column_snowflake_quoted_identifiers, + frame.cached_data_column_snowpark_pandas_types, ): is_id_col = pandas_label in pandas_id_columns is_var_col = pandas_label in pandas_value_columns @@ -752,9 +756,11 @@ def _simple_unpivot( col(var_quoted) == pandas_lit(pandas_label) ) unpivot_quoted_columns.append(snowflake_quoted_identifier) + unpivot_quoted_column_types.append(sp_pandas_type) if is_id_col: id_col_names.append(pandas_label) id_col_quoted_identifiers.append(snowflake_quoted_identifier) + id_col_types.append(sp_pandas_type) # create the case expressions used for the final result set ordering based # on the column position. This clause will be appled after the unpivot @@ -787,7 +793,7 @@ def _simple_unpivot( pandas_labels=[unquoted_col_name], )[0] ) - # coalese the values to unpivot and preserve null values This code + # coalesce the values to unpivot and preserve null values This code # can be removed when UNPIVOT_INCLUDE_NULLS is enabled unpivot_columns_normalized_types.append( coalesce(to_variant(c), to_variant(pandas_lit(null_replace_value))).alias( @@ -870,6 +876,13 @@ def _simple_unpivot( var_quoted, corrected_value_column_name, ] + corrected_value_column_type = None + if len(set(unpivot_quoted_column_types)) == 1: + corrected_value_column_type = unpivot_quoted_column_types[0] + final_snowflake_quoted_col_types = id_col_types + [ + None, + corrected_value_column_type, + ] # Create the new frame and compiler return InternalFrame.create( @@ -881,8 +894,8 @@ def _simple_unpivot( index_column_snowflake_quoted_identifiers=[ ordered_dataframe.row_position_snowflake_quoted_identifier ], - data_column_types=None, - 
index_column_types=None, + data_column_types=final_snowflake_quoted_col_types, + index_column_types=[None], ) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 7e6336c397e..a0efa4ef601 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -5757,8 +5757,6 @@ def insert( Returns: A new SnowflakeQueryCompiler instance with new column. """ - self._raise_not_implemented_error_for_timedelta() - if not isinstance(value, SnowflakeQueryCompiler): # Scalar value new_internal_frame = self._modin_frame.append_column( @@ -5848,7 +5846,9 @@ def move_last_element(arr: list, index: int) -> None: data_column_snowflake_quoted_identifiers = ( new_internal_frame.data_column_snowflake_quoted_identifiers ) + data_column_types = new_internal_frame.cached_data_column_snowpark_pandas_types move_last_element(data_column_snowflake_quoted_identifiers, loc) + move_last_element(data_column_types, loc) new_internal_frame = InternalFrame.create( ordered_dataframe=new_internal_frame.ordered_dataframe, @@ -5857,8 +5857,8 @@ def move_last_element(arr: list, index: int) -> None: data_column_pandas_index_names=new_internal_frame.data_column_pandas_index_names, index_column_pandas_labels=new_internal_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=new_internal_frame.index_column_snowflake_quoted_identifiers, - data_column_types=None, - index_column_types=None, + data_column_types=data_column_types, + index_column_types=new_internal_frame.cached_index_column_snowpark_pandas_types, ) return SnowflakeQueryCompiler(new_internal_frame) @@ -6645,8 +6645,6 @@ def melt( Notes: melt does not yet handle multiindex or ignore index """ - self._raise_not_implemented_error_for_timedelta() - if col_level is not None: raise NotImplementedError( "Snowpark Pandas 
doesn't support 'col_level' argument in melt API" @@ -6749,8 +6747,6 @@ def merge( Returns: SnowflakeQueryCompiler instance with merged result. """ - self._raise_not_implemented_error_for_timedelta() - if validate: ErrorMessage.not_implemented( "Snowpark pandas merge API doesn't yet support 'validate' parameter" @@ -9815,6 +9811,10 @@ def _fillna_with_masking( # case 2: fillna with a method if method is not None: + # no Snowpark pandas type change in this case + data_column_snowpark_pandas_types = ( + self._modin_frame.cached_data_column_snowpark_pandas_types + ) method = FillNAMethod.get_enum_for_string_method(method) method_is_ffill = method is FillNAMethod.FFILL_METHOD if axis == 0: @@ -9921,6 +9921,7 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn: include_index=False, ) fillna_column_map = {} + data_column_snowpark_pandas_types = [] if columns_mask is not None: columns_to_ignore = itertools.compress( self._modin_frame.data_column_pandas_labels, @@ -9940,10 +9941,18 @@ def fillna_expr(snowflake_quoted_id: str) -> SnowparkColumn: col(id), coalesce(id, pandas_lit(val)), ) + col_type = self._modin_frame.get_snowflake_type(id) + col_pandas_type = ( + col_type + if isinstance(col_type, SnowparkPandasType) + and col_type.type_match(val) + else None + ) + data_column_snowpark_pandas_types.append(col_pandas_type) return SnowflakeQueryCompiler( self._modin_frame.update_snowflake_quoted_identifiers_with_expressions( - fillna_column_map + fillna_column_map, data_column_snowpark_pandas_types ).frame ) @@ -10217,7 +10226,8 @@ def diff(self, periods: int, axis: int) -> "SnowflakeQueryCompiler": } return SnowflakeQueryCompiler( self._modin_frame.update_snowflake_quoted_identifiers_with_expressions( - diff_label_to_value_map + diff_label_to_value_map, + self._modin_frame.cached_data_column_snowpark_pandas_types, ).frame ) diff --git a/tests/integ/modin/data.py b/tests/integ/modin/data.py index 653e0037e09..35c4d321787 100644 --- a/tests/integ/modin/data.py +++ 
b/tests/integ/modin/data.py @@ -1,6 +1,7 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import pandas as native_pd RAW_NA_DF_DATA_TEST_CASES = [ ({"A": [1, 2, 3], "B": [4, 5, 6]}, "numeric-no"), @@ -16,9 +17,18 @@ ({"A": [True, 1, "X"], "B": ["Y", 3.14, False]}, "mixed"), ({"A": [True, None, "X"], "B": [None, 3.14, None]}, "mixed-mixed-1"), ({"A": [None, 1, None], "B": ["Y", None, False]}, "mixed-mixed-2"), + ( + { + "A": [None, native_pd.Timedelta(2), None], + "B": [native_pd.Timedelta(4), None, native_pd.Timedelta(6)], + }, + "timedelta-mixed-1", + ), ] RAW_NA_DF_SERIES_TEST_CASES = [ (list(df_data.values()), test_case) - for (df_data, test_case) in RAW_NA_DF_DATA_TEST_CASES + for (df_data, test_case) in RAW_NA_DF_DATA_TEST_CASES[ + :-1 + ] # drop only the last case: "timedelta-mixed-1" is not json serializable ] diff --git a/tests/integ/modin/frame/test_assign.py b/tests/integ/modin/frame/test_assign.py index b0da2a110bf..b1677deda8f 100644 --- a/tests/integ/modin/frame/test_assign.py +++ b/tests/integ/modin/frame/test_assign.py @@ -238,3 +238,24 @@ def test_overwrite_columns_via_assign(): eval_snowpark_pandas_result( snow_df, native_df, lambda df: df.assign(a=df["b"], last_col=[10, 11, 12]) ) + + +@sql_count_checker(query_count=2, join_count=1) +def test_assign_basic_timedelta_series(): + snow_df, native_df = create_test_dfs( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], + columns=native_pd.Index(list("abc"), name="columns"), + index=native_pd.Index([0, 1, 2], name="index"), + ) + native_df.columns.names = ["columns"] + native_df.index.names = ["index"] + + native_td = native_pd.timedelta_range("1 day", periods=3) + + def assign_func(df): + if isinstance(df, pd.DataFrame): + return df.assign(new_col=pd.Series(native_td)) + else: + return df.assign(new_col=native_pd.Series(native_td)) + + eval_snowpark_pandas_result(snow_df, native_df, assign_func) diff --git a/tests/integ/modin/frame/test_bfill_ffill.py b/tests/integ/modin/frame/test_bfill_ffill.py index
7938fe4059f..504261b80fe 100644 --- a/tests/integ/modin/frame/test_bfill_ffill.py +++ b/tests/integ/modin/frame/test_bfill_ffill.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize("func", ["backfill", "bfill", "ffill", "pad"]) @sql_count_checker(query_count=1) -def test_df_func(func): +def test_df_fill(func): native_df = native_pd.DataFrame( [ [np.nan, 2, np.nan, 0], @@ -31,3 +31,24 @@ def test_df_func(func): native_df, lambda df: getattr(df, func)(), ) + + +@pytest.mark.parametrize("func", ["backfill", "bfill", "ffill", "pad"]) +@sql_count_checker(query_count=1) +def test_df_timedelta_fill(func): + native_df = native_pd.DataFrame( + [ + [np.nan, 2, np.nan, 0], + [3, 4, np.nan, 1], + [np.nan, np.nan, np.nan, np.nan], + [np.nan, 3, np.nan, 4], + [3, np.nan, 4, np.nan], + ], + columns=list("ABCD"), + ).astype("timedelta64[ns]") + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: getattr(df, func)(), + ) diff --git a/tests/integ/modin/frame/test_compare.py b/tests/integ/modin/frame/test_compare.py index 9a0f7caf88d..c7f0c6f81d4 100644 --- a/tests/integ/modin/frame/test_compare.py +++ b/tests/integ/modin/frame/test_compare.py @@ -35,16 +35,10 @@ def base_df() -> native_pd.DataFrame: return native_pd.DataFrame( [ - [None, None, 3.1, pd.Timestamp("2024-01-01"), [130]], - [ - "a", - 1, - 4.2, - pd.Timestamp("2024-02-01"), - [131], - ], - ["b", 2, 5.3, pd.Timestamp("2024-03-01"), [132]], - [None, 3, 6.4, pd.Timestamp("2024-04-01"), [133]], + [None, None, 3.1, pd.Timestamp("2024-01-01"), [130], pd.Timedelta(1)], + ["a", 1, 4.2, pd.Timestamp("2024-02-01"), [131], pd.Timedelta(11)], + ["b", 2, 5.3, pd.Timestamp("2024-03-01"), [132], pd.Timedelta(21)], + [None, 3, 6.4, pd.Timestamp("2024-04-01"), [133], pd.Timedelta(13)], ], index=pd.MultiIndex.from_tuples( [ @@ -64,6 +58,7 @@ def base_df() -> native_pd.DataFrame: ("group_2", "float_col"), ("group_2", "timestamp_col"), ("group_2", "list_col"), + ("group_2", "timedelta_col"), ], 
names=["column_level1", "column_level2"], ), diff --git a/tests/integ/modin/frame/test_diff.py b/tests/integ/modin/frame/test_diff.py index 26aa5b74c85..185b2eab89e 100644 --- a/tests/integ/modin/frame/test_diff.py +++ b/tests/integ/modin/frame/test_diff.py @@ -140,6 +140,20 @@ def test_df_diff_bool_df(periods): eval_snowpark_pandas_result(snow_df, native_df, lambda df: df.diff(periods=periods)) +@sql_count_checker(query_count=1) +@pytest.mark.parametrize("periods", [0, 1]) +def test_df_diff_timedelta_df(periods): + native_df = native_pd.DataFrame( + np.arange(NUM_ROWS_TALL_DF * NUM_COLS_TALL_DF).reshape( + (NUM_ROWS_TALL_DF, NUM_COLS_TALL_DF) + ), + columns=["A", "B", "C", "D"], + ) + native_df = native_df.astype({"A": "timedelta64[ns]", "C": "timedelta64[ns]"}) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result(snow_df, native_df, lambda df: df.diff(periods=periods)) + + @sql_count_checker(query_count=1) @pytest.mark.parametrize("periods", [0, 1]) def test_df_diff_int_and_bool_df(periods): diff --git a/tests/integ/modin/frame/test_drop.py b/tests/integ/modin/frame/test_drop.py index e71999dd28d..cc1a1a203d3 100644 --- a/tests/integ/modin/frame/test_drop.py +++ b/tests/integ/modin/frame/test_drop.py @@ -70,6 +70,18 @@ def test_drop_list_like(native_df, labels): eval_snowpark_pandas_result(snow_df, native_df, lambda df: df.drop(labels, axis=1)) +@pytest.mark.parametrize( + "labels", [Index(["red", "green"]), np.array(["red", "green"])] +) +@sql_count_checker(query_count=1) +def test_drop_timedelta(native_df, labels): + native_df_dt = native_df.astype({"red": "timedelta64[ns]"}) + snow_df = pd.DataFrame(native_df_dt) + eval_snowpark_pandas_result( + snow_df, native_df_dt, lambda df: df.drop(labels, axis=1) + ) + + @pytest.mark.parametrize( "labels, axis, expected_query_count", [ diff --git a/tests/integ/modin/frame/test_dropna.py b/tests/integ/modin/frame/test_dropna.py index e5fb2085417..d77c65d055e 100644 --- a/tests/integ/modin/frame/test_dropna.py 
+++ b/tests/integ/modin/frame/test_dropna.py @@ -19,6 +19,7 @@ def test_dropna_df(): "name": ["Alfred", "Batman", "Catwoman"], "toy": [np.nan, "Batmobile", "Bullwhip"], "born": [pd.NaT, pd.Timestamp("1940-04-25"), pd.NaT], + "dt": [pd.NaT, pd.Timedelta(1), pd.NaT], } ) diff --git a/tests/integ/modin/frame/test_duplicated.py b/tests/integ/modin/frame/test_duplicated.py index e4c5d594ecc..0eade6af114 100644 --- a/tests/integ/modin/frame/test_duplicated.py +++ b/tests/integ/modin/frame/test_duplicated.py @@ -53,11 +53,24 @@ def test_duplicated_with_misspelled_column_name_or_empty_subset(subset): (["A"], native_pd.Series([False, False, True, False, True])), (["B"], native_pd.Series([False, False, False, True, True])), (["A", "B"], native_pd.Series([False, False, False, False, True])), + ("C", native_pd.Series([False, False, True, False, True])), ], ) @sql_count_checker(query_count=1, join_count=1) def test_duplicated_subset(subset, expected): - df = pd.DataFrame({"A": [0, 1, 1, 2, 0], "B": ["a", "b", "c", "b", "a"]}) + df = pd.DataFrame( + { + "A": [0, 1, 1, 2, 0], + "B": ["a", "b", "c", "b", "a"], + "C": [ + pd.Timedelta(1), + pd.Timedelta(10), + pd.Timedelta(1), + pd.Timedelta(0), + pd.Timedelta(10), + ], + } + ) result = df.duplicated(subset=subset) assert_snowpark_pandas_equal_to_pandas(result, expected) diff --git a/tests/integ/modin/frame/test_empty.py b/tests/integ/modin/frame/test_empty.py index 0ed4d2c9fa9..b39a77eae91 100644 --- a/tests/integ/modin/frame/test_empty.py +++ b/tests/integ/modin/frame/test_empty.py @@ -16,7 +16,14 @@ @pytest.mark.parametrize( "dataframe_input, test_case_name", [ - ({"A": [1, 2, 3], "B": [4, 5, 6]}, "simple non-empty"), + ( + { + "A": [1, 2, 3], + "B": [4, 5, 6], + "C": native_pd.timedelta_range(1, periods=3), + }, + "simple non-empty", + ), ({"A": [], "B": []}, "empty column"), ({"A": [np.nan]}, "np nan column"), ], diff --git a/tests/integ/modin/frame/test_equals.py b/tests/integ/modin/frame/test_equals.py index 
95b6b8ffd6f..45b35e3274f 100644 --- a/tests/integ/modin/frame/test_equals.py +++ b/tests/integ/modin/frame/test_equals.py @@ -25,6 +25,8 @@ ([1, 2, None], [1, 2, None], True), # nulls are considered equal ([1, 2, 3], [1.0, 2.0, 3.0], False), # float and integer types are not equal ([1, 2, 3], ["1", "2", "3"], False), # integer and string types are not equal + # TODO(SNOW-1637101, SNOW-1637102): Support these cases. + # ([1, 2, 3], pandas.timedelta_range(1, periods=3), False), # timedelta and integer types are not equal ], ) @sql_count_checker(query_count=2, join_count=2) diff --git a/tests/integ/modin/frame/test_fillna.py b/tests/integ/modin/frame/test_fillna.py index 189e757c8b2..677c8d3ddc5 100644 --- a/tests/integ/modin/frame/test_fillna.py +++ b/tests/integ/modin/frame/test_fillna.py @@ -150,6 +150,23 @@ def test_value_scalar(test_fillna_df): ) +@sql_count_checker(query_count=2) +def test_timedelta_value_scalar(test_fillna_df): + timedelta_df = test_fillna_df.astype("timedelta64[ns]") + eval_snowpark_pandas_result( + pd.DataFrame(timedelta_df), + timedelta_df, + lambda df: df.fillna(pd.Timedelta(1)), # dtype keeps to be timedelta64[ns] + ) + + # Snowpark pandas dtype will be changed to int in this case + eval_snowpark_pandas_result( + pd.DataFrame(timedelta_df), + test_fillna_df, + lambda df: df.fillna(1), + ) + + @sql_count_checker(query_count=1) def test_value_scalar_none_index(test_fillna_df_none_index): # note: none in index should not be filled diff --git a/tests/integ/modin/frame/test_idxmax_idxmin.py b/tests/integ/modin/frame/test_idxmax_idxmin.py index f5a8a6d4b85..72fe88968bc 100644 --- a/tests/integ/modin/frame/test_idxmax_idxmin.py +++ b/tests/integ/modin/frame/test_idxmax_idxmin.py @@ -194,6 +194,26 @@ def test_idxmax_idxmin_with_dates(func, axis): ) +@sql_count_checker(query_count=1) +@pytest.mark.parametrize("func", ["idxmax", "idxmin"]) +@pytest.mark.parametrize("axis", [0, 1]) +@pytest.mark.xfail(reason="SNOW-1625380 TODO") +def 
test_idxmax_idxmin_with_timedelta(func, axis): + native_df = native_pd.DataFrame( + data={ + "date_1": native_pd.timedelta_range(1, periods=3), + "date_2": [pd.Timedelta(1), pd.Timedelta(-1), pd.Timedelta(0)], + }, + index=[10, 17, 12], + ) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: getattr(df, func)(axis=axis), + ) + + @sql_count_checker(query_count=1) @pytest.mark.parametrize("func", ["idxmax", "idxmin"]) @pytest.mark.parametrize("axis", [0, 1]) diff --git a/tests/integ/modin/frame/test_insert.py b/tests/integ/modin/frame/test_insert.py index 258d4d2e641..86f5bd8082c 100644 --- a/tests/integ/modin/frame/test_insert.py +++ b/tests/integ/modin/frame/test_insert.py @@ -1,6 +1,8 @@ # # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +import functools + import modin.pandas as pd import numpy as np import pandas as native_pd @@ -768,3 +770,32 @@ def insert_op(df): expected_res = native_df1.join(native_df2["bar"], how="left", sort=False) expected_res = expected_res[["bar", "foo"]] assert_frame_equal(snow_res, expected_res, check_dtype=False) + + +@sql_count_checker(query_count=4, join_count=6) +def test_insert_timedelta(): + native_df = native_pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + snow_df = pd.DataFrame(native_df) + + def insert(column, vals, df): + if isinstance(df, pd.DataFrame) and isinstance(vals, native_pd.Series): + values = pd.Series(vals) + else: + values = vals + df.insert(1, column, values) + return df + + vals = native_pd.timedelta_range(1, periods=2) + eval_snowpark_pandas_result( + snow_df, native_df, functools.partial(insert, "td", vals) + ) + + vals = native_pd.Series(native_pd.timedelta_range(1, periods=2)) + eval_snowpark_pandas_result( + snow_df, native_df, functools.partial(insert, "td2", vals) + ) + + vals = native_pd.Series(native_pd.timedelta_range(1, periods=2), index=[0, 2]) + eval_snowpark_pandas_result( + snow_df, native_df, functools.partial(insert, 
"td3", vals) + ) diff --git a/tests/integ/modin/frame/test_isin.py b/tests/integ/modin/frame/test_isin.py index c0f0a3ce37b..5fb960518a2 100644 --- a/tests/integ/modin/frame/test_isin.py +++ b/tests/integ/modin/frame/test_isin.py @@ -248,3 +248,25 @@ def test_isin_dataframe_values_type_negative(): ): df = pd.DataFrame([1, 2, 3]) df.isin(values="abcdef") + + +@sql_count_checker(query_count=6) +def test_isin_timedelta(): + native_df = native_pd.DataFrame({"a": [1, 2, 3], "b": [None, 4, 2]}).astype( + "timedelta64[ns]" + ) + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: _test_isin_with_snowflake_logic(df, [2, 3], query_count=1), + ) + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: _test_isin_with_snowflake_logic( + df, [pd.Timedelta(2), pd.Timedelta(3)], query_count=1 + ), + ) diff --git a/tests/integ/modin/frame/test_items.py b/tests/integ/modin/frame/test_items.py index d409a0f326a..9cbd4945ee6 100644 --- a/tests/integ/modin/frame/test_items.py +++ b/tests/integ/modin/frame/test_items.py @@ -51,6 +51,7 @@ def assert_items_results_equal(snow_result, pandas_result) -> None: ), native_pd.DataFrame(index=["a"]), native_pd.DataFrame(columns=["a"]), + native_pd.DataFrame({"ts": native_pd.timedelta_range(10, periods=10)}), ], ) def test_items(dataframe): diff --git a/tests/integ/modin/frame/test_iterrows.py b/tests/integ/modin/frame/test_iterrows.py index 700d1b4ec27..843848c6a5c 100644 --- a/tests/integ/modin/frame/test_iterrows.py +++ b/tests/integ/modin/frame/test_iterrows.py @@ -53,6 +53,7 @@ def assert_iterators_equal(snowpark_iterator, native_iterator): ), # empty df native_pd.DataFrame([]), + native_pd.DataFrame({"ts": native_pd.timedelta_range(10, periods=10)}), ], ) def test_df_iterrows(native_df): diff --git a/tests/integ/modin/frame/test_itertuples.py b/tests/integ/modin/frame/test_itertuples.py index c3687a939c7..eed33f9e1a4 100644 --- a/tests/integ/modin/frame/test_itertuples.py +++ 
b/tests/integ/modin/frame/test_itertuples.py @@ -37,6 +37,7 @@ native_pd.DataFrame([[1, 1.5], [2, 2.5], [3, 7.8]], columns=["i nt", "flo at"]), # empty df native_pd.DataFrame([]), + native_pd.DataFrame({"ts": native_pd.timedelta_range(10, periods=10)}), ] diff --git a/tests/integ/modin/frame/test_join.py b/tests/integ/modin/frame/test_join.py index 91500189d12..964b6f5426b 100644 --- a/tests/integ/modin/frame/test_join.py +++ b/tests/integ/modin/frame/test_join.py @@ -259,3 +259,23 @@ def test_join_validate_negative(lvalues, rvalues, validate): msg = "Snowpark pandas merge API doesn't yet support 'validate' parameter" with pytest.raises(NotImplementedError, match=msg): left.join(right, validate=validate) + + +@sql_count_checker(query_count=6, join_count=2) +def test_join_timedelta(left, right): + right = right.astype("timedelta64[ns]") + eval_snowpark_pandas_result( + left, + left.to_pandas(), + lambda df: df.join( + right if isinstance(df, pd.DataFrame) else right.to_pandas() + ), + ) + left = left.astype("timedelta64[ns]") + eval_snowpark_pandas_result( + left, + left.to_pandas(), + lambda df: df.join( + right if isinstance(df, pd.DataFrame) else right.to_pandas() + ), + ) diff --git a/tests/integ/modin/frame/test_len.py b/tests/integ/modin/frame/test_len.py index 1adeec50caa..d52df4bf567 100644 --- a/tests/integ/modin/frame/test_len.py +++ b/tests/integ/modin/frame/test_len.py @@ -16,6 +16,7 @@ ({"a": []}, 0), ({"a": [1, 2]}, 2), ({"a": [1, 2], "b": [1, 2], "c": [1, 2]}, 2), + ({"td": native_pd.timedelta_range(1, periods=20)}, 20), ], ) @sql_count_checker(query_count=1) diff --git a/tests/integ/modin/frame/test_mask.py b/tests/integ/modin/frame/test_mask.py index 684d8ba4342..2a5c441f539 100644 --- a/tests/integ/modin/frame/test_mask.py +++ b/tests/integ/modin/frame/test_mask.py @@ -954,3 +954,12 @@ def perform_mask(df): native_df, perform_mask, ) + + +@pytest.mark.xfail(reason="TODO(SNOW-1637101, SNOW-1637102): Support these cases.") +def 
test_mask_timedelta(test_data): + native_df = native_pd.DataFrame(test_data, dtype="timedelta64[ns]") + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, native_df, lambda df: df.mask(df > pd.Timedelta(1)) + ) diff --git a/tests/integ/modin/frame/test_melt.py b/tests/integ/modin/frame/test_melt.py index 68d25b1e482..29728f26956 100644 --- a/tests/integ/modin/frame/test_melt.py +++ b/tests/integ/modin/frame/test_melt.py @@ -303,3 +303,22 @@ def test_everything(): value_name="dependent", ), ) + + +@sql_count_checker(query_count=2) +def test_melt_timedelta(): + native_df = npd.DataFrame( + { + "A": {0: "a", 1: "b", 2: "c"}, + "B": {0: 1, 1: 3, 2: 5}, + "C": {0: 2, 1: 4, 2: 6}, + } + ).astype({"B": "timedelta64[ns]", "C": "timedelta64[ns]"}) + snow_df = pd.DataFrame(native_df) + eval_snowpark_pandas_result( + snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=["B"]) + ) + + eval_snowpark_pandas_result( + snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=["B", "C"]) + ) diff --git a/tests/integ/modin/frame/test_merge.py b/tests/integ/modin/frame/test_merge.py index 7ac88042e7f..e1c75d1d853 100644 --- a/tests/integ/modin/frame/test_merge.py +++ b/tests/integ/modin/frame/test_merge.py @@ -1156,3 +1156,56 @@ def test_merge_validate_negative(lvalues, rvalues, validate): msg = "Snowpark pandas merge API doesn't yet support 'validate' parameter" with pytest.raises(NotImplementedError, match=msg): left.merge(right, left_on="A", right_on="B", validate=validate) + + +@sql_count_checker(query_count=4, join_count=4) +def test_merge_timedelta(): + left_df = native_pd.DataFrame( + {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]} + ).astype({"value": "timedelta64[ns]"}) + right_df = native_pd.DataFrame( + {"rkey": ["foo", "bar", "baz", "foo"], "value": [5, 6, 7, 8]} + ).astype({"value": "timedelta64[ns]"}) + eval_snowpark_pandas_result( + pd.DataFrame(left_df), + left_df, + lambda df: df.merge( + 
pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, + left_on="lkey", + right_on="rkey", + ), + ) + + left_df = native_pd.DataFrame({"a": ["foo", "bar"], "b": [1, 2]}).astype( + {"b": "timedelta64[ns]"} + ) + right_df = native_pd.DataFrame({"a": ["foo", "baz"], "c": [3, 4]}).astype( + {"c": "timedelta64[ns]"} + ) + eval_snowpark_pandas_result( + pd.DataFrame(left_df), + left_df, + lambda df: df.merge( + pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, + how="inner", + on="a", + ), + ) + + eval_snowpark_pandas_result( + pd.DataFrame(left_df), + left_df, + lambda df: df.merge( + pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, + how="right", + on="a", + ), + ) + eval_snowpark_pandas_result( + pd.DataFrame(left_df), + left_df, + lambda df: df.merge( + pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, + how="cross", + ), + ) diff --git a/tests/integ/modin/frame/test_nlargest_nsmallest.py b/tests/integ/modin/frame/test_nlargest_nsmallest.py index c32fb64a80e..c528c99d1b0 100644 --- a/tests/integ/modin/frame/test_nlargest_nsmallest.py +++ b/tests/integ/modin/frame/test_nlargest_nsmallest.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
# import modin.pandas as pd +import pandas as native_pd import pytest import snowflake.snowpark.modin.plugin # noqa: F401 @@ -124,3 +125,20 @@ def test_nlargest_nsmallest_non_numeric_types(method, data): n = 2 expected_df = snow_df.sort_values("A", ascending=(method == "nsmallest")).head(n) assert_frame_equal(getattr(snow_df, method)(n, "A"), expected_df) + + +@pytest.mark.parametrize("n", [1, 2, 4]) +@pytest.mark.parametrize("columns", ["A", "B", ["A", "B"], ["B", "A"]]) +@pytest.mark.parametrize("keep", ["first", "last"]) +@sql_count_checker(query_count=1) +def test_time_delta_nlargest_nsmallest(method, n, columns, keep): + native_df = native_pd.DataFrame( + {"A": [3, 2, 1, 4, 4], "B": [1, 2, 3, 4, 5]} + ).astype("timedelta64[ns]") + snow_df = pd.DataFrame(native_df) + + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: getattr(df, method)(n, columns=columns, keep=keep), + ) diff --git a/tests/unit/modin/test_snowpark_pandas_types.py b/tests/unit/modin/test_snowpark_pandas_types.py index 36d64d164c8..754f031c1a6 100644 --- a/tests/unit/modin/test_snowpark_pandas_types.py +++ b/tests/unit/modin/test_snowpark_pandas_types.py @@ -14,6 +14,7 @@ SnowparkPandasType, TimedeltaType, ) +from snowflake.snowpark.types import LongType def test_timedelta_type_is_immutable(): @@ -68,3 +69,9 @@ def test_get_snowpark_pandas_type_for_pandas_type(pandas_obj, snowpark_pandas_ty ) def test_TimedeltaType_from_pandas(timedelta, snowpark_pandas_value): assert TimedeltaType.from_pandas(timedelta) == snowpark_pandas_value + + +def test_equals(): + assert TimedeltaType() == TimedeltaType() + assert TimedeltaType() != LongType() + assert LongType() != TimedeltaType() From 2082f8631af5322b2ea33327cbe3cb10a0216ee6 Mon Sep 17 00:00:00 2001 From: azhan Date: Tue, 27 Aug 2024 11:40:43 -0700 Subject: [PATCH 2/6] resolve comments --- CHANGELOG.md | 3 +-- tests/integ/modin/frame/test_equals.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git 
a/CHANGELOG.md b/CHANGELOG.md index 169776af310..6e0b1017830 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,8 +51,7 @@ #### New Features - Added limited support for the `Timedelta` type, including the following features. Snowpark pandas will raise `NotImplementedError` for unsupported `Timedelta` use cases. - - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `fillna`, `diff`, `duplicated`, `empty`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `melt`, `nlargest`, `nsmallest`. - - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`. + - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `assign`, `bfill`, `ffill`, `fillna`, `compare`, `diff`, `drop`, `dropna`, `duplicated`, `empty`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `melt`, `merge`, `nlargest`, `nsmallest`. - converting non-timedelta to timedelta via `astype`. - `NotImplementedError` will be raised for the rest of methods that do not support `Timedelta`. - support for subtracting two timestamps to get a Timedelta. diff --git a/tests/integ/modin/frame/test_equals.py b/tests/integ/modin/frame/test_equals.py index 45b35e3274f..13d3061cd7b 100644 --- a/tests/integ/modin/frame/test_equals.py +++ b/tests/integ/modin/frame/test_equals.py @@ -25,8 +25,16 @@ ([1, 2, None], [1, 2, None], True), # nulls are considered equal ([1, 2, 3], [1.0, 2.0, 3.0], False), # float and integer types are not equal ([1, 2, 3], ["1", "2", "3"], False), # integer and string types are not equal - # TODO(SNOW-1637101, SNOW-1637102): Support these cases. 
- # ([1, 2, 3], pandas.timedelta_range(1, periods=3), False), # timedelta and integer types are not equal + pytest.param( + [1, 2, 3], + pandas.timedelta_range(1, periods=3), + False, # timedelta and integer types are not equal + marks=pytest.mark.xfail( + strict=True, + raises=NotImplementedError, + reason="TODO(SNOW-1637101, SNOW-1637102): Support these cases.", + ), + ), ], ) @sql_count_checker(query_count=2, join_count=2) From 62e572e5f13f84742af361103ed3ffd919dca0b5 Mon Sep 17 00:00:00 2001 From: azhan Date: Tue, 27 Aug 2024 15:50:38 -0700 Subject: [PATCH 3/6] add merge tests --- CHANGELOG.md | 2 +- .../modin/plugin/_internal/join_utils.py | 24 ++++++ tests/integ/modin/frame/test_equals.py | 7 +- tests/integ/modin/frame/test_isin.py | 21 +++--- tests/integ/modin/frame/test_mask.py | 2 +- tests/integ/modin/frame/test_melt.py | 11 +-- tests/integ/modin/frame/test_merge.py | 74 ++++++++++--------- 7 files changed, 81 insertions(+), 60 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e0b1017830..ad1c8e9cb95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,7 +51,7 @@ #### New Features - Added limited support for the `Timedelta` type, including the following features. Snowpark pandas will raise `NotImplementedError` for unsupported `Timedelta` use cases. - - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `assign`, `bfill`, `ffill`, `fillna`, `compare`, `diff`, `drop`, `dropna`, `duplicated`, `empty`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `melt`, `merge`, `nlargest`, `nsmallest`. + - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`, `assign`, `bfill`, `ffill`, `fillna`, `compare`, `diff`, `drop`, `dropna`, `duplicated`, `empty`, `equals`, `insert`, `isin`, `isna`, `items`, `iterrows`, `join`, `len`, `mask`, `melt`, `merge`, `nlargest`, `nsmallest`. - converting non-timedelta to timedelta via `astype`. 
- `NotImplementedError` will be raised for the rest of methods that do not support `Timedelta`. - support for subtracting two timestamps to get a Timedelta. diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index 331901f1a67..be36b004ed4 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -172,6 +172,30 @@ def join( JoinTypeLit ), f"Invalid join type: {how}. Allowed values are {get_args(JoinTypeLit)}" + def assert_snowpark_pandas_types_match() -> None: + """If Snowpark pandas types does not match, then a ValueError will be raised.""" + left_types = [ + left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in left_on + ] + right_types = [ + right.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) + for id in right_on + ] + for i, (lt, rt) in enumerate(zip(left_types, right_types)): + if lt != rt: + left_on_id = left_on[i] + idx = left.data_column_snowflake_quoted_identifiers.index(left_on_id) + key = left.data_column_pandas_labels[idx] + lt = lt if lt is not None else left.get_snowflake_type(left_on_id) + rt = rt if rt is not None else right.get_snowflake_type(right_on[i]) + raise ValueError( + f"You are trying to merge on {type(lt).__name__} and {type(rt).__name__} columns for key '{key}'. " + f"If you wish to proceed you should use pd.concat" + ) + + assert_snowpark_pandas_types_match() + # Re-project the active columns to make sure all active columns of the internal frame participate # in the join operation, and unnecessary columns are dropped from the projected columns. 
left = left.select_active_columns() diff --git a/tests/integ/modin/frame/test_equals.py b/tests/integ/modin/frame/test_equals.py index 13d3061cd7b..2e2dc2fa129 100644 --- a/tests/integ/modin/frame/test_equals.py +++ b/tests/integ/modin/frame/test_equals.py @@ -25,15 +25,10 @@ ([1, 2, None], [1, 2, None], True), # nulls are considered equal ([1, 2, 3], [1.0, 2.0, 3.0], False), # float and integer types are not equal ([1, 2, 3], ["1", "2", "3"], False), # integer and string types are not equal - pytest.param( + ( [1, 2, 3], pandas.timedelta_range(1, periods=3), False, # timedelta and integer types are not equal - marks=pytest.mark.xfail( - strict=True, - raises=NotImplementedError, - reason="TODO(SNOW-1637101, SNOW-1637102): Support these cases.", - ), ), ], ) diff --git a/tests/integ/modin/frame/test_isin.py b/tests/integ/modin/frame/test_isin.py index 5fb960518a2..cc6113c7466 100644 --- a/tests/integ/modin/frame/test_isin.py +++ b/tests/integ/modin/frame/test_isin.py @@ -250,8 +250,15 @@ def test_isin_dataframe_values_type_negative(): df.isin(values="abcdef") -@sql_count_checker(query_count=6) -def test_isin_timedelta(): +@sql_count_checker(query_count=3) +@pytest.mark.parametrize( + "values", + [ + pytest.param([2, 3], id="integers"), + pytest.param([pd.Timedelta(2), pd.Timedelta(3)], id="timedeltas"), + ], +) +def test_isin_timedelta(values): native_df = native_pd.DataFrame({"a": [1, 2, 3], "b": [None, 4, 2]}).astype( "timedelta64[ns]" ) @@ -260,13 +267,5 @@ def test_isin_timedelta(): eval_snowpark_pandas_result( snow_df, native_df, - lambda df: _test_isin_with_snowflake_logic(df, [2, 3], query_count=1), - ) - - eval_snowpark_pandas_result( - snow_df, - native_df, - lambda df: _test_isin_with_snowflake_logic( - df, [pd.Timedelta(2), pd.Timedelta(3)], query_count=1 - ), + lambda df: _test_isin_with_snowflake_logic(df, values, query_count=1), ) diff --git a/tests/integ/modin/frame/test_mask.py b/tests/integ/modin/frame/test_mask.py index 2a5c441f539..e490f34e905 
100644 --- a/tests/integ/modin/frame/test_mask.py +++ b/tests/integ/modin/frame/test_mask.py @@ -956,7 +956,7 @@ def perform_mask(df): ) -@pytest.mark.xfail(reason="TODO(SNOW-1637101, SNOW-1637102): Support these cases.") +@sql_count_checker(query_count=1) def test_mask_timedelta(test_data): native_df = native_pd.DataFrame(test_data, dtype="timedelta64[ns]") snow_df = pd.DataFrame(native_df) diff --git a/tests/integ/modin/frame/test_melt.py b/tests/integ/modin/frame/test_melt.py index 29728f26956..0812bb2c60c 100644 --- a/tests/integ/modin/frame/test_melt.py +++ b/tests/integ/modin/frame/test_melt.py @@ -305,8 +305,9 @@ def test_everything(): ) -@sql_count_checker(query_count=2) -def test_melt_timedelta(): +@sql_count_checker(query_count=1) +@pytest.mark.parametrize("value_vars", [["B"], ["B", "C"]]) +def test_melt_timedelta(value_vars): native_df = npd.DataFrame( { "A": {0: "a", 1: "b", 2: "c"}, @@ -316,9 +317,5 @@ def test_melt_timedelta(): ).astype({"B": "timedelta64[ns]", "C": "timedelta64[ns]"}) snow_df = pd.DataFrame(native_df) eval_snowpark_pandas_result( - snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=["B"]) - ) - - eval_snowpark_pandas_result( - snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=["B", "C"]) + snow_df, native_df, lambda df: df.melt(id_vars=["A"], value_vars=value_vars) ) diff --git a/tests/integ/modin/frame/test_merge.py b/tests/integ/modin/frame/test_merge.py index e1c75d1d853..c1ced99fc67 100644 --- a/tests/integ/modin/frame/test_merge.py +++ b/tests/integ/modin/frame/test_merge.py @@ -1158,8 +1158,8 @@ def test_merge_validate_negative(lvalues, rvalues, validate): left.merge(right, left_on="A", right_on="B", validate=validate) -@sql_count_checker(query_count=4, join_count=4) -def test_merge_timedelta(): +@sql_count_checker(query_count=1, join_count=1) +def test_merge_timedelta_on(): left_df = native_pd.DataFrame( {"lkey": ["foo", "bar", "baz", "foo"], "value": [1, 2, 3, 5]} ).astype({"value": 
"timedelta64[ns]"}) @@ -1176,36 +1176,42 @@ def test_merge_timedelta(): ), ) - left_df = native_pd.DataFrame({"a": ["foo", "bar"], "b": [1, 2]}).astype( - {"b": "timedelta64[ns]"} - ) - right_df = native_pd.DataFrame({"a": ["foo", "baz"], "c": [3, 4]}).astype( - {"c": "timedelta64[ns]"} - ) - eval_snowpark_pandas_result( - pd.DataFrame(left_df), - left_df, - lambda df: df.merge( - pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, - how="inner", - on="a", - ), - ) - eval_snowpark_pandas_result( - pd.DataFrame(left_df), - left_df, - lambda df: df.merge( - pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, - how="right", - on="a", - ), - ) - eval_snowpark_pandas_result( - pd.DataFrame(left_df), - left_df, - lambda df: df.merge( - pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, - how="cross", - ), - ) +@pytest.mark.parametrize( + "kwargs", + [ + {"how": "inner", "on": "a"}, + {"how": "right", "on": "a"}, + {"how": "right", "on": "b"}, + {"how": "left", "on": "c"}, + {"how": "cross"}, + ], +) +def test_merge_timedelta_how(kwargs): + left_df = native_pd.DataFrame( + {"a": ["foo", "bar"], "b": [1, 2], "c": [3, 5]} + ).astype({"b": "timedelta64[ns]"}) + right_df = native_pd.DataFrame( + {"a": ["foo", "baz"], "b": [1, 3], "c": [3, 4]} + ).astype({"b": "timedelta64[ns]", "c": "timedelta64[ns]"}) + count = 1 + expect_exception = False + if "c" == kwargs.get("on", None): # merge timedelta with int exception + expect_exception = True + count = 0 + + with SqlCounter(query_count=count, join_count=count): + eval_snowpark_pandas_result( + pd.DataFrame(left_df), + left_df, + lambda df: df.merge( + pd.DataFrame(right_df) if isinstance(df, pd.DataFrame) else right_df, + **kwargs, + ), + expect_exception=expect_exception, + expect_exception_match="You are trying to merge on LongType and TimedeltaType columns for key 'c'. 
If you " + "wish to proceed you should use pd.concat", + expect_exception_type=ValueError, + assert_exception_equal=False, # pandas exception: You are trying to merge on int64 and timedelta64[ns] + # columns for key 'c'. If you wish to proceed you should use pd.concat + ) From e79b62913c8cfb04ff59743949c95feae07e21f8 Mon Sep 17 00:00:00 2001 From: azhan Date: Tue, 27 Aug 2024 17:45:29 -0700 Subject: [PATCH 4/6] fix tests --- .../snowpark/modin/plugin/_internal/snowpark_pandas_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py b/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py index a8806d7c90d..551a53b804c 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/snowpark_pandas_types.py @@ -118,7 +118,7 @@ class SnowparkPandasColumn(NamedTuple): snowpark_pandas_type: Optional[SnowparkPandasType] -class TimedeltaType(SnowparkPandasType): +class TimedeltaType(SnowparkPandasType, LongType): """ Timedelta represents the difference between two times. From dc619af411a24e64b3bae1532c64e29f1304cfe6 Mon Sep 17 00:00:00 2001 From: azhan Date: Tue, 27 Aug 2024 17:56:13 -0700 Subject: [PATCH 5/6] improve timedeltaType equals --- .../snowpark/modin/plugin/_internal/binary_op_utils.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py index a0ca357c59b..6d79de24ffb 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/binary_op_utils.py @@ -512,10 +512,8 @@ def are_equal_types(type1: DataType, type2: DataType) -> bool: Returns: True if given types are equal, False otherwise. 
""" - if isinstance(type1, TimedeltaType) and not isinstance(type2, TimedeltaType): - return False - if isinstance(type2, TimedeltaType) and not isinstance(type1, TimedeltaType): - return False + if isinstance(type1, TimedeltaType) or isinstance(type2, TimedeltaType): + return type1 == type2 if isinstance(type1, _IntegralType) and isinstance(type2, _IntegralType): return True if isinstance(type1, _FractionalType) and isinstance(type2, _FractionalType): From 40994a3e2dcad4d28d2414d68e803a77669c3ca4 Mon Sep 17 00:00:00 2001 From: azhan Date: Wed, 28 Aug 2024 09:20:37 -0700 Subject: [PATCH 6/6] resolve comments and fix tests --- src/snowflake/snowpark/modin/plugin/_internal/join_utils.py | 2 +- .../modin/plugin/compiler/snowflake_query_compiler.py | 2 +- tests/integ/modin/frame/test_isin.py | 2 +- tests/integ/modin/frame/test_iterrows.py | 2 +- tests/integ/modin/series/test_shift.py | 6 +----- 5 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py index be36b004ed4..846f3c64079 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/join_utils.py @@ -173,7 +173,7 @@ def join( ), f"Invalid join type: {how}. 
Allowed values are {get_args(JoinTypeLit)}" def assert_snowpark_pandas_types_match() -> None: - """If Snowpark pandas types does not match, then a ValueError will be raised.""" + """If Snowpark pandas types do not match, then a ValueError will be raised.""" left_types = [ left.snowflake_quoted_identifier_to_snowpark_pandas_type.get(id, None) for id in left_on diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index a0efa4ef601..e13c77f8ec3 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -1499,7 +1499,7 @@ def _shift_values_axis_0( row_position_quoted_identifier = frame.row_position_snowflake_quoted_identifier fill_value_dtype = infer_object_type(fill_value) - fill_value = pandas_lit(fill_value) if fill_value is not None else None + fill_value = None if pd.isna(fill_value) else pandas_lit(fill_value) def shift_expression_and_type( quoted_identifier: str, dtype: DataType diff --git a/tests/integ/modin/frame/test_isin.py b/tests/integ/modin/frame/test_isin.py index cc6113c7466..cd560a5715a 100644 --- a/tests/integ/modin/frame/test_isin.py +++ b/tests/integ/modin/frame/test_isin.py @@ -260,7 +260,7 @@ def test_isin_dataframe_values_type_negative(): ) def test_isin_timedelta(values): native_df = native_pd.DataFrame({"a": [1, 2, 3], "b": [None, 4, 2]}).astype( - "timedelta64[ns]" + {"b": "timedelta64[ns]"} ) snow_df = pd.DataFrame(native_df) diff --git a/tests/integ/modin/frame/test_iterrows.py b/tests/integ/modin/frame/test_iterrows.py index 843848c6a5c..fc415b2daf5 100644 --- a/tests/integ/modin/frame/test_iterrows.py +++ b/tests/integ/modin/frame/test_iterrows.py @@ -53,7 +53,7 @@ def assert_iterators_equal(snowpark_iterator, native_iterator): ), # empty df native_pd.DataFrame([]), - native_pd.DataFrame({"ts": native_pd.timedelta_range(10, 
periods=10)}), + native_pd.DataFrame({"ts": native_pd.timedelta_range(10, periods=4)}), ], ) def test_df_iterrows(native_df): diff --git a/tests/integ/modin/series/test_shift.py b/tests/integ/modin/series/test_shift.py index 7f27c4d313b..f5d4169026e 100644 --- a/tests/integ/modin/series/test_shift.py +++ b/tests/integ/modin/series/test_shift.py @@ -46,11 +46,7 @@ def test_series_with_values_shift(series, periods, fill_value): lambda s: s.shift( periods=periods, fill_value=pd.Timedelta(fill_value) - if isinstance( - s, native_pd.Series - ) # pandas does not support fill int to timedelta - and s.dtype == "timedelta64[ns]" - and fill_value is not no_default + if s.dtype == "timedelta64[ns]" and fill_value is not no_default else fill_value, ), )