diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py index 497dea43697..d1c31c48ef1 100644 --- a/src/snowflake/snowpark/modin/pandas/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -183,8 +183,11 @@ # dt and str accessors raise AttributeErrors that get caught by Modin __getitem__. Whitelist # them in _ATTRS_NO_LOOKUP here to avoid this. +# In upstream Modin, we should change __getitem__ to perform a direct getitem call rather than +# calling self.index[]. modin.pandas.base._ATTRS_NO_LOOKUP.add("dt") modin.pandas.base._ATTRS_NO_LOOKUP.add("str") +modin.pandas.base._ATTRS_NO_LOOKUP.add("columns") modin.pandas.base._ATTRS_NO_LOOKUP.update(_ATTRS_NO_LOOKUP) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 2682fd2b985..df1c1fa9e7e 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -262,7 +262,7 @@ def __getattr__(self, key: str) -> Any: def _binary_ops(self, method: str, other: Any) -> Index: if isinstance(other, Index): other = other.to_series().reset_index(drop=True) - series = self.to_series().reset_index(drop=True).__getattr__(method)(other) + series = getattr(self.to_series().reset_index(drop=True), method)(other) qc = series._query_compiler qc = qc.set_index_from_columns(qc.columns, include_index=False) # Use base constructor to ensure that the correct type is returned. 
@@ -272,7 +272,7 @@ def _binary_ops(self, method: str, other: Any) -> Index: def _unary_ops(self, method: str) -> Index: return self.__constructor__( - self.to_series().reset_index(drop=True).__getattr__(method)() + getattr(self.to_series().reset_index(drop=True), method)() ) def __add__(self, other: Any) -> Index: diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index abe3b19d0df..97e1b1383b9 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -908,6 +908,27 @@ def argmin(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01, return result +# Modin uses the same implementation as Snowpark pandas starting from 0.31.0. +# Until then, upstream Modin does not convert arguments in the caselist into query compilers. +@register_series_accessor("case_when") +@snowpark_pandas_telemetry_method_decorator +def case_when(self, caselist) -> Series: # noqa: PR01, RT01, D200 + """ + Replace values where the conditions are True. + """ + modin_type = type(self) + caselist = [ + tuple( + data._query_compiler if isinstance(data, modin_type) else data + for data in case_tuple + ) + for case_tuple in caselist + ] + return self.__constructor__( + query_compiler=self._query_compiler.case_when(caselist=caselist) + ) + + # Snowpark pandas does not respect `ignore_index`, and upstream Modin does not respect `how`. 
@register_series_accessor("dropna") @snowpark_pandas_telemetry_method_decorator diff --git a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py index f673bf157bf..7740d9cd2f9 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py +++ b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py @@ -3,9 +3,9 @@ # from typing import Any, Optional, Union +import modin.pandas as pd from modin.pandas.base import BasePandasDataset -import snowflake.snowpark.modin.pandas as pd from snowflake.snowpark.modin.pandas.utils import is_scalar from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage