diff --git a/CHANGELOG.md b/CHANGELOG.md index c1a667898e0..d201ed4c301 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -56,6 +56,7 @@ - Added support for `Index.all` and `Index.any`. - Added support for `Series.dt.is_year_start` and `Series.dt.is_year_end`. - Added support for `Series.dt.is_quarter_start` and `Series.dt.is_quarter_end`. +- Added support for `Series.argmax` and `Series.argmin`. #### Improvements - Removed the public preview warning message upon importing Snowpark pandas. diff --git a/docs/source/modin/supported/series_supported.rst b/docs/source/modin/supported/series_supported.rst index 36d0b29a0a4..ea78a3a0e68 100644 --- a/docs/source/modin/supported/series_supported.rst +++ b/docs/source/modin/supported/series_supported.rst @@ -94,9 +94,9 @@ Methods +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``apply`` | P | ``convert_dtype`` is ignored | ``N`` if ``func`` is not callable. | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``argmax`` | N | | | +| ``argmax`` | P | | ``N`` if the Series has a MultiIndex index. | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ -| ``argmin`` | N | | | +| ``argmin`` | P | | ``N`` if the Series has a MultiIndex index. | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ | ``argsort`` | N | | | +-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+ diff --git a/src/snowflake/snowpark/modin/pandas/series.py b/src/snowflake/snowpark/modin/pandas/series.py index daa1c51c488..12dc9d10972 100644 --- a/src/snowflake/snowpark/modin/pandas/series.py +++ b/src/snowflake/snowpark/modin/pandas/series.py @@ -781,25 +781,37 @@ def apply( return self.__constructor__(query_compiler=new_query_compiler) - @series_not_implemented() def argmax(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01, D200 """ Return int position of the largest value in the Series. """ # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - result = self.idxmax(axis=axis, skipna=skipna, *args, **kwargs) - if np.isnan(result) or result is pandas.NA: + if self._query_compiler.has_multiindex(): + # The index is a MultiIndex, current logic does not support this. + ErrorMessage.not_implemented( + "Series.argmax is not yet supported when the index is a MultiIndex." + ) + result = self.reset_index(drop=True).idxmax( + axis=axis, skipna=skipna, *args, **kwargs + ) + if not is_integer(result): # if result is None, return -1 result = -1 return result - @series_not_implemented() def argmin(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01, D200 """ Return int position of the smallest value in the Series. """ # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions - result = self.idxmin(axis=axis, skipna=skipna, *args, **kwargs) - if np.isnan(result) or result is pandas.NA: + if self._query_compiler.has_multiindex(): + # The index is a MultiIndex, current logic does not support this. + ErrorMessage.not_implemented( + "Series.argmin is not yet supported when the index is a MultiIndex." + ) + result = self.reset_index(drop=True).idxmin( + axis=axis, skipna=skipna, *args, **kwargs + ) + if not is_integer(result): # if result is None, return -1 result = -1 return result diff --git a/tests/integ/modin/series/test_argmax_argmin.py b/tests/integ/modin/series/test_argmax_argmin.py new file mode 100644 index 00000000000..607b36a27f3 --- /dev/null +++ b/tests/integ/modin/series/test_argmax_argmin.py @@ -0,0 +1,52 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import modin.pandas as pd +import pandas as native_pd +import pytest + +import snowflake.snowpark.modin.plugin # noqa: F401 +from tests.integ.modin.sql_counter import sql_count_checker + + +@sql_count_checker(query_count=1) +@pytest.mark.parametrize( + "data, index", + [ + ([1, None, 4, 3, 4], ["A", "B", "C", "D", "E"]), + ([4, None, 1, 3, 4, 1], ["A", "B", "C", "D", "E", "F"]), + ([4, None, 1, 3, 4, 1], [None, "B", "C", "D", "E", "F"]), + ([1, 10, 4, 3, 4], ["E", "D", "C", "A", "B"]), + ], +) +@pytest.mark.parametrize("func", ["argmax", "argmin"]) +@pytest.mark.parametrize( + "skipna", + [True, False], +) +def test_argmax_argmin_series(data, index, func, skipna): + native_series = native_pd.Series(data=data, index=index) + snow_series = pd.Series(native_series) + + native_output = native_series.__getattribute__(func)(skipna=skipna) + snow_output = snow_series.__getattribute__(func)(skipna=skipna) + assert snow_output == native_output + + +@pytest.mark.parametrize("func", ["argmax", "argmin"]) +@pytest.mark.parametrize("skipna", [True, False]) +@sql_count_checker(query_count=0) +def test_series_argmax_argmin_with_multiindex_negative( + multiindex_native_int_series, func, skipna +): + """ + Test Series.argmax and Series.argmin with a MultiIndex Series. + """ + native_series = multiindex_native_int_series + snow_series = pd.Series(native_series) + with pytest.raises( + NotImplementedError, + match=f"Series.{func} is not yet supported when the index is a MultiIndex.", + ): + snow_series.__getattribute__(func)(skipna=skipna) diff --git a/tests/integ/modin/series/test_idxmax_idxmin.py b/tests/integ/modin/series/test_idxmax_idxmin.py index ff7f74b3369..ea536240a42 100644 --- a/tests/integ/modin/series/test_idxmax_idxmin.py +++ b/tests/integ/modin/series/test_idxmax_idxmin.py @@ -48,10 +48,8 @@ def test_series_idxmax_idxmin_with_multiindex( multiindex_native_int_series, func, skipna ): """ - Test DataFrameGroupBy.idxmax and DataFrameGroupBy.idxmin with a MultiIndex DataFrame. - Here, the MultiIndex DataFrames are grouped by `level` and not `by`. + Test Series.idxmax and Series.idxmin with a MultiIndex Series. """ - # Create MultiIndex DataFrames. native_series = multiindex_native_int_series snow_series = pd.Series(native_series) with pytest.raises( diff --git a/tests/integ/modin/test_unimplemented.py b/tests/integ/modin/test_unimplemented.py index 1e784440b5d..5bfd2bbd05e 100644 --- a/tests/integ/modin/test_unimplemented.py +++ b/tests/integ/modin/test_unimplemented.py @@ -154,8 +154,6 @@ def test_unsupported_str_methods(func, func_name, caplog) -> None: lambda idx: idx.is_monotonic_decreasing(), lambda idx: idx.nbytes(), lambda idx: idx.memory_usage(), - lambda idx: idx.argmin(), - lambda idx: idx.argmax(), lambda idx: idx.delete(), lambda idx: idx.drop_duplicates(), lambda idx: idx.factorize(),