diff --git a/CHANGELOG.md b/CHANGELOG.md index afb4515f69d..3a424de7548 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ - Added partial support for `DataFrame.pct_change` and `Series.pct_change` without the `freq` and `limit` parameters. - Added support for `Series.str.get`. - Added support for `Series.dt.dayofweek`, `Series.dt.day_of_week`, `Series.dt.dayofyear`, and `Series.dt.day_of_year`. +- Added support for `Series.str.__getitem__` (`Series.str[...]`). #### Bug Fixes diff --git a/docs/source/modin/supported/series_str_supported.rst b/docs/source/modin/supported/series_str_supported.rst index 79732758b08..3dc2cc23bd5 100644 --- a/docs/source/modin/supported/series_str_supported.rst +++ b/docs/source/modin/supported/series_str_supported.rst @@ -14,6 +14,9 @@ the method in the left column. | StringMethods | Snowpark implemented? (Y/N/P/D) | Notes for current implementation | | (Series.str) | | | +-----------------------------+---------------------------------+----------------------------------------------------+ +| ``__getitem__`` | P | ``N`` if the `key` parameter is set to a non-int | +| | | scalar value. | ++-----------------------------+---------------------------------+----------------------------------------------------+ | ``capitalize`` | Y | | +-----------------------------+---------------------------------+----------------------------------------------------+ | ``casefold`` | N | | diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 714648d949b..a02bf4e108c 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -6,6 +6,7 @@ import json import logging import re +import typing from collections.abc import Hashable, Iterable, Mapping, Sequence from datetime import timedelta, tzinfo from typing import Any, Callable, Literal, Optional, Union, get_args @@ -12190,6 +12191,34 @@ def _get_regex_params(self, flags: int = 0) -> str: params = params + "s" return params + def str___getitem__(self, key: Union[Scalar, slice]) -> "SnowflakeQueryCompiler": + """ + Retrieve character(s) or substring(s) from each element in the Series or Index according to `key`. + + Parameters + ---------- + key : scalar or slice + Index to retrieve data from. + + Returns + ------- + SnowflakeQueryCompiler representing result of the string operation. + """ + if not is_scalar(key) and not isinstance(key, slice): + # Follow pandas behavior; all values will be None. + key = None + if is_scalar(key): + if key is not None and not isinstance(key, int): + ErrorMessage.not_implemented( + "Snowpark pandas string indexing doesn't yet support non-numeric keys" + ) + return self.str_get(typing.cast(int, key)) + else: + assert isinstance(key, slice), "key is expected to be slice here" + if key.step == 0: + raise ValueError("slice step cannot be zero") + return self.str_slice(key.start, key.stop, key.step) + def str_center(self, width: int, fillchar: str = " ") -> None: ErrorMessage.method_not_implemented_error("center", "Series.str") diff --git a/tests/integ/modin/series/test_str_accessor.py b/tests/integ/modin/series/test_str_accessor.py index 26749695a0b..f9e624e9099 100644 --- a/tests/integ/modin/series/test_str_accessor.py +++ b/tests/integ/modin/series/test_str_accessor.py @@ -182,6 +182,62 @@ def test_str_get_neg(): snow_ser.str.get(i="a") +@pytest.mark.parametrize( + "key", + [ + None, + [1, 2], + (1, 2), + {1: "a", 2: "b"}, + -100, + -2, + -1, + 0, + 1, + 2, + 100, + slice(None, None, None), + slice(0, -1, 1), + slice(-1, 0, -1), + slice(0, -1, 2), + slice(-1, 0, -2), + slice(-100, 100, 2), + slice(100, -100, -2), + ], +) +@sql_count_checker(query_count=1) +def test_str___getitem__(key): + native_ser = native_pd.Series(TEST_DATA) + snow_ser = pd.Series(native_ser) + eval_snowpark_pandas_result( + snow_ser, + native_ser, + lambda ser: ser.str[key], + ) + + +@sql_count_checker(query_count=0) +def test_str___getitem___zero_step(): + native_ser = native_pd.Series(TEST_DATA) + snow_ser = pd.Series(native_ser) + with pytest.raises( + ValueError, + match="slice step cannot be zero", + ): + snow_ser.str[slice(None, None, 0)] + + +@sql_count_checker(query_count=0) +def test_str___getitem___string_key(): + native_ser = native_pd.Series(TEST_DATA) + snow_ser = pd.Series(native_ser) + with pytest.raises( + NotImplementedError, + match="Snowpark pandas string indexing doesn't yet support non-numeric keys", + ): + snow_ser.str["a"] + + @pytest.mark.parametrize("start", [None, -100, -2, -1, 0, 1, 2, 100]) @pytest.mark.parametrize("stop", [None, -100, -2, -1, 0, 1, 2, 100]) @pytest.mark.parametrize("step", [None, -100, -2, -1, 1, 2, 100])