diff --git a/CHANGELOG.md b/CHANGELOG.md
index aed3cb3dab2..f594f31d2e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,6 +38,7 @@
 - Added support for `Index.equals`.
 - Added support for `Index.value_counts`.
 - Added support for `Series.dt.day_name` and `Series.dt.month_name`.
+- Added support for indexing on Index, e.g., `df.index[:10]`.
 
 #### Improvements
 - Removed the public preview warning message upon importing Snowpark pandas.
diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py
index 1dc2b7c23dd..89460b35340 100644
--- a/src/snowflake/snowpark/modin/pandas/indexing.py
+++ b/src/snowflake/snowpark/modin/pandas/indexing.py
@@ -1132,6 +1132,9 @@ def __getitem__(
         # Convert all scalar, list-like, and indexer row_loc to a Series object to get a query compiler object.
         if is_scalar(row_loc):
             row_loc = pd.Series([row_loc])
+        elif isinstance(row_loc, pd.Index):
+            # Convert Index row_loc to a Series
+            row_loc = row_loc.to_series().reset_index(drop=True)
         elif is_list_like(row_loc):
             if hasattr(row_loc, "dtype"):
                 dtype = row_loc.dtype
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py
index 0c0457b4679..592c6d53742 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py
@@ -2354,19 +2354,18 @@ def __len__(self) -> int:
     @is_lazy_check
     def __getitem__(self, key: Any) -> np.ndarray | None | Index:
         """
-        Override numpy.ndarray's __getitem__ method to work as desired.
-
-        This function adds lists and Series as valid boolean indexers
-        (ndarrays only supports ndarray with dtype=bool).
-
-        If resulting ndim != 1, plain ndarray is returned instead of
-        corresponding `Index` subclass.
+        Reuse Series.iloc to implement __getitem__ for Index.
         """
-        WarningMessage.index_to_pandas_warning("__getitem__")
-        item = self.to_pandas().__getitem__(key=key)
-        if isinstance(item, native_pd.Index):
-            return Index(item, convert_to_lazy=self.is_lazy)
-        return item
+        try:
+            res = self.to_series().iloc[key]
+            if isinstance(res, Series):
+                res = res.index
+            return res
+        except IndexError as ie:
+            raise IndexError(
+                "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or "
+                "boolean arrays are valid indices"
+            ) from ie
 
     @is_lazy_check
     def __setitem__(self, key: Any, value: Any) -> None:
diff --git a/tests/integ/modin/frame/test_iloc.py b/tests/integ/modin/frame/test_iloc.py
index 5e43397b0bc..5df42ec98ec 100644
--- a/tests/integ/modin/frame/test_iloc.py
+++ b/tests/integ/modin/frame/test_iloc.py
@@ -714,9 +714,8 @@ def iloc_helper(df):
 
     # One extra query for index conversion to series to set item
     query_count = (
-        2 if "index" in key_type or (key_type == "series" and axis == "col") else 1
+        2 if ("index" in key_type or key_type == "series") and axis == "col" else 1
     )
-    expected_join_count = 0
     if axis == "row":
         if key == [] and key_type in ["list", "ndarray"]:
             expected_join_count = 2
@@ -952,7 +951,7 @@ def iloc_helper(df):
 
     # one extra query for index conversion to series to set item
     query_count = (
-        2 if "index" in key_type or (key_type == "series" and axis == "col") else 1
+        2 if ("index" in key_type or key_type == "series") and axis == "col" else 1
     )
 
     join_count = 2 if axis == "row" else 0
@@ -1300,15 +1299,7 @@ def test_df_iloc_get_non_numeric_key_negative(
     if isinstance(key, native_pd.Index):
         key = pd.Index(key)
 
-    # 2 extra queries for repr
-    # 1 extra query to convert index to series if row case
-    with SqlCounter(
-        query_count=3
-        if isinstance(key, pd.Index) and axis == "row"
-        else 2
-        if isinstance(key, pd.Index)
-        else 0
-    ):
+    with SqlCounter(query_count=2 if isinstance(key, pd.Index) else 0):
         # General case fails with TypeError.
         error_msg = re.escape(f".iloc requires numeric indexers, got {key}")
         with pytest.raises(IndexError, match=error_msg):
diff --git a/tests/integ/modin/frame/test_set_index.py b/tests/integ/modin/frame/test_set_index.py
index fdfd3de001d..6bd6d754278 100644
--- a/tests/integ/modin/frame/test_set_index.py
+++ b/tests/integ/modin/frame/test_set_index.py
@@ -351,8 +351,8 @@ def test_set_index_pass_multiindex(drop, append, native_df):
 @pytest.mark.parametrize(
     "keys, expected_query_count",
     [
-        (["a"], 4),
-        ([[1, 6, 6]], 6),
+        (["a"], 3),
+        ([[1, 6, 6]], 5),
     ],
 )
 def test_set_index_verify_integrity_negative(native_df, keys, expected_query_count):
diff --git a/tests/integ/modin/index/test_indexing.py b/tests/integ/modin/index/test_indexing.py
new file mode 100644
index 00000000000..3fc28270623
--- /dev/null
+++ b/tests/integ/modin/index/test_indexing.py
@@ -0,0 +1,91 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+import re
+
+import modin.pandas as pd
+import numpy as np
+import pandas as native_pd
+import pytest
+
+import snowflake.snowpark.modin.plugin  # noqa: F401
+from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
+from tests.integ.modin.utils import assert_index_equal
+
+
+@pytest.mark.parametrize(
+    "index",
+    [
+        native_pd.Index([1, 2, 3, 4, 5, 6, 7]),
+    ],
+)
+@pytest.mark.parametrize(
+    "key",
+    [
+        1,
+        slice(1, None, None),
+        slice(None, None, -2),
+        [0, 2, 4],
+        [True, False, False, True, False, False, True],
+        ...,
+    ],
+)
+def test_index_indexing(index, key):
+    if isinstance(key, slice) or key is ...:
+        join_count = 0  # slice keys use a filter, not a join
+    elif isinstance(key, list) and isinstance(key[0], bool):
+        join_count = 1  # the key must be joined in
+    else:
+        join_count = 2  # the key must be joined in and the result squeezed
+    with SqlCounter(query_count=1, join_count=join_count):
+        if isinstance(key, (slice, list)) or key is ...:
+            assert_index_equal(pd.Index(index)[key], index[key])
+        else:
+            assert pd.Index(index)[key] == index[key]
+
+
+@pytest.mark.parametrize(
+    "index",
+    [
+        native_pd.Index([1, 2, 3, 4, 5, 6, 7]),
+    ],
+)
+@pytest.mark.parametrize(
+    "key",
+    [
+        np.array([1, 3, 5]),
+        native_pd.Index([0, 1]),
+        native_pd.Series([0, 1]),
+    ],
+)
+@sql_count_checker(query_count=1, join_count=2)
+def test_index_indexing_other_list_like_key(index, key):
+    if isinstance(key, native_pd.Index):
+        key1 = pd.Index(key)
+    elif isinstance(key, native_pd.Series):
+        key1 = pd.Series(key)
+    else:
+        key1 = key
+    assert_index_equal(pd.Index(index)[key1], index[key])
+
+
+@pytest.mark.parametrize(
+    "index",
+    [
+        native_pd.Index([1, 2, 3, 4, 5, 6, 7]),
+    ],
+)
+@pytest.mark.parametrize(
+    "key",
+    ["1", ["1"]],
+)
+@sql_count_checker(query_count=0)
+def test_index_indexing_negative(index, key):
+    ie = (
+        "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid "
+        "indices"
+    )
+    with pytest.raises(IndexError, match=re.escape(ie)):
+        index[key]
+    with pytest.raises(IndexError, match=re.escape(ie)):
+        pd.Index(index)[key]
diff --git a/tests/integ/modin/series/test_iloc.py b/tests/integ/modin/series/test_iloc.py
index 051f657a0b6..ac02f368dd8 100644
--- a/tests/integ/modin/series/test_iloc.py
+++ b/tests/integ/modin/series/test_iloc.py
@@ -198,8 +198,7 @@ def iloc_helper(ser):
 
     expected_join_count = 1
 
-    # 1 extra query for converting to series
-    expected_query_count = 2 if "index" in key_type else 1
+    expected_query_count = 1
 
     if key == [] and key_type in ["list", "ndarray"]:
         expected_join_count += 1
@@ -289,12 +288,9 @@ def iloc_helper(ser):
             # Index objects have dtype object when empty
            return
 
-    # 1 extra query for converting to series
-    query_count = 2 if "index" in key_type else 1
-
     default_index_int_series = pd.Series(default_index_native_int_series)
     # test ser with default index
-    with SqlCounter(query_count=query_count, join_count=2):
+    with SqlCounter(query_count=1, join_count=2):
         eval_snowpark_pandas_result(
             default_index_int_series,
             default_index_native_int_series,
@@ -306,7 +302,7 @@ def iloc_helper(ser):
         native_int_series_with_non_default_index
     )
     # test ser with non default index
-    with SqlCounter(query_count=query_count, join_count=2):
+    with SqlCounter(query_count=1, join_count=2):
         eval_snowpark_pandas_result(
             int_series_with_non_default_index,
             native_int_series_with_non_default_index,
@@ -316,7 +312,7 @@ def iloc_helper(ser):
     # test ser with MultiIndex
     # Index dtype is different between Snowpark and native pandas if key produces empty df.
     int_series_with_multiindex = pd.Series(multiindex_native_int_series)
-    with SqlCounter(query_count=query_count, join_count=2):
+    with SqlCounter(query_count=1, join_count=2):
         eval_snowpark_pandas_result(
             int_series_with_multiindex,
             multiindex_native_int_series,
@@ -482,8 +478,8 @@ def test_series_iloc_get_non_numeric_key_negative(key, default_index_native_int_
         key = pd.Index(key)
     snowpark_index_int_series = pd.Series(default_index_native_int_series)
     error_msg = re.escape(f".iloc requires numeric indexers, got {key}")
-    # 2 extra queries for repr
-    with SqlCounter(query_count=2 if isinstance(key, pd.Index) else 0):
+    # 1 extra query for repr
+    with SqlCounter(query_count=1 if isinstance(key, pd.Index) else 0):
         with pytest.raises(IndexError, match=error_msg):
             _ = snowpark_index_int_series.iloc[key]
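Note: the new Index.__getitem__ above answers positional keys by delegating to Series.iloc instead of materializing the index with to_pandas(). A minimal sketch of that delegation pattern, written against plain pandas purely for illustration (the helper name index_getitem_via_iloc is made up here; the real method lives on the Snowpark pandas Index class and goes through its lazy Series):

import pandas as pd


def index_getitem_via_iloc(index: pd.Index, key):
    """Sketch of the delegation pattern: Index[key] answered via Series.iloc."""
    try:
        # to_series() builds a Series whose index holds the Index values;
        # positional selection on it matches Index.__getitem__ semantics.
        res = index.to_series().iloc[key]
    except IndexError as ie:
        # Re-raise with numpy's wording, as the diff does.
        raise IndexError(
            "only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) "
            "and integer or boolean arrays are valid indices"
        ) from ie
    # A slice or list-like key yields a Series whose index holds the selected
    # values; a scalar key yields the element itself.
    return res.index if isinstance(res, pd.Series) else res


idx = pd.Index([1, 2, 3, 4, 5, 6, 7])
assert index_getitem_via_iloc(idx, 1) == idx[1]                       # scalar position
assert index_getitem_via_iloc(idx, slice(1, None)).equals(idx[1:])    # slice
assert index_getitem_via_iloc(idx, [0, 2, 4]).equals(idx[[0, 2, 4]])  # position list
mask = [True, False, False, True, False, False, True]
assert index_getitem_via_iloc(idx, mask).equals(idx[mask])            # boolean mask

Returning res.index for list-like keys mirrors the diff: the selected positions come back as an Index, and in the Snowpark pandas version the whole operation stays lazy, replacing the previous to_pandas() materialization and its index_to_pandas_warning.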