From dd0bbee73076e9e8d1efd68ee8fddde9ff6db834 Mon Sep 17 00:00:00 2001 From: Naresh Kumar Date: Fri, 13 Sep 2024 11:40:46 -0700 Subject: [PATCH] SNOW-1649780: Bug fix Series.sort_values fails when name overlaps with index name --- CHANGELOG.md | 3 ++ .../compiler/snowflake_query_compiler.py | 4 +- .../plugin/extensions/series_overrides.py | 43 ++++++++++--------- tests/integ/modin/series/test_sort_values.py | 10 +++++ 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0bd719dcb8c..4faeb22c8aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,9 @@ - Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`. - Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`. +#### Bug Fixes + +- Fixed a bug where `Series.sort_values` failed if series name overlapped with index column name. ## 1.22.1 (2024-09-11) This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content. diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index e971b15b6d6..8b0f75e4421 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -3255,6 +3255,7 @@ def sort_rows_by_column_values( ignore_index: bool, key: Optional[IndexKeyFunc] = None, include_indexer: bool = False, + include_index: bool = True, ) -> "SnowflakeQueryCompiler": """ Reorder the rows based on the lexicographic order of the given columns. @@ -3270,6 +3271,7 @@ def sort_rows_by_column_values( key: Apply the key function to the values before sorting. include_indexer: If True, add a data column with the original row numbers in the same order as the index, i.e., add an indexer column. This is used with Index.sort_values. + include_index: If True, include index columns in the sort. Returns: A new SnowflakeQueryCompiler instance after applying the sort. @@ -3297,7 +3299,7 @@ def sort_rows_by_column_values( matched_identifiers = ( self._modin_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels( - columns + columns, include_index ) ) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index b104c223e26..a0c4fcdf571 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -9,7 +9,7 @@ from __future__ import annotations -from typing import IO, Any, Callable, Hashable, Literal, Mapping, Sequence +from typing import IO, Any, Callable, Hashable, Literal, Mapping, Sequence, get_args import modin.pandas as pd import numpy as np @@ -27,6 +27,7 @@ IndexKeyFunc, IndexLabel, Level, + NaPosition, Renamer, Scalar, ) @@ -37,7 +38,7 @@ ) from pandas.core.common import apply_if_callable, is_bool_indexer from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like -from pandas.util._validators import validate_bool_kwarg +from pandas.util._validators import validate_ascending, validate_bool_kwarg from snowflake.snowpark.modin import pandas as spd # noqa: F401 from snowflake.snowpark.modin.pandas.api.extensions import register_series_accessor @@ -1519,7 +1520,6 @@ def sort_values( Sort by the values. """ # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from modin.pandas.dataframe import DataFrame if is_list_like(ascending) and len(ascending) != 1: raise ValueError(f"Length of ascending ({len(ascending)}) must be 1 for Series") @@ -1527,25 +1527,26 @@ def sort_values( if axis is not None: # Validate `axis` self._get_axis_number(axis) - - # When we convert to a DataFrame, the name is automatically converted to 0 if it - # is None, so we do this to avoid a KeyError. - by = self.name if self.name is not None else 0 - result = ( - DataFrame(self.copy()) - .sort_values( - by=by, - ascending=ascending, - inplace=False, - kind=kind, - na_position=na_position, - ignore_index=ignore_index, - key=key, - ) - .squeeze(axis=1) + # Validate inplace, ascending and na_position. + inplace = validate_bool_kwarg(inplace, "inplace") + ascending = validate_ascending(ascending) + if na_position not in get_args(NaPosition): + # Same error message as native pandas for invalid 'na_position' value. + raise ValueError(f"invalid na_position: {na_position}") + + # Convert 'ascending' to sequence if needed. + if not isinstance(ascending, Sequence): + ascending = [ascending] + result = self._query_compiler.sort_rows_by_column_values( + self._query_compiler.columns, + ascending, + kind, + na_position, + ignore_index, + key, + include_index=False, ) - result.name = self.name - return self._create_or_update_from_compiler(result._query_compiler, inplace=inplace) + return self._create_or_update_from_compiler(result, inplace=inplace) # Upstream Modin defaults at the frontend layer. diff --git a/tests/integ/modin/series/test_sort_values.py b/tests/integ/modin/series/test_sort_values.py index e966409dfc9..d5342f5a15e 100644 --- a/tests/integ/modin/series/test_sort_values.py +++ b/tests/integ/modin/series/test_sort_values.py @@ -178,3 +178,13 @@ def test_sort_values_repeat(snow_series): snow_series.to_pandas(), lambda s: s.sort_values().sort_values(ascending=False), ) + + +@sql_count_checker(query_count=1) +def test_sort_values_shared_name_with_index(): + # Bug fix: SNOW-1649780 + native_series = native_pd.Series( + [1, 3], name="X", index=native_pd.Index([2, 1], name="X") + ) + snow_series = pd.Series(native_series) + eval_snowpark_pandas_result(snow_series, native_series, lambda s: s.sort_values())