Skip to content

Commit

Permalink
SNOW-1570506 Fix Series.argmax and Series.argmin, add tests (#2023)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-vbudati authored Aug 3, 2024
1 parent 74ff08d commit 0e511e5
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
- Added support for `Index.all` and `Index.any`.
- Added support for `Series.dt.is_year_start` and `Series.dt.is_year_end`.
- Added support for `Series.dt.is_quarter_start` and `Series.dt.is_quarter_end`.
- Added support for `Series.argmax` and `Series.argmin`.

#### Improvements
- Removed the public preview warning message upon importing Snowpark pandas.
Expand Down
4 changes: 2 additions & 2 deletions docs/source/modin/supported/series_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``apply`` | P | ``convert_dtype`` is ignored | ``N`` if ``func`` is not callable. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``argmax`` | N | | |
| ``argmax`` | P | | ``N`` if the Series has a MultiIndex index. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``argmin`` | N | | |
| ``argmin`` | P | | ``N`` if the Series has a MultiIndex index. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``argsort`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
24 changes: 18 additions & 6 deletions src/snowflake/snowpark/modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,25 +781,37 @@ def apply(

return self.__constructor__(query_compiler=new_query_compiler)

def argmax(self, axis=None, skipna=True, *args, **kwargs):  # noqa: PR01, RT01, D200
    """
    Return int position of the largest value in the Series.

    Parameters
    ----------
    axis : default None
        Forwarded unchanged to ``idxmax``.
    skipna : bool, default True
        Forwarded unchanged to ``idxmax``.
    *args, **kwargs
        Forwarded unchanged to ``idxmax``.

    Returns
    -------
    int
        The value produced by ``idxmax`` on a copy of the Series whose index
        was dropped via ``reset_index(drop=True)``, or ``-1`` when that value
        is not an integer (for example ``None``).

    Raises
    ------
    NotImplementedError
        If the Series has a MultiIndex index.
    """
    # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
    if self._query_compiler.has_multiindex():
        # The index is a MultiIndex, current logic does not support this.
        ErrorMessage.not_implemented(
            "Series.argmax is not yet supported when the index is a MultiIndex."
        )
    # Drop the existing labels first so that the label returned by idxmax is
    # the position of the value rather than the user-supplied index label.
    result = self.reset_index(drop=True).idxmax(
        *args, axis=axis, skipna=skipna, **kwargs
    )
    if not is_integer(result):  # if result is None, return -1
        result = -1
    return result

def argmin(self, axis=None, skipna=True, *args, **kwargs):  # noqa: PR01, RT01, D200
    """
    Return int position of the smallest value in the Series.

    Parameters
    ----------
    axis : default None
        Forwarded unchanged to ``idxmin``.
    skipna : bool, default True
        Forwarded unchanged to ``idxmin``.
    *args, **kwargs
        Forwarded unchanged to ``idxmin``.

    Returns
    -------
    int
        The value produced by ``idxmin`` on a copy of the Series whose index
        was dropped via ``reset_index(drop=True)``, or ``-1`` when that value
        is not an integer (for example ``None``).

    Raises
    ------
    NotImplementedError
        If the Series has a MultiIndex index.
    """
    # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
    if self._query_compiler.has_multiindex():
        # The index is a MultiIndex, current logic does not support this.
        ErrorMessage.not_implemented(
            "Series.argmin is not yet supported when the index is a MultiIndex."
        )
    # Drop the existing labels first so that the label returned by idxmin is
    # the position of the value rather than the user-supplied index label.
    result = self.reset_index(drop=True).idxmin(
        *args, axis=axis, skipna=skipna, **kwargs
    )
    if not is_integer(result):  # if result is None, return -1
        result = -1
    return result

Expand Down
52 changes: 52 additions & 0 deletions tests/integ/modin/series/test_argmax_argmin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

import re

import modin.pandas as pd
import pandas as native_pd
import pytest

import snowflake.snowpark.modin.plugin # noqa: F401
from tests.integ.modin.sql_counter import sql_count_checker


@sql_count_checker(query_count=1)
@pytest.mark.parametrize(
    "data, index",
    [
        ([1, None, 4, 3, 4], ["A", "B", "C", "D", "E"]),
        ([4, None, 1, 3, 4, 1], ["A", "B", "C", "D", "E", "F"]),
        ([4, None, 1, 3, 4, 1], [None, "B", "C", "D", "E", "F"]),
        ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]),
    ],
)
@pytest.mark.parametrize("func", ["argmax", "argmin"])
@pytest.mark.parametrize(
    "skipna",
    [True, False],
)
def test_argmax_argmin_series(data, index, func, skipna):
    """
    Test Series.argmax and Series.argmin against native pandas.

    The cases use string (and one missing) index labels, duplicated extreme
    values and missing data, and run with both ``skipna`` settings.
    """
    native_series = native_pd.Series(data=data, index=index)
    snow_series = pd.Series(native_series)

    # Look the method up by name with getattr rather than calling the
    # __getattribute__ dunder directly.
    native_output = getattr(native_series, func)(skipna=skipna)
    snow_output = getattr(snow_series, func)(skipna=skipna)
    assert snow_output == native_output


@pytest.mark.parametrize("func", ["argmax", "argmin"])
@pytest.mark.parametrize("skipna", [True, False])
@sql_count_checker(query_count=0)
def test_series_argmax_argmin_with_multiindex_negative(
    multiindex_native_int_series, func, skipna
):
    """
    Test Series.argmax and Series.argmin with a MultiIndex Series.
    """
    native_series = multiindex_native_int_series
    snow_series = pd.Series(native_series)
    # `match` is applied as a regular expression, so escape the message to
    # make the dots match literally instead of acting as wildcards.
    expected_message = re.escape(
        f"Series.{func} is not yet supported when the index is a MultiIndex."
    )
    with pytest.raises(NotImplementedError, match=expected_message):
        getattr(snow_series, func)(skipna=skipna)
4 changes: 1 addition & 3 deletions tests/integ/modin/series/test_idxmax_idxmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,8 @@ def test_series_idxmax_idxmin_with_multiindex(
multiindex_native_int_series, func, skipna
):
"""
Test DataFrameGroupBy.idxmax and DataFrameGroupBy.idxmin with a MultiIndex DataFrame.
Here, the MultiIndex DataFrames are grouped by `level` and not `by`.
Test Series.idxmax and Series.idxmin with a MultiIndex Series.
"""
# Create MultiIndex DataFrames.
native_series = multiindex_native_int_series
snow_series = pd.Series(native_series)
with pytest.raises(
Expand Down
2 changes: 0 additions & 2 deletions tests/integ/modin/test_unimplemented.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,6 @@ def test_unsupported_str_methods(func, func_name, caplog) -> None:
lambda idx: idx.is_monotonic_decreasing(),
lambda idx: idx.nbytes(),
lambda idx: idx.memory_usage(),
lambda idx: idx.argmin(),
lambda idx: idx.argmax(),
lambda idx: idx.delete(),
lambda idx: idx.drop_duplicates(),
lambda idx: idx.factorize(),
Expand Down

0 comments on commit 0e511e5

Please sign in to comment.