From 6ad66c4590545a0f5b9b248eac6dd82140352b6b Mon Sep 17 00:00:00 2001 From: azhan Date: Tue, 3 Sep 2024 17:34:24 -0700 Subject: [PATCH] SNOW-1649528 Fix monotonic bug --- .../compiler/snowflake_query_compiler.py | 30 ++++++++++++++----- tests/integ/modin/index/test_monotonic.py | 4 +-- tests/integ/modin/series/test_monotonic.py | 14 +++++++-- 3 files changed, 36 insertions(+), 12 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 848c5e438b3..a709751b816 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -2344,7 +2344,7 @@ def _check_monotonic(self, increasing: bool) -> "SnowflakeQueryCompiler": monotonic_decreasing_snowflake_quoted_identifier ) new_modin_frame = new_qc._modin_frame - return SnowflakeQueryCompiler( + qc = SnowflakeQueryCompiler( InternalFrame.create( ordered_dataframe=new_modin_frame.ordered_dataframe.limit( n=1, sort=False @@ -2358,6 +2358,8 @@ def _check_monotonic(self, increasing: bool) -> "SnowflakeQueryCompiler": index_column_types=None, ) ) + # use agg all to handle empty case + return qc.agg(func="all", args=(), kwargs={}, axis=0) def _add_columns_for_monotonicity_checks( self, col_to_check: str, columns_to_add: Optional[str] = None @@ -2371,7 +2373,7 @@ def _add_columns_for_monotonicity_checks( col_to_check : str The Snowflake quoted identifier for the column whose monotonicity to check. columns_to_add : str, optional - Whether or not to add all columns, and if not, which columns to add. + Whether to add all columns, and if not, which columns to add. Returns ------- @@ -2402,9 +2404,15 @@ def _add_columns_for_monotonicity_checks( if columns_to_add in [None, "decreasing"]: modin_frame = modin_frame.append_column( "_is_monotonic_decreasing", - coalesce( - min_(col(col_to_check) <= col(lag_col_snowflake_quoted_id)).over(), - pandas_lit(False), + iff( + count("*").over() <= 1, + pandas_lit(True), + coalesce( + min_( + col(col_to_check) <= col(lag_col_snowflake_quoted_id) + ).over(), + pandas_lit(False), + ), ), ) monotonic_decreasing_snowflake_quoted_id = ( @@ -2413,9 +2421,15 @@ def _add_columns_for_monotonicity_checks( if columns_to_add in [None, "increasing"]: modin_frame = modin_frame.append_column( "_is_monotonic_increasing", - coalesce( - min_(col(col_to_check) >= col(lag_col_snowflake_quoted_id)).over(), - pandas_lit(False), + iff( + count("*").over() <= 1, + pandas_lit(True), + coalesce( + min_( + col(col_to_check) >= col(lag_col_snowflake_quoted_id) + ).over(), + pandas_lit(False), + ), ), ) monotonic_increasing_snowflake_quoted_id = ( diff --git a/tests/integ/modin/index/test_monotonic.py b/tests/integ/modin/index/test_monotonic.py index 5a15e4eb021..05d560c700b 100644 --- a/tests/integ/modin/index/test_monotonic.py +++ b/tests/integ/modin/index/test_monotonic.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize( - "values", [[1, 2, 3], [3, 2, 1], [1, 3, 2], [1, 2, 2], [1, np.NaN, 3]] + "values", [[], [1], [3, 2], [1, 3, 2], [1, 2, 2], [1, np.NaN, 3]] ) @sql_count_checker(query_count=1) def test_monotonic_increasing_numbers(values): @@ -23,7 +23,7 @@ def test_monotonic_increasing_numbers(values): @pytest.mark.parametrize( - "values", [[3, 2, 1], [1, 2, 3], [3, 1, 2], [2, 2, 1], [3, np.NaN, 1]] + "values", [[], [3], [1, 2], [3, 1, 2], [2, 2, 1], [3, np.NaN, 1]] ) @sql_count_checker(query_count=1) def test_monotonic_decreasing_numbers(values): diff --git a/tests/integ/modin/series/test_monotonic.py b/tests/integ/modin/series/test_monotonic.py index 8726b9d9bd8..5812250b84a 100644 --- a/tests/integ/modin/series/test_monotonic.py +++ b/tests/integ/modin/series/test_monotonic.py @@ -12,7 +12,7 @@ @pytest.mark.parametrize( - "values", [[1, 2, 3], [3, 2, 1], [1, 3, 2], [1, 2, 2], [1, np.NaN, 3]] + "values", [[], [1], [3, 2], [1, 3, 2], [1, 2, 2], [1, np.NaN, 3]] ) @sql_count_checker(query_count=1) def test_monotonic_increasing_numbers(values): @@ -23,7 +23,7 @@ def test_monotonic_increasing_numbers(values): @pytest.mark.parametrize( - "values", [[3, 2, 1], [1, 2, 3], [3, 1, 2], [2, 2, 1], [3, np.NaN, 1]] + "values", [[], [3], [1, 2], [3, 1, 2], [2, 2, 1], [3, np.NaN, 1]] ) @sql_count_checker(query_count=1) def test_monotonic_decreasing_numbers(values): @@ -95,3 +95,13 @@ def test_monotonic_decreasing_dates(values): pd.Series(values).is_monotonic_decreasing == native_pd.Series(values).is_monotonic_decreasing ) + + +@sql_count_checker(query_count=2) +def test_monotonic_type_mismatch(): + # Snowpark pandas may have different behavior when the column type is variant. pandas always returns False while + # Snowflake engine does implicit casting (“coercion”) and then check monotonic + assert not native_pd.Series([0, "a"]).is_monotonic_increasing + assert pd.Series([0, "a"]).is_monotonic_increasing + assert not native_pd.Series(["a", 0]).is_monotonic_decreasing + assert pd.Series(["a", 0]).is_monotonic_decreasing