From 25c1006f4df475befc59bc57be652be2ffdfefcf Mon Sep 17 00:00:00 2001
From: Jonathan Shi <149419494+sfc-gh-joshi@users.noreply.github.com>
Date: Fri, 30 Aug 2024 12:01:00 -0700
Subject: [PATCH] SNOW-1119855: Remove modin/pandas/base.py (2/2) (#2167)
1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.
Fixes SNOW-1119855
2. Fill out the following pre-review checklist:
- [ ] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
- [ ] I am adding new logging messages
- [ ] I am adding a new telemetry message
- [ ] I am adding new credentials
- [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
3. Please describe how your code solves the related issue.
This follow-up to #2059 removes our vendored copy of `base.py` from the
codebase altogether. Many methods are still overridden in
`snowflake/snowpark/modin/plugin/extensions/base_overrides.py`, with the
reason for overriding each given inline in the comments. These methods
are as follows:
- agg/aggregate
- _agg_helper
- various aggregations (count, max, min, mean, median, std, sum, var)
- _binary_op
- _dropna
- fillna
- isin
- quantile
- _to_series_list
- shift
- skew
- resample
- expanding
- rolling
- indexer properties (iloc, loc, iat, at)
- __getitem__
- sort_values
- where
- mask
- to_csv
- sample
- pct_change
- astype
- drop
- __len__
- set_axis
- describe
- diff
- tail
- idxmax
- idxmin
- unary operators (abs, __invert__, __neg__)
- rename_axis
- __array_ufunc__
- reindex
- _get_index
- _set_index
Some of these differences can be upstreamed fairly easily, and we will
work to upstream them once the updated modin build process for Snowflake
becomes clearer. Others will require significantly more work to
reconcile.
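
As context for the diff below, here is a minimal sketch of the registration pattern that `register_base_override` builds on: the override is attached to `Series` and `DataFrame` through the `register_series_accessor`/`register_dataframe_accessor` decorators and then patched onto `BasePandasDataset` itself. This is an illustration only, with a hypothetical `echo_shape` method; the real decorator in `base_overrides.py` additionally skips subclasses that define their own override and instruments the method with telemetry.

```python
from modin.pandas.base import BasePandasDataset

from snowflake.snowpark.modin.pandas.api.extensions import (
    register_dataframe_accessor,
    register_series_accessor,
)


def register_base_override_sketch(method_name: str):
    """Simplified, illustrative stand-in for base_overrides.register_base_override."""

    def decorator(base_method):
        # Attach the override to both concrete classes so that method
        # dispatch and docstring extension see it on DataFrame and Series.
        register_series_accessor(method_name)(base_method)
        register_dataframe_accessor(method_name)(base_method)
        # Also patch the base class so inherited lookups resolve to the override.
        setattr(BasePandasDataset, method_name, base_method)
        return base_method

    return decorator


@register_base_override_sketch("echo_shape")  # hypothetical method name
def echo_shape(self):
    # After registration this runs on both DataFrame and Series.
    return self.shape
```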
---
.../snowpark/modin/pandas/__init__.py | 79 +-
.../snowpark/modin/pandas/dataframe.py | 6 +-
.../snowpark/modin/pandas/general.py | 2 +-
.../snowpark/modin/pandas/indexing.py | 2 +-
src/snowflake/snowpark/modin/pandas/series.py | 6 +-
src/snowflake/snowpark/modin/pandas/utils.py | 3 +-
.../snowpark/modin/plugin/__init__.py | 26 +-
.../modin/plugin/_internal/telemetry.py | 83 +-
.../compiler/snowflake_query_compiler.py | 38 +-
.../modin/plugin/docstrings/dataframe.py | 4 +-
.../modin/plugin/docstrings/series.py | 4 +-
.../plugin/extensions/base_extensions.py | 46 +
.../modin/plugin/extensions/base_overrides.py | 1996 ++++++++++++++++-
.../snowpark/modin/plugin/extensions/index.py | 2 +-
.../plugin/extensions/series_overrides.py | 1 +
.../modin/plugin/utils/frontend_constants.py | 27 +
.../modin/plugin/utils/numpy_to_pandas.py | 3 +-
tests/integ/modin/test_telemetry.py | 23 +-
tests/unit/modin/modin/test_envvars.py | 48 +-
19 files changed, 2298 insertions(+), 101 deletions(-)
create mode 100644 src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py
create mode 100644 src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py
diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py
index c4eb07d9589..b51a47b64b3 100644
--- a/src/snowflake/snowpark/modin/pandas/__init__.py
+++ b/src/snowflake/snowpark/modin/pandas/__init__.py
@@ -85,10 +85,16 @@
timedelta_range,
)
+import modin.pandas
+
# TODO: SNOW-851745 make sure add all Snowpark pandas API general functions
from modin.pandas import plotting # type: ignore[import]
-from snowflake.snowpark.modin.pandas.dataframe import DataFrame
+from snowflake.snowpark.modin.pandas.api.extensions import (
+ register_dataframe_accessor,
+ register_series_accessor,
+)
+from snowflake.snowpark.modin.pandas.dataframe import _DATAFRAME_EXTENSIONS_, DataFrame
from snowflake.snowpark.modin.pandas.general import (
concat,
crosstab,
@@ -140,15 +146,15 @@
read_xml,
to_pickle,
)
-from snowflake.snowpark.modin.pandas.series import Series
+from snowflake.snowpark.modin.pandas.series import _SERIES_EXTENSIONS_, Series
from snowflake.snowpark.modin.plugin._internal.session import SnowpandasSessionHolder
+from snowflake.snowpark.modin.plugin._internal.telemetry import (
+ try_add_telemetry_to_attribute,
+)
# The extensions assigned to this module
_PD_EXTENSIONS_: dict = {}
-# base needs to be re-exported in order to properly override docstrings for BasePandasDataset
-# moving this import higher prevents sphinx from building documentation (??)
-from snowflake.snowpark.modin.pandas import base # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.pd_extensions as pd_extensions # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.pd_overrides # isort: skip # noqa: E402,F401
@@ -157,12 +163,71 @@
DatetimeIndex,
TimedeltaIndex,
)
+
+# this must occur before overrides are applied
+_attrs_defined_on_modin_base = set(dir(modin.pandas.base.BasePandasDataset))
+_attrs_defined_on_series = set(
+ dir(Series)
+) # TODO: SNOW-1063347 revisit when series.py is removed
+_attrs_defined_on_dataframe = set(
+ dir(DataFrame)
+) # TODO: SNOW-1063346 revisit when dataframe.py is removed
+
+# base overrides occur before subclass overrides in case subclasses override a base method
+import snowflake.snowpark.modin.plugin.extensions.base_extensions # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.base_overrides # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.dataframe_extensions # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.dataframe_overrides # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.series_extensions # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.series_overrides # isort: skip # noqa: E402,F401
+# For any method defined on Series/DF, add telemetry to it if it meets all of the following conditions:
+# 1. The method was defined directly on upstream BasePandasDataset (_attrs_defined_on_modin_base)
+# 2. The method is not overridden by a child class (this will change)
+# 3. The method is not overridden by an extensions module
+# 4. The method name does not start with an _
+#
+# TODO: SNOW-1063347
+# Since we still use the vendored version of Series and the overrides for the top-level
+# namespace haven't been performed yet, we need to set properties on the vendored version
+_base_telemetry_added_attrs = set()
+
+_series_ext = _SERIES_EXTENSIONS_.copy()
+for attr_name in dir(Series):
+ if (
+ attr_name in _attrs_defined_on_modin_base
+ and attr_name in _attrs_defined_on_series
+ and attr_name not in _series_ext
+ and not attr_name.startswith("_")
+ ):
+ register_series_accessor(attr_name)(
+ try_add_telemetry_to_attribute(attr_name, getattr(Series, attr_name))
+ )
+ _base_telemetry_added_attrs.add(attr_name)
+
+# TODO: SNOW-1063346
+# Since we still use the vendored version of DataFrame and the overrides for the top-level
+# namespace haven't been performed yet, we need to set properties on the vendored version
+_dataframe_ext = _DATAFRAME_EXTENSIONS_.copy()
+for attr_name in dir(DataFrame):
+ if (
+ attr_name in _attrs_defined_on_modin_base
+ and attr_name in _attrs_defined_on_dataframe
+ and attr_name not in _dataframe_ext
+ and not attr_name.startswith("_")
+ ):
+ # If telemetry was already added via Series, register the override but don't re-wrap
+ # the method in the telemetry annotation. If we don't do this check, we will end up
+ # double-reporting telemetry on some methods.
+ original_attr = getattr(DataFrame, attr_name)
+ new_attr = (
+ original_attr
+ if attr_name in _base_telemetry_added_attrs
+ else try_add_telemetry_to_attribute(attr_name, original_attr)
+ )
+ register_dataframe_accessor(attr_name)(new_attr)
+ _base_telemetry_added_attrs.add(attr_name)
+
def __getattr__(name: str) -> Any:
"""
@@ -220,7 +285,6 @@ def __getattr__(name: str) -> Any:
"date_range",
"Index",
"MultiIndex",
- "Series",
"bdate_range",
"period_range",
"DatetimeIndex",
@@ -318,8 +382,7 @@ def __getattr__(name: str) -> Any:
# Manually re-export the members of the pd_extensions namespace, which are not declared in __all__.
_EXTENSION_ATTRS = ["read_snowflake", "to_snowflake", "to_snowpark", "to_pandas"]
# We also need to re-export native_pd.offsets, since modin.pandas doesn't re-export it.
-# snowflake.snowpark.pandas.base also needs to be re-exported to make docstring overrides for BasePandasDataset work.
-_ADDITIONAL_ATTRS = ["offsets", "base"]
+_ADDITIONAL_ATTRS = ["offsets"]
# This code should eventually be moved into the `snowflake.snowpark.modin.plugin` module instead.
# Currently, trying to do so would result in incorrect results because `snowflake.snowpark.modin.pandas`
diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py
index a7d53813779..b42ad5a04c7 100644
--- a/src/snowflake/snowpark/modin/pandas/dataframe.py
+++ b/src/snowflake/snowpark/modin/pandas/dataframe.py
@@ -37,6 +37,7 @@
import numpy as np
import pandas
from modin.pandas.accessor import CachedAccessor, SparseFrameAccessor
+from modin.pandas.base import BasePandasDataset
# from . import _update_engine
from modin.pandas.iterator import PartitionIterator
@@ -73,7 +74,6 @@
from pandas.util._validators import validate_bool_kwarg
from snowflake.snowpark.modin import pandas as pd
-from snowflake.snowpark.modin.pandas.base import _ATTRS_NO_LOOKUP, BasePandasDataset
from snowflake.snowpark.modin.pandas.groupby import (
DataFrameGroupBy,
validate_groupby_args,
@@ -91,12 +91,14 @@
replace_external_data_keys_with_empty_pandas_series,
replace_external_data_keys_with_query_compiler,
)
+from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta
from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated
from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
from snowflake.snowpark.modin.plugin.utils.error_message import (
ErrorMessage,
dataframe_not_implemented,
)
+from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP
from snowflake.snowpark.modin.plugin.utils.warning_message import (
SET_DATAFRAME_ATTRIBUTE_WARNING,
WarningMessage,
@@ -136,7 +138,7 @@
],
apilink="pandas.DataFrame",
)
-class DataFrame(BasePandasDataset):
+class DataFrame(BasePandasDataset, metaclass=TelemetryMeta):
_pandas_class = pandas.DataFrame
def __init__(
diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py
index f14e14840bb..07f0617d612 100644
--- a/src/snowflake/snowpark/modin/pandas/general.py
+++ b/src/snowflake/snowpark/modin/pandas/general.py
@@ -30,6 +30,7 @@
import numpy as np
import pandas
import pandas.core.common as common
+from modin.pandas.base import BasePandasDataset
from pandas import IntervalIndex, NaT, Timedelta, Timestamp
from pandas._libs import NaTType, lib
from pandas._libs.tslibs import to_offset
@@ -61,7 +62,6 @@
# add this line to make doctests runnable
from snowflake.snowpark.modin import pandas as pd # noqa: F401
-from snowflake.snowpark.modin.pandas.base import BasePandasDataset
from snowflake.snowpark.modin.pandas.dataframe import DataFrame
from snowflake.snowpark.modin.pandas.series import Series
from snowflake.snowpark.modin.pandas.utils import (
diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py
index 0ac62f504ce..c83e3fe41c4 100644
--- a/src/snowflake/snowpark/modin/pandas/indexing.py
+++ b/src/snowflake/snowpark/modin/pandas/indexing.py
@@ -43,6 +43,7 @@
import numpy as np
import pandas
+from modin.pandas.base import BasePandasDataset
from pandas._libs.tslibs import Resolution, parsing
from pandas._typing import AnyArrayLike, Scalar
from pandas.api.types import is_bool, is_list_like
@@ -58,7 +59,6 @@
import snowflake.snowpark.modin.pandas as pd
import snowflake.snowpark.modin.pandas.utils as frontend_utils
-from snowflake.snowpark.modin.pandas.base import BasePandasDataset
from snowflake.snowpark.modin.pandas.dataframe import DataFrame
from snowflake.snowpark.modin.pandas.series import (
SERIES_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE_ERROR_MESSAGE,
diff --git a/src/snowflake/snowpark/modin/pandas/series.py b/src/snowflake/snowpark/modin/pandas/series.py
index 1ce3ecfc997..6e1b93437a8 100644
--- a/src/snowflake/snowpark/modin/pandas/series.py
+++ b/src/snowflake/snowpark/modin/pandas/series.py
@@ -31,6 +31,7 @@
import numpy.typing as npt
import pandas
from modin.pandas.accessor import CachedAccessor, SparseAccessor
+from modin.pandas.base import BasePandasDataset
from modin.pandas.iterator import PartitionIterator
from pandas._libs.lib import NoDefault, is_integer, no_default
from pandas._typing import (
@@ -51,17 +52,18 @@
from pandas.core.series import _coerce_method
from pandas.util._validators import validate_bool_kwarg
-from snowflake.snowpark.modin.pandas.base import _ATTRS_NO_LOOKUP, BasePandasDataset
from snowflake.snowpark.modin.pandas.utils import (
from_pandas,
is_scalar,
try_convert_index_to_native,
)
+from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta
from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
from snowflake.snowpark.modin.plugin.utils.error_message import (
ErrorMessage,
series_not_implemented,
)
+from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP
from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage
from snowflake.snowpark.modin.utils import (
MODIN_UNNAMED_SERIES_LABEL,
@@ -108,7 +110,7 @@
],
apilink="pandas.Series",
)
-class Series(BasePandasDataset):
+class Series(BasePandasDataset, metaclass=TelemetryMeta):
_pandas_class = pandas.Series
__array_priority__ = pandas.Series.__array_priority__
diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py
index f971e0ff964..32702c8b1a4 100644
--- a/src/snowflake/snowpark/modin/pandas/utils.py
+++ b/src/snowflake/snowpark/modin/pandas/utils.py
@@ -170,10 +170,9 @@ def is_scalar(obj):
bool
True if given object is scalar and False otherwise.
"""
+ from modin.pandas.base import BasePandasDataset
from pandas.api.types import is_scalar as pandas_is_scalar
- from .base import BasePandasDataset
-
return not isinstance(obj, BasePandasDataset) and pandas_is_scalar(obj)
diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py
index a76b9fe1613..c4172f26696 100644
--- a/src/snowflake/snowpark/modin/plugin/__init__.py
+++ b/src/snowflake/snowpark/modin/plugin/__init__.py
@@ -63,15 +63,23 @@
import modin.utils # type: ignore[import] # isort: skip # noqa: E402
import modin.pandas.series_utils # type: ignore[import] # isort: skip # noqa: E402
-modin.utils._inherit_docstrings(
- docstrings.series_utils.StringMethods,
- overwrite_existing=True,
-)(modin.pandas.series_utils.StringMethods)
-
-modin.utils._inherit_docstrings(
- docstrings.series_utils.CombinedDatetimelikeProperties,
- overwrite_existing=True,
-)(modin.pandas.series_utils.DatetimeProperties)
+# TODO: SNOW-1643979 pull in fixes for
+# https://github.com/modin-project/modin/issues/7113 and https://github.com/modin-project/modin/issues/7134
+# Upstream Modin has issues with certain docstring generation edge cases, so we should use our version instead
+_inherit_docstrings = snowflake.snowpark.modin.utils._inherit_docstrings
+
+inherit_modules = [
+ (docstrings.base.BasePandasDataset, modin.pandas.base.BasePandasDataset),
+ (docstrings.series_utils.StringMethods, modin.pandas.series_utils.StringMethods),
+ (
+ docstrings.series_utils.CombinedDatetimelikeProperties,
+ modin.pandas.series_utils.DatetimeProperties,
+ ),
+]
+
+for (doc_module, target_object) in inherit_modules:
+ _inherit_docstrings(doc_module, overwrite_existing=True)(target_object)
+
# Don't warn the user about our internal usage of private preview pivot
# features. The user should have already been warned that Snowpark pandas
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py
index 0a022b0d588..8057cf93885 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py
@@ -495,6 +495,49 @@ def wrap(*args, **kwargs): # type: ignore
}
+def try_add_telemetry_to_attribute(attr_name: str, attr_value: Any) -> Any:
+ """
+ Attempts to add telemetry to an attribute.
+
+    If the attribute is a callable whose name starts with an underscore and is not in
+    TELEMETRY_PRIVATE_METHODS, the original attribute is returned as-is. Otherwise, a version
+    of the method/property annotated with Snowpark pandas telemetry is returned.
+ """
+ if callable(attr_value) and (
+ not attr_name.startswith("_") or (attr_name in TELEMETRY_PRIVATE_METHODS)
+ ):
+ return snowpark_pandas_telemetry_method_decorator(attr_value)
+ elif isinstance(attr_value, property):
+ # wrap on getter and setter
+ return property(
+ snowpark_pandas_telemetry_method_decorator(
+ cast(
+ # add a cast because mypy doesn't recognize that
+ # non-None fget and __get__ are both callable
+ # arguments to snowpark_pandas_telemetry_method_decorator.
+ Callable,
+ attr_value.__get__ # pragma: no cover: we don't encounter this case in pandas or modin because every property has an fget method.
+ if attr_value.fget is None
+ else attr_value.fget,
+ ),
+ property_name=attr_name,
+ property_method_type=PropertyMethodType.FGET,
+ ),
+ snowpark_pandas_telemetry_method_decorator(
+ attr_value.__set__ if attr_value.fset is None else attr_value.fset,
+ property_name=attr_name,
+ property_method_type=PropertyMethodType.FSET,
+ ),
+ snowpark_pandas_telemetry_method_decorator(
+ attr_value.__delete__ if attr_value.fdel is None else attr_value.fdel,
+ property_name=attr_name,
+ property_method_type=PropertyMethodType.FDEL,
+ ),
+ doc=attr_value.__doc__,
+ )
+ return attr_value
+
+
class TelemetryMeta(type):
def __new__(
cls, name: str, bases: tuple, attrs: dict[str, Any]
@@ -536,43 +579,5 @@ def __new__(
The modified class with decorated methods.
"""
for attr_name, attr_value in attrs.items():
- if callable(attr_value) and (
- not attr_name.startswith("_")
- or (attr_name in TELEMETRY_PRIVATE_METHODS)
- ):
- attrs[attr_name] = snowpark_pandas_telemetry_method_decorator(
- attr_value
- )
- elif isinstance(attr_value, property):
- # wrap on getter and setter
- attrs[attr_name] = property(
- snowpark_pandas_telemetry_method_decorator(
- cast(
- # add a cast because mypy doesn't recognize that
- # non-None fget and __get__ are both callable
- # arguments to snowpark_pandas_telemetry_method_decorator.
- Callable,
- attr_value.__get__ # pragma: no cover: we don't encounter this case in pandas or modin because every property has an fget method.
- if attr_value.fget is None
- else attr_value.fget,
- ),
- property_name=attr_name,
- property_method_type=PropertyMethodType.FGET,
- ),
- snowpark_pandas_telemetry_method_decorator(
- attr_value.__set__
- if attr_value.fset is None
- else attr_value.fset,
- property_name=attr_name,
- property_method_type=PropertyMethodType.FSET,
- ),
- snowpark_pandas_telemetry_method_decorator(
- attr_value.__delete__
- if attr_value.fdel is None
- else attr_value.fdel,
- property_name=attr_name,
- property_method_type=PropertyMethodType.FDEL,
- ),
- doc=attr_value.__doc__,
- )
+ attrs[attr_name] = try_add_telemetry_to_attribute(attr_name, attr_value)
return type.__new__(cls, name, bases, attrs)
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index bbebbec1783..50ce5e71310 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -388,6 +388,8 @@ class SnowflakeQueryCompiler(BaseQueryCompiler):
this class is best explained by looking at https://github.com/modin-project/modin/blob/a8be482e644519f2823668210cec5cf1564deb7e/modin/experimental/core/storage_formats/hdk/query_compiler.py
"""
+ lazy_execution = True
+
def __init__(self, frame: InternalFrame) -> None:
"""this stores internally a local pandas object (refactor this)"""
assert frame is not None and isinstance(
@@ -767,6 +769,7 @@ def execute(self) -> None:
def to_numpy(
self,
dtype: Optional[npt.DTypeLike] = None,
+ copy: Optional[bool] = False,
na_value: object = lib.no_default,
**kwargs: Any,
) -> np.ndarray:
@@ -774,6 +777,12 @@ def to_numpy(
# i.e., for something like df.values internally to_numpy().flatten() is called
# with flatten being another query compiler call into the numpy frontend layer.
# here it's overwritten to actually perform numpy conversion, i.e. return an actual numpy object
+ if copy:
+ WarningMessage.ignored_argument(
+ operation="to_numpy",
+ argument="copy",
+ message="copy is ignored in Snowflake backend",
+ )
return self.to_pandas().to_numpy(dtype=dtype, na_value=na_value, **kwargs)
def repartition(self, axis: Any = None) -> "SnowflakeQueryCompiler":
@@ -1400,17 +1409,6 @@ def cache_result(self) -> "SnowflakeQueryCompiler":
"""
return SnowflakeQueryCompiler(self._modin_frame.persist_to_temporary_table())
- @property
- def columns(self) -> native_pd.Index:
- """
- Get pandas column labels.
-
- Returns:
- an index containing all pandas column labels
- """
- # TODO SNOW-837664: add more tests for df.columns
- return self._modin_frame.data_columns_index
-
@snowpark_pandas_type_immutable_check
def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
"""
@@ -1465,6 +1463,12 @@ def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
)
return SnowflakeQueryCompiler(new_internal_frame)
+ # TODO SNOW-837664: add more tests for df.columns
+ def get_columns(self) -> native_pd.Index:
+ return self._modin_frame.data_columns_index
+
+ columns: native_pd.Index = property(get_columns, set_columns)
+
def _shift_values(
self, periods: int, axis: Union[Literal[0], Literal[1]], fill_value: Hashable
) -> "SnowflakeQueryCompiler":
@@ -2807,6 +2811,8 @@ def reset_index(
Returns:
A new SnowflakeQueryCompiler instance with updated index.
"""
+ if allow_duplicates is no_default:
+ allow_duplicates = False
# These levels will be moved from index columns to data columns
levels_to_be_reset = self._modin_frame.parse_levels_to_integer_levels(
level, allow_duplicates=False
@@ -3007,9 +3013,11 @@ def first_last_valid_index(
def sort_index(
self,
+ *,
axis: int,
level: Optional[list[Union[str, int]]],
ascending: Union[bool, list[bool]],
+ inplace: bool = False,
kind: SortKind,
na_position: NaPosition,
sort_remaining: bool,
@@ -3025,6 +3033,8 @@ def sort_index(
level: If not None, sort on values in specified index level(s).
ascending: A list of bools to represent ascending vs descending sort. Defaults to True.
When the index is a MultiIndex the sort direction can be controlled for each level individually.
+ inplace: Whether or not the sort occurs in-place. This argument is ignored and only provided
+ for compatibility with Modin.
kind: Choice of sorting algorithm. Perform stable sort if 'stable'. Defaults to unstable sort.
Snowpark pandas ignores choice of sorting algorithm except 'stable'.
na_position: Puts NaNs at the beginning if 'first'; 'last' puts NaNs at the end. Defaults to 'last'
@@ -10859,6 +10869,12 @@ def is_multiindex(self, *, axis: int = 0) -> bool:
"""
return self._modin_frame.is_multiindex(axis=axis)
+ def abs(self) -> "SnowflakeQueryCompiler":
+ return self.unary_op("abs")
+
+ def negative(self) -> "SnowflakeQueryCompiler":
+ return self.unary_op("__neg__")
+
def unary_op(self, op: str) -> "SnowflakeQueryCompiler":
"""
Applies a unary operation `op` on each element of the `SnowflakeQueryCompiler`.
diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py
index 047b4592068..f0c02aa0e65 100644
--- a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py
+++ b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py
@@ -13,6 +13,8 @@
_shared_docs,
)
+from .base import BasePandasDataset
+
_doc_binary_op_kwargs = {"returns": "BasePandasDataset", "left": "BasePandasDataset"}
@@ -49,7 +51,7 @@
}
-class DataFrame:
+class DataFrame(BasePandasDataset):
"""
Snowpark pandas representation of ``pandas.DataFrame`` with a lazily-evaluated relational dataset.
diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series.py b/src/snowflake/snowpark/modin/plugin/docstrings/series.py
index 6e48a7e57f3..4878c82635a 100644
--- a/src/snowflake/snowpark/modin/plugin/docstrings/series.py
+++ b/src/snowflake/snowpark/modin/plugin/docstrings/series.py
@@ -15,6 +15,8 @@
)
from snowflake.snowpark.modin.utils import _create_operator_docstring
+from .base import BasePandasDataset
+
_shared_doc_kwargs = {
"axes": "index",
"klass": "Series",
@@ -35,7 +37,7 @@
}
-class Series:
+class Series(BasePandasDataset):
"""
Snowpark pandas representation of `pandas.Series` with a lazily-evaluated relational dataset.
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py
new file mode 100644
index 00000000000..496136d736e
--- /dev/null
+++ b/src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py
@@ -0,0 +1,46 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+File containing BasePandasDataset APIs defined in Snowpark pandas but not in the Modin API layer.
+"""
+
+from snowflake.snowpark.modin.plugin._internal.telemetry import (
+ snowpark_pandas_telemetry_method_decorator,
+)
+
+from .base_overrides import register_base_override
+
+
+@register_base_override("__array_function__")
+@snowpark_pandas_telemetry_method_decorator
+def __array_function__(self, func: callable, types: tuple, args: tuple, kwargs: dict):
+ """
+ Apply the `func` to the `BasePandasDataset`.
+
+ Parameters
+ ----------
+ func : np.func
+ The NumPy func to apply.
+ types : tuple
+ The types of the args.
+ args : tuple
+ The args to the func.
+ kwargs : dict
+ Additional keyword arguments.
+
+ Returns
+ -------
+ BasePandasDataset
+        The result of `func` applied to the `BasePandasDataset`.
+ """
+ from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import (
+ numpy_to_pandas_func_map,
+ )
+
+ if func.__name__ in numpy_to_pandas_func_map:
+ return numpy_to_pandas_func_map[func.__name__](*args, **kwargs)
+ else:
+        # per NEP 18 we return NotImplemented so that numpy can dispatch to another implementation
+ return NotImplemented # pragma: no cover
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
index 332df757787..abbcb9bc762 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
@@ -6,31 +6,123 @@
Methods defined on BasePandasDataset that are overridden in Snowpark pandas. Adding a method to this file
should be done with discretion, and only when relevant changes cannot be made to the query compiler or
upstream frontend to accommodate Snowpark pandas.
+
+If you must override a method in this file, please add a comment describing why it must be overridden,
+and if possible, whether this can be reconciled with upstream Modin.
"""
from __future__ import annotations
import pickle as pkl
-from typing import Any
+import warnings
+from collections.abc import Sequence
+from typing import Any, Callable, Hashable, Literal, Mapping, get_args
+import modin.pandas as pd
import numpy as np
+import numpy.typing as npt
import pandas
from modin.pandas.base import BasePandasDataset
-from pandas._libs.lib import no_default
+from pandas._libs import lib
+from pandas._libs.lib import NoDefault, is_bool, no_default
from pandas._typing import (
+ AggFuncType,
+ AnyArrayLike,
+ Axes,
Axis,
CompressionOptions,
+ FillnaOptions,
+ IgnoreRaise,
+ IndexKeyFunc,
+ IndexLabel,
+ Level,
+ NaPosition,
+ RandomState,
+ Scalar,
StorageOptions,
TimedeltaConvertibleTypes,
+ TimestampConvertibleTypes,
+)
+from pandas.core.common import apply_if_callable
+from pandas.core.dtypes.common import (
+ is_dict_like,
+ is_dtype_equal,
+ is_list_like,
+ is_numeric_dtype,
+ pandas_dtype,
+)
+from pandas.core.dtypes.inference import is_integer
+from pandas.core.methods.describe import _refine_percentiles
+from pandas.errors import SpecificationError
+from pandas.util._validators import (
+ validate_ascending,
+ validate_bool_kwarg,
+ validate_percentile,
)
+import snowflake.snowpark.modin.pandas as spd
from snowflake.snowpark.modin.pandas.api.extensions import (
register_dataframe_accessor,
register_series_accessor,
)
+from snowflake.snowpark.modin.pandas.utils import (
+ ensure_index,
+ extract_validate_and_try_convert_named_aggs_from_kwargs,
+ get_as_shape_compatible_dataframe_or_series,
+ is_scalar,
+ raise_if_native_pandas_objects,
+ validate_and_try_convert_agg_func_arg_func_to_str,
+)
from snowflake.snowpark.modin.plugin._internal.telemetry import (
snowpark_pandas_telemetry_method_decorator,
+ try_add_telemetry_to_attribute,
+)
+from snowflake.snowpark.modin.plugin._typing import ListLike
+from snowflake.snowpark.modin.plugin.utils.error_message import (
+ ErrorMessage,
+ base_not_implemented,
)
-from snowflake.snowpark.modin.plugin.utils.error_message import base_not_implemented
+from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage
+from snowflake.snowpark.modin.utils import validate_int_kwarg
+
+
+def register_base_override(method_name: str):
+ """
+ Decorator function to override a method on BasePandasDataset. Since Modin does not provide a mechanism
+ for directly overriding methods on BasePandasDataset, we mock this by performing the override on
+ DataFrame and Series, and manually performing a `setattr` on the base class. These steps are necessary
+ to allow both the docstring extension and method dispatch to work properly.
+
+ Methods annotated here also are automatically instrumented with Snowpark pandas telemetry.
+ """
+
+ def decorator(base_method: Any):
+ base_method = try_add_telemetry_to_attribute(method_name, base_method)
+ parent_method = getattr(BasePandasDataset, method_name, None)
+ if isinstance(parent_method, property):
+ parent_method = parent_method.fget
+    # If the method was not defined on Series/DataFrame and was instead inherited from the superclass,
+    # we need to override it there as well, apparently because the MRO was already determined.
+ # TODO: SNOW-1063347
+ # Since we still use the vendored version of Series and the overrides for the top-level
+ # namespace haven't been performed yet, we need to set properties on the vendored version
+ series_method = getattr(spd.series.Series, method_name, None)
+ if isinstance(series_method, property):
+ series_method = series_method.fget
+ if series_method is None or series_method is parent_method:
+ register_series_accessor(method_name)(base_method)
+ # TODO: SNOW-1063346
+ # Since we still use the vendored version of DataFrame and the overrides for the top-level
+ # namespace haven't been performed yet, we need to set properties on the vendored version
+ df_method = getattr(spd.dataframe.DataFrame, method_name, None)
+ if isinstance(df_method, property):
+ df_method = df_method.fget
+ if df_method is None or df_method is parent_method:
+ register_dataframe_accessor(method_name)(base_method)
+ # Replace base method
+ setattr(BasePandasDataset, method_name, base_method)
+ return base_method
+
+ return decorator
def register_base_not_implemented():
@@ -303,3 +395,1901 @@ def truncate(
@register_base_not_implemented()
def __finalize__(self, other, method=None, **kwargs):
pass # pragma: no cover
+
+
+# === OVERRIDDEN METHODS ===
+# The below methods have their frontend implementations overridden compared to the version present
+# in base.py. This is usually for one of the following reasons:
+# 1. The underlying QC interface used differs from that of modin. Notably, this applies to aggregate
+# and binary operations; further work is needed to refactor either our implementation or upstream
+# modin's implementation.
+# 2. Modin performs extra validation that issues extra SQL queries. Some of these are already
+# fixed on main; see https://github.com/modin-project/modin/issues/7340 for details.
+# 3. Upstream Modin defaults to pandas for some edge cases. Defaulting to pandas at the query compiler
+# layer is acceptable because we can force the method to raise NotImplementedError, but if a method
+# defaults at the frontend, Modin raises a warning and performs the operation by coercing the
+# dataset to a native pandas object. Removing these is tracked by
+# https://github.com/modin-project/modin/issues/7104
+# 4. Snowpark pandas uses different default arguments from modin. This occurs if some parameters are
+# only partially supported (like `numeric_only=True` for `skew`), but this behavior should likewise
+# be revisited.
+
+# `aggregate` for axis=1 is performed as a call to `BasePandasDataset.apply` in upstream Modin,
+# which is unacceptable for Snowpark pandas. Upstream Modin should be changed to allow the query
+# compiler or a different layer to control dispatch.
+@register_base_override("aggregate")
+def aggregate(
+ self, func: AggFuncType = None, axis: Axis | None = 0, *args: Any, **kwargs: Any
+):
+ """
+ Aggregate using one or more operations over the specified axis.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ from snowflake.snowpark.modin.pandas import Series
+
+ origin_axis = axis
+ axis = self._get_axis_number(axis)
+
+ if axis == 1 and isinstance(self, Series):
+ raise ValueError(f"No axis named {origin_axis} for object type Series")
+
+ if len(self._query_compiler.columns) == 0:
+ # native pandas raise error with message "no result", here we raise a more readable error.
+ raise ValueError("No column to aggregate on.")
+
+    # If we are using named kwargs, then we do not clear the kwargs (we need them in the QC for
+    # processing order, as well as for formatting error messages).
+ uses_named_kwargs = False
+ # If aggregate is called on a Series, named aggregations can be passed in via a dictionary
+ # to func.
+ if func is None or (is_dict_like(func) and not self._is_dataframe):
+ if axis == 1:
+ raise ValueError(
+ "`func` must not be `None` when `axis=1`. Named aggregations are not supported with `axis=1`."
+ )
+ if func is not None:
+ # If named aggregations are passed in via a dictionary to func, then we
+ # ignore the kwargs.
+ if any(is_dict_like(value) for value in func.values()):
+ # We can only get to this codepath if self is a Series, and func is a dictionary.
+ # In this case, if any of the values of func are themselves dictionaries, we must raise
+ # a Specification Error, as that is what pandas does.
+ raise SpecificationError("nested renamer is not supported")
+ kwargs = func
+ func = extract_validate_and_try_convert_named_aggs_from_kwargs(
+ self, allow_duplication=False, axis=axis, **kwargs
+ )
+ uses_named_kwargs = True
+ else:
+ func = validate_and_try_convert_agg_func_arg_func_to_str(
+ agg_func=func,
+ obj=self,
+ allow_duplication=False,
+ axis=axis,
+ )
+
+ # This is to stay consistent with pandas result format, when the func is single
+ # aggregation function in format of callable or str, reduce the result dimension to
+ # convert dataframe to series, or convert series to scalar.
+ # Note: When named aggregations are used, the result is not reduced, even if there
+ # is only a single function.
+    # need_reduce_dimension cannot be True if we are using named aggregations, since
+ # the values for func in that case are either NamedTuples (AggFuncWithLabels) or
+ # lists of NamedTuples, both of which are list like.
+ need_reduce_dimension = (
+ (callable(func) or isinstance(func, str))
+ # A Series should be returned when a single scalar string/function aggregation function, or a
+ # dict of scalar string/functions is specified. In all other cases (including if the function
+ # is a 1-element list), the result is a DataFrame.
+ #
+ # The examples below have axis=1, but the same logic is applied for axis=0.
+ # >>> df = pd.DataFrame({"a": [0, 1], "b": [2, 3]})
+ #
+ # single aggregation: return Series
+ # >>> df.agg("max", axis=1)
+ # 0 2
+ # 1 3
+ # dtype: int64
+ #
+ # list of aggregations: return DF
+ # >>> df.agg(["max"], axis=1)
+ # max
+ # 0 2
+ # 1 3
+ #
+ # dict where all aggregations are strings: return Series
+ # >>> df.agg({1: "max", 0: "min"}, axis=1)
+ # 1 3
+ # 0 0
+ # dtype: int64
+ #
+ # dict where one element is a list: return DF
+ # >>> df.agg({1: "max", 0: ["min"]}, axis=1)
+ # max min
+ # 1 3.0 NaN
+ # 0 NaN 0.0
+ or (
+ is_dict_like(func)
+ and all(not is_list_like(value) for value in func.values())
+ )
+ )
+
+ # If func is a dict, pandas will not respect kwargs for each aggregation function, and
+    # we should drop them before passing them to the query compiler.
+ #
+ # >>> native_pd.DataFrame({"a": [0, 1], "b": [np.nan, 0]}).agg("max", skipna=False, axis=1)
+ # 0 NaN
+ # 1 1.0
+ # dtype: float64
+ # >>> native_pd.DataFrame({"a": [0, 1], "b": [np.nan, 0]}).agg(["max"], skipna=False, axis=1)
+ # max
+ # 0 0.0
+ # 1 1.0
+ # >>> pd.DataFrame([[np.nan], [0]]).aggregate("count", skipna=True, axis=0)
+ # 0 1
+ # dtype: int8
+ # >>> pd.DataFrame([[np.nan], [0]]).count(skipna=True, axis=0)
+ # TypeError: got an unexpected keyword argument 'skipna'
+ if is_dict_like(func) and not uses_named_kwargs:
+ kwargs.clear()
+
+ result = self.__constructor__(
+ query_compiler=self._query_compiler.agg(
+ func=func,
+ axis=axis,
+ args=args,
+ kwargs=kwargs,
+ )
+ )
+
+ if need_reduce_dimension:
+ if self._is_dataframe:
+ result = Series(query_compiler=result._query_compiler)
+
+ if isinstance(result, Series):
+ # When func is just "quantile" with a scalar q, result has quantile value as name
+ q = kwargs.get("q", 0.5)
+ if func == "quantile" and is_scalar(q):
+ result.name = q
+ else:
+ result.name = None
+
+ # handle case for single scalar (same as result._reduce_dimension())
+ if isinstance(self, Series):
+ return result.to_pandas().squeeze()
+
+ return result
+
+
+# `agg` is an alias of `aggregate`.
+agg = aggregate
+register_base_override("agg")(agg)
+
+
+# `_agg_helper` is not defined in modin, and used by Snowpark pandas to do extra validation.
+@register_base_override("_agg_helper")
+def _agg_helper(
+ self,
+ func: str,
+ skipna: bool = True,
+ axis: int | None | NoDefault = no_default,
+ numeric_only: bool = False,
+ **kwargs: Any,
+):
+ if not self._is_dataframe and numeric_only and not is_numeric_dtype(self.dtype):
+ # Series aggregations on non-numeric data do not support numeric_only:
+ # https://github.com/pandas-dev/pandas/blob/cece8c6579854f6b39b143e22c11cac56502c4fd/pandas/core/series.py#L6358
+ raise TypeError(
+ f"Series.{func} does not allow numeric_only=True with non-numeric dtypes."
+ )
+ axis = self._get_axis_number(axis)
+ numeric_only = validate_bool_kwarg(numeric_only, "numeric_only", none_allowed=True)
+ skipna = validate_bool_kwarg(skipna, "skipna", none_allowed=False)
+ agg_kwargs: dict[str, Any] = {
+ "numeric_only": numeric_only,
+ "skipna": skipna,
+ }
+ agg_kwargs.update(kwargs)
+ return self.aggregate(func=func, axis=axis, **agg_kwargs)
+
+
+# See _agg_helper
+@register_base_override("count")
+def count(
+ self,
+ axis: Axis | None = 0,
+ numeric_only: bool = False,
+):
+ """
+ Count non-NA cells for `BasePandasDataset`.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ return self._agg_helper(
+ func="count",
+ axis=axis,
+ numeric_only=numeric_only,
+ )
+
+
+# See _agg_helper
+@register_base_override("max")
+def max(
+ self,
+ axis: Axis | None = 0,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ **kwargs: Any,
+):
+ """
+ Return the maximum of the values over the requested axis.
+ """
+ return self._agg_helper(
+ func="max",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("min")
+def min(
+ self,
+ axis: Axis | None | NoDefault = no_default,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ **kwargs,
+):
+ """
+ Return the minimum of the values over the requested axis.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ return self._agg_helper(
+ func="min",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("mean")
+def mean(
+ self,
+ axis: Axis | None | NoDefault = no_default,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ **kwargs: Any,
+):
+ """
+ Return the mean of the values over the requested axis.
+ """
+ return self._agg_helper(
+ func="mean",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("median")
+def median(
+ self,
+ axis: Axis | None | NoDefault = no_default,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ **kwargs: Any,
+):
+ """
+    Return the median of the values over the requested axis.
+ """
+ return self._agg_helper(
+ func="median",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("std")
+def std(
+ self,
+ axis: Axis | None = None,
+ skipna: bool = True,
+ ddof: int = 1,
+ numeric_only: bool = False,
+ **kwargs,
+):
+ """
+ Return sample standard deviation over requested axis.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ kwargs.update({"ddof": ddof})
+ return self._agg_helper(
+ func="std",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("sum")
+def sum(
+ self,
+ axis: Axis | None = None,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ min_count: int = 0,
+ **kwargs: Any,
+):
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ min_count = validate_int_kwarg(min_count, "min_count")
+ kwargs.update({"min_count": min_count})
+ return self._agg_helper(
+ func="sum",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("var")
+def var(
+ self,
+ axis: Axis | None = None,
+ skipna: bool = True,
+ ddof: int = 1,
+ numeric_only: bool = False,
+ **kwargs: Any,
+):
+ """
+ Return unbiased variance over requested axis.
+ """
+ kwargs.update({"ddof": ddof})
+ return self._agg_helper(
+ func="var",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# Modin does not provide `MultiIndex` support and will default to pandas when `level` is specified,
+# and allows binary ops against native pandas objects that Snowpark pandas prohibits.
+@register_base_override("_binary_op")
+def _binary_op(
+ self,
+ op: str,
+ other: BasePandasDataset,
+ axis: Axis = None,
+ level: Level | None = None,
+ fill_value: float | None = None,
+ **kwargs: Any,
+):
+ """
+ Do binary operation between two datasets.
+
+ Parameters
+ ----------
+ op : str
+ Name of binary operation.
+ other : modin.pandas.BasePandasDataset
+ Second operand of binary operation.
+    axis: Whether to compare by the index (0 or ‘index’) or columns (1 or ‘columns’).
+ level: Broadcast across a level, matching Index values on the passed MultiIndex level.
+ fill_value: Fill existing missing (NaN) values, and any new element needed for
+ successful DataFrame alignment, with this value before computation.
+ If data in both corresponding DataFrame locations is missing the result will be missing.
+        Only arithmetic binary operations have this parameter (e.g., add() does, but eq() doesn't).
+
+ kwargs can contain the following parameters passed in at the frontend:
+        func: Only used for the `combine` method. Function that takes two series as inputs and
+        returns a Series or a scalar. Used to merge the two dataframes column by column.
+
+ Returns
+ -------
+ modin.pandas.BasePandasDataset
+ Result of binary operation.
+ """
+ # In upstream modin, _axis indicates the operator will use the default axis
+ if kwargs.pop("_axis", None) is None:
+ if axis is not None:
+ axis = self._get_axis_number(axis)
+ else:
+ axis = 1
+ else:
+ axis = 0
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ raise_if_native_pandas_objects(other)
+ axis = self._get_axis_number(axis)
+ squeeze_self = isinstance(self, pd.Series)
+
+    # pandas itself will ignore the axis argument when using Series.
+ # Per default, it is set to axis=0. However, for the case of a Series interacting with
+ # a DataFrame the behavior is axis=1. Manually check here for this case and adjust the axis.
+
+    is_lhs_series_and_rhs_dataframe = isinstance(self, pd.Series) and isinstance(
+        other, pd.DataFrame
+    )
+
+ new_query_compiler = self._query_compiler.binary_op(
+ op=op,
+ other=other,
+ axis=1 if is_lhs_series_and_rhs_dataframe else axis,
+ level=level,
+ fill_value=fill_value,
+ squeeze_self=squeeze_self,
+ **kwargs,
+ )
+
+ from snowflake.snowpark.modin.pandas.dataframe import DataFrame
+
+ # Modin Bug: https://github.com/modin-project/modin/issues/7236
+ # For a Series interacting with a DataFrame, always return a DataFrame
+ return (
+ DataFrame(query_compiler=new_query_compiler)
+ if is_lhs_series_and_rhs_dataframe
+ else self._create_or_update_from_compiler(new_query_compiler)
+ )
+
+
+# Current Modin does not use _dropna and instead defines `dropna` directly, but Snowpark pandas
+# Series/DF still do. Snowpark pandas still needs to add support for the `ignore_index` parameter
+# (added in pandas 2.0), and should be able to refactor to remove this override.
+@register_base_override("_dropna")
+def _dropna(
+ self,
+ axis: Axis = 0,
+ how: str | NoDefault = no_default,
+ thresh: int | NoDefault = no_default,
+ subset: IndexLabel = None,
+ inplace: bool = False,
+):
+ inplace = validate_bool_kwarg(inplace, "inplace")
+
+ if is_list_like(axis):
+ raise TypeError("supplying multiple axes to axis is no longer supported.")
+
+ axis = self._get_axis_number(axis)
+
+ if (how is not no_default) and (thresh is not no_default):
+ raise TypeError(
+ "You cannot set both the how and thresh arguments at the same time."
+ )
+
+ if how is no_default:
+ how = "any"
+ if how not in ["any", "all"]:
+ raise ValueError("invalid how option: %s" % how)
+ if subset is not None:
+ if axis == 1:
+ indices = self.index.get_indexer_for(subset)
+ check = indices == -1
+ if check.any():
+ raise KeyError(list(np.compress(check, subset)))
+ else:
+ indices = self.columns.get_indexer_for(subset)
+ check = indices == -1
+ if check.any():
+ raise KeyError(list(np.compress(check, subset)))
+
+ new_query_compiler = self._query_compiler.dropna(
+ axis=axis,
+ how=how,
+ thresh=thresh,
+ subset=subset,
+ )
+ return self._create_or_update_from_compiler(new_query_compiler, inplace)
+
+
+# Snowpark pandas uses `self_is_series` instead of `squeeze_self` and `squeeze_value` to determine
+# the shape of `self` and `value`. Further work is needed to reconcile these two approaches.
+@register_base_override("fillna")
+def fillna(
+ self,
+ self_is_series,
+ value: Hashable | Mapping | pd.Series | pd.DataFrame = None,
+ method: FillnaOptions | None = None,
+ axis: Axis | None = None,
+ inplace: bool = False,
+ limit: int | None = None,
+ downcast: dict | None = None,
+):
+ """
+ Fill NA/NaN values using the specified method.
+
+ Parameters
+ ----------
+ self_is_series : bool
+ If True then self contains a Series object, if False then self contains
+ a DataFrame object.
+ value : scalar, dict, Series, or DataFrame, default: None
+ Value to use to fill holes (e.g. 0), alternately a
+ dict/Series/DataFrame of values specifying which value to use for
+ each index (for a Series) or column (for a DataFrame). Values not
+ in the dict/Series/DataFrame will not be filled. This value cannot
+ be a list.
+ method : {'backfill', 'bfill', 'pad', 'ffill', None}, default: None
+ Method to use for filling holes in reindexed Series
+ pad / ffill: propagate last valid observation forward to next valid
+ backfill / bfill: use next valid observation to fill gap.
+ axis : {None, 0, 1}, default: None
+ Axis along which to fill missing values.
+ inplace : bool, default: False
+ If True, fill in-place. Note: this will modify any
+ other views on this object (e.g., a no-copy slice for a column in a
+ DataFrame).
+ limit : int, default: None
+ If method is specified, this is the maximum number of consecutive
+ NaN values to forward/backward fill. In other words, if there is
+ a gap with more than this number of consecutive NaNs, it will only
+ be partially filled. If method is not specified, this is the
+ maximum number of entries along the entire axis where NaNs will be
+ filled. Must be greater than 0 if not None.
+ downcast : dict, default: None
+ A dict of item->dtype of what to downcast if possible,
+ or the string 'infer' which will try to downcast to an appropriate
+ equal type (e.g. float64 to int64 if possible).
+
+ Returns
+ -------
+ Series, DataFrame or None
+ Object with missing values filled or None if ``inplace=True``.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ raise_if_native_pandas_objects(value)
+ inplace = validate_bool_kwarg(inplace, "inplace")
+ axis = self._get_axis_number(axis)
+ if isinstance(value, (list, tuple)):
+ raise TypeError(
+ '"value" parameter must be a scalar or dict, but '
+ + f'you passed a "{type(value).__name__}"'
+ )
+ if value is None and method is None:
+ # same as pandas
+ raise ValueError("Must specify a fill 'value' or 'method'.")
+ if value is not None and method is not None:
+ raise ValueError("Cannot specify both 'value' and 'method'.")
+ if method is not None and method not in ["backfill", "bfill", "pad", "ffill"]:
+ expecting = "pad (ffill) or backfill (bfill)"
+ msg = "Invalid fill method. Expecting {expecting}. Got {method}".format(
+ expecting=expecting, method=method
+ )
+ raise ValueError(msg)
+ if limit is not None:
+ if not isinstance(limit, int):
+ raise ValueError("Limit must be an integer")
+ elif limit <= 0:
+ raise ValueError("Limit must be greater than 0")
+
+ new_query_compiler = self._query_compiler.fillna(
+ self_is_series=self_is_series,
+ value=value,
+ method=method,
+ axis=axis,
+ limit=limit,
+ downcast=downcast,
+ )
+ return self._create_or_update_from_compiler(new_query_compiler, inplace)
+
+
+# Snowpark pandas passes the query compiler object from a BasePandasDataset, which Modin does not do.
+@register_base_override("isin")
+def isin(
+ self, values: BasePandasDataset | ListLike | dict[Hashable, ListLike]
+) -> BasePandasDataset: # noqa: PR01, RT01, D200
+ """
+ Whether elements in `BasePandasDataset` are contained in `values`.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+
+ # Pass as query compiler if values is BasePandasDataset.
+ if isinstance(values, BasePandasDataset):
+ values = values._query_compiler
+
+    # Convert non-dict values to a list if values is neither a list nor an np.ndarray. For the
+    # non-lazy case, where values is not a BasePandasDataset, SnowflakeQueryCompiler expects the data
+    # to be materialized as a list or numpy array; use a list here since numpy may perform implicit type conversions.
+    elif not isinstance(values, dict) and not isinstance(
+        values, (list, np.ndarray)
+    ):
+ values = list(values)
+
+ return self.__constructor__(query_compiler=self._query_compiler.isin(values=values))
+
+
+# Snowpark pandas uses the single `quantiles_along_axis0` query compiler method, while upstream
+# Modin splits this into `quantile_for_single_value` and `quantile_for_list_of_values` calls.
+# It should be possible to merge those two functions upstream and reconcile the implementations.
+@register_base_override("quantile")
+def quantile(
+ self,
+ q: Scalar | ListLike = 0.5,
+ axis: Axis = 0,
+ numeric_only: bool = False,
+ interpolation: Literal[
+ "linear", "lower", "higher", "midpoint", "nearest"
+ ] = "linear",
+ method: Literal["single", "table"] = "single",
+) -> float | BasePandasDataset:
+ """
+ Return values at the given quantile over requested axis.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ axis = self._get_axis_number(axis)
+
+ # TODO
+ # - SNOW-1008361: support axis=1
+ # - SNOW-1008367: support when q is Snowpandas DF/Series (need to require QC interface to accept QC q values)
+ # - SNOW-1003587: support datetime/timedelta columns
+
+ if axis == 1 or interpolation not in ["linear", "nearest"] or method != "single":
+ ErrorMessage.not_implemented(
+ f"quantile function with parameters axis={axis}, interpolation={interpolation}, method={method} not supported"
+ )
+
+ if not numeric_only:
+ # If not numeric_only and columns, then check all columns are either
+ # numeric, timestamp, or timedelta
+ # Check if dtype is numeric, timedelta ("m"), or datetime ("M")
+ if not axis and not all(
+ is_numeric_dtype(t) or lib.is_np_dtype(t, "mM") for t in self._get_dtypes()
+ ):
+ raise TypeError("can't multiply sequence by non-int of type 'float'")
+ # If over rows, then make sure that all dtypes are equal for not
+ # numeric_only
+ elif axis:
+ for i in range(1, len(self._get_dtypes())):
+ pre_dtype = self._get_dtypes()[i - 1]
+ curr_dtype = self._get_dtypes()[i]
+ if not is_dtype_equal(pre_dtype, curr_dtype):
+ raise TypeError(
+ "Cannot compare type '{}' with type '{}'".format(
+ pre_dtype, curr_dtype
+ )
+ )
+ else:
+ # Normally pandas returns this near the end of the quantile, but we
+ # can't afford the overhead of running the entire operation before
+ # we error.
+ if not any(is_numeric_dtype(t) for t in self._get_dtypes()):
+ raise ValueError("need at least one array to concatenate")
+
+ # check that all qs are between 0 and 1
+ validate_percentile(q)
+ axis = self._get_axis_number(axis)
+ query_compiler = self._query_compiler.quantiles_along_axis0(
+ q=q if is_list_like(q) else [q],
+ numeric_only=numeric_only,
+ interpolation=interpolation,
+ method=method,
+ )
+ if is_list_like(q):
+ return self.__constructor__(query_compiler=query_compiler)
+ else:
+ # result is either a scalar or Series
+ result = self._reduce_dimension(query_compiler.transpose_single_row())
+ if isinstance(result, BasePandasDataset):
+ result.name = q
+ return result
+
+
+# Current Modin does not define this method. Snowpark pandas currently only uses it in
+# `DataFrame.set_index`. Modin does not support MultiIndex, or have its own lazy index class,
+# so we may need to keep this method for the foreseeable future.
+@register_base_override("_to_series_list")
+def _to_series_list(self, index: pd.Index) -> list[pd.Series]:
+ """
+ Convert index to a list of series
+ Args:
+ index: can be single or multi index
+
+ Returns:
+ the list of series
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ if isinstance(index, pd.MultiIndex):
+ return [
+ pd.Series(index.get_level_values(level)) for level in range(index.nlevels)
+ ]
+ elif isinstance(index, pd.Index):
+ return [pd.Series(index)]
+ else:
+ raise Exception("invalid index: " + str(index))
+
+
+# Upstream modin defaults to pandas when `suffix` is provided.
+@register_base_override("shift")
+def shift(
+ self,
+ periods: int | Sequence[int] = 1,
+ freq=None,
+ axis: Axis = 0,
+ fill_value: Hashable = no_default,
+ suffix: str | None = None,
+) -> BasePandasDataset:
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ if periods == 0 and freq is None:
+ # Check obvious case first, freq manipulates the index even for periods == 0 so check for it in addition.
+ return self.copy()
+
+ # pandas compatible ValueError for freq='infer'
+ # TODO: Test as part of SNOW-1023324.
+ if freq == "infer": # pragma: no cover
+ if not hasattr(self, "freq") and not hasattr( # pragma: no cover
+ self, "inferred_freq" # pragma: no cover
+ ): # pragma: no cover
+ raise ValueError() # pragma: no cover
+
+ axis = self._get_axis_number(axis)
+
+ if fill_value == no_default:
+ fill_value = None
+
+ new_query_compiler = self._query_compiler.shift(
+ periods, freq, axis, fill_value, suffix
+ )
+ return self._create_or_update_from_compiler(new_query_compiler, False)
+
+
+# Snowpark pandas supports only `numeric_only=True`, which is not the default value of the argument,
+# so we override the method here. We should revisit this behavior.
+@register_base_override("skew")
+def skew(
+ self,
+ axis: Axis | None | NoDefault = no_default,
+ skipna: bool = True,
+ numeric_only=True,
+ **kwargs,
+): # noqa: PR01, RT01, D200
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ """
+ Return unbiased skew over requested axis.
+ """
+ return self._stat_operation("skew", axis, skipna, numeric_only, **kwargs)
+
+
+@register_base_override("resample")
+def resample(
+ self,
+ rule,
+ axis: Axis = lib.no_default,
+ closed: str | None = None,
+ label: str | None = None,
+ convention: str = "start",
+ kind: str | None = None,
+ on: Level = None,
+ level: Level = None,
+ origin: str | TimestampConvertibleTypes = "start_day",
+ offset: TimedeltaConvertibleTypes | None = None,
+ group_keys=no_default,
+): # noqa: PR01, RT01, D200
+ """
+ Resample time-series data.
+ """
+ from snowflake.snowpark.modin.pandas.resample import Resampler
+
+ if axis is not lib.no_default: # pragma: no cover
+ axis = self._get_axis_number(axis)
+ if axis == 1:
+ warnings.warn(
+ "DataFrame.resample with axis=1 is deprecated. Do "
+ + "`frame.T.resample(...)` without axis instead.",
+ FutureWarning,
+ stacklevel=1,
+ )
+ else:
+ warnings.warn(
+ f"The 'axis' keyword in {type(self).__name__}.resample is "
+ + "deprecated and will be removed in a future version.",
+ FutureWarning,
+ stacklevel=1,
+ )
+ else:
+ axis = 0
+
+ return Resampler(
+ dataframe=self,
+ rule=rule,
+ axis=axis,
+ closed=closed,
+ label=label,
+ convention=convention,
+ kind=kind,
+ on=on,
+ level=level,
+ origin=origin,
+ offset=offset,
+ group_keys=group_keys,
+ )
+
+
+# Snowpark pandas needs to return a custom Expanding window object. We cannot use the
+# extensions module for this at the moment because modin performs a relative import of
+# `from .window import Expanding`.
+@register_base_override("expanding")
+def expanding(self, min_periods=1, axis=0, method="single"): # noqa: PR01, RT01, D200
+ """
+ Provide expanding window calculations.
+ """
+ from snowflake.snowpark.modin.pandas.window import Expanding
+
+ if axis is not lib.no_default:
+ axis = self._get_axis_number(axis)
+ name = "expanding"
+ if axis == 1:
+ warnings.warn(
+ f"Support for axis=1 in {type(self).__name__}.{name} is "
+ + "deprecated and will be removed in a future version. "
+ + f"Use obj.T.{name}(...) instead",
+ FutureWarning,
+ stacklevel=1,
+ )
+ else:
+ warnings.warn(
+ f"The 'axis' keyword in {type(self).__name__}.{name} is "
+ + "deprecated and will be removed in a future version. "
+ + "Call the method without the axis keyword instead.",
+ FutureWarning,
+ stacklevel=1,
+ )
+ else:
+ axis = 0
+
+ return Expanding(
+ self,
+ min_periods=min_periods,
+ axis=axis,
+ method=method,
+ )
+
+
+# Same as `expanding`: Snowpark pandas needs to return a custom Rolling or Window object.
+@register_base_override("rolling")
+def rolling(
+ self,
+ window,
+ min_periods: int | None = None,
+ center: bool = False,
+ win_type: str | None = None,
+ on: str | None = None,
+ axis: Axis = lib.no_default,
+ closed: str | None = None,
+ step: int | None = None,
+ method: str = "single",
+): # noqa: PR01, RT01, D200
+ """
+ Provide rolling window calculations.
+ """
+ if axis is not lib.no_default:
+ axis = self._get_axis_number(axis)
+ name = "rolling"
+ if axis == 1:
+ warnings.warn(
+ f"Support for axis=1 in {type(self).__name__}.{name} is "
+ + "deprecated and will be removed in a future version. "
+ + f"Use obj.T.{name}(...) instead",
+ FutureWarning,
+ stacklevel=1,
+ )
+ else: # pragma: no cover
+ warnings.warn(
+ f"The 'axis' keyword in {type(self).__name__}.{name} is "
+ + "deprecated and will be removed in a future version. "
+ + "Call the method without the axis keyword instead.",
+ FutureWarning,
+ stacklevel=1,
+ )
+ else:
+ axis = 0
+
+ if win_type is not None:
+ from snowflake.snowpark.modin.pandas.window import Window
+
+ return Window(
+ self,
+ window=window,
+ min_periods=min_periods,
+ center=center,
+ win_type=win_type,
+ on=on,
+ axis=axis,
+ closed=closed,
+ step=step,
+ method=method,
+ )
+ from snowflake.snowpark.modin.pandas.window import Rolling
+
+ return Rolling(
+ self,
+ window=window,
+ min_periods=min_periods,
+ center=center,
+ win_type=win_type,
+ on=on,
+ axis=axis,
+ closed=closed,
+ step=step,
+ method=method,
+ )
+
+
+# Snowpark pandas uses a custom indexer object for all indexing methods.
+@register_base_override("iloc")
+@property
+def iloc(self):
+ """
+ Purely integer-location based indexing for selection by position.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ # TODO: SNOW-930028 enable all skipped doctests
+ from snowflake.snowpark.modin.pandas.indexing import _iLocIndexer
+
+ return _iLocIndexer(self)
+
+
+# Snowpark pandas uses a custom indexer object for all indexing methods.
+@register_base_override("loc")
+@property
+def loc(self):
+ """
+ Get a group of rows and columns by label(s) or a boolean array.
+ """
+ # TODO: SNOW-935444 fix doctest where index key has name
+ # TODO: SNOW-933782 fix multiindex transpose bug, e.g., Name: (cobra, mark ii) => Name: ('cobra', 'mark ii')
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ from snowflake.snowpark.modin.pandas.indexing import _LocIndexer
+
+ return _LocIndexer(self)
+
+
+# Snowpark pandas uses a custom indexer object for all indexing methods.
+@register_base_override("iat")
+@property
+def iat(self, axis=None): # noqa: PR01, RT01, D200
+ """
+ Get a single value for a row/column pair by integer position.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ from snowflake.snowpark.modin.pandas.indexing import _iAtIndexer
+
+ return _iAtIndexer(self)
+
+
+# Snowpark pandas uses a custom indexer object for all indexing methods.
+@register_base_override("at")
+@property
+def at(self, axis=None): # noqa: PR01, RT01, D200
+ """
+ Get a single value for a row/column label pair.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ from snowflake.snowpark.modin.pandas.indexing import _AtIndexer
+
+ return _AtIndexer(self)
+
+
+# Snowpark pandas performs different dispatch logic; some changes may need to be upstreamed
+# to fix edge case indexing behaviors.
+@register_base_override("__getitem__")
+def __getitem__(self, key):
+ """
+ Retrieve dataset according to `key`.
+
+ Parameters
+ ----------
+ key : callable, scalar, slice, str or tuple
+ The global row index to retrieve data from.
+
+ Returns
+ -------
+ BasePandasDataset
+ Located dataset.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ key = apply_if_callable(key, self)
+ # If a slice is passed in, use .iloc[key].
+ if isinstance(key, slice):
+ if (is_integer(key.start) or key.start is None) and (
+ is_integer(key.stop) or key.stop is None
+ ):
+ return self.iloc[key]
+ else:
+ return self.loc[key]
+
+ # If the object calling getitem is a Series, only use .loc[key] to filter index.
+ if isinstance(self, pd.Series):
+ return self.loc[key]
+
+ # Sometimes the result of a callable is a DataFrame (e.g. df[df > 0]) - use where.
+ elif isinstance(key, pd.DataFrame):
+ return self.where(cond=key)
+
+ # If the object is a boolean list-like object, use .loc[key] to filter index.
+    # The if statement is structured this way to avoid calling dtype unnecessarily and to reduce the query count.
+ if isinstance(key, pd.Series):
+ if pandas.api.types.is_bool_dtype(key.dtype):
+ return self.loc[key]
+ elif is_list_like(key):
+ if hasattr(key, "dtype"):
+ if pandas.api.types.is_bool_dtype(key.dtype):
+ return self.loc[key]
+ if (all(is_bool(k) for k in key)) and len(key) > 0:
+ return self.loc[key]
+
+ # In all other cases, use .loc[:, key] to filter columns.
+ return self.loc[:, key]
+
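+# Editor's summary of the dispatch above (hypothetical `df` and `bool_series`,
+# not part of the patch):
+#
+#   df[1:3]          # integer slice   -> df.iloc[1:3]
+#   df["a":"c"]      # label slice     -> df.loc["a":"c"]
+#   df[df > 0]       # DataFrame key   -> df.where(cond=df > 0)
+#   df[bool_series]  # boolean mask    -> df.loc[bool_series]
+#   df["a"]          # anything else   -> df.loc[:, "a"] (column selection)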
+
+# Snowpark pandas does extra argument validation, which may need to be upstreamed.
+@register_base_override("sort_values")
+def sort_values(
+ self,
+ by,
+ axis=0,
+ ascending=True,
+ inplace: bool = False,
+ kind="quicksort",
+ na_position="last",
+ ignore_index: bool = False,
+ key: IndexKeyFunc | None = None,
+): # noqa: PR01, RT01, D200
+ """
+ Sort by the values along either axis.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ axis = self._get_axis_number(axis)
+ inplace = validate_bool_kwarg(inplace, "inplace")
+ ascending = validate_ascending(ascending)
+ if axis == 0:
+        # If any column is None raise KeyError (same as native pandas).
+ if by is None or (isinstance(by, list) and None in by):
+ # Same error message as native pandas.
+ raise KeyError(None)
+ if not isinstance(by, list):
+ by = [by]
+
+ # Convert 'ascending' to sequence if needed.
+ if not isinstance(ascending, Sequence):
+ ascending = [ascending] * len(by)
+ if len(by) != len(ascending):
+ # Same error message as native pandas.
+ raise ValueError(
+ f"Length of ascending ({len(ascending)})"
+ f" != length of by ({len(by)})"
+ )
+
+ columns = self._query_compiler.columns.values.tolist()
+ index_names = self._query_compiler.get_index_names()
+ for by_col in by:
+ col_count = columns.count(by_col)
+ index_count = index_names.count(by_col)
+ if col_count == 0 and index_count == 0:
+ # Same error message as native pandas.
+ raise KeyError(by_col)
+ if col_count and index_count:
+ # Same error message as native pandas.
+ raise ValueError(
+ f"'{by_col}' is both an index level and a column label, which is ambiguous."
+ )
+ if col_count > 1:
+ # Same error message as native pandas.
+ raise ValueError(f"The column label '{by_col}' is not unique.")
+
+ if na_position not in get_args(NaPosition):
+ # Same error message as native pandas for invalid 'na_position' value.
+ raise ValueError(f"invalid na_position: {na_position}")
+ result = self._query_compiler.sort_rows_by_column_values(
+ by,
+ ascending=ascending,
+ kind=kind,
+ na_position=na_position,
+ ignore_index=ignore_index,
+ key=key,
+ )
+ else:
+ result = self._query_compiler.sort_columns_by_row_values(
+ by,
+ ascending=ascending,
+ kind=kind,
+ na_position=na_position,
+ ignore_index=ignore_index,
+ key=key,
+ )
+ return self._create_or_update_from_compiler(result, inplace)
+
+
+# Modin does not define `where` on BasePandasDataset, and defaults to pandas at the frontend
+# layer for Series.
+@register_base_override("where")
+def where(
+ self,
+ cond: BasePandasDataset | Callable | AnyArrayLike,
+ other: BasePandasDataset | Callable | Scalar | None = np.nan,
+ inplace: bool = False,
+ axis: Axis | None = None,
+ level: Level | None = None,
+):
+ """
+ Replace values where the condition is False.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ # TODO: SNOW-985670: Refactor `where` and `mask`
+ # will move pre-processing to QC layer.
+ inplace = validate_bool_kwarg(inplace, "inplace")
+ if cond is None:
+ raise ValueError("Array conditional must be same shape as self")
+
+ cond = apply_if_callable(cond, self)
+
+ if isinstance(cond, Callable):
+ raise NotImplementedError("Do not support callable for 'cond' parameter.")
+
+ from snowflake.snowpark.modin.pandas import Series
+
+ if isinstance(cond, Series):
+ cond._query_compiler._shape_hint = "column"
+ if isinstance(self, Series):
+ self._query_compiler._shape_hint = "column"
+ if isinstance(other, Series):
+ other._query_compiler._shape_hint = "column"
+
+ if not isinstance(cond, BasePandasDataset):
+ cond = get_as_shape_compatible_dataframe_or_series(cond, self)
+ cond._query_compiler._shape_hint = "array"
+
+ if other is not None:
+ other = apply_if_callable(other, self)
+
+ if isinstance(other, np.ndarray):
+ other = get_as_shape_compatible_dataframe_or_series(
+ other,
+ self,
+ shape_mismatch_message="other must be the same shape as self when an ndarray",
+ )
+ other._query_compiler._shape_hint = "array"
+
+ if isinstance(other, BasePandasDataset):
+ other = other._query_compiler
+
+ query_compiler = self._query_compiler.where(
+ cond._query_compiler,
+ other,
+ axis,
+ level,
+ )
+
+ return self._create_or_update_from_compiler(query_compiler, inplace)
+
+
+# Snowpark pandas performs extra argument validation, some of which should be pushed down
+# to the QC layer.
+@register_base_override("mask")
+def mask(
+ self,
+ cond: BasePandasDataset | Callable | AnyArrayLike,
+ other: BasePandasDataset | Callable | Scalar | None = np.nan,
+ inplace: bool = False,
+ axis: Axis | None = None,
+ level: Level | None = None,
+):
+ """
+ Replace values where the condition is True.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ # TODO: https://snowflakecomputing.atlassian.net/browse/SNOW-985670
+ # will move pre-processing to QC layer.
+ inplace = validate_bool_kwarg(inplace, "inplace")
+ if cond is None:
+ raise ValueError("Array conditional must be same shape as self")
+
+ cond = apply_if_callable(cond, self)
+
+ if isinstance(cond, Callable):
+ raise NotImplementedError("Do not support callable for 'cond' parameter.")
+
+ from snowflake.snowpark.modin.pandas import Series
+
+ if isinstance(cond, Series):
+ cond._query_compiler._shape_hint = "column"
+ if isinstance(self, Series):
+ self._query_compiler._shape_hint = "column"
+ if isinstance(other, Series):
+ other._query_compiler._shape_hint = "column"
+
+ if not isinstance(cond, BasePandasDataset):
+ cond = get_as_shape_compatible_dataframe_or_series(cond, self)
+ cond._query_compiler._shape_hint = "array"
+
+ if other is not None:
+ other = apply_if_callable(other, self)
+
+ if isinstance(other, np.ndarray):
+ other = get_as_shape_compatible_dataframe_or_series(
+ other,
+ self,
+ shape_mismatch_message="other must be the same shape as self when an ndarray",
+ )
+ other._query_compiler._shape_hint = "array"
+
+ if isinstance(other, BasePandasDataset):
+ other = other._query_compiler
+
+ query_compiler = self._query_compiler.mask(
+ cond._query_compiler,
+ other,
+ axis,
+ level,
+ )
+
+ return self._create_or_update_from_compiler(query_compiler, inplace)
+
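+# Editor's note: `mask` is the complement of `where` above -- it replaces values
+# where `cond` is True instead of False. Hedged sketch (hypothetical `df`):
+#
+#   df.where(df > 0, 0)  # keep values where df > 0, replace the rest with 0
+#   df.mask(df > 0, 0)   # replace values where df > 0 with 0, keep the rest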
+
+# Snowpark pandas uses a custom I/O dispatcher class.
+@register_base_override("to_csv")
+def to_csv(
+ self,
+ path_or_buf=None,
+ sep=",",
+ na_rep="",
+ float_format=None,
+ columns=None,
+ header=True,
+ index=True,
+ index_label=None,
+ mode="w",
+ encoding=None,
+ compression="infer",
+ quoting=None,
+ quotechar='"',
+ lineterminator=None,
+ chunksize=None,
+ date_format=None,
+ doublequote=True,
+ escapechar=None,
+ decimal=".",
+ errors: str = "strict",
+ storage_options: StorageOptions = None,
+): # pragma: no cover
+ from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import (
+ FactoryDispatcher,
+ )
+
+ return FactoryDispatcher.to_csv(
+ self._query_compiler,
+ path_or_buf=path_or_buf,
+ sep=sep,
+ na_rep=na_rep,
+ float_format=float_format,
+ columns=columns,
+ header=header,
+ index=index,
+ index_label=index_label,
+ mode=mode,
+ encoding=encoding,
+ compression=compression,
+ quoting=quoting,
+ quotechar=quotechar,
+ lineterminator=lineterminator,
+ chunksize=chunksize,
+ date_format=date_format,
+ doublequote=doublequote,
+ escapechar=escapechar,
+ decimal=decimal,
+ errors=errors,
+ storage_options=storage_options,
+ )
+
+
+# Modin performs extra argument validation and defaults to pandas for some edge cases.
+@register_base_override("sample")
+def sample(
+ self,
+ n: int | None = None,
+ frac: float | None = None,
+ replace: bool = False,
+ weights: str | np.ndarray | None = None,
+ random_state: RandomState | None = None,
+ axis: Axis | None = None,
+ ignore_index: bool = False,
+):
+ """
+ Return a random sample of items from an axis of object.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ if self._get_axis_number(axis):
+ if weights is not None and isinstance(weights, str):
+ raise ValueError(
+ "Strings can only be passed to weights when sampling from rows on a DataFrame"
+ )
+ else:
+ if n is None and frac is None:
+ n = 1
+ elif n is not None and frac is not None:
+ raise ValueError("Please enter a value for `frac` OR `n`, not both")
+ else:
+ if n is not None:
+ if n < 0:
+ raise ValueError(
+ "A negative number of rows requested. Please provide `n` >= 0."
+ )
+ if n % 1 != 0:
+ raise ValueError("Only integers accepted as `n` values")
+ else:
+ if frac < 0:
+ raise ValueError(
+ "A negative number of rows requested. Please provide `frac` >= 0."
+ )
+
+ query_compiler = self._query_compiler.sample(
+ n, frac, replace, weights, random_state, axis, ignore_index
+ )
+ return self.__constructor__(query_compiler=query_compiler)
+
+
+# Modin performs an extra query calling self.isna() to raise a warning when fill_method is unspecified.
+@register_base_override("pct_change")
+def pct_change(
+ self, periods=1, fill_method=no_default, limit=no_default, freq=None, **kwargs
+): # noqa: PR01, RT01, D200
+ """
+ Percentage change between the current and a prior element.
+ """
+ if fill_method not in (lib.no_default, None) or limit is not lib.no_default:
+ warnings.warn(
+ "The 'fill_method' keyword being not None and the 'limit' keyword in "
+ + f"{type(self).__name__}.pct_change are deprecated and will be removed "
+ + "in a future version. Either fill in any non-leading NA values prior "
+ + "to calling pct_change or specify 'fill_method=None' to not fill NA "
+ + "values.",
+ FutureWarning,
+ stacklevel=1,
+ )
+ if fill_method is lib.no_default:
+ warnings.warn(
+ f"The default fill_method='pad' in {type(self).__name__}.pct_change is "
+ + "deprecated and will be removed in a future version. Either fill in any "
+ + "non-leading NA values prior to calling pct_change or specify 'fill_method=None' "
+ + "to not fill NA values.",
+ FutureWarning,
+ stacklevel=1,
+ )
+ fill_method = "pad"
+
+ if limit is lib.no_default:
+ limit = None
+
+ if "axis" in kwargs:
+ kwargs["axis"] = self._get_axis_number(kwargs["axis"])
+
+ # Attempting to match pandas error behavior here
+ if not isinstance(periods, int):
+ raise TypeError(f"periods must be an int. got {type(periods)} instead")
+
+ # Attempting to match pandas error behavior here
+ for dtype in self._get_dtypes():
+ if not is_numeric_dtype(dtype):
+ raise TypeError(
+ f"cannot perform pct_change on non-numeric column with dtype {dtype}"
+ )
+
+ return self.__constructor__(
+ query_compiler=self._query_compiler.pct_change(
+ periods=periods,
+ fill_method=fill_method,
+ limit=limit,
+ freq=freq,
+ **kwargs,
+ )
+ )
+
+
+# Snowpark pandas has different `copy` behavior, and some different behavior with native series arguments.
+@register_base_override("astype")
+def astype(
+ self,
+ dtype: str | type | pd.Series | dict[str, type],
+ copy: bool = True,
+ errors: Literal["raise", "ignore"] = "raise",
+) -> pd.DataFrame | pd.Series:
+ """
+ Cast a Modin object to a specified dtype `dtype`.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ # dtype can be a series, a dict, or a scalar. If it's series or scalar,
+ # convert it to a dict before passing it to the query compiler.
+ raise_if_native_pandas_objects(dtype)
+ from snowflake.snowpark.modin.pandas import Series
+
+ if isinstance(dtype, Series):
+ dtype = dtype.to_pandas()
+ if not dtype.index.is_unique:
+ raise ValueError(
+ "The new Series of types must have a unique index, i.e. "
+ + "it must be one-to-one mapping from column names to "
+ + " their new dtypes."
+ )
+ dtype = dtype.to_dict()
+ # If we got a series or dict originally, dtype is a dict now. Its keys
+ # must be column names.
+ if isinstance(dtype, dict):
+ # Avoid materializing columns. The query compiler will handle errors where
+ # dtype dict includes keys that are not in columns.
+ col_dtypes = dtype
+ for col_name in col_dtypes:
+ if col_name not in self._query_compiler.columns:
+ raise KeyError(
+ "Only a column name can be used for the key in a dtype mappings argument. "
+ f"'{col_name}' not found in columns."
+ )
+ else:
+ # Assume that the dtype is a scalar.
+ col_dtypes = {column: dtype for column in self._query_compiler.columns}
+
+ # ensure values are pandas dtypes
+ col_dtypes = {k: pandas_dtype(v) for k, v in col_dtypes.items()}
+ new_query_compiler = self._query_compiler.astype(col_dtypes, errors=errors)
+ return self._create_or_update_from_compiler(new_query_compiler, not copy)
+
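+# Editor's sketch of the accepted `dtype` forms above (hypothetical `df` with
+# columns "a" and "b", not part of the patch):
+#
+#   df.astype(str)                                # scalar: applied to every column
+#   df.astype({"a": "int64"})                     # dict: per-column mapping
+#   df.astype(pd.Series(["int64"], index=["a"]))  # Series: converted to a dict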
+
+# Modin defaults to pandas when `level` is specified, and has some extra axis validation that
+# is guarded in newer versions.
+@register_base_override("drop")
+def drop(
+ self,
+ labels: IndexLabel = None,
+ axis: Axis = 0,
+ index: IndexLabel = None,
+ columns: IndexLabel = None,
+ level: Level = None,
+ inplace: bool = False,
+ errors: IgnoreRaise = "raise",
+) -> BasePandasDataset | None:
+ """
+ Drop specified labels from `BasePandasDataset`.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ inplace = validate_bool_kwarg(inplace, "inplace")
+ if labels is not None:
+ if index is not None or columns is not None:
+ raise ValueError("Cannot specify both 'labels' and 'index'/'columns'")
+ axes = {self._get_axis_number(axis): labels}
+ elif index is not None or columns is not None:
+ axes = {0: index, 1: columns}
+ else:
+ raise ValueError(
+ "Need to specify at least one of 'labels', 'index' or 'columns'"
+ )
+
+ for axis, labels in axes.items():
+ if labels is not None:
+ if level is not None and not self._query_compiler.has_multiindex(axis=axis):
+ # Same error as native pandas.
+ raise AssertionError("axis must be a MultiIndex")
+ # According to pandas documentation, a tuple will be used as a single
+ # label and not treated as a list-like.
+ if not is_list_like(labels) or isinstance(labels, tuple):
+ axes[axis] = [labels]
+
+ new_query_compiler = self._query_compiler.drop(
+ index=axes.get(0), columns=axes.get(1), level=level, errors=errors
+ )
+ return self._create_or_update_from_compiler(new_query_compiler, inplace)
+
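+# Editor's sketch of the tuple special case above (hypothetical MultiIndex
+# frame, not part of the patch):
+#
+#   df.drop(("a", 1))    # tuple: treated as a single label, wrapped as [("a", 1)]
+#   df.drop(["a", "b"])  # list-like: each element dropped as its own label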
+
+# Modin calls len(self.index) instead of a direct query compiler method.
+@register_base_override("__len__")
+def __len__(self) -> int:
+ """
+ Return length of info axis.
+
+ Returns
+ -------
+ int
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ return self._query_compiler.get_axis_len(axis=0)
+
+
+# Snowpark pandas ignores `copy`.
+@register_base_override("set_axis")
+def set_axis(
+ self,
+ labels: IndexLabel,
+ *,
+ axis: Axis = 0,
+ copy: bool | NoDefault = no_default,
+):
+ """
+ Assign desired index to given axis.
+ """
+    # Behavior based on copy:
+    # -----------------------------------
+    # - In native pandas, `copy` determines whether to create a copy of the data (not of the
+    #   DataFrame/Series object itself).
+    # - We cannot emulate native pandas' copy behavior in Snowpark since a copy of only the
+    #   data cannot be created -- you can only copy the whole object (DataFrame/Series).
+    #
+    # Snowpark behavior:
+    # ------------------
+    # - copy is kept for compatibility with native pandas but is ignored; warn the user that
+    #   it is unused.
+ if copy is not no_default:
+ WarningMessage.single_warning(
+ message=f"{type(self).__name__}.set_axis 'copy' keyword is unused and is ignored."
+ )
+ if labels is None:
+ raise TypeError("None is not a valid value for the parameter 'labels'.")
+
+ # Determine whether to update self or a copy and perform update.
+ obj = self.copy()
+ setattr(obj, axis, labels)
+ return obj
+
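+# Editor's sketch (hypothetical two-row `df`, not part of the patch):
+#
+#   df.set_axis(["x", "y"], axis=0)              # copy of df with the new index
+#   df.set_axis(["x", "y"], axis=0, copy=False)  # same result, plus a warning
+#                                                # that 'copy' is ignored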
+
+# Modin has different behavior for empty dataframes and some slightly different length validation.
+@register_base_override("describe")
+def describe(
+ self,
+ percentiles: ListLike | None = None,
+ include: ListLike | Literal["all"] | None = None,
+ exclude: ListLike | None = None,
+) -> BasePandasDataset:
+ """
+ Generate descriptive statistics.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ percentiles = _refine_percentiles(percentiles)
+ data = self
+ if self._is_dataframe:
+ # Upstream modin lacks this check because it defaults to pandas for describing empty dataframes
+ if len(self.columns) == 0:
+ raise ValueError("Cannot describe a DataFrame without columns")
+
+ # include/exclude are ignored for Series
+ if (include is None) and (exclude is None):
+ # when some numerics are found, keep only numerics
+ default_include: list[npt.DTypeLike] = [np.number]
+ default_include.append("datetime")
+ data = self.select_dtypes(include=default_include)
+ if len(data.columns) == 0:
+ data = self
+ elif include == "all":
+ if exclude is not None:
+ raise ValueError("exclude must be None when include is 'all'")
+ data = self
+ else:
+ data = self.select_dtypes(
+ include=include,
+ exclude=exclude,
+ )
+ # Upstream modin uses data.empty, but that incurs an extra row count query
+ if self._is_dataframe and len(data.columns) == 0:
+ # Match pandas error from concatenating empty list of series descriptions.
+ raise ValueError("No objects to concatenate")
+
+ return self.__constructor__(
+ query_compiler=data._query_compiler.describe(percentiles=percentiles)
+ )
+
+
+# Modin does type validation on self that Snowpark pandas defers to SQL.
+@register_base_override("diff")
+def diff(self, periods: int = 1, axis: Axis = 0):
+ """
+ First discrete difference of element.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ # We must only accept integer (or float values that are whole numbers)
+ # for periods.
+ int_periods = validate_int_kwarg(periods, "periods", float_allowed=True)
+ axis = self._get_axis_number(axis)
+ return self.__constructor__(
+ query_compiler=self._query_compiler.diff(axis=axis, periods=int_periods)
+ )
+
+
+# Modin does an unnecessary len call when n == 0.
+@register_base_override("tail")
+def tail(self, n: int = 5):
+ if n == 0:
+ return self.iloc[0:0]
+ return self.iloc[-n:]
+
+
+# Snowpark pandas does extra argument validation (which should probably be deferred to SQL instead).
+@register_base_override("idxmax")
+def idxmax(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200
+ """
+ Return index of first occurrence of maximum over requested axis.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ dtypes = self._get_dtypes()
+ if (
+ axis == 1
+ and not numeric_only
+ and any(not is_numeric_dtype(d) for d in dtypes)
+ and len(set(dtypes)) > 1
+ ):
+ # For numeric_only=False, if we have any non-numeric dtype, e.g.
+ # a string type, we need every other column to be of the same type.
+ # We can't compare two objects of different non-numeric types, e.g.
+ # a string and a timestamp.
+        # If we have only numeric data, we can compare columns even if they
+        # have different types, e.g. we can compare an int column to a float
+ # column.
+ raise TypeError("'>' not supported for these dtypes")
+ axis = self._get_axis_number(axis)
+ return self._reduce_dimension(
+ self._query_compiler.idxmax(axis=axis, skipna=skipna, numeric_only=numeric_only)
+ )
+
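+# Editor's sketch of the axis=1 dtype rule above (hypothetical frames, not part
+# of the patch):
+#
+#   pd.DataFrame({"i": [1], "f": [1.5]}).idxmax(axis=1)  # OK: all-numeric mix
+#   pd.DataFrame({"s": ["a"], "t": [pd.Timestamp(0)]}).idxmax(axis=1)
+#   # TypeError: two distinct non-numeric dtypes cannot be compared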
+
+# Snowpark pandas does extra argument validation (which should probably be deferred to SQL instead).
+@register_base_override("idxmin")
+def idxmin(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200
+ """
+ Return index of first occurrence of minimum over requested axis.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ dtypes = self._get_dtypes()
+ if (
+ axis == 1
+ and not numeric_only
+ and any(not is_numeric_dtype(d) for d in dtypes)
+ and len(set(dtypes)) > 1
+ ):
+ # For numeric_only=False, if we have any non-numeric dtype, e.g.
+ # a string type, we need every other column to be of the same type.
+ # We can't compare two objects of different non-numeric types, e.g.
+ # a string and a timestamp.
+        # If we have only numeric data, we can compare columns even if they
+        # have different types, e.g. we can compare an int column to a float
+ # column.
+ raise TypeError("'<' not supported for these dtypes")
+ axis = self._get_axis_number(axis)
+ return self._reduce_dimension(
+ self._query_compiler.idxmin(axis=axis, skipna=skipna, numeric_only=numeric_only)
+ )
+
+
+# Modin does dtype validation on unary ops that Snowpark pandas does not.
+@register_base_override("__abs__")
+def abs(self): # noqa: RT01, D200
+ """
+ Return a `BasePandasDataset` with absolute numeric value of each element.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ return self.__constructor__(query_compiler=self._query_compiler.abs())
+
+
+# Modin does dtype validation on unary ops that Snowpark pandas does not.
+@register_base_override("__invert__")
+def __invert__(self):
+ """
+ Apply bitwise inverse to each element of the `BasePandasDataset`.
+
+ Returns
+ -------
+ BasePandasDataset
+ New BasePandasDataset containing bitwise inverse to each value.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ return self.__constructor__(query_compiler=self._query_compiler.invert())
+
+
+# Modin does dtype validation on unary ops that Snowpark pandas does not.
+@register_base_override("__neg__")
+def __neg__(self):
+ """
+ Change the sign for every value of self.
+
+ Returns
+ -------
+ BasePandasDataset
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ return self.__constructor__(query_compiler=self._query_compiler.negative())
+
+
+# Modin needs to add a check that `mapper` is not None; without it, query counts in
+# test_concat.py change.
+@register_base_override("rename_axis")
+def rename_axis(
+ self,
+ mapper=lib.no_default,
+ *,
+ index=lib.no_default,
+ columns=lib.no_default,
+ axis=0,
+ copy=None,
+ inplace=False,
+): # noqa: PR01, RT01, D200
+ """
+ Set the name of the axis for the index or columns.
+ """
+ axes = {"index": index, "columns": columns}
+
+ if copy is None:
+ copy = True
+
+ if axis is not None:
+ axis = self._get_axis_number(axis)
+
+ inplace = validate_bool_kwarg(inplace, "inplace")
+
+ if mapper is not lib.no_default and mapper is not None:
+ # Use v0.23 behavior if a scalar or list
+ non_mapper = is_scalar(mapper) or (
+ is_list_like(mapper) and not is_dict_like(mapper)
+ )
+ if non_mapper:
+ return self._set_axis_name(mapper, axis=axis, inplace=inplace)
+ else:
+ raise ValueError("Use `.rename` to alter labels with a mapper.")
+ else:
+ # Use new behavior. Means that index and/or columns is specified
+ result = self if inplace else self.copy(deep=copy)
+
+ for axis in range(self.ndim):
+ v = axes.get(pandas.DataFrame._get_axis_name(axis))
+ if v is lib.no_default:
+ continue
+ non_mapper = is_scalar(v) or (is_list_like(v) and not is_dict_like(v))
+ if non_mapper:
+ newnames = v
+ else:
+
+ def _get_rename_function(mapper):
+ if isinstance(mapper, (dict, BasePandasDataset)):
+
+ def f(x):
+ if x in mapper:
+ return mapper[x]
+ else:
+ return x
+
+ else:
+ f = mapper
+
+ return f
+
+ f = _get_rename_function(v)
+ curnames = self.index.names if axis == 0 else self.columns.names
+ newnames = [f(name) for name in curnames]
+ result._set_axis_name(newnames, axis=axis, inplace=True)
+ if not inplace:
+ return result
+
+
+# Snowpark pandas has custom dispatch logic for ufuncs, while modin defaults to pandas.
+@register_base_override("__array_ufunc__")
+def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
+ """
+ Apply the `ufunc` to the `BasePandasDataset`.
+
+ Parameters
+ ----------
+ ufunc : np.ufunc
+ The NumPy ufunc to apply.
+ method : str
+ The method to apply.
+ *inputs : tuple
+ The inputs to the ufunc.
+ **kwargs : dict
+ Additional keyword arguments.
+
+ Returns
+ -------
+ BasePandasDataset
+ The result of the ufunc applied to the `BasePandasDataset`.
+ """
+    # Only the plain "__call__" ufunc method is supported; for other ufunc
+    # methods (e.g. reduce, accumulate), return the sentinel NotImplemented
+    # so that NumPy can fall back to its own handling.
+    if method != "__call__":
+ return NotImplemented # pragma: no cover
+ from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import (
+ numpy_to_pandas_universal_func_map,
+ )
+
+ if ufunc.__name__ in numpy_to_pandas_universal_func_map:
+ ufunc = numpy_to_pandas_universal_func_map[ufunc.__name__]
+ return ufunc(self, inputs[1:], kwargs)
+ # return the sentinel NotImplemented if we do not support this function
+ return NotImplemented # pragma: no cover
+
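+# Editor's sketch (assumes np.add has an entry in
+# numpy_to_pandas_universal_func_map, which is the intent of the mapping for
+# common arithmetic ufuncs; not part of the patch):
+#
+#   np.add(df, 1)      # method == "__call__": dispatched to the mapped pandas op
+#   np.add.reduce(df)  # method != "__call__": NotImplemented, NumPy falls back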
+
+# Snowpark pandas does extra argument validation.
+@register_base_override("reindex")
+def reindex(
+ self,
+ index=None,
+ columns=None,
+ copy=True,
+ **kwargs,
+): # noqa: PR01, RT01, D200
+ """
+ Conform `BasePandasDataset` to new index with optional filling logic.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ if kwargs.get("limit", None) is not None and kwargs.get("method", None) is None:
+ raise ValueError(
+ "limit argument only valid if doing pad, backfill or nearest reindexing"
+ )
+ new_query_compiler = None
+ if index is not None:
+ if not isinstance(index, pandas.Index) or not index.equals(self.index):
+ new_query_compiler = self._query_compiler.reindex(
+ axis=0, labels=index, **kwargs
+ )
+ if new_query_compiler is None:
+ new_query_compiler = self._query_compiler
+ final_query_compiler = None
+ if columns is not None:
+        if not isinstance(columns, pandas.Index) or not columns.equals(self.columns):
+ final_query_compiler = new_query_compiler.reindex(
+ axis=1, labels=columns, **kwargs
+ )
+ if final_query_compiler is None:
+ final_query_compiler = new_query_compiler
+ return self._create_or_update_from_compiler(
+ final_query_compiler, inplace=False if copy is None else not copy
+ )
+
+
+# No direct override annotation; used as part of `property`.
+# Snowpark pandas may return a custom lazy index object.
+def _get_index(self):
+ """
+ Get the index for this DataFrame.
+
+ Returns
+ -------
+ pandas.Index
+ The union of all indexes across the partitions.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ from snowflake.snowpark.modin.plugin.extensions.index import Index
+
+ if self._query_compiler.is_multiindex():
+ # Lazy multiindex is not supported
+ return self._query_compiler.index
+
+ idx = Index(query_compiler=self._query_compiler)
+ idx._set_parent(self)
+ return idx
+
+
+# No direct override annotation; used as part of `property`.
+# Snowpark pandas may return a custom lazy index object.
+def _set_index(self, new_index: Axes) -> None:
+ """
+ Set the index for this DataFrame.
+
+ Parameters
+ ----------
+ new_index : pandas.Index
+        The new index to set.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ self._update_inplace(
+ new_query_compiler=self._query_compiler.set_index(
+ [s._query_compiler for s in self._to_series_list(ensure_index(new_index))]
+ )
+ )
+
+
+# Snowpark pandas may return a custom lazy index object.
+register_base_override("index")(property(_get_index, _set_index))
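+# Editor's note: registering `property(_get_index, _set_index)` wires both the
+# getter and the setter through these overrides, so (hypothetically):
+#
+#   df.index           # _get_index: lazy Index object unless it is a MultiIndex
+#   df.index = [1, 2]  # _set_index: set_index on the query compiler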
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py
index 95fcf684924..808489b8917 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py
@@ -29,6 +29,7 @@
import modin
import numpy as np
import pandas as native_pd
+from modin.pandas.base import BasePandasDataset
from pandas import get_option
from pandas._libs import lib
from pandas._libs.lib import is_list_like, is_scalar
@@ -48,7 +49,6 @@
from pandas.core.dtypes.inference import is_hashable
from snowflake.snowpark.modin.pandas import DataFrame, Series
-from snowflake.snowpark.modin.pandas.base import BasePandasDataset
from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native
from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta
from snowflake.snowpark.modin.plugin._internal.timestamp_utils import DateTimeOrigin
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
index 0afea30e29a..a33d7702203 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
@@ -161,6 +161,7 @@ def plot(
@register_series_accessor("transform")
+@snowpark_pandas_telemetry_method_decorator
@series_not_implemented()
def transform(self, func, axis=0, *args, **kwargs): # noqa: PR01, RT01, D200
pass # pragma: no cover
diff --git a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py
new file mode 100644
index 00000000000..f2b28e8bfc1
--- /dev/null
+++ b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py
@@ -0,0 +1,27 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+# Do not look up certain attributes in columns or index, as they're used for some
+# special purposes, like serving remote context
+# TODO: SNOW-1643986 examine whether to update upstream modin to follow this
+_ATTRS_NO_LOOKUP = {
+ "____id_pack__",
+ "__name__",
+ "_cache",
+ "_ipython_canary_method_should_not_exist_",
+ "_ipython_display_",
+ "_repr_html_",
+ "_repr_javascript_",
+ "_repr_jpeg_",
+ "_repr_json_",
+ "_repr_latex_",
+ "_repr_markdown_",
+ "_repr_mimebundle_",
+ "_repr_pdf_",
+ "_repr_png_",
+ "_repr_svg_",
+ "__array_struct__",
+ "__array_interface__",
+ "_typ",
+}
diff --git a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py
index 3da545c64b6..f673bf157bf 100644
--- a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py
+++ b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py
@@ -3,8 +3,9 @@
#
from typing import Any, Optional, Union
+from modin.pandas.base import BasePandasDataset
+
import snowflake.snowpark.modin.pandas as pd
-from snowflake.snowpark.modin.pandas.base import BasePandasDataset
from snowflake.snowpark.modin.pandas.utils import is_scalar
from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage
diff --git a/tests/integ/modin/test_telemetry.py b/tests/integ/modin/test_telemetry.py
index 9c24c6b6853..06fbc71eec7 100644
--- a/tests/integ/modin/test_telemetry.py
+++ b/tests/integ/modin/test_telemetry.py
@@ -398,7 +398,7 @@ def test_telemetry_getitem_setitem():
s = df["a"]
assert len(df._query_compiler.snowpark_pandas_api_calls) == 0
assert s._query_compiler.snowpark_pandas_api_calls == [
- {"name": "DataFrame.BasePandasDataset.__getitem__"}
+ {"name": "DataFrame.__getitem__"}
]
df["a"] = 0
df["b"] = 0
@@ -412,12 +412,12 @@ def test_telemetry_getitem_setitem():
# the telemetry log from the connector to validate
_ = s[0]
data = _extract_snowpark_pandas_telemetry_log_data(
- expected_func_name="Series.BasePandasDataset.__getitem__",
+ expected_func_name="Series.__getitem__",
session=s._query_compiler._modin_frame.ordered_dataframe.session,
)
assert data["api_calls"] == [
- {"name": "DataFrame.BasePandasDataset.__getitem__"},
- {"name": "Series.BasePandasDataset.__getitem__"},
+ {"name": "DataFrame.__getitem__"},
+ {"name": "Series.__getitem__"},
]
@@ -547,3 +547,18 @@ def test_telemetry_repr():
{"name": "Series.property.name_set"},
{"name": "Series.Series.__repr__"},
]
+
+
+@sql_count_checker(query_count=0)
+def test_telemetry_copy():
+ # copy() is defined in upstream modin's BasePandasDataset class, and not overridden by any
+ # child class or the extensions module.
+ s = pd.Series([1, 2, 3, 4])
+ copied = s.copy()
+ assert s._query_compiler.snowpark_pandas_api_calls == [
+ {"name": "Series.property.name_set"}
+ ]
+ assert copied._query_compiler.snowpark_pandas_api_calls == [
+ {"name": "Series.property.name_set"},
+ {"name": "Series.BasePandasDataset.copy"},
+ ]
diff --git a/tests/unit/modin/modin/test_envvars.py b/tests/unit/modin/modin/test_envvars.py
index 4f3540a63bf..7c5e3a40bb0 100644
--- a/tests/unit/modin/modin/test_envvars.py
+++ b/tests/unit/modin/modin/test_envvars.py
@@ -90,6 +90,32 @@ def test_custom_help(make_custom_envvar):
assert "custom var" in make_custom_envvar.get_help()
+def _init_doc_module():
+ # Put the docs_module on the path
+ sys.path.append(f"{os.path.dirname(__file__)}")
+ # We use base.py from upstream modin, so we need to initialize its doc module
+ # However, since using the environment variable causes an importlib.reload call,
+ # we need to manually call _inherit_docstrings (https://github.com/modin-project/modin/issues/7138)
+ from .docs_module import classes
+
+ # As a workaround for upstream modin bugs, we use our own _inherit_docstrings instead of the upstream
+ # function. We accordingly need to clear the docstring dictionary in testing because
+ # we manually called the annotation on initializing snowflake.snowpark.modin.pandas.
+ # snowflake.snowpark.modin.utils._attributes_with_docstrings_replaced.clear()
+ # TODO: once modin 0.31.0 is available, use the actual modin DocModule class
+ snowflake.snowpark.modin.utils._inherit_docstrings(
+ classes.BasePandasDataset,
+ overwrite_existing=True,
+ )(pd.base.BasePandasDataset)
+ DocModule.put("docs_module")
+
+
+DOC_OVERRIDE_XFAIL_REASON = (
+ "test docstring overrides currently cannot override real docstring overrides until "
+ "modin 0.31.0 is available"
+)
+
+
class TestDocModule:
"""
Test using a module to replace default docstrings.
@@ -99,11 +125,9 @@ class TestDocModule:
which we need to fix in upstream modin.
"""
+ @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON)
def test_overrides(self):
- # Put the docs_module on the path
- sys.path.append(f"{os.path.dirname(__file__)}")
- DocModule.put("docs_module")
-
+ _init_doc_module()
# Test for override
# TODO(https://github.com/modin-project/modin/issues/7134): Upstream
# the BasePandasDataset tests to modin.
@@ -144,11 +168,7 @@ def test_overrides(self):
def test_not_redefining_classes_modin_issue_7138(self):
original_dataframe_class = pd.DataFrame
-
- # Put the docs_module on the path
- sys.path.append(f"{os.path.dirname(__file__)}")
- DocModule.put("docs_module")
-
+ _init_doc_module()
# Test for override
assert (
pd.DataFrame.apply.__doc__
@@ -157,22 +177,20 @@ def test_not_redefining_classes_modin_issue_7138(self):
assert pd.DataFrame is original_dataframe_class
+ @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON)
def test_base_docstring_override_with_no_dataframe_or_series_class_modin_issue_7113(
self,
):
# TODO(https://github.com/modin-project/modin/issues/7113): Upstream
# this test case to Modin. This test case tests scenario 1 from issue 7113.
- sys.path.append(f"{os.path.dirname(__file__)}")
- DocModule.put("docs_module_with_just_base")
+ _init_doc_module()
assert pd.base.BasePandasDataset.astype.__doc__ == (
"This is a test of the documentation module for BasePandasDataSet.astype."
)
+ @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON)
def test_base_property_not_overridden_in_either_subclass_modin_issue_7113(self):
- # Put the docs_module on the path
- sys.path.append(f"{os.path.dirname(__file__)}")
- DocModule.put("docs_module")
-
+ _init_doc_module()
assert (
pd.base.BasePandasDataset.loc.__doc__
== "This is a test of the documentation module for BasePandasDataset.loc."