Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1642293 Add support for lazy index labels in reindex and fix reindex name bug #2175

Merged
merged 18 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
- Added support for `Index.is_boolean`, `Index.is_integer`, `Index.is_floating`, `Index.is_numeric`, and `Index.is_object`.
- Added support for `DatetimeIndex.round`, `DatetimeIndex.floor` and `DatetimeIndex.ceil`.
- Added support for `Series.dt.days_in_month` and `Series.dt.daysinmonth`.
- Added support for lazy `Index` objects as `labels` in `DataFrame.reindex` and `Series.reindex`.

#### Improvements

Expand All @@ -86,6 +87,7 @@

- Stopped ignoring nanoseconds in `pd.Timedelta` scalars.
- Fixed AssertionError in tree of binary operations.
- Fixed a bug where `Series.reindex` and `DataFrame.reindex` did not update the result index's name correctly.

#### Behavior Change

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2237,7 +2237,7 @@ def any(
def reindex(
self,
axis: int,
labels: Union[pandas.Index, list[Any]],
labels: Union[pandas.Index, "pd.Index", list[Any]],
sfc-gh-azhan marked this conversation as resolved.
Show resolved Hide resolved
**kwargs: dict[str, Any],
) -> "SnowflakeQueryCompiler":
"""
Expand Down Expand Up @@ -2347,7 +2347,7 @@ def _add_columns_for_monotonicity_checks(

def _reindex_axis_0(
self,
labels: Union[pandas.Index, list[Any]],
labels: Union[pandas.Index, "pd.Index", list[Any]],
**kwargs: dict[str, Any],
) -> "SnowflakeQueryCompiler":
"""
Expand All @@ -2373,7 +2373,13 @@ def _reindex_axis_0(
"""
self._raise_not_implemented_error_for_timedelta()

new_index_qc = pd.Series(labels)._query_compiler
if isinstance(labels, native_pd.Index):
labels = pd.Index(labels)
if isinstance(labels, pd.Index):
new_index_qc = labels.to_series()._query_compiler
else:
new_index_qc = pd.Series(labels)._query_compiler

new_index_modin_frame = new_index_qc._modin_frame
modin_frame = self._modin_frame
method = kwargs.get("method", None)
Expand Down Expand Up @@ -2462,7 +2468,7 @@ def _reindex_axis_0(
data_column_pandas_labels=data_column_pandas_labels,
data_column_snowflake_quoted_identifiers=data_column_snowflake_quoted_identifiers,
data_column_pandas_index_names=modin_frame.data_column_pandas_index_names,
index_column_pandas_labels=modin_frame.index_column_pandas_labels,
index_column_pandas_labels=new_index_modin_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=result_frame_column_mapper.map_left_quoted_identifiers(
new_index_modin_frame.data_column_snowflake_quoted_identifiers
),
Expand Down
43 changes: 43 additions & 0 deletions tests/integ/modin/frame/test_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,3 +569,46 @@ def test_reindex_multiindex_negative(axis):
snow_df.reindex(index=[1, 2, 3])
else:
snow_df.T.reindex(columns=[1, 2, 3])


@sql_count_checker(query_count=1, join_count=1)
def test_reindex_with_index_name():
    """Reindex a DataFrame with a named native Index and compare to native pandas.

    The source frame has an unnamed index; the new labels carry the name
    "weewoo". The Snowpark pandas result must equal the native pandas result
    (dtype differences are ignored by the assertion helper's name).
    """
    pandas_frame = native_pd.DataFrame(
        [[0, 1, 2], [0, 0, 1], [1, 0, 0]],
        index=list("ABC"),
    )
    snowpark_frame = pd.DataFrame(pandas_frame)
    named_labels = native_pd.Index(list("CAB"), name="weewoo")

    actual = snowpark_frame.reindex(index=named_labels)
    expected = pandas_frame.reindex(index=named_labels)
    assert_snowpark_pandas_equals_to_pandas_without_dtypecheck(actual, expected)


@sql_count_checker(query_count=1, join_count=1)
def test_reindex_with_index_name_and_df_index_name():
    """Reindex a DataFrame whose index is already named with a differently named Index.

    The frame's index is named "AAAAA" and the new labels are named "weewoo";
    the Snowpark pandas result must equal what native pandas produces.
    """
    original_index = native_pd.Index(list("ABC"), name="AAAAA")
    pandas_frame = native_pd.DataFrame(
        {"X": [1, 2, 3], "Y": [8, 7, 3], "Z": [3, 4, 5]},
        index=original_index,
    )
    snowpark_frame = pd.DataFrame(pandas_frame)
    named_labels = native_pd.Index(list("CAB"), name="weewoo")

    actual = snowpark_frame.reindex(index=named_labels)
    expected = pandas_frame.reindex(index=named_labels)
    assert_snowpark_pandas_equals_to_pandas_without_dtypecheck(actual, expected)


@sql_count_checker(query_count=1, join_count=1)
def test_reindex_with_lazy_index():
    """Reindex a DataFrame using a Snowpark pandas (lazy) Index as the labels.

    The native frame is reindexed with the native Index and the Snowpark
    pandas frame with the lazy Index built from it; both results are handed to
    ``eval_snowpark_pandas_result`` for comparison.
    """
    pandas_frame = native_pd.DataFrame(
        [[1, np.nan, 3], [np.nan, 5, np.nan], [7, 8, np.nan]], index=list("XYZ")
    )
    snowpark_frame = pd.DataFrame(pandas_frame)
    native_labels = native_pd.Index(list("CAB"))
    lazy_labels = pd.Index(native_labels)

    def reindex_with_matching_labels(df):
        # Each library gets labels of its own Index flavour.
        if isinstance(df, native_pd.DataFrame):
            return df.reindex(index=native_labels)
        return df.reindex(index=lazy_labels)

    eval_snowpark_pandas_result(
        snowpark_frame,
        pandas_frame,
        reindex_with_matching_labels,
    )
4 changes: 1 addition & 3 deletions tests/integ/modin/index/test_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,7 @@ def test_ordered_index_unordered_new_index():
@pytest.mark.parametrize("limit", [None, 1, 2, 100])
@pytest.mark.parametrize("method", ["bfill", "backfill", "pad", "ffill"])
def test_datetime_with_fill(limit, method):
query_count = 2
join_count = 2
with SqlCounter(query_count=query_count, join_count=join_count):
with SqlCounter(query_count=2 if limit is None else 3, join_count=2):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a TODO to fix the `limit` case once @sfc-gh-rdurrani's is-monotonic PR is done.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The query count is still the same, so no changes are required.

native_date_index = native_pd.date_range("1/1/2010", periods=6, freq="D")
snow_date_index = pd.date_range("1/1/2010", periods=6, freq="D")
assert_reindex_result_equal(
Expand Down
39 changes: 39 additions & 0 deletions tests/integ/modin/series/test_reindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,42 @@ def test_reindex_multiindex_negative():
match="Snowpark pandas doesn't support `reindex` with MultiIndex",
):
snow_series.reindex(index=[1, 2, 3])


@sql_count_checker(query_count=1, join_count=1)
def test_reindex_with_index_name():
    """Reindex a Series with a named native Index and compare to native pandas.

    The source Series has an unnamed index; the new labels carry the name
    "weewoo". The Snowpark pandas result must equal the native pandas result.
    """
    pandas_series = native_pd.Series([0, 1, 2], index=list("ABC"), name="test")
    snowpark_series = pd.Series(pandas_series)
    named_labels = native_pd.Index(list("CAB"), name="weewoo")

    actual = snowpark_series.reindex(index=named_labels)
    expected = pandas_series.reindex(index=named_labels)
    assert_snowpark_pandas_equals_to_pandas_without_dtypecheck(actual, expected)


@sql_count_checker(query_count=1, join_count=1)
def test_reindex_with_index_name_and_series_index_name():
    """Reindex a Series whose index is already named with a differently named Index.

    The Series index is named "AAAAA" and the new labels are named "weewoo";
    the Snowpark pandas result must equal what native pandas produces.
    """
    original_index = native_pd.Index(list("ABC"), name="AAAAA")
    pandas_series = native_pd.Series([0, 1, 2], index=original_index, name="test")
    snowpark_series = pd.Series(pandas_series)
    named_labels = native_pd.Index(list("CAB"), name="weewoo")

    actual = snowpark_series.reindex(index=named_labels)
    expected = pandas_series.reindex(index=named_labels)
    assert_snowpark_pandas_equals_to_pandas_without_dtypecheck(actual, expected)


@sql_count_checker(query_count=1, join_count=1)
def test_reindex_with_lazy_index():
    """Reindex a Series using a Snowpark pandas (lazy) Index as the labels.

    The native Series is reindexed with the native Index and the Snowpark
    pandas Series with the lazy Index built from it; both results are handed
    to ``eval_snowpark_pandas_result`` for comparison.
    """
    pandas_series = native_pd.Series([0, 1, 2], index=list("ABC"))
    snowpark_series = pd.Series(pandas_series)
    native_labels = native_pd.Index(list("CAB"))
    lazy_labels = pd.Index(native_labels)

    def reindex_with_matching_labels(series):
        # Each library gets labels of its own Index flavour.
        if isinstance(series, native_pd.Series):
            return series.reindex(index=native_labels)
        return series.reindex(index=lazy_labels)

    eval_snowpark_pandas_result(
        snowpark_series,
        pandas_series,
        reindex_with_matching_labels,
    )
Loading