From 84aab65cb3e6d29b1a2968126f012ba25ab5a87c Mon Sep 17 00:00:00 2001 From: Naresh Kumar Date: Fri, 13 Sep 2024 11:40:46 -0700 Subject: [PATCH] SNOW-1649780: Bug fix Series.sort_values fails when name overlaps with index name --- CHANGELOG.md | 3 ++ .../plugin/extensions/series_overrides.py | 36 +++++++++---------- tests/integ/modin/series/test_sort_values.py | 10 ++++++ 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e0589d4a358..d9e8bc2b085 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ - Added support for `TimedeltaIndex.mean` method. +#### Bug Fixes + +- Fixed a bug where `Series.sort_values` failed if series name overlapped with index column name. ## 1.22.1 (2024-09-11) This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content. diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index 5011defa685..dc047a989c3 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -18,6 +18,7 @@ Literal, Mapping, Sequence, + get_args, ) import modin.pandas as pd @@ -36,6 +37,7 @@ IndexKeyFunc, IndexLabel, Level, + NaPosition, Renamer, Scalar, ) @@ -46,7 +48,7 @@ ) from pandas.core.common import apply_if_callable, is_bool_indexer from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like -from pandas.util._validators import validate_bool_kwarg +from pandas.util._validators import validate_ascending, validate_bool_kwarg from snowflake.snowpark.modin import pandas as spd # noqa: F401 from snowflake.snowpark.modin.pandas.api.extensions import register_series_accessor @@ -1521,7 +1523,6 @@ def sort_values( Sort by the values. """ # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - from snowflake.snowpark.modin.pandas.dataframe import DataFrame if is_list_like(ascending) and len(ascending) != 1: raise ValueError(f"Length of ascending ({len(ascending)}) must be 1 for Series") @@ -1529,25 +1530,20 @@ def sort_values( if axis is not None: # Validate `axis` self._get_axis_number(axis) - - # When we convert to a DataFrame, the name is automatically converted to 0 if it - # is None, so we do this to avoid a KeyError. - by = self.name if self.name is not None else 0 - result = ( - DataFrame(self.copy()) - .sort_values( - by=by, - ascending=ascending, - inplace=False, - kind=kind, - na_position=na_position, - ignore_index=ignore_index, - key=key, - ) - .squeeze(axis=1) + # Validate inplace, ascending and na_position. + inplace = validate_bool_kwarg(inplace, "inplace") + ascending = validate_ascending(ascending) + if na_position not in get_args(NaPosition): + # Same error message as native pandas for invalid 'na_position' value. + raise ValueError(f"invalid na_position: {na_position}") + + # Convert 'ascending' to sequence if needed. + if not isinstance(ascending, Sequence): + ascending = [ascending] + result = self._query_compiler.sort_rows_by_column_values( + [self.name], ascending, kind, na_position, ignore_index, key ) - result.name = self.name - return self._create_or_update_from_compiler(result._query_compiler, inplace=inplace) + return self._create_or_update_from_compiler(result, inplace=inplace) # TODO: SNOW-1063346 diff --git a/tests/integ/modin/series/test_sort_values.py b/tests/integ/modin/series/test_sort_values.py index e966409dfc9..b49de842b9d 100644 --- a/tests/integ/modin/series/test_sort_values.py +++ b/tests/integ/modin/series/test_sort_values.py @@ -178,3 +178,13 @@ def test_sort_values_repeat(snow_series): snow_series.to_pandas(), lambda s: s.sort_values().sort_values(ascending=False), ) + + +@sql_count_checker(query_count=1) +def test_sort_values_shared_name_with_index(): + # Bug fix: SNOW-1649780 + native_series = native_pd.Series( + [1], name="X", index=native_pd.Index([2], name="X") + ) + snow_series = pd.Series(native_series) + eval_snowpark_pandas_result(snow_series, native_series, lambda s: s.sort_values())