From 983d65c11f6c24e841e0beaea9f913e0f8725920 Mon Sep 17 00:00:00 2001 From: Naresh Kumar <113932371+sfc-gh-nkumar@users.noreply.github.com> Date: Tue, 25 Jun 2024 15:04:44 -0700 Subject: [PATCH] SNOW-1453559: Add support for Series.case_when (#1800) --- CHANGELOG.md | 1 + docs/source/modin/series.rst | 1 + .../modin/supported/series_supported.rst | 2 +- src/snowflake/snowpark/modin/pandas/series.py | 1 - .../compiler/snowflake_query_compiler.py | 132 +++++++++++++++++- .../modin/plugin/docstrings/series.py | 41 ++++++ tests/integ/modin/series/test_case_when.py | 115 +++++++++++++++ 7 files changed, 290 insertions(+), 3 deletions(-) create mode 100644 tests/integ/modin/series/test_case_when.py diff --git a/CHANGELOG.md b/CHANGELOG.md index b7e3d88ecf2..0d34d53844a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -69,6 +69,7 @@ - Added support for `replace` and `frac > 1` in `DataFrame.sample` and `Series.sample`. - Added support for `Series.at`, `Series.iat`, `DataFrame.at`, and `DataFrame.iat`. - Added support for `Series.dt.isocalendar`. +- Added support for `Series.case_when` except when condition or replacement is callable. #### Bug Fixes diff --git a/docs/source/modin/series.rst b/docs/source/modin/series.rst index 57d13e66b18..2bf450e9447 100644 --- a/docs/source/modin/series.rst +++ b/docs/source/modin/series.rst @@ -151,6 +151,7 @@ Series .. autosummary:: :toctree: pandas_api/ + Series.case_when Series.drop Series.drop_duplicates Series.get diff --git a/docs/source/modin/supported/series_supported.rst b/docs/source/modin/supported/series_supported.rst index dffba6bb9f1..f3c088ec485 100644 --- a/docs/source/modin/supported/series_supported.rst +++ b/docs/source/modin/supported/series_supported.rst @@ -123,7 +123,7 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``bool`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``case_when`` | N | | | +| ``case_when`` | P | | ``N`` if condition or replacement is a callable. | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``clip`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/pandas/series.py b/src/snowflake/snowpark/modin/pandas/series.py index 39e092eed92..e7fd0aa8fce 100644 --- a/src/snowflake/snowpark/modin/pandas/series.py +++ b/src/snowflake/snowpark/modin/pandas/series.py @@ -1108,7 +1108,6 @@ def factorize( use_na_sentinel=use_na_sentinel, ) - @series_not_implemented() def case_when(self, caselist) -> Series: # noqa: PR01, RT01, D200 """ Replace values where the conditions are True. diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 5ad9dd8b3bd..13b862b4282 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -10,7 +10,7 @@ import uuid from collections.abc import Hashable, Iterable, Mapping, Sequence from datetime import timedelta, tzinfo -from typing import Any, Callable, Literal, Optional, Union, get_args +from typing import Any, Callable, List, Literal, Optional, Union, get_args import numpy as np import numpy.typing as npt @@ -8041,6 +8041,136 @@ def getitem_row_array( return SnowflakeQueryCompiler(new_frame) + def case_when(self, caselist: List[tuple]) -> "SnowflakeQueryCompiler": + """ + Replace values where the conditions are True. + + Args: + caselist: A list of tuples of conditions and expected replacements + Takes the form: (condition0, replacement0), (condition1, replacement1), … . + condition should be a 1-D boolean array-like object or a callable. + replacement should be a 1-D array-like object, a scalar or a callable. + + Returns: + New QueryCompiler with replacements. + """ + # Validate caselist. Errors raised are same as native pandas. + if not isinstance(caselist, list): + # modin frotnend always passes a list, but we still keep this check to guard + # against any breaking changes in frontend layer. + raise TypeError( + f"The caselist argument should be a list; instead got {type(caselist)}" + ) + if not caselist: + raise ValueError( + "provide at least one boolean condition, with a corresponding replacement." + ) + + # Validate entries in caselist. Errors raised are same as native pandas. + for num, entry in enumerate(caselist): + if not isinstance(entry, tuple): + # modin frotnend always passes a tuple, but we still eep this check to + # guard against any breaking changes in frontend layer. + raise TypeError( + f"Argument {num} must be a tuple; instead got {type(entry)}." + ) + if len(entry) != 2: + raise ValueError( + f"Argument {num} must have length 2; " + "a condition and replacement; " + f"instead got length {len(entry)}." + ) + + orig_frame = self._modin_frame + joined_frame = self._modin_frame + case_expr: Optional[CaseExpr] = None + for cond, replacement in caselist: + if isinstance(cond, SnowflakeQueryCompiler): + joined_frame, _ = join_utils.align_on_index( + joined_frame, cond._modin_frame, "left" + ) + elif is_list_like(cond): + cond_frame = self.from_pandas( + pandas.DataFrame(cond) + )._modin_frame.ensure_row_position_column() + joined_frame = joined_frame.ensure_row_position_column() + joined_frame, _ = join_utils.join( + joined_frame, + cond_frame, + how="left", + left_on=[joined_frame.row_position_snowflake_quoted_identifier], + right_on=[cond_frame.row_position_snowflake_quoted_identifier], + ) + elif callable(cond): + # TODO SNOW-1489503: Add support for callable + ErrorMessage.not_implemented( + "Snowpark pandas method Series.case_when doesn't yet support callable as condition" + ) + else: + raise TypeError( + f"condition must be a Series or 1-D array-like object; instead got {type(cond)}" + ) + + # if indices are misaligned treat the condition as True + cond_expr = coalesce( + col(joined_frame.data_column_snowflake_quoted_identifiers[-1]), + pandas_lit(True), + ) + if isinstance(replacement, SnowflakeQueryCompiler): + joined_frame, _ = join_utils.align_on_index( + joined_frame, replacement._modin_frame, "left" + ) + value = col(joined_frame.data_column_snowflake_quoted_identifiers[-1]) + elif is_scalar(replacement): + value = pandas_lit(replacement) + elif is_list_like(replacement): + repl_frame = self.from_pandas( + pandas.DataFrame(replacement) + )._modin_frame.ensure_row_position_column() + joined_frame = joined_frame.ensure_row_position_column() + joined_frame, _ = join_utils.join( + joined_frame, + repl_frame, + how="left", + left_on=[joined_frame.row_position_snowflake_quoted_identifier], + right_on=[repl_frame.row_position_snowflake_quoted_identifier], + ) + value = col(joined_frame.data_column_snowflake_quoted_identifiers[-1]) + elif callable(replacement): + # TODO SNOW-1489503: Add support for callable + ErrorMessage.not_implemented( + "Snowpark pandas method Series.case_when doesn't yet support callable as replacement" + ) + else: + raise TypeError( + f"replacement must be a Series, 1-D array-like object or scalar; instead got {type(replacement)}" + ) + + case_expr = ( + when(cond_expr, value) + if case_expr is None + else case_expr.when(cond_expr, value) + ) + orig_col = col(joined_frame.data_column_snowflake_quoted_identifiers[0]) + case_expr = orig_col if case_expr is None else case_expr.otherwise(orig_col) + ( + joined_frame, + _, + ) = joined_frame.update_snowflake_quoted_identifiers_with_expressions( + {joined_frame.data_column_snowflake_quoted_identifiers[0]: case_expr} + ) + new_frame = InternalFrame.create( + ordered_dataframe=joined_frame.ordered_dataframe, + index_column_pandas_labels=orig_frame.index_column_pandas_labels, + index_column_snowflake_quoted_identifiers=joined_frame.index_column_snowflake_quoted_identifiers, + data_column_pandas_index_names=orig_frame.data_column_pandas_index_names, + data_column_pandas_labels=orig_frame.data_column_pandas_labels, + data_column_snowflake_quoted_identifiers=joined_frame.data_column_snowflake_quoted_identifiers[ + :1 + ], + ) + return SnowflakeQueryCompiler(new_frame) + def mask( self, cond: "SnowflakeQueryCompiler", diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series.py b/src/snowflake/snowpark/modin/plugin/docstrings/series.py index 49bc747dc22..23f73bb637d 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series.py @@ -1115,6 +1115,47 @@ def factorize(): def case_when(): """ Replace values where the conditions are True. + + Parameters + ---------- + caselist : A list of tuples of conditions and expected replacements + Takes the form: ``(condition0, replacement0)``, + ``(condition1, replacement1)``, ... . + ``condition`` should be a 1-D boolean array-like object + or a callable. If ``condition`` is a callable, + it is computed on the Series + and should return a boolean Series or array. + The callable must not change the input Series + (though pandas doesn`t check it). ``replacement`` should be a + 1-D array-like object, a scalar or a callable. + If ``replacement`` is a callable, it is computed on the Series + and should return a scalar or Series. The callable + must not change the input Series + (though pandas doesn`t check it). + + .. versionadded:: 2.2.0 + + Returns + ------- + Series + + See Also + -------- + Series.mask : Replace values where the condition is True. + + Examples + -------- + >>> c = pd.Series([6, 7, 8, 9], name='c') + >>> a = pd.Series([0, 0, 1, 2]) + >>> b = pd.Series([0, 3, 4, 5]) + + >>> c.case_when(caselist=[(a.gt(0), a), # condition, replacement + ... (b.gt(0), b)]) + 0 6 + 1 3 + 2 1 + 3 2 + Name: c, dtype: int64 """ def fillna(): diff --git a/tests/integ/modin/series/test_case_when.py b/tests/integ/modin/series/test_case_when.py new file mode 100644 index 00000000000..edddd3ef9c1 --- /dev/null +++ b/tests/integ/modin/series/test_case_when.py @@ -0,0 +1,115 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from typing import List + +import modin.pandas as pd +import numpy as np +import pandas as native_pd +import pytest +from pandas.api.types import is_scalar + +import snowflake.snowpark.modin.plugin # noqa: F401 +from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker +from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result + + +def _verify_case_when(series: native_pd.Series, caselist: List[tuple]) -> None: + native_res = series.case_when(caselist) + caselist = [ + ( + pd.Series(cond) if isinstance(cond, native_pd.Series) else cond, + pd.Series(repl) if isinstance(repl, native_pd.Series) else repl, + ) + for cond, repl in caselist + ] + snow_res = pd.Series(series).case_when(caselist) + assert_series_equal(snow_res, native_res) + + +@pytest.mark.parametrize( + "repl", [native_pd.Series([11, 12, 13, 14]), [11, 12, 13, 14], 99] +) +@pytest.mark.parametrize( + "cond", [native_pd.Series([True, False, True, False]), [True, False, True, False]] +) +def test_case_when(cond, repl): + with SqlCounter(query_count=1, join_count=1 if is_scalar(repl) else 2): + series = native_pd.Series([1, 2, 3, 4]) + _verify_case_when(series, [(cond, repl)]) + + +@sql_count_checker(query_count=1, join_count=1) +def test_case_when_misaligned_index(): + series = native_pd.Series([1, 2, 3, 4, 5, 6]) + cond = native_pd.Series([True, False, True, False, True], index=[0, 1, 2, 6, 7]) + _verify_case_when(series, [(cond, 99)]) + + +@sql_count_checker(query_count=1, join_count=2) +def test_case_when_mulitple_cases(): + series = native_pd.Series([1, 2, 3, 4, 5, 6]) + cond1 = native_pd.Series([True, False, True, False, True]) + cond2 = native_pd.Series([False, True, False, False, True]) + caselist = [(cond1, 98), (cond2, 99)] + _verify_case_when(series, caselist) + + +@pytest.mark.parametrize("caselist", [[], [()], [(97, 98, 99)]]) +@sql_count_checker(query_count=0) +def test_case_when_invalid_caselist(caselist): + series = native_pd.Series([1, 2, 3, 4, 5, 6]) + if not caselist: + error_msg = ( + "provide at least one boolean condition, with a corresponding replacement" + ) + else: + error_msg = f"Argument 0 must have length 2; a condition and replacement; instead got length {len(caselist[0])}." + + eval_snowpark_pandas_result( + pd.Series(series), + series, + lambda s: s.case_when(caselist), + expect_exception=True, + expect_exception_type=ValueError, + expect_exception_match=error_msg, + ) + + +@sql_count_checker(query_count=0) +def test_case_when_invalid_condition_type(): + series = native_pd.Series([1, 2, 3, 4, 5, 6]) + error_msg = ( + "condition must be a Series or 1-D array-like object; instead got " + ) + # Native pandas raises ValueError('Failed to apply condition0 and replacement0.') + # Snowpark pandas raise more helpful error message. + with pytest.raises(TypeError, match=error_msg): + pd.Series(series).case_when([("xyz", 99)]) + + +@sql_count_checker(query_count=0) +def test_case_when_invalid_replacement_type(): + series = native_pd.Series([1, 2, 3, 4, 5, 6]) + error_msg = "replacement must be a Series, 1-D array-like object or scalar; instead got " + # Native pandas raises ValueError('Failed to apply condition0 and replacement0.') + # Snowpark pandas raise more helpful error message. + with pytest.raises(TypeError, match=error_msg): + pd.Series(series).case_when([(pd.Series([True, False]), np.array(2))]) + + +@sql_count_checker(query_count=0) +def test_case_when_callable_condition_not_implemented_error(): + series = native_pd.Series([1, 2, 3, 4, 5, 6]) + error_msg = "Snowpark pandas method Series.case_when doesn't yet support callable as condition" + with pytest.raises(NotImplementedError, match=error_msg): + pd.Series(series).case_when([(lambda x: x > 3, 99)]) + + +@sql_count_checker(query_count=0) +def test_case_when_callable_replacement_not_implemented_error(): + series = native_pd.Series([1, 2, 3, 4, 5, 6]) + error_msg = "Snowpark pandas method Series.case_when doesn't yet support callable as replacement" + with pytest.raises(NotImplementedError, match=error_msg): + pd.Series(series).case_when([(pd.Series([True, False]), lambda x: x > 3)])