diff --git a/CHANGELOG.md b/CHANGELOG.md index 67271108fba..3460c106bf4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -84,6 +84,7 @@ - Added support for `Index.is_monotonic_increasing` and `Index.is_monotonic_decreasing`. - Added support for `pd.crosstab`. - Added support for `pd.bdate_range` and included business frequency support (B, BME, BMS, BQE, BQS, BYE, BYS) for both `pd.date_range` and `pd.bdate_range`. +- Added support for lazy `Index` objects as `labels` in `DataFrame.reindex` and `Series.reindex`. #### Improvements @@ -97,6 +98,7 @@ - Fixed AssertionError in tree of binary operations. - Fixed bug in `Series.dt.isocalendar` using a named Series - Fixed `inplace` argument for Series objects derived from DataFrame columns. +- Fixed a bug where `Series.reindex` and `DataFrame.reindex` did not update the result index's name correctly. #### Behavior Change diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index a6f7b62b58c..831bbcec8f5 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -2260,7 +2260,7 @@ def any( def reindex( self, axis: int, - labels: Union[pandas.Index, list[Any]], + labels: Union[pandas.Index, "pd.Index", list[Any]], **kwargs: dict[str, Any], ) -> "SnowflakeQueryCompiler": """ @@ -2468,7 +2468,7 @@ def _add_columns_for_monotonicity_checks( def _reindex_axis_0( self, - labels: Union[pandas.Index, list[Any]], + labels: Union[pandas.Index, "pd.Index", list[Any]], **kwargs: dict[str, Any], ) -> "SnowflakeQueryCompiler": """ @@ -2494,7 +2494,13 @@ def _reindex_axis_0( """ self._raise_not_implemented_error_for_timedelta() - new_index_qc = pd.Series(labels)._query_compiler + if isinstance(labels, native_pd.Index): + labels = pd.Index(labels) + if isinstance(labels, pd.Index): + new_index_qc = labels.to_series()._query_compiler + else: + new_index_qc = pd.Series(labels)._query_compiler + new_index_modin_frame = new_index_qc._modin_frame modin_frame = self._modin_frame method = kwargs.get("method", None) @@ -2583,7 +2589,7 @@ def _reindex_axis_0( data_column_pandas_labels=data_column_pandas_labels, data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers, data_column_pandas_index_names=modin_frame.data_column_pandas_index_names, - index_column_pandas_labels=modin_frame.index_column_pandas_labels, + index_column_pandas_labels=new_index_modin_frame.index_column_pandas_labels, index_column_snowflake_quoted_identifiers=result_frame_column_mapper.map_left_quoted_identifiers( new_index_modin_frame.data_column_snowflake_quoted_identifiers ), diff --git a/tests/integ/modin/frame/test_reindex.py b/tests/integ/modin/frame/test_reindex.py index 692fd66471f..de1aacd786e 100644 --- a/tests/integ/modin/frame/test_reindex.py +++ b/tests/integ/modin/frame/test_reindex.py @@ -569,3 +569,46 @@ def test_reindex_multiindex_negative(axis): snow_df.reindex(index=[1, 2, 3]) else: snow_df.T.reindex(columns=[1, 2, 3]) + + +@sql_count_checker(query_count=1, join_count=1) +def test_reindex_with_index_name(): + native_df = native_pd.DataFrame( + [[0, 1, 2], [0, 0, 1], [1, 0, 0]], + index=list("ABC"), + ) + snow_df = pd.DataFrame(native_df) + index_with_name = native_pd.Index(list("CAB"), name="weewoo") + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( + snow_df.reindex(index=index_with_name), native_df.reindex(index=index_with_name) + ) + + +@sql_count_checker(query_count=1, join_count=1) +def test_reindex_with_index_name_and_df_index_name(): + native_df = native_pd.DataFrame( + {"X": [1, 2, 3], "Y": [8, 7, 3], "Z": [3, 4, 5]}, + index=native_pd.Index(list("ABC"), name="AAAAA"), + ) + snow_df = pd.DataFrame(native_df) + index_with_name = native_pd.Index(list("CAB"), name="weewoo") + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( + snow_df.reindex(index=index_with_name), native_df.reindex(index=index_with_name) + ) + + +@sql_count_checker(query_count=1, join_count=1) +def test_reindex_with_lazy_index(): + native_df = native_pd.DataFrame( + [[1, np.nan, 3], [np.nan, 5, np.nan], [7, 8, np.nan]], index=list("XYZ") + ) + snow_df = pd.DataFrame(native_df) + native_idx = native_pd.Index(list("CAB")) + lazy_idx = pd.Index(native_idx) + eval_snowpark_pandas_result( + snow_df, + native_df, + lambda df: df.reindex( + index=native_idx if isinstance(df, native_pd.DataFrame) else lazy_idx + ), + ) diff --git a/tests/integ/modin/index/test_reindex.py b/tests/integ/modin/index/test_reindex.py index 39a15322ca0..c33b7465461 100644 --- a/tests/integ/modin/index/test_reindex.py +++ b/tests/integ/modin/index/test_reindex.py @@ -128,9 +128,7 @@ def test_ordered_index_unordered_new_index(): @pytest.mark.parametrize("limit", [None, 1, 2, 100]) @pytest.mark.parametrize("method", ["bfill", "backfill", "pad", "ffill"]) def test_datetime_with_fill(limit, method): - query_count = 2 - join_count = 2 - with SqlCounter(query_count=query_count, join_count=join_count): + with SqlCounter(query_count=2 if limit is None else 3, join_count=2): native_date_index = native_pd.date_range("1/1/2010", periods=6, freq="D") snow_date_index = pd.date_range("1/1/2010", periods=6, freq="D") assert_reindex_result_equal( diff --git a/tests/integ/modin/series/test_reindex.py b/tests/integ/modin/series/test_reindex.py index 7c2bbba906e..0c9aa353d0a 100644 --- a/tests/integ/modin/series/test_reindex.py +++ b/tests/integ/modin/series/test_reindex.py @@ -376,3 +376,42 @@ def test_reindex_multiindex_negative(): match="Snowpark pandas doesn't support `reindex` with MultiIndex", ): snow_series.reindex(index=[1, 2, 3]) + + +@sql_count_checker(query_count=1, join_count=1) +def test_reindex_with_index_name(): + native_series = native_pd.Series([0, 1, 2], index=list("ABC"), name="test") + snow_series = pd.Series(native_series) + index_with_name = native_pd.Index(list("CAB"), name="weewoo") + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( + snow_series.reindex(index=index_with_name), + native_series.reindex(index=index_with_name), + ) + + +@sql_count_checker(query_count=1, join_count=1) +def test_reindex_with_index_name_and_series_index_name(): + native_series = native_pd.Series( + [0, 1, 2], index=native_pd.Index(list("ABC"), name="AAAAA"), name="test" + ) + snow_series = pd.Series(native_series) + index_with_name = native_pd.Index(list("CAB"), name="weewoo") + assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( + snow_series.reindex(index=index_with_name), + native_series.reindex(index=index_with_name), + ) + + +@sql_count_checker(query_count=1, join_count=1) +def test_reindex_with_lazy_index(): + native_series = native_pd.Series([0, 1, 2], index=list("ABC")) + snow_series = pd.Series(native_series) + native_idx = native_pd.Index(list("CAB")) + lazy_idx = pd.Index(native_idx) + eval_snowpark_pandas_result( + snow_series, + native_series, + lambda series: series.reindex( + index=native_idx if isinstance(series, native_pd.Series) else lazy_idx + ), + )