Skip to content

Commit

Permalink
SNOW-1649780: Bug fix Series.sort_values fails when name overlaps wit…
Browse files Browse the repository at this point in the history
…h index name
  • Loading branch information
sfc-gh-nkumar committed Sep 13, 2024
1 parent 2866998 commit f2f19b5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
36 changes: 16 additions & 20 deletions src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Literal,
Mapping,
Sequence,
get_args,
)

import modin.pandas as pd
Expand All @@ -36,6 +37,7 @@
IndexKeyFunc,
IndexLabel,
Level,
NaPosition,
Renamer,
Scalar,
)
Expand All @@ -46,7 +48,7 @@
)
from pandas.core.common import apply_if_callable, is_bool_indexer
from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like
from pandas.util._validators import validate_bool_kwarg
from pandas.util._validators import validate_ascending, validate_bool_kwarg

from snowflake.snowpark.modin import pandas as spd # noqa: F401
from snowflake.snowpark.modin.pandas.api.extensions import register_series_accessor
Expand Down Expand Up @@ -1521,33 +1523,27 @@ def sort_values(
Sort by the values.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
from snowflake.snowpark.modin.pandas.dataframe import DataFrame

if is_list_like(ascending) and len(ascending) != 1:
raise ValueError(f"Length of ascending ({len(ascending)}) must be 1 for Series")

if axis is not None:
# Validate `axis`
self._get_axis_number(axis)

# When we convert to a DataFrame, the name is automatically converted to 0 if it
# is None, so we do this to avoid a KeyError.
by = self.name if self.name is not None else 0
result = (
DataFrame(self.copy())
.sort_values(
by=by,
ascending=ascending,
inplace=False,
kind=kind,
na_position=na_position,
ignore_index=ignore_index,
key=key,
)
.squeeze(axis=1)
# Validate inplace, ascending and na_position.
inplace = validate_bool_kwarg(inplace, "inplace")
ascending = validate_ascending(ascending)
if na_position not in get_args(NaPosition):
# Same error message as native pandas for invalid 'na_position' value.
raise ValueError(f"invalid na_position: {na_position}")

# Convert 'ascending' to sequence if needed.
if not isinstance(ascending, Sequence):
ascending = [ascending]
result = self._query_compiler.sort_rows_by_column_values(
[self.name], ascending, kind, na_position, ignore_index, key
)
result.name = self.name
return self._create_or_update_from_compiler(result._query_compiler, inplace=inplace)
return self._create_or_update_from_compiler(result, inplace=inplace)


# TODO: SNOW-1063346
Expand Down
10 changes: 10 additions & 0 deletions tests/integ/modin/series/test_sort_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,13 @@ def test_sort_values_repeat(snow_series):
snow_series.to_pandas(),
lambda s: s.sort_values().sort_values(ascending=False),
)


@sql_count_checker(query_count=1)
def test_sort_values_shared_name_with_index():
# Bug fix: SNOW-1649780
native_series = native_pd.Series(
[1], name="X", index=native_pd.Index([2], name="X")
)
snow_series = pd.Series(native_series)
eval_snowpark_pandas_result(snow_series, native_series, lambda s: s.sort_values())

0 comments on commit f2f19b5

Please sign in to comment.