Skip to content

Commit

Permalink
SNOW-1453559: Add support for Series.case_when (#1800)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-nkumar authored Jun 25, 2024
1 parent 4453759 commit 983d65c
Show file tree
Hide file tree
Showing 7 changed files with 290 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
- Added support for `replace` and `frac > 1` in `DataFrame.sample` and `Series.sample`.
- Added support for `Series.at`, `Series.iat`, `DataFrame.at`, and `DataFrame.iat`.
- Added support for `Series.dt.isocalendar`.
- Added support for `Series.case_when` except when condition or replacement is callable.

#### Bug Fixes

Expand Down
1 change: 1 addition & 0 deletions docs/source/modin/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ Series
.. autosummary::
:toctree: pandas_api/

Series.case_when
Series.drop
Series.drop_duplicates
Series.get
Expand Down
2 changes: 1 addition & 1 deletion docs/source/modin/supported/series_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``bool`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``case_when`` | N | | |
| ``case_when`` | P | | ``N`` if condition or replacement is a callable. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``clip`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
1 change: 0 additions & 1 deletion src/snowflake/snowpark/modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,6 @@ def factorize(
use_na_sentinel=use_na_sentinel,
)

@series_not_implemented()
def case_when(self, caselist) -> Series: # noqa: PR01, RT01, D200
"""
Replace values where the conditions are True.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import uuid
from collections.abc import Hashable, Iterable, Mapping, Sequence
from datetime import timedelta, tzinfo
from typing import Any, Callable, Literal, Optional, Union, get_args
from typing import Any, Callable, List, Literal, Optional, Union, get_args

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -8041,6 +8041,136 @@ def getitem_row_array(

return SnowflakeQueryCompiler(new_frame)

def case_when(self, caselist: List[tuple]) -> "SnowflakeQueryCompiler":
"""
Replace values where the conditions are True.

Args:
caselist: A list of tuples of conditions and expected replacements
Takes the form: (condition0, replacement0), (condition1, replacement1), … .
condition should be a 1-D boolean array-like object or a callable.
replacement should be a 1-D array-like object, a scalar or a callable.

Returns:
New QueryCompiler with replacements.
"""
# Validate caselist. Errors raised are same as native pandas.
if not isinstance(caselist, list):
# modin frotnend always passes a list, but we still keep this check to guard
# against any breaking changes in frontend layer.
raise TypeError(
f"The caselist argument should be a list; instead got {type(caselist)}"
)
if not caselist:
raise ValueError(
"provide at least one boolean condition, with a corresponding replacement."
)

# Validate entries in caselist. Errors raised are same as native pandas.
for num, entry in enumerate(caselist):
if not isinstance(entry, tuple):
# modin frotnend always passes a tuple, but we still eep this check to
# guard against any breaking changes in frontend layer.
raise TypeError(
f"Argument {num} must be a tuple; instead got {type(entry)}."
)
if len(entry) != 2:
raise ValueError(
f"Argument {num} must have length 2; "
"a condition and replacement; "
f"instead got length {len(entry)}."
)

orig_frame = self._modin_frame
joined_frame = self._modin_frame
case_expr: Optional[CaseExpr] = None
for cond, replacement in caselist:
if isinstance(cond, SnowflakeQueryCompiler):
joined_frame, _ = join_utils.align_on_index(
joined_frame, cond._modin_frame, "left"
)
elif is_list_like(cond):
cond_frame = self.from_pandas(
pandas.DataFrame(cond)
)._modin_frame.ensure_row_position_column()
joined_frame = joined_frame.ensure_row_position_column()
joined_frame, _ = join_utils.join(
joined_frame,
cond_frame,
how="left",
left_on=[joined_frame.row_position_snowflake_quoted_identifier],
right_on=[cond_frame.row_position_snowflake_quoted_identifier],
)
elif callable(cond):
# TODO SNOW-1489503: Add support for callable
ErrorMessage.not_implemented(
"Snowpark pandas method Series.case_when doesn't yet support callable as condition"
)
else:
raise TypeError(
f"condition must be a Series or 1-D array-like object; instead got {type(cond)}"
)

# if indices are misaligned treat the condition as True
cond_expr = coalesce(
col(joined_frame.data_column_snowflake_quoted_identifiers[-1]),
pandas_lit(True),
)
if isinstance(replacement, SnowflakeQueryCompiler):
joined_frame, _ = join_utils.align_on_index(
joined_frame, replacement._modin_frame, "left"
)
value = col(joined_frame.data_column_snowflake_quoted_identifiers[-1])
elif is_scalar(replacement):
value = pandas_lit(replacement)
elif is_list_like(replacement):
repl_frame = self.from_pandas(
pandas.DataFrame(replacement)
)._modin_frame.ensure_row_position_column()
joined_frame = joined_frame.ensure_row_position_column()
joined_frame, _ = join_utils.join(
joined_frame,
repl_frame,
how="left",
left_on=[joined_frame.row_position_snowflake_quoted_identifier],
right_on=[repl_frame.row_position_snowflake_quoted_identifier],
)
value = col(joined_frame.data_column_snowflake_quoted_identifiers[-1])
elif callable(replacement):
# TODO SNOW-1489503: Add support for callable
ErrorMessage.not_implemented(
"Snowpark pandas method Series.case_when doesn't yet support callable as replacement"
)
else:
raise TypeError(
f"replacement must be a Series, 1-D array-like object or scalar; instead got {type(replacement)}"
)

case_expr = (
when(cond_expr, value)
if case_expr is None
else case_expr.when(cond_expr, value)
)
orig_col = col(joined_frame.data_column_snowflake_quoted_identifiers[0])
case_expr = orig_col if case_expr is None else case_expr.otherwise(orig_col)
(
joined_frame,
_,
) = joined_frame.update_snowflake_quoted_identifiers_with_expressions(
{joined_frame.data_column_snowflake_quoted_identifiers[0]: case_expr}
)
new_frame = InternalFrame.create(
ordered_dataframe=joined_frame.ordered_dataframe,
index_column_pandas_labels=orig_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=joined_frame.index_column_snowflake_quoted_identifiers,
data_column_pandas_index_names=orig_frame.data_column_pandas_index_names,
data_column_pandas_labels=orig_frame.data_column_pandas_labels,
data_column_snowflake_quoted_identifiers=joined_frame.data_column_snowflake_quoted_identifiers[
:1
],
)
return SnowflakeQueryCompiler(new_frame)

def mask(
self,
cond: "SnowflakeQueryCompiler",
Expand Down
41 changes: 41 additions & 0 deletions src/snowflake/snowpark/modin/plugin/docstrings/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,47 @@ def factorize():
def case_when():
"""
Replace values where the conditions are True.
Parameters
----------
caselist : A list of tuples of conditions and expected replacements
Takes the form: ``(condition0, replacement0)``,
``(condition1, replacement1)``, ... .
``condition`` should be a 1-D boolean array-like object
or a callable. If ``condition`` is a callable,
it is computed on the Series
and should return a boolean Series or array.
The callable must not change the input Series
(though pandas doesn`t check it). ``replacement`` should be a
1-D array-like object, a scalar or a callable.
If ``replacement`` is a callable, it is computed on the Series
and should return a scalar or Series. The callable
must not change the input Series
(though pandas doesn`t check it).
.. versionadded:: 2.2.0
Returns
-------
Series
See Also
--------
Series.mask : Replace values where the condition is True.
Examples
--------
>>> c = pd.Series([6, 7, 8, 9], name='c')
>>> a = pd.Series([0, 0, 1, 2])
>>> b = pd.Series([0, 3, 4, 5])
>>> c.case_when(caselist=[(a.gt(0), a), # condition, replacement
... (b.gt(0), b)])
0 6
1 3
2 1
3 2
Name: c, dtype: int64
"""

def fillna():
Expand Down
115 changes: 115 additions & 0 deletions tests/integ/modin/series/test_case_when.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import List

import modin.pandas as pd
import numpy as np
import pandas as native_pd
import pytest
from pandas.api.types import is_scalar

import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
from tests.integ.modin.utils import assert_series_equal, eval_snowpark_pandas_result


def _verify_case_when(series: native_pd.Series, caselist: List[tuple]) -> None:
native_res = series.case_when(caselist)
caselist = [
(
pd.Series(cond) if isinstance(cond, native_pd.Series) else cond,
pd.Series(repl) if isinstance(repl, native_pd.Series) else repl,
)
for cond, repl in caselist
]
snow_res = pd.Series(series).case_when(caselist)
assert_series_equal(snow_res, native_res)


@pytest.mark.parametrize(
"repl", [native_pd.Series([11, 12, 13, 14]), [11, 12, 13, 14], 99]
)
@pytest.mark.parametrize(
"cond", [native_pd.Series([True, False, True, False]), [True, False, True, False]]
)
def test_case_when(cond, repl):
with SqlCounter(query_count=1, join_count=1 if is_scalar(repl) else 2):
series = native_pd.Series([1, 2, 3, 4])
_verify_case_when(series, [(cond, repl)])


@sql_count_checker(query_count=1, join_count=1)
def test_case_when_misaligned_index():
series = native_pd.Series([1, 2, 3, 4, 5, 6])
cond = native_pd.Series([True, False, True, False, True], index=[0, 1, 2, 6, 7])
_verify_case_when(series, [(cond, 99)])


@sql_count_checker(query_count=1, join_count=2)
def test_case_when_mulitple_cases():
series = native_pd.Series([1, 2, 3, 4, 5, 6])
cond1 = native_pd.Series([True, False, True, False, True])
cond2 = native_pd.Series([False, True, False, False, True])
caselist = [(cond1, 98), (cond2, 99)]
_verify_case_when(series, caselist)


@pytest.mark.parametrize("caselist", [[], [()], [(97, 98, 99)]])
@sql_count_checker(query_count=0)
def test_case_when_invalid_caselist(caselist):
series = native_pd.Series([1, 2, 3, 4, 5, 6])
if not caselist:
error_msg = (
"provide at least one boolean condition, with a corresponding replacement"
)
else:
error_msg = f"Argument 0 must have length 2; a condition and replacement; instead got length {len(caselist[0])}."

eval_snowpark_pandas_result(
pd.Series(series),
series,
lambda s: s.case_when(caselist),
expect_exception=True,
expect_exception_type=ValueError,
expect_exception_match=error_msg,
)


@sql_count_checker(query_count=0)
def test_case_when_invalid_condition_type():
series = native_pd.Series([1, 2, 3, 4, 5, 6])
error_msg = (
"condition must be a Series or 1-D array-like object; instead got <class 'str'>"
)
# Native pandas raises ValueError('Failed to apply condition0 and replacement0.')
# Snowpark pandas raise more helpful error message.
with pytest.raises(TypeError, match=error_msg):
pd.Series(series).case_when([("xyz", 99)])


@sql_count_checker(query_count=0)
def test_case_when_invalid_replacement_type():
series = native_pd.Series([1, 2, 3, 4, 5, 6])
error_msg = "replacement must be a Series, 1-D array-like object or scalar; instead got <class 'numpy.ndarray'>"
# Native pandas raises ValueError('Failed to apply condition0 and replacement0.')
# Snowpark pandas raise more helpful error message.
with pytest.raises(TypeError, match=error_msg):
pd.Series(series).case_when([(pd.Series([True, False]), np.array(2))])


@sql_count_checker(query_count=0)
def test_case_when_callable_condition_not_implemented_error():
series = native_pd.Series([1, 2, 3, 4, 5, 6])
error_msg = "Snowpark pandas method Series.case_when doesn't yet support callable as condition"
with pytest.raises(NotImplementedError, match=error_msg):
pd.Series(series).case_when([(lambda x: x > 3, 99)])


@sql_count_checker(query_count=0)
def test_case_when_callable_replacement_not_implemented_error():
series = native_pd.Series([1, 2, 3, 4, 5, 6])
error_msg = "Snowpark pandas method Series.case_when doesn't yet support callable as replacement"
with pytest.raises(NotImplementedError, match=error_msg):
pd.Series(series).case_when([(pd.Series([True, False]), lambda x: x > 3)])

0 comments on commit 983d65c

Please sign in to comment.