-
Notifications
You must be signed in to change notification settings - Fork 120
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
SNOW-1722641: Add support for Series.between (#2775)
<!--- Please answer these questions before creating your pull request. Thanks! ---> 1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR. <!--- In this section, please add a Snowflake Jira issue number. Note that if a corresponding GitHub issue exists, you should still include the Snowflake Jira issue number. For example, for GitHub issue #1400, you should add "SNOW-1335071" here. ---> Fixes SNOW-1722641 2. Fill out the following pre-review checklist: - [x] I am adding a new automated test(s) to verify correctness of my new code - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing - [ ] I am adding new logging messages - [ ] I am adding a new telemetry message - [ ] I am adding new credentials - [ ] I am adding a new dependency - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes. - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development) 3. Please describe how your code solves the related issue. Implements `Series.between`. The frontend implementation uses [modin's implementation](https://github.com/modin-project/modin/blob/1c4d173d3b2c44a1c1b5d5516552c7717b26de32/modin/pandas/series.py#L795), which passes the `modin.pandas.Series` object to the [native pandas method](https://github.com/pandas-dev/pandas/blob/9fe33bcbca79e098f9ba8ffd9fcf95440b95032b/pandas/core/series.py#L5362-L5380), which directly uses comparison operators to implement the method.
- Loading branch information
1 parent
a0ac492
commit be0ae7b
Showing
6 changed files
with
202 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
# | ||
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. | ||
# | ||
|
||
import datetime as dt | ||
|
||
import numpy as np | ||
import pandas as native_pd | ||
import pytest | ||
|
||
import modin.pandas as pd | ||
import snowflake.snowpark.modin.plugin # noqa: F401 | ||
from snowflake.snowpark.exceptions import SnowparkSQLException | ||
from tests.integ.modin.utils import ( | ||
eval_snowpark_pandas_result, | ||
create_test_series, | ||
assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, | ||
) | ||
from tests.integ.utils.sql_counter import sql_count_checker | ||
|
||
|
||
@sql_count_checker(query_count=1) | ||
def test_series_between_default_inclusive(): | ||
eval_snowpark_pandas_result( | ||
*create_test_series(list(range(0, 10))), lambda ser: ser.between(2, 8) | ||
) | ||
|
||
|
||
# tuples of (data, low, high) | ||
BETWEEN_TEST_ARGUMENTS = [ | ||
([0.8, 1.1, 0.9, 1.2], 0.9, 1.1), | ||
([0.8, -1.1, 0.9, 1.2], -1, 1.14), | ||
# strings are compared lexicographically | ||
(["quick", "brown", "fox", "Quick", "Brown", "Fox"], "Zeta", "alpha"), | ||
(["quick", "brown", "fox", "Quick", "Brown", "Fox"], "Delta", "kappa"), | ||
( | ||
[ | ||
dt.datetime(2024, 10, 11, 17, 5), | ||
dt.datetime(2020, 1, 2, 2, 40), | ||
dt.datetime(1998, 7, 7, 12, 33), | ||
], | ||
dt.datetime(2019, 1, 1, 0, 0), | ||
dt.datetime(2021, 1, 1, 0, 0), | ||
), | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("data, low, high", BETWEEN_TEST_ARGUMENTS) | ||
@pytest.mark.parametrize("inclusive", ["both", "neither", "left", "right"]) | ||
@sql_count_checker(query_count=1) | ||
def test_series_between(data, low, high, inclusive): | ||
eval_snowpark_pandas_result( | ||
*create_test_series(data), lambda ser: ser.between(low, high, inclusive) | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"data, low, high", | ||
[ | ||
([0.8, 1.1, 0.9, 1.2, np.nan], np.nan, 1.1), | ||
([0.8, 1.1, 0.9, 1.2, np.nan], np.nan, np.nan), | ||
([0.8, 1.1, 0.9, 1.2, np.nan], -1, 1.1), | ||
([None, "", "aa", "aaaa"], "", "aaa"), | ||
([None, "", "aa", "aaaa"], None, "aaa"), | ||
([None, "", "aa", "aaaa"], None, None), | ||
], | ||
) | ||
@pytest.mark.parametrize("inclusive", ["both", "neither", "left", "right"]) | ||
@sql_count_checker(query_count=1) | ||
def test_series_between_with_nulls(data, low, high, inclusive): | ||
# In Snowflake SQL, comparisons with NULL will always result in a NULL value in the output. | ||
# Any comparison with NULL will return NULL, though the conjunction FALSE AND NULL will | ||
# short-circuit and return FALSE. | ||
eval_snowpark_pandas_result( | ||
*create_test_series(data), | ||
lambda ser: ser.between(low, high, inclusive).astype(bool), | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("data, low, high", BETWEEN_TEST_ARGUMENTS) | ||
@sql_count_checker(query_count=1) | ||
def test_series_between_flip_left_right(data, low, high): | ||
# When left/right are out of order, comparisons are still performed (high >= low is not enforced) | ||
eval_snowpark_pandas_result( | ||
*create_test_series(data), lambda ser: ser.between(high, low) | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"data, low, high", | ||
[ | ||
([1, 2, 3, 4], [0, 1, 2, 3], [1.1, 1.9, 3.1, 3.9]), | ||
(["a", "b", "aa", "aaa"], ["", "a", "ccc", "aaaa"], ["c", "bb", "aaa", "d"]), | ||
], | ||
) | ||
@pytest.mark.parametrize("inclusive", ["both", "neither", "left", "right"]) | ||
@sql_count_checker(query_count=1, join_count=3) | ||
def test_series_between_series(data, low, high, inclusive): | ||
eval_snowpark_pandas_result( | ||
*create_test_series(data), | ||
lambda ser: ser.between( | ||
pd.Series(high) if isinstance(ser, pd.Series) else native_pd.Series(high), | ||
pd.Series(low) if isinstance(ser, pd.Series) else native_pd.Series(low), | ||
inclusive, | ||
), | ||
) | ||
|
||
|
||
@sql_count_checker(query_count=1, join_count=3) | ||
def test_series_between_series_different_dimensions(): | ||
# When attempting to compare with low/high of different lengths, Snowflake will leave NULLs | ||
# but pandas will raise an error. | ||
data = [1.1] | ||
low = [1, 2] | ||
high = [1, 2, 3] | ||
assert_snowpark_pandas_equals_to_pandas_without_dtypecheck( | ||
pd.Series(data).between(low, high), | ||
native_pd.Series([False]), | ||
) | ||
with pytest.raises( | ||
ValueError, match="Can only compare identically-labeled Series objects" | ||
): | ||
native_pd.Series(data).between(native_pd.Series(low), native_pd.Series(high)) | ||
|
||
|
||
@sql_count_checker(query_count=0) | ||
def test_series_between_invalid_comparison(): | ||
with pytest.raises( | ||
TypeError, match="Invalid comparison between dtype=int64 and str" | ||
): | ||
native_pd.Series([1]).between("a", "b") | ||
with pytest.raises( | ||
SnowparkSQLException, match="Numeric value 'a' is not recognized" | ||
): | ||
pd.Series([1]).between("a", "b").to_pandas() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters