Skip to content

Commit

Permalink
SNOW-1649780: Bug fix Series.sort_values fails when name overlaps wit…
Browse files Browse the repository at this point in the history
…h index name
  • Loading branch information
sfc-gh-nkumar committed Sep 16, 2024
1 parent 0ee3033 commit dd0bbee
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 22 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
- Added support for some cases of aggregating `Timedelta` columns on `axis=0` with `agg` or `aggregate`.
- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`.

#### Bug Fixes

- Fixed a bug where `Series.sort_values` failed if series name overlapped with index column name.

## 1.22.1 (2024-09-11)
This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3255,6 +3255,7 @@ def sort_rows_by_column_values(
ignore_index: bool,
key: Optional[IndexKeyFunc] = None,
include_indexer: bool = False,
include_index: bool = True,
) -> "SnowflakeQueryCompiler":
"""
Reorder the rows based on the lexicographic order of the given columns.
Expand All @@ -3270,6 +3271,7 @@ def sort_rows_by_column_values(
key: Apply the key function to the values before sorting.
include_indexer: If True, add a data column with the original row numbers in the same order as
the index, i.e., add an indexer column. This is used with Index.sort_values.
include_index: If True, include index columns in the sort.

Returns:
A new SnowflakeQueryCompiler instance after applying the sort.
Expand Down Expand Up @@ -3297,7 +3299,7 @@ def sort_rows_by_column_values(

matched_identifiers = (
self._modin_frame.get_snowflake_quoted_identifiers_group_by_pandas_labels(
columns
columns, include_index
)
)

Expand Down
43 changes: 22 additions & 21 deletions src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from __future__ import annotations

from typing import IO, Any, Callable, Hashable, Literal, Mapping, Sequence
from typing import IO, Any, Callable, Hashable, Literal, Mapping, Sequence, get_args

import modin.pandas as pd
import numpy as np
Expand All @@ -27,6 +27,7 @@
IndexKeyFunc,
IndexLabel,
Level,
NaPosition,
Renamer,
Scalar,
)
Expand All @@ -37,7 +38,7 @@
)
from pandas.core.common import apply_if_callable, is_bool_indexer
from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like
from pandas.util._validators import validate_bool_kwarg
from pandas.util._validators import validate_ascending, validate_bool_kwarg

from snowflake.snowpark.modin import pandas as spd # noqa: F401
from snowflake.snowpark.modin.pandas.api.extensions import register_series_accessor
Expand Down Expand Up @@ -1519,33 +1520,33 @@ def sort_values(
Sort by the values.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
from modin.pandas.dataframe import DataFrame

if is_list_like(ascending) and len(ascending) != 1:
raise ValueError(f"Length of ascending ({len(ascending)}) must be 1 for Series")

if axis is not None:
# Validate `axis`
self._get_axis_number(axis)

# When we convert to a DataFrame, the name is automatically converted to 0 if it
# is None, so we do this to avoid a KeyError.
by = self.name if self.name is not None else 0
result = (
DataFrame(self.copy())
.sort_values(
by=by,
ascending=ascending,
inplace=False,
kind=kind,
na_position=na_position,
ignore_index=ignore_index,
key=key,
)
.squeeze(axis=1)
# Validate inplace, ascending and na_position.
inplace = validate_bool_kwarg(inplace, "inplace")
ascending = validate_ascending(ascending)
if na_position not in get_args(NaPosition):
# Same error message as native pandas for invalid 'na_position' value.
raise ValueError(f"invalid na_position: {na_position}")

# Convert 'ascending' to sequence if needed.
if not isinstance(ascending, Sequence):
ascending = [ascending]
result = self._query_compiler.sort_rows_by_column_values(
self._query_compiler.columns,
ascending,
kind,
na_position,
ignore_index,
key,
include_index=False,
)
result.name = self.name
return self._create_or_update_from_compiler(result._query_compiler, inplace=inplace)
return self._create_or_update_from_compiler(result, inplace=inplace)


# Upstream Modin defaults at the frontend layer.
Expand Down
10 changes: 10 additions & 0 deletions tests/integ/modin/series/test_sort_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,13 @@ def test_sort_values_repeat(snow_series):
snow_series.to_pandas(),
lambda s: s.sort_values().sort_values(ascending=False),
)


@sql_count_checker(query_count=1)
def test_sort_values_shared_name_with_index():
# Bug fix: SNOW-1649780
native_series = native_pd.Series(
[1, 3], name="X", index=native_pd.Index([2, 1], name="X")
)
snow_series = pd.Series(native_series)
eval_snowpark_pandas_result(snow_series, native_series, lambda s: s.sort_values())

0 comments on commit dd0bbee

Please sign in to comment.