From 9f16cedcd5b7aef0a9f01799d688a41c7c746dcf Mon Sep 17 00:00:00 2001
From: Jonathan Shi
Date: Thu, 27 Jun 2024 13:36:48 -0700
Subject: [PATCH 01/14] remove base.py

---
 .../snowpark/modin/pandas/__init__.py         |    9 +-
 .../snowpark/modin/pandas/dataframe.py        |    7 +-
 .../snowpark/modin/pandas/general.py          |    2 +-
 .../snowpark/modin/pandas/indexing.py         |    2 +-
 src/snowflake/snowpark/modin/pandas/series.py |   29 +-
 src/snowflake/snowpark/modin/pandas/utils.py  |    3 +-
 .../snowpark/modin/plugin/__init__.py         |   13 +-
 .../compiler/snowflake_query_compiler.py      |   38 +-
 .../plugin/extensions/base_not_implemented.py |  414 ++++
 .../modin/plugin/extensions/base_overrides.py | 1998 ++++++++++++++++-
 .../plugin/extensions/dataframe_extensions.py |   33 +
 .../plugin/extensions/series_extensions.py    |   33 +
 .../plugin/extensions/series_overrides.py     |    1 +
 .../modin/plugin/utils/numpy_to_pandas.py     |    3 +-
 tests/integ/modin/test_telemetry.py           |    8 +-
 tests/unit/modin/modin/test_envvars.py        |   48 +-
 16 files changed, 2590 insertions(+), 51 deletions(-)
 create mode 100644 src/snowflake/snowpark/modin/plugin/extensions/base_not_implemented.py

diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py
index c4eb07d9589..02d8c950cce 100644
--- a/src/snowflake/snowpark/modin/pandas/__init__.py
+++ b/src/snowflake/snowpark/modin/pandas/__init__.py
@@ -146,10 +146,9 @@
 # The extensions assigned to this module
 _PD_EXTENSIONS_: dict = {}
 
-# base needs to be re-exported in order to properly override docstrings for BasePandasDataset
-# moving this import higher prevents sphinx from building documentation (??)
-from snowflake.snowpark.modin.pandas import base  # isort: skip  # noqa: E402,F401
+import snowflake.snowpark.modin.plugin.extensions.base_overrides  # isort: skip  # noqa: E402,F401
+import snowflake.snowpark.modin.plugin.extensions.base_not_implemented  # isort: skip  # noqa: E402,F401
 import snowflake.snowpark.modin.plugin.extensions.pd_extensions as pd_extensions  # isort: skip  # noqa: E402,F401
 import snowflake.snowpark.modin.plugin.extensions.pd_overrides  # isort: skip  # noqa: E402,F401
 from snowflake.snowpark.modin.plugin.extensions.pd_overrides import (  # isort: skip  # noqa: E402,F401
@@ -220,7 +219,6 @@ def __getattr__(name: str) -> Any:
     "date_range",
     "Index",
     "MultiIndex",
-    "Series",
     "bdate_range",
     "period_range",
     "DatetimeIndex",
@@ -318,8 +316,7 @@ def __getattr__(name: str) -> Any:
 # Manually re-export the members of the pd_extensions namespace, which are not declared in __all__.
 _EXTENSION_ATTRS = ["read_snowflake", "to_snowflake", "to_snowpark", "to_pandas"]
 # We also need to re-export native_pd.offsets, since modin.pandas doesn't re-export it.
-# snowflake.snowpark.pandas.base also needs to be re-exported to make docstring overrides for BasePandasDataset work.
-_ADDITIONAL_ATTRS = ["offsets", "base"]
+_ADDITIONAL_ATTRS = ["offsets"]
 
 # This code should eventually be moved into the `snowflake.snowpark.modin.plugin` module instead.
 # Currently, trying to do so would result in incorrect results because `snowflake.snowpark.modin.pandas`
diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py
index a7d53813779..a199d73e999 100644
--- a/src/snowflake/snowpark/modin/pandas/dataframe.py
+++ b/src/snowflake/snowpark/modin/pandas/dataframe.py
@@ -37,6 +37,7 @@
 import numpy as np
 import pandas
 from modin.pandas.accessor import CachedAccessor, SparseFrameAccessor
+from modin.pandas.base import BasePandasDataset
 
 # from . import _update_engine
 from modin.pandas.iterator import PartitionIterator
@@ -73,12 +74,11 @@
 from pandas.util._validators import validate_bool_kwarg
 
 from snowflake.snowpark.modin import pandas as pd
-from snowflake.snowpark.modin.pandas.base import _ATTRS_NO_LOOKUP, BasePandasDataset
 from snowflake.snowpark.modin.pandas.groupby import (
     DataFrameGroupBy,
     validate_groupby_args,
 )
-from snowflake.snowpark.modin.pandas.series import Series
+from snowflake.snowpark.modin.pandas.series import _ATTRS_NO_LOOKUP, Series
 from snowflake.snowpark.modin.pandas.snow_partition_iterator import (
     SnowparkPandasRowPartitionIterator,
 )
@@ -91,6 +91,7 @@
     replace_external_data_keys_with_empty_pandas_series,
     replace_external_data_keys_with_query_compiler,
 )
+from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta
 from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated
 from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
 from snowflake.snowpark.modin.plugin.utils.error_message import (
@@ -136,7 +137,7 @@
     ],
     apilink="pandas.DataFrame",
 )
-class DataFrame(BasePandasDataset):
+class DataFrame(BasePandasDataset, metaclass=TelemetryMeta):
     _pandas_class = pandas.DataFrame
 
     def __init__(
diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py
index f14e14840bb..07f0617d612 100644
--- a/src/snowflake/snowpark/modin/pandas/general.py
+++ b/src/snowflake/snowpark/modin/pandas/general.py
@@ -30,6 +30,7 @@
 import numpy as np
 import pandas
 import pandas.core.common as common
+from modin.pandas.base import BasePandasDataset
 from pandas import IntervalIndex, NaT, Timedelta, Timestamp
 from pandas._libs import NaTType, lib
 from pandas._libs.tslibs import to_offset
@@ -61,7 +62,6 @@
 # add this line to make doctests runnable
 from snowflake.snowpark.modin import pandas as pd  # noqa: F401
 
-from snowflake.snowpark.modin.pandas.base import BasePandasDataset
 from snowflake.snowpark.modin.pandas.dataframe import DataFrame
 from snowflake.snowpark.modin.pandas.series import Series
 from snowflake.snowpark.modin.pandas.utils import (
diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py
index 0ac62f504ce..c83e3fe41c4 100644
--- a/src/snowflake/snowpark/modin/pandas/indexing.py
+++ b/src/snowflake/snowpark/modin/pandas/indexing.py
@@ -43,6 +43,7 @@
 
 import numpy as np
 import pandas
+from modin.pandas.base import BasePandasDataset
 from pandas._libs.tslibs import Resolution, parsing
 from pandas._typing import AnyArrayLike, Scalar
 from pandas.api.types import is_bool, is_list_like
@@ -58,7 +59,6 @@
 
 import snowflake.snowpark.modin.pandas as pd
 import snowflake.snowpark.modin.pandas.utils as frontend_utils
-from snowflake.snowpark.modin.pandas.base import BasePandasDataset
 from snowflake.snowpark.modin.pandas.dataframe import DataFrame
 from snowflake.snowpark.modin.pandas.series import (
     SERIES_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE_ERROR_MESSAGE,
diff --git a/src/snowflake/snowpark/modin/pandas/series.py b/src/snowflake/snowpark/modin/pandas/series.py
index 1ce3ecfc997..59f00ef2574 100644
--- a/src/snowflake/snowpark/modin/pandas/series.py
+++ b/src/snowflake/snowpark/modin/pandas/series.py
@@ -31,6 +31,7 @@
 import numpy.typing as npt
 import pandas
 from modin.pandas.accessor import CachedAccessor, SparseAccessor
+from modin.pandas.base import BasePandasDataset
 from modin.pandas.iterator import PartitionIterator
 from pandas._libs.lib import NoDefault, is_integer, no_default
 from pandas._typing import (
@@ -51,12 +52,12 @@
 from pandas.core.series import _coerce_method
 from pandas.util._validators import validate_bool_kwarg
 
-from snowflake.snowpark.modin.pandas.base import _ATTRS_NO_LOOKUP, BasePandasDataset
 from snowflake.snowpark.modin.pandas.utils import (
     from_pandas,
     is_scalar,
     try_convert_index_to_native,
 )
+from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta
 from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
 from snowflake.snowpark.modin.plugin.utils.error_message import (
     ErrorMessage,
@@ -96,6 +97,30 @@
 
 _SERIES_EXTENSIONS_ = {}
 
+# Do not look up certain attributes in columns or index, as they're used for some
+# special purposes, like serving remote context
+_ATTRS_NO_LOOKUP = {
+    "____id_pack__",
+    "__name__",
+    "_cache",
+    "_ipython_canary_method_should_not_exist_",
+    "_ipython_display_",
+    "_repr_html_",
+    "_repr_javascript_",
+    "_repr_jpeg_",
+    "_repr_json_",
+    "_repr_latex_",
+    "_repr_markdown_",
+    "_repr_mimebundle_",
+    "_repr_pdf_",
+    "_repr_png_",
+    "_repr_svg_",
+    "__array_struct__",
+    "__array_interface__",
+    "_typ",
+}
+
+
 @_inherit_docstrings(
     pandas.Series,
     excluded=[
@@ -108,7 +133,7 @@
     ],
     apilink="pandas.Series",
 )
-class Series(BasePandasDataset):
+class Series(BasePandasDataset, metaclass=TelemetryMeta):
     _pandas_class = pandas.Series
     __array_priority__ = pandas.Series.__array_priority__
 
diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py
index f971e0ff964..32702c8b1a4 100644
--- a/src/snowflake/snowpark/modin/pandas/utils.py
+++ b/src/snowflake/snowpark/modin/pandas/utils.py
@@ -170,10 +170,9 @@ def is_scalar(obj):
     bool
         True if given object is scalar and False otherwise.
     """
+    from modin.pandas.base import BasePandasDataset
     from pandas.api.types import is_scalar as pandas_is_scalar
 
-    from .base import BasePandasDataset
-
     return not isinstance(obj, BasePandasDataset) and pandas_is_scalar(obj)
diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py
index a76b9fe1613..b46d11f7f8e 100644
--- a/src/snowflake/snowpark/modin/plugin/__init__.py
+++ b/src/snowflake/snowpark/modin/plugin/__init__.py
@@ -63,12 +63,21 @@
 import modin.utils  # type: ignore[import]  # isort: skip  # noqa: E402
 import modin.pandas.series_utils  # type: ignore[import]  # isort: skip  # noqa: E402
 
-modin.utils._inherit_docstrings(
+# TODO: https://github.com/modin-project/modin/issues/7113 and https://github.com/modin-project/modin/issues/7134
+# Upstream Modin has issues with certain docstring generation edge cases, so we should use our version instead
+_inherit_docstrings = snowflake.snowpark.modin.utils._inherit_docstrings
+
+_inherit_docstrings(
+    docstrings.base.BasePandasDataset,
+    overwrite_existing=True,
+)(modin.pandas.base.BasePandasDataset)
+
+_inherit_docstrings(
     docstrings.series_utils.StringMethods,
     overwrite_existing=True,
 )(modin.pandas.series_utils.StringMethods)
 
-modin.utils._inherit_docstrings(
+_inherit_docstrings(
     docstrings.series_utils.CombinedDatetimelikeProperties,
     overwrite_existing=True,
 )(modin.pandas.series_utils.DatetimeProperties)
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index bbebbec1783..8e1abce8d1e 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -388,6 +388,8 @@ class SnowflakeQueryCompiler(BaseQueryCompiler):
     this class is best explained by looking at
     https://github.com/modin-project/modin/blob/a8be482e644519f2823668210cec5cf1564deb7e/modin/experimental/core/storage_formats/hdk/query_compiler.py
     """
 
+    lazy_execution = True
+
     def __init__(self, frame: InternalFrame) -> None:
         """this stores internally a local pandas object (refactor this)"""
         assert frame is not None and isinstance(
@@ -767,6 +769,7 @@ def execute(self) -> None:
     def to_numpy(
         self,
         dtype: Optional[npt.DTypeLike] = None,
+        copy: Optional[bool] = False,
         na_value: object = lib.no_default,
         **kwargs: Any,
     ) -> np.ndarray:
@@ -774,6 +777,12 @@ def to_numpy(
         # i.e., for something like df.values internally to_numpy().flatten() is called
         # with flatten being another query compiler call into the numpy frontend layer.
         # here it's overwritten to actually perform numpy conversion, i.e. return an actual numpy object
+        if copy:
+            WarningMessage.ignored_argument(
+                operation="to_numpy",
+                argument="copy",
+                message="copy is ignored in Snowflake backend",
+            )
         return self.to_pandas().to_numpy(dtype=dtype, na_value=na_value, **kwargs)
 
     def repartition(self, axis: Any = None) -> "SnowflakeQueryCompiler":
@@ -1400,17 +1409,6 @@ def cache_result(self) -> "SnowflakeQueryCompiler":
         """
         return SnowflakeQueryCompiler(self._modin_frame.persist_to_temporary_table())
 
-    @property
-    def columns(self) -> native_pd.Index:
-        """
-        Get pandas column labels.
-
-        Returns:
-            an index containing all pandas column labels
-        """
-        # TODO SNOW-837664: add more tests for df.columns
-        return self._modin_frame.data_columns_index
-
     @snowpark_pandas_type_immutable_check
     def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
         """
@@ -1465,6 +1463,12 @@ def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
         )
         return SnowflakeQueryCompiler(new_internal_frame)
 
+    # TODO SNOW-837664: add more tests for df.columns
+    columns: native_pd.Index = property(
+        lambda self: self._modin_frame.data_columns_index,
+        lambda self, labels: self.set_columns(labels),
+    )
+
     def _shift_values(
         self, periods: int, axis: Union[Literal[0], Literal[1]], fill_value: Hashable
     ) -> "SnowflakeQueryCompiler":
@@ -2807,6 +2811,8 @@ def reset_index(
         Returns:
             A new SnowflakeQueryCompiler instance with updated index.
         """
+        if allow_duplicates is no_default:
+            allow_duplicates = False
         # These levels will be moved from index columns to data columns
         levels_to_be_reset = self._modin_frame.parse_levels_to_integer_levels(
             level, allow_duplicates=False
         )
@@ -3007,9 +3013,11 @@ def first_last_valid_index(
 
     def sort_index(
         self,
+        *,
         axis: int,
         level: Optional[list[Union[str, int]]],
         ascending: Union[bool, list[bool]],
+        inplace: bool = False,
         kind: SortKind,
         na_position: NaPosition,
         sort_remaining: bool,
@@ -3025,6 +3033,8 @@ def sort_index(
             level: If not None, sort on values in specified index level(s).
             ascending: A list of bools to represent ascending vs descending sort. Defaults to True.
                 When the index is a MultiIndex the sort direction can be controlled for each level individually.
+            inplace: Whether or not the sort occurs in-place. This argument is ignored and only provided
+                for compatibility with Modin.
             kind: Choice of sorting algorithm. Perform stable sort if 'stable'. Defaults to unstable sort.
                 Snowpark pandas ignores choice of sorting algorithm except 'stable'.
             na_position: Puts NaNs at the beginning if 'first'; 'last' puts NaNs at the end. Defaults to 'last'
@@ -10859,6 +10869,12 @@ def is_multiindex(self, *, axis: int = 0) -> bool:
         """
         return self._modin_frame.is_multiindex(axis=axis)
 
+    def abs(self) -> "SnowflakeQueryCompiler":
+        return self.unary_op("abs")
+
+    def negative(self) -> "SnowflakeQueryCompiler":
+        return self.unary_op("__neg__")
+
     def unary_op(self, op: str) -> "SnowflakeQueryCompiler":
         """
         Applies a unary operation `op` on each element of the `SnowflakeQueryCompiler`.
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_not_implemented.py b/src/snowflake/snowpark/modin/plugin/extensions/base_not_implemented.py
new file mode 100644
index 00000000000..aaf8b86494d
--- /dev/null
+++ b/src/snowflake/snowpark/modin/plugin/extensions/base_not_implemented.py
@@ -0,0 +1,414 @@
+#
+# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
+#
+
+"""
+The functions in this file are not implemented in Snowpark pandas. In the future, they
+should raise NotImplementedError at the query compiler layer, but doing so requires a longer-term
+effort.
+
+We currently test unsupported APIs under tests/unit/modin/test_unsupported.py, which does not initialize
+a session. As such, many frontend methods have additional query compiler API calls that would have to
+be mocked before the NotImplementedError can appropriately be raised.
+"""
+from __future__ import annotations
+
+import pickle as pkl
+from typing import Any
+
+import numpy as np
+import pandas
+from modin.pandas.base import BasePandasDataset
+from pandas._libs import lib
+from pandas._libs.lib import no_default
+from pandas._typing import (
+    Axis,
+    CompressionOptions,
+    StorageOptions,
+    TimedeltaConvertibleTypes,
+)
+
+from snowflake.snowpark.modin.pandas.api.extensions import (
+    register_dataframe_accessor,
+    register_series_accessor,
+)
+from snowflake.snowpark.modin.plugin._internal.telemetry import (
+    snowpark_pandas_telemetry_method_decorator,
+)
+from snowflake.snowpark.modin.plugin.utils.error_message import base_not_implemented
+
+
+def register_base_not_implemented():
+    def decorator(base_method: Any):
+        func = snowpark_pandas_telemetry_method_decorator(
+            base_not_implemented()(base_method)
+        )
+        register_series_accessor(base_method.__name__)(func)
+        register_dataframe_accessor(base_method.__name__)(func)
+        return func
+
+    return decorator
+
+
+@register_base_not_implemented()
+def align(
+    self,
+    other,
+    join="outer",
+    axis=None,
+    level=None,
+    copy=None,
+    fill_value=None,
+    method=lib.no_default,
+    limit=lib.no_default,
+    fill_axis=lib.no_default,
+    broadcast_axis=lib.no_default,
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def asof(self, where, subset=None):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def at_time(self, time, asof=False, axis=None):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def between_time(
+    self: BasePandasDataset,
+    start_time,
+    end_time,
+    inclusive: str | None = None,
+    axis=None,
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def bool(self):  # noqa: RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def clip(
+    self, lower=None, upper=None, axis=None, inplace=False, *args, **kwargs
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def combine(self, other, func, fill_value=None, **kwargs):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def combine_first(self, other):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def droplevel(self, level, axis=0):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def explode(self, column, ignore_index: bool = False):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def ewm(
+    self,
+    com: float | None = None,
+    span: float | None = None,
+    halflife: float | TimedeltaConvertibleTypes | None = None,
+    alpha: float | None = None,
+    min_periods: int | None = 0,
+    adjust: bool = True,
+    ignore_na: bool = False,
+    axis: Axis = 0,
+    times: str | np.ndarray | BasePandasDataset | None = None,
+    method: str = "single",
+) -> pandas.core.window.ewm.ExponentialMovingWindow:  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def filter(
+    self, items=None, like=None, regex=None, axis=None
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def infer_objects(
+    self, copy: bool | None = None
+) -> BasePandasDataset:  # pragma: no cover  # noqa: RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def kurt(self, axis=no_default, skipna=True, numeric_only=False, **kwargs):
+    pass
+
+
+@register_base_not_implemented()
+def kurtosis(self, axis=no_default, skipna=True, numeric_only=False, **kwargs):
+    pass
+
+
+@register_base_not_implemented()
+def mode(self, axis=0, numeric_only=False, dropna=True):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def pipe(self, func, *args, **kwargs):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def pop(self, item):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def reindex_like(
+    self, other, method=None, copy=True, limit=None, tolerance=None
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def reorder_levels(self, order, axis=0):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def sem(
+    self,
+    axis: Axis | None = None,
+    skipna: bool = True,
+    ddof: int = 1,
+    numeric_only=False,
+    **kwargs,
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def set_flags(
+    self, *, copy: bool = False, allows_duplicate_labels: bool | None = None
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def swapaxes(self, axis1, axis2, copy=True):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def swaplevel(self, i=-2, j=-1, axis=0):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_clipboard(
+    self, excel=True, sep=None, **kwargs
+):  # pragma: no cover  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_excel(
+    self,
+    excel_writer,
+    sheet_name="Sheet1",
+    na_rep="",
+    float_format=None,
+    columns=None,
+    header=True,
+    index=True,
+    index_label=None,
+    startrow=0,
+    startcol=0,
+    engine=None,
+    merge_cells=True,
+    encoding=no_default,
+    inf_rep="inf",
+    verbose=no_default,
+    freeze_panes=None,
+    storage_options: StorageOptions = None,
+):  # pragma: no cover  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_hdf(
+    self, path_or_buf, key, format="table", **kwargs
+):  # pragma: no cover  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_json(
+    self,
+    path_or_buf=None,
+    orient=None,
+    date_format=None,
+    double_precision=10,
+    force_ascii=True,
+    date_unit="ms",
+    default_handler=None,
+    lines=False,
+    compression="infer",
+    index=True,
+    indent=None,
+    storage_options: StorageOptions = None,
+):  # pragma: no cover  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_latex(
+    self,
+    buf=None,
+    columns=None,
+    col_space=None,
+    header=True,
+    index=True,
+    na_rep="NaN",
+    formatters=None,
+    float_format=None,
+    sparsify=None,
+    index_names=True,
+    bold_rows=False,
+    column_format=None,
+    longtable=None,
+    escape=None,
+    encoding=None,
+    decimal=".",
+    multicolumn=None,
+    multicolumn_format=None,
+    multirow=None,
+    caption=None,
+    label=None,
+    position=None,
+):  # pragma: no cover  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_markdown(
+    self,
+    buf=None,
+    mode: str = "wt",
+    index: bool = True,
+    storage_options: StorageOptions = None,
+    **kwargs,
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_pickle(
+    self,
+    path,
+    compression: CompressionOptions = "infer",
+    protocol: int = pkl.HIGHEST_PROTOCOL,
+    storage_options: StorageOptions = None,
+):  # pragma: no cover  # noqa: PR01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_string(
+    self,
+    buf=None,
+    columns=None,
+    col_space=None,
+    header=True,
+    index=True,
+    na_rep="NaN",
+    formatters=None,
+    float_format=None,
+    sparsify=None,
+    index_names=True,
+    justify=None,
+    max_rows=None,
+    min_rows=None,
+    max_cols=None,
+    show_dimensions=False,
+    decimal=".",
+    line_width=None,
+    max_colwidth=None,
+    encoding=None,
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_sql(
+    self,
+    name,
+    con,
+    schema=None,
+    if_exists="fail",
+    index=True,
+    index_label=None,
+    chunksize=None,
+    dtype=None,
+    method=None,
+):  # noqa: PR01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_timestamp(
+    self, freq=None, how="start", axis=0, copy=True
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def to_xarray(self):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def truncate(
+    self, before=None, after=None, axis=None, copy=True
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def tz_convert(self, tz, axis=0, level=None, copy=True):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def tz_localize(
+    self, tz, axis=0, level=None, copy=True, ambiguous="raise", nonexistent="raise"
+):  # noqa: PR01, RT01, D200
+    pass
+
+
+@register_base_not_implemented()
+def __array_wrap__(self, result, context=None):
+    pass
+
+
+@register_base_not_implemented()
+def __finalize__(self, other, method=None, **kwargs):
+    pass
+
+
+@register_base_not_implemented()
+def __sizeof__(self):
+    pass
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
index 332df757787..afde17548a8 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
@@ -10,27 +10,145 @@
 from __future__ import annotations
 
 import pickle as pkl
-from typing import Any
+import warnings
+from collections.abc import Sequence
+from typing import Any, Callable, Hashable, Literal, Mapping, cast, get_args
 
+import modin.pandas as pd
 import numpy as np
+import numpy.typing as npt
 import pandas
 from modin.pandas.base import BasePandasDataset
-from pandas._libs.lib import no_default
+from pandas._libs import lib
+from pandas._libs.lib import NoDefault, is_bool, no_default
 from pandas._typing import (
+    AggFuncType,
+    AnyArrayLike,
+    Axes,
     Axis,
     CompressionOptions,
+    FillnaOptions,
+    IgnoreRaise,
+    IndexKeyFunc,
+    IndexLabel,
+    Level,
+    NaPosition,
+    RandomState,
+    Scalar,
     StorageOptions,
     TimedeltaConvertibleTypes,
+    TimestampConvertibleTypes,
+)
+from pandas.core.common import apply_if_callable
+from pandas.core.dtypes.common import (
+    is_dict_like,
+    is_dtype_equal,
+    is_list_like,
+    is_numeric_dtype,
+    pandas_dtype,
+)
+from pandas.core.dtypes.inference import is_integer
+from pandas.core.methods.describe import _refine_percentiles
+from pandas.errors import SpecificationError
+from pandas.util._validators import (
+    validate_ascending,
+    validate_bool_kwarg,
+    validate_percentile,
 )
 
+import snowflake.snowpark.modin.pandas as spd
 from snowflake.snowpark.modin.pandas.api.extensions import (
     register_dataframe_accessor,
     register_series_accessor,
 )
+from snowflake.snowpark.modin.pandas.utils import (
+    ensure_index,
+    extract_validate_and_try_convert_named_aggs_from_kwargs,
+    get_as_shape_compatible_dataframe_or_series,
+    is_scalar,
+    raise_if_native_pandas_objects,
+    validate_and_try_convert_agg_func_arg_func_to_str,
+)
 from snowflake.snowpark.modin.plugin._internal.telemetry import (
+    TELEMETRY_PRIVATE_METHODS,
+    PropertyMethodType,
     snowpark_pandas_telemetry_method_decorator,
 )
-from snowflake.snowpark.modin.plugin.utils.error_message import base_not_implemented
+from snowflake.snowpark.modin.plugin._typing import ListLike
+from snowflake.snowpark.modin.plugin.utils.error_message import (
+    ErrorMessage,
+    base_not_implemented,
+)
+from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage
+from snowflake.snowpark.modin.utils import validate_int_kwarg
+
+
+def register_base_override(method_name: str):
+    def decorator(base_method: Any):
+        if callable(base_method) and (
+            not method_name.startswith("_")
+            or (method_name in TELEMETRY_PRIVATE_METHODS)
+        ):
+            base_method = snowpark_pandas_telemetry_method_decorator(base_method)
+        elif isinstance(base_method, property):
+            base_method = property(
+                snowpark_pandas_telemetry_method_decorator(
+                    cast(
+                        # add a cast because mypy doesn't recognize that
+                        # non-None fget and __get__ are both callable
+                        # arguments to snowpark_pandas_telemetry_method_decorator.
+                        Callable,
+                        base_method.fget,  # all properties defined in this file have an fget
+                    ),
+                    property_name=method_name,
+                    property_method_type=PropertyMethodType.FGET,
+                ),
+                snowpark_pandas_telemetry_method_decorator(
+                    (
+                        base_method.__set__
+                        if base_method.fset is None
+                        else base_method.fset
+                    ),
+                    property_name=method_name,
+                    property_method_type=PropertyMethodType.FSET,
+                ),
+                snowpark_pandas_telemetry_method_decorator(
+                    (
+                        base_method.__delete__
+                        if base_method.fdel is None
+                        else base_method.fdel
+                    ),
+                    property_name=method_name,
+                    property_method_type=PropertyMethodType.FDEL,
+                ),
+                doc=base_method.__doc__,
+            )
+        parent_method = getattr(BasePandasDataset, method_name, None)
+        if isinstance(parent_method, property):
+            parent_method = parent_method.fget
+        # If the method was not defined on Series/DataFrame and was instead inherited from the superclass,
+        # we need to override it as well, presumably because the MRO was already determined at class creation.
+        # TODO: SNOW-1063347
+        # Since we still use the vendored version of Series and the overrides for the top-level
+        # namespace haven't been performed yet, we need to set properties on the vendored version
+        series_method = getattr(spd.series.Series, method_name, None)
+        if isinstance(series_method, property):
+            series_method = series_method.fget
+        if series_method is None or series_method is parent_method:
+            register_series_accessor(method_name)(base_method)
+        # TODO: SNOW-1063346
+        # Since we still use the vendored version of DataFrame and the overrides for the top-level
+        # namespace haven't been performed yet, we need to set properties on the vendored version
+        df_method = getattr(spd.dataframe.DataFrame, method_name, None)
+        if isinstance(df_method, property):
+            df_method = df_method.fget
+        if df_method is None or df_method is parent_method:
+            register_dataframe_accessor(method_name)(base_method)
+        # Replace base method
+        setattr(BasePandasDataset, method_name, base_method)
+        return base_method
+
+    return decorator
 
 
 def register_base_not_implemented():
@@ -303,3 +421,1877 @@ def truncate(
 @register_base_not_implemented()
 def __finalize__(self, other, method=None, **kwargs):
     pass  # pragma: no cover
+
+
+# === OVERRIDDEN METHODS ===
+
+
+@register_base_override("aggregate")
+def aggregate(
+    self, func: AggFuncType = None, axis: Axis | None = 0, *args: Any, **kwargs: Any
+):
+    """
+    Aggregate using one or more operations over the specified axis.
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    from snowflake.snowpark.modin.pandas import Series
+
+    origin_axis = axis
+    axis = self._get_axis_number(axis)
+
+    if axis == 1 and isinstance(self, Series):
+        raise ValueError(f"No axis named {origin_axis} for object type Series")
+
+    if len(self._query_compiler.columns) == 0:
+        # native pandas raise error with message "no result", here we raise a more readable error.
+        raise ValueError("No column to aggregate on.")
+
+    # If we are using named kwargs, then we do not clear the kwargs (need them in the QC for processing
+    # order, as well as formatting error messages.)
+    uses_named_kwargs = False
+    # If aggregate is called on a Series, named aggregations can be passed in via a dictionary
+    # to func.
+    if func is None or (is_dict_like(func) and not self._is_dataframe):
+        if axis == 1:
+            raise ValueError(
+                "`func` must not be `None` when `axis=1`. Named aggregations are not supported with `axis=1`."
+            )
+        if func is not None:
+            # If named aggregations are passed in via a dictionary to func, then we
+            # ignore the kwargs.
+            if any(is_dict_like(value) for value in func.values()):
+                # We can only get to this codepath if self is a Series, and func is a dictionary.
+                # In this case, if any of the values of func are themselves dictionaries, we must raise
+                # a Specification Error, as that is what pandas does.
+                raise SpecificationError("nested renamer is not supported")
+            kwargs = func
+        func = extract_validate_and_try_convert_named_aggs_from_kwargs(
+            self, allow_duplication=False, axis=axis, **kwargs
+        )
+        uses_named_kwargs = True
+    else:
+        func = validate_and_try_convert_agg_func_arg_func_to_str(
+            agg_func=func,
+            obj=self,
+            allow_duplication=False,
+            axis=axis,
+        )
+
+    # This is to stay consistent with the pandas result format: when func is a single
+    # aggregation function in the form of a callable or str, reduce the result dimension to
+    # convert a dataframe to a series, or a series to a scalar.
+    # Note: When named aggregations are used, the result is not reduced, even if there
+    # is only a single function.
+    # needs_reduce_dimension cannot be True if we are using named aggregations, since
+    # the values for func in that case are either NamedTuples (AggFuncWithLabels) or
+    # lists of NamedTuples, both of which are list like.
+    need_reduce_dimension = (
+        (callable(func) or isinstance(func, str))
+        # A Series should be returned when a single scalar string/function aggregation function, or a
+        # dict of scalar string/functions is specified. In all other cases (including if the function
+        # is a 1-element list), the result is a DataFrame.
+        #
+        # The examples below have axis=1, but the same logic is applied for axis=0.
+        # >>> df = pd.DataFrame({"a": [0, 1], "b": [2, 3]})
+        #
+        # single aggregation: return Series
+        # >>> df.agg("max", axis=1)
+        # 0    2
+        # 1    3
+        # dtype: int64
+        #
+        # list of aggregations: return DF
+        # >>> df.agg(["max"], axis=1)
+        #    max
+        # 0    2
+        # 1    3
+        #
+        # dict where all aggregations are strings: return Series
+        # >>> df.agg({1: "max", 0: "min"}, axis=1)
+        # 1    3
+        # 0    0
+        # dtype: int64
+        #
+        # dict where one element is a list: return DF
+        # >>> df.agg({1: "max", 0: ["min"]}, axis=1)
+        #    max  min
+        # 1  3.0  NaN
+        # 0  NaN  0.0
+        or (
+            is_dict_like(func)
+            and all(not is_list_like(value) for value in func.values())
+        )
+    )
+
+    # If func is a dict, pandas will not respect kwargs for each aggregation function, and
+    # we should drop them before passing them to the query compiler.
+    #
+    # >>> native_pd.DataFrame({"a": [0, 1], "b": [np.nan, 0]}).agg("max", skipna=False, axis=1)
+    # 0    NaN
+    # 1    1.0
+    # dtype: float64
+    # >>> native_pd.DataFrame({"a": [0, 1], "b": [np.nan, 0]}).agg(["max"], skipna=False, axis=1)
+    #    max
+    # 0  0.0
+    # 1  1.0
+    # >>> pd.DataFrame([[np.nan], [0]]).aggregate("count", skipna=True, axis=0)
+    # 0    1
+    # dtype: int8
+    # >>> pd.DataFrame([[np.nan], [0]]).count(skipna=True, axis=0)
+    # TypeError: got an unexpected keyword argument 'skipna'
+    if is_dict_like(func) and not uses_named_kwargs:
+        kwargs.clear()
+
+    result = self.__constructor__(
+        query_compiler=self._query_compiler.agg(
+            func=func,
+            axis=axis,
+            args=args,
+            kwargs=kwargs,
+        )
+    )
+
+    if need_reduce_dimension:
+        if self._is_dataframe:
+            result = Series(query_compiler=result._query_compiler)
+
+        if isinstance(result, Series):
+            # When func is just "quantile" with a scalar q, result has quantile value as name
+            q = kwargs.get("q", 0.5)
+            if func == "quantile" and is_scalar(q):
+                result.name = q
+            else:
+                result.name = None
+
+        # handle case for single scalar (same as result._reduce_dimension())
+        if isinstance(self, Series):
+            return result.to_pandas().squeeze()
+
+    return result
+
+
+agg = aggregate
+register_base_override("agg")(agg)
+
+
+@register_base_override("_binary_op")
+def _binary_op(
+    self,
+    op: str,
+    other: BasePandasDataset,
+    axis: Axis,
+    level: Level | None = None,
+    fill_value: float | None = None,
+    **kwargs: Any,
+):
+    """
+    Do binary operation between two datasets.
+
+    Parameters
+    ----------
+    op : str
+        Name of binary operation.
+    other : modin.pandas.BasePandasDataset
+        Second operand of binary operation.
+    axis: Whether to compare by the index (0 or 'index') or columns. (1 or 'columns').
+    level: Broadcast across a level, matching Index values on the passed MultiIndex level.
+    fill_value: Fill existing missing (NaN) values, and any new element needed for
+        successful DataFrame alignment, with this value before computation.
+        If data in both corresponding DataFrame locations is missing the result will be missing.
+        only arithmetic binary operation has this parameter (e.g., add() has, but eq() doesn't have).
+
+    kwargs can contain the following parameters passed in at the frontend:
+        func: Only used for `combine` method. Function that takes two series as inputs and
+            return a Series or a scalar. Used to merge the two dataframes column by columns.
+
+    Returns
+    -------
+    modin.pandas.BasePandasDataset
+        Result of binary operation.
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    raise_if_native_pandas_objects(other)
+    axis = self._get_axis_number(axis)
+    squeeze_self = isinstance(self, pd.Series)
+
+    # pandas itself will ignore the axis argument when using Series.
+    # Per default, it is set to axis=0. However, for the case of a Series interacting with
+    # a DataFrame the behavior is axis=1. Manually check here for this case and adjust the axis.
+
+    is_lhs_series_and_rhs_dataframe = (
+        True
+        if isinstance(self, pd.Series) and isinstance(other, pd.DataFrame)
+        else False
+    )
+
+    new_query_compiler = self._query_compiler.binary_op(
+        op=op,
+        other=other,
+        axis=1 if is_lhs_series_and_rhs_dataframe else axis,
+        level=level,
+        fill_value=fill_value,
+        squeeze_self=squeeze_self,
+        **kwargs,
+    )
+
+    from snowflake.snowpark.modin.pandas.dataframe import DataFrame
+
+    # Modin Bug: https://github.com/modin-project/modin/issues/7236
+    # For a Series interacting with a DataFrame, always return a DataFrame
+    return (
+        DataFrame(query_compiler=new_query_compiler)
+        if is_lhs_series_and_rhs_dataframe
+        else self._create_or_update_from_compiler(new_query_compiler)
+    )
+
+
+@register_base_override("_dropna")
+def _dropna(
+    self,
+    axis: Axis = 0,
+    how: str | NoDefault = no_default,
+    thresh: int | NoDefault = no_default,
+    subset: IndexLabel = None,
+    inplace: bool = False,
+):
+    inplace = validate_bool_kwarg(inplace, "inplace")
+
+    if is_list_like(axis):
+        raise TypeError("supplying multiple axes to axis is no longer supported.")
+
+    axis = self._get_axis_number(axis)
+
+    if (how is not no_default) and (thresh is not no_default):
+        raise TypeError(
+            "You cannot set both the how and thresh arguments at the same time."
+        )
+
+    if how is no_default:
+        how = "any"
+    if how not in ["any", "all"]:
+        raise ValueError("invalid how option: %s" % how)
+    if subset is not None:
+        if axis == 1:
+            indices = self.index.get_indexer_for(subset)
+            check = indices == -1
+            if check.any():
+                raise KeyError(list(np.compress(check, subset)))
+        else:
+            indices = self.columns.get_indexer_for(subset)
+            check = indices == -1
+            if check.any():
+                raise KeyError(list(np.compress(check, subset)))
+
+    new_query_compiler = self._query_compiler.dropna(
+        axis=axis,
+        how=how,
+        thresh=thresh,
+        subset=subset,
+    )
+    return self._create_or_update_from_compiler(new_query_compiler, inplace)
+
+
+@register_base_override("fillna")
+def fillna(
+    self,
+    self_is_series,
+    value: Hashable | Mapping | pd.Series | pd.DataFrame = None,
+    method: FillnaOptions | None = None,
+    axis: Axis | None = None,
+    inplace: bool = False,
+    limit: int | None = None,
+    downcast: dict | None = None,
+):
+    """
+    Fill NA/NaN values using the specified method.
+
+    Parameters
+    ----------
+    self_is_series : bool
+        If True then self contains a Series object, if False then self contains
+        a DataFrame object.
+    value : scalar, dict, Series, or DataFrame, default: None
+        Value to use to fill holes (e.g. 0), alternately a
+        dict/Series/DataFrame of values specifying which value to use for
+        each index (for a Series) or column (for a DataFrame). Values not
+        in the dict/Series/DataFrame will not be filled. This value cannot
+        be a list.
+    method : {'backfill', 'bfill', 'pad', 'ffill', None}, default: None
+        Method to use for filling holes in reindexed Series
+        pad / ffill: propagate last valid observation forward to next valid
+        backfill / bfill: use next valid observation to fill gap.
+    axis : {None, 0, 1}, default: None
+        Axis along which to fill missing values.
+    inplace : bool, default: False
+        If True, fill in-place. Note: this will modify any
+        other views on this object (e.g., a no-copy slice for a column in a
+        DataFrame).
+    limit : int, default: None
+        If method is specified, this is the maximum number of consecutive
+        NaN values to forward/backward fill. In other words, if there is
+        a gap with more than this number of consecutive NaNs, it will only
+        be partially filled. If method is not specified, this is the
+        maximum number of entries along the entire axis where NaNs will be
+        filled. Must be greater than 0 if not None.
+    downcast : dict, default: None
+        A dict of item->dtype of what to downcast if possible,
+        or the string 'infer' which will try to downcast to an appropriate
+        equal type (e.g. float64 to int64 if possible).
+
+    Returns
+    -------
+    Series, DataFrame or None
+        Object with missing values filled or None if ``inplace=True``.
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    raise_if_native_pandas_objects(value)
+    inplace = validate_bool_kwarg(inplace, "inplace")
+    axis = self._get_axis_number(axis)
+    if isinstance(value, (list, tuple)):
+        raise TypeError(
+            '"value" parameter must be a scalar or dict, but '
+            + f'you passed a "{type(value).__name__}"'
+        )
+    if value is None and method is None:
+        # same as pandas
+        raise ValueError("Must specify a fill 'value' or 'method'.")
+    if value is not None and method is not None:
+        raise ValueError("Cannot specify both 'value' and 'method'.")
+    if method is not None and method not in ["backfill", "bfill", "pad", "ffill"]:
+        expecting = "pad (ffill) or backfill (bfill)"
+        msg = "Invalid fill method. Expecting {expecting}. Got {method}".format(
+            expecting=expecting, method=method
+        )
+        raise ValueError(msg)
+    if limit is not None:
+        if not isinstance(limit, int):
+            raise ValueError("Limit must be an integer")
+        elif limit <= 0:
+            raise ValueError("Limit must be greater than 0")
+
+    new_query_compiler = self._query_compiler.fillna(
+        self_is_series=self_is_series,
+        value=value,
+        method=method,
+        axis=axis,
+        limit=limit,
+        downcast=downcast,
+    )
+    return self._create_or_update_from_compiler(new_query_compiler, inplace)
+
+
+@register_base_override("isin")
+def isin(
+    self, values: BasePandasDataset | ListLike | dict[Hashable, ListLike]
+) -> BasePandasDataset:  # noqa: PR01, RT01, D200
+    """
+    Whether elements in `BasePandasDataset` are contained in `values`.
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+
+    # Pass as query compiler if values is BasePandasDataset.
+    if isinstance(values, BasePandasDataset):
+        values = values._query_compiler
+
+    # Convert non-dict values to List if values is neither List[Any] nor np.ndarray. SnowflakeQueryCompiler
+    # expects for the non-lazy case, where values is not a BasePandasDataset, the data to be materialized
+    # as list or numpy array. Because numpy may perform implicit type conversions, use here list to be more general.
+    elif not isinstance(values, dict) and (
+        not isinstance(values, list) and not isinstance(values, np.ndarray)
+    ):
+        values = list(values)
+
+    return self.__constructor__(query_compiler=self._query_compiler.isin(values=values))
+
+
+@register_base_override("quantile")
+def quantile(
+    self,
+    q: Scalar | ListLike = 0.5,
+    axis: Axis = 0,
+    numeric_only: bool = False,
+    interpolation: Literal[
+        "linear", "lower", "higher", "midpoint", "nearest"
+    ] = "linear",
+    method: Literal["single", "table"] = "single",
+) -> float | BasePandasDataset:
+    """
+    Return values at the given quantile over requested axis.
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    axis = self._get_axis_number(axis)
+
+    # TODO
+    # - SNOW-1008361: support axis=1
+    # - SNOW-1008367: support when q is Snowpandas DF/Series (need to require QC interface to accept QC q values)
+    # - SNOW-1003587: support datetime/timedelta columns
+
+    if axis == 1 or interpolation not in ["linear", "nearest"] or method != "single":
+        ErrorMessage.not_implemented(
+            f"quantile function with parameters axis={axis}, interpolation={interpolation}, method={method} not supported"
+        )
+
+    if not numeric_only:
+        # If not numeric_only and columns, then check all columns are either
+        # numeric, timestamp, or timedelta
+        # Check if dtype is numeric, timedelta ("m"), or datetime ("M")
+        if not axis and not all(
+            is_numeric_dtype(t) or lib.is_np_dtype(t, "mM") for t in self._get_dtypes()
+        ):
+            raise TypeError("can't multiply sequence by non-int of type 'float'")
+        # If over rows, then make sure that all dtypes are equal for not
+        # numeric_only
+        elif axis:
+            for i in range(1, len(self._get_dtypes())):
+                pre_dtype = self._get_dtypes()[i - 1]
+                curr_dtype = self._get_dtypes()[i]
+                if not is_dtype_equal(pre_dtype, curr_dtype):
+                    raise TypeError(
+                        "Cannot compare type '{}' with type '{}'".format(
+                            pre_dtype, curr_dtype
+                        )
+                    )
+    else:
+        # Normally pandas returns this near the end of the quantile, but we
+        # can't afford the overhead of running the entire operation before
+        # we error.
+        if not any(is_numeric_dtype(t) for t in self._get_dtypes()):
+            raise ValueError("need at least one array to concatenate")
+
+    # check that all qs are between 0 and 1
+    validate_percentile(q)
+    axis = self._get_axis_number(axis)
+    query_compiler = self._query_compiler.quantiles_along_axis0(
+        q=q if is_list_like(q) else [q],
+        numeric_only=numeric_only,
+        interpolation=interpolation,
+        method=method,
+    )
+    if is_list_like(q):
+        return self.__constructor__(query_compiler=query_compiler)
+    else:
+        # result is either a scalar or Series
+        result = self._reduce_dimension(query_compiler.transpose_single_row())
+        if isinstance(result, BasePandasDataset):
+            result.name = q
+        return result
+
+
+@register_base_override("_to_series_list")
+def _to_series_list(self, index: pd.Index) -> list[pd.Series]:
+    """
+    Convert index to a list of series
+    Args:
+        index: can be single or multi index
+
+    Returns:
+        the list of series
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    if isinstance(index, pd.MultiIndex):
+        return [
+            pd.Series(index.get_level_values(level)) for level in range(index.nlevels)
+        ]
+    elif isinstance(index, pd.Index):
+        return [pd.Series(index)]
+    else:
+        raise Exception("invalid index: " + str(index))
+
+
+@register_base_override("_set_index")
+def _set_index(self, new_index: Axes) -> None:
+    """
+    Set the index for this DataFrame.
+
+    Parameters
+    ----------
+    new_index : pandas.Index
+        The new index to set this.
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    self._update_inplace(
+        new_query_compiler=self._query_compiler.set_index(
+            [s._query_compiler for s in self._to_series_list(ensure_index(new_index))]
+        )
+    )
+
+
+index = property(lambda self: self._query_compiler.index, _set_index)
+register_base_override("index")(index)
+
+
+@register_base_override("shift")
+def shift(
+    self,
+    periods: int | Sequence[int] = 1,
+    freq=None,
+    axis: Axis = 0,
+    fill_value: Hashable = no_default,
+    suffix: str | None = None,
+) -> BasePandasDataset:
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    if periods == 0 and freq is None:
+        # Check obvious case first, freq manipulates the index even for periods == 0 so check for it in addition.
+        return self.copy()
+
+    # pandas compatible ValueError for freq='infer'
+    # TODO: Test as part of SNOW-1023324.
+    if freq == "infer":  # pragma: no cover
+        if not hasattr(self, "freq") and not hasattr(  # pragma: no cover
+            self, "inferred_freq"  # pragma: no cover
+        ):  # pragma: no cover
+            raise ValueError()  # pragma: no cover
+
+    axis = self._get_axis_number(axis)
+
+    if fill_value == no_default:
+        fill_value = None
+
+    new_query_compiler = self._query_compiler.shift(
+        periods, freq, axis, fill_value, suffix
+    )
+    return self._create_or_update_from_compiler(new_query_compiler, False)
+
+
+@register_base_override("skew")
+def skew(
+    self,
+    axis: Axis | None | NoDefault = no_default,
+    skipna: bool = True,
+    numeric_only=True,
+    **kwargs,
+):  # noqa: PR01, RT01, D200
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    """
+    Return unbiased skew over requested axis.
+ """ + return self._stat_operation("skew", axis, skipna, numeric_only, **kwargs) + + +@register_base_override("_agg_helper") +def _agg_helper( + self, + func: str, + skipna: bool = True, + axis: int | None | NoDefault = no_default, + numeric_only: bool = False, + **kwargs: Any, +): + if not self._is_dataframe and numeric_only and not is_numeric_dtype(self.dtype): + # Series aggregations on non-numeric data do not support numeric_only: + # https://github.com/pandas-dev/pandas/blob/cece8c6579854f6b39b143e22c11cac56502c4fd/pandas/core/series.py#L6358 + raise TypeError( + f"Series.{func} does not allow numeric_only=True with non-numeric dtypes." + ) + axis = self._get_axis_number(axis) + numeric_only = validate_bool_kwarg(numeric_only, "numeric_only", none_allowed=True) + skipna = validate_bool_kwarg(skipna, "skipna", none_allowed=False) + agg_kwargs: dict[str, Any] = { + "numeric_only": numeric_only, + "skipna": skipna, + } + agg_kwargs.update(kwargs) + return self.aggregate(func=func, axis=axis, **agg_kwargs) + + +@register_base_override("count") +def count( + self, + axis: Axis | None = 0, + numeric_only: bool = False, +): + """ + Count non-NA cells for `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._agg_helper( + func="count", + axis=axis, + numeric_only=numeric_only, + ) + + +@register_base_override("max") +def max( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs: Any, +): + """ + Return the maximum of the values over the requested axis. + """ + return self._agg_helper( + func="max", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +@register_base_override("min") +def min( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, +): + """ + Return the minimum of the values over the requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._agg_helper( + func="min", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +@register_base_override("mean") +def mean( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only: bool = False, + **kwargs: Any, +): + """ + Return the mean of the values over the requested axis. + """ + return self._agg_helper( + func="mean", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +@register_base_override("median") +def median( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only: bool = False, + **kwargs: Any, +): + """ + Return the mean of the values over the requested axis. + """ + return self._agg_helper( + func="median", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +@register_base_override("std") +def std( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, +): + """ + Return sample standard deviation over requested axis. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + kwargs.update({"ddof": ddof}) + return self._agg_helper( + func="std", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +@register_base_override("sum") +def sum( + self, + axis: Axis | None = None, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs: Any, +): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + min_count = validate_int_kwarg(min_count, "min_count") + kwargs.update({"min_count": min_count}) + return self._agg_helper( + func="sum", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +@register_base_override("var") +def var( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs: Any, +): + """ + Return unbiased variance over requested axis. + """ + kwargs.update({"ddof": ddof}) + return self._agg_helper( + func="var", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +@register_base_override("resample") +def resample( + self, + rule, + axis: Axis = lib.no_default, + closed: str | None = None, + label: str | None = None, + convention: str = "start", + kind: str | None = None, + on: Level = None, + level: Level = None, + origin: str | TimestampConvertibleTypes = "start_day", + offset: TimedeltaConvertibleTypes | None = None, + group_keys=no_default, +): # noqa: PR01, RT01, D200 + """ + Resample time-series data. + """ + from snowflake.snowpark.modin.pandas.resample import Resampler + + if axis is not lib.no_default: # pragma: no cover + axis = self._get_axis_number(axis) + if axis == 1: + warnings.warn( + "DataFrame.resample with axis=1 is deprecated. Do " + + "`frame.T.resample(...)` without axis instead.", + FutureWarning, + stacklevel=1, + ) + else: + warnings.warn( + f"The 'axis' keyword in {type(self).__name__}.resample is " + + "deprecated and will be removed in a future version.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + return Resampler( + dataframe=self, + rule=rule, + axis=axis, + closed=closed, + label=label, + convention=convention, + kind=kind, + on=on, + level=level, + origin=origin, + offset=offset, + group_keys=group_keys, + ) + + +@register_base_override("expanding") +def expanding(self, min_periods=1, axis=0, method="single"): # noqa: PR01, RT01, D200 + """ + Provide expanding window calculations. + """ + from snowflake.snowpark.modin.pandas.window import Expanding + + if axis is not lib.no_default: + axis = self._get_axis_number(axis) + name = "expanding" + if axis == 1: + warnings.warn( + f"Support for axis=1 in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. " + + f"Use obj.T.{name}(...) instead", + FutureWarning, + stacklevel=1, + ) + else: + warnings.warn( + f"The 'axis' keyword in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. 
" + + "Call the method without the axis keyword instead.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + return Expanding( + self, + min_periods=min_periods, + axis=axis, + method=method, + ) + + +@register_base_override("rolling") +def rolling( + self, + window, + min_periods: int | None = None, + center: bool = False, + win_type: str | None = None, + on: str | None = None, + axis: Axis = lib.no_default, + closed: str | None = None, + step: int | None = None, + method: str = "single", +): # noqa: PR01, RT01, D200 + """ + Provide rolling window calculations. + """ + if axis is not lib.no_default: + axis = self._get_axis_number(axis) + name = "rolling" + if axis == 1: + warnings.warn( + f"Support for axis=1 in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. " + + f"Use obj.T.{name}(...) instead", + FutureWarning, + stacklevel=1, + ) + else: # pragma: no cover + warnings.warn( + f"The 'axis' keyword in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. " + + "Call the method without the axis keyword instead.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + if win_type is not None: + from snowflake.snowpark.modin.pandas.window import Window + + return Window( + self, + window=window, + min_periods=min_periods, + center=center, + win_type=win_type, + on=on, + axis=axis, + closed=closed, + step=step, + method=method, + ) + from snowflake.snowpark.modin.pandas.window import Rolling + + return Rolling( + self, + window=window, + min_periods=min_periods, + center=center, + win_type=win_type, + on=on, + axis=axis, + closed=closed, + step=step, + method=method, + ) + + +@register_base_override("iloc") +@property +def iloc(self): + """ + Purely integer-location based indexing for selection by position. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # TODO: SNOW-930028 enable all skipped doctests + from snowflake.snowpark.modin.pandas.indexing import _iLocIndexer + + return _iLocIndexer(self) + + +@register_base_override("loc") +@property +def loc(self): + """ + Get a group of rows and columns by label(s) or a boolean array. + """ + # TODO: SNOW-935444 fix doctest where index key has name + # TODO: SNOW-933782 fix multiindex transpose bug, e.g., Name: (cobra, mark ii) => Name: ('cobra', 'mark ii') + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.pandas.indexing import _LocIndexer + + return _LocIndexer(self) + + +@register_base_override("iat") +@property +def iat(self, axis=None): # noqa: PR01, RT01, D200 + """ + Get a single value for a row/column pair by integer position. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.pandas.indexing import _iAtIndexer + + return _iAtIndexer(self) + + +@register_base_override("at") +@property +def at(self, axis=None): # noqa: PR01, RT01, D200 + """ + Get a single value for a row/column label pair. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.pandas.indexing import _AtIndexer + + return _AtIndexer(self) + + +@register_base_override("__getitem__") +def __getitem__(self, key): + """ + Retrieve dataset according to `key`. + + Parameters + ---------- + key : callable, scalar, slice, str or tuple + The global row index to retrieve data from. + + Returns + ------- + BasePandasDataset + Located dataset. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + key = apply_if_callable(key, self) + # If a slice is passed in, use .iloc[key]. + if isinstance(key, slice): + if (is_integer(key.start) or key.start is None) and ( + is_integer(key.stop) or key.stop is None + ): + return self.iloc[key] + else: + return self.loc[key] + + # If the object calling getitem is a Series, only use .loc[key] to filter index. + if isinstance(self, pd.Series): + return self.loc[key] + + # Sometimes the result of a callable is a DataFrame (e.g. df[df > 0]) - use where. + elif isinstance(key, pd.DataFrame): + return self.where(cond=key) + + # If the object is a boolean list-like object, use .loc[key] to filter index. + # The if statement is structured this way to avoid calling dtype and reduce query count. + if isinstance(key, pd.Series): + if key.dtype == bool: + return self.loc[key] + elif is_list_like(key): + if hasattr(key, "dtype"): + if key.dtype == bool: + return self.loc[key] + if (all(is_bool(k) for k in key)) and len(key) > 0: + return self.loc[key] + + # In all other cases, use .loc[:, key] to filter columns. + return self.loc[:, key] + + +@register_base_override("sort_values") +def sort_values( + self, + by, + axis=0, + ascending=True, + inplace: bool = False, + kind="quicksort", + na_position="last", + ignore_index: bool = False, + key: IndexKeyFunc | None = None, +): # noqa: PR01, RT01, D200 + """ + Sort by the values along either axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + inplace = validate_bool_kwarg(inplace, "inplace") + ascending = validate_ascending(ascending) + if axis == 0: + # If any column is None raise KeyError (same a native pandas). + if by is None or (isinstance(by, list) and None in by): + # Same error message as native pandas. + raise KeyError(None) + if not isinstance(by, list): + by = [by] + + # Convert 'ascending' to sequence if needed. + if not isinstance(ascending, Sequence): + ascending = [ascending] * len(by) + if len(by) != len(ascending): + # Same error message as native pandas. + raise ValueError( + f"Length of ascending ({len(ascending)})" + f" != length of by ({len(by)})" + ) + + columns = self._query_compiler.columns.values.tolist() + index_names = self._query_compiler.get_index_names() + for by_col in by: + col_count = columns.count(by_col) + index_count = index_names.count(by_col) + if col_count == 0 and index_count == 0: + # Same error message as native pandas. + raise KeyError(by_col) + if col_count and index_count: + # Same error message as native pandas. + raise ValueError( + f"'{by_col}' is both an index level and a column label, which is ambiguous." + ) + if col_count > 1: + # Same error message as native pandas. + raise ValueError(f"The column label '{by_col}' is not unique.") + + if na_position not in get_args(NaPosition): + # Same error message as native pandas for invalid 'na_position' value. 
+ raise ValueError(f"invalid na_position: {na_position}") + result = self._query_compiler.sort_rows_by_column_values( + by, + ascending=ascending, + kind=kind, + na_position=na_position, + ignore_index=ignore_index, + key=key, + ) + else: + result = self._query_compiler.sort_columns_by_row_values( + by, + ascending=ascending, + kind=kind, + na_position=na_position, + ignore_index=ignore_index, + key=key, + ) + return self._create_or_update_from_compiler(result, inplace) + + +@register_base_override("where") +def where( + self, + cond: BasePandasDataset | Callable | AnyArrayLike, + other: BasePandasDataset | Callable | Scalar | None = np.nan, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + """ + Replace values where the condition is False. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # TODO: SNOW-985670: Refactor `where` and `mask` + # will move pre-processing to QC layer. + inplace = validate_bool_kwarg(inplace, "inplace") + if cond is None: + raise ValueError("Array conditional must be same shape as self") + + cond = apply_if_callable(cond, self) + + if isinstance(cond, Callable): + raise NotImplementedError("Do not support callable for 'cond' parameter.") + + from snowflake.snowpark.modin.pandas import Series + + if isinstance(cond, Series): + cond._query_compiler._shape_hint = "column" + if isinstance(self, Series): + self._query_compiler._shape_hint = "column" + if isinstance(other, Series): + other._query_compiler._shape_hint = "column" + + if not isinstance(cond, BasePandasDataset): + cond = get_as_shape_compatible_dataframe_or_series(cond, self) + cond._query_compiler._shape_hint = "array" + + if other is not None: + other = apply_if_callable(other, self) + + if isinstance(other, np.ndarray): + other = get_as_shape_compatible_dataframe_or_series( + other, + self, + shape_mismatch_message="other must be the same shape as self when an ndarray", + ) + other._query_compiler._shape_hint = "array" + + if isinstance(other, BasePandasDataset): + other = other._query_compiler + + query_compiler = self._query_compiler.where( + cond._query_compiler, + other, + axis, + level, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace) + + +@register_base_override("to_csv") +def to_csv( + self, + path_or_buf=None, + sep=",", + na_rep=",", + float_format=None, + columns=None, + header=True, + index=True, + index_label=None, + mode="w", + encoding=None, + compression="infer", + quoting=None, + quotechar='"', + lineterminator=None, + chunksize=None, + date_format=None, + doublequote=True, + escapechar=None, + decimal=".", + errors: str = "strict", + storage_options: StorageOptions = None, +): # pragma: no cover + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return FactoryDispatcher.to_csv( + self._query_compiler, + path_or_buf=path_or_buf, + sep=sep, + na_rep=na_rep, + float_format=float_format, + columns=columns, + header=header, + index=index, + index_label=index_label, + mode=mode, + encoding=encoding, + compression=compression, + quoting=quoting, + quotechar=quotechar, + lineterminator=lineterminator, + chunksize=chunksize, + date_format=date_format, + doublequote=doublequote, + escapechar=escapechar, + decimal=decimal, + errors=errors, + storage_options=storage_options, + ) + + +@register_base_override("mask") +def mask( + self, + cond: BasePandasDataset | Callable | AnyArrayLike, + other: BasePandasDataset | Callable | Scalar | 
None = np.nan, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + """ + Replace values where the condition is True. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # TODO: https://snowflakecomputing.atlassian.net/browse/SNOW-985670 + # will move pre-processing to QC layer. + inplace = validate_bool_kwarg(inplace, "inplace") + if cond is None: + raise ValueError("Array conditional must be same shape as self") + + cond = apply_if_callable(cond, self) + + if isinstance(cond, Callable): + raise NotImplementedError("Do not support callable for 'cond' parameter.") + + from snowflake.snowpark.modin.pandas import Series + + if isinstance(cond, Series): + cond._query_compiler._shape_hint = "column" + if isinstance(self, Series): + self._query_compiler._shape_hint = "column" + if isinstance(other, Series): + other._query_compiler._shape_hint = "column" + + if not isinstance(cond, BasePandasDataset): + cond = get_as_shape_compatible_dataframe_or_series(cond, self) + cond._query_compiler._shape_hint = "array" + + if other is not None: + other = apply_if_callable(other, self) + + if isinstance(other, np.ndarray): + other = get_as_shape_compatible_dataframe_or_series( + other, + self, + shape_mismatch_message="other must be the same shape as self when an ndarray", + ) + other._query_compiler._shape_hint = "array" + + if isinstance(other, BasePandasDataset): + other = other._query_compiler + + query_compiler = self._query_compiler.mask( + cond._query_compiler, + other, + axis, + level, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace) + + +@register_base_override("sample") +def sample( + self, + n: int | None = None, + frac: float | None = None, + replace: bool = False, + weights: str | np.ndarray | None = None, + random_state: RandomState | None = None, + axis: Axis | None = None, + ignore_index: bool = False, +): + """ + Return a random sample of items from an axis of object. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if self._get_axis_number(axis): + if weights is not None and isinstance(weights, str): + raise ValueError( + "Strings can only be passed to weights when sampling from rows on a DataFrame" + ) + else: + if n is None and frac is None: + n = 1 + elif n is not None and frac is not None: + raise ValueError("Please enter a value for `frac` OR `n`, not both") + else: + if n is not None: + if n < 0: + raise ValueError( + "A negative number of rows requested. Please provide `n` >= 0." + ) + if n % 1 != 0: + raise ValueError("Only integers accepted as `n` values") + else: + if frac < 0: + raise ValueError( + "A negative number of rows requested. Please provide `frac` >= 0." + ) + + query_compiler = self._query_compiler.sample( + n, frac, replace, weights, random_state, axis, ignore_index + ) + return self.__constructor__(query_compiler=query_compiler) + + +@register_base_override("pct_change") +def pct_change( + self, periods=1, fill_method=no_default, limit=no_default, freq=None, **kwargs +): # noqa: PR01, RT01, D200 + """ + Percentage change between the current and a prior element. + """ + if fill_method not in (lib.no_default, None) or limit is not lib.no_default: + warnings.warn( + "The 'fill_method' keyword being not None and the 'limit' keyword in " + + f"{type(self).__name__}.pct_change are deprecated and will be removed " + + "in a future version. 
Either fill in any non-leading NA values prior " + + "to calling pct_change or specify 'fill_method=None' to not fill NA " + + "values.", + FutureWarning, + stacklevel=1, + ) + if fill_method is lib.no_default: + warnings.warn( + f"The default fill_method='pad' in {type(self).__name__}.pct_change is " + + "deprecated and will be removed in a future version. Either fill in any " + + "non-leading NA values prior to calling pct_change or specify 'fill_method=None' " + + "to not fill NA values.", + FutureWarning, + stacklevel=1, + ) + fill_method = "pad" + + if limit is lib.no_default: + limit = None + + if "axis" in kwargs: + kwargs["axis"] = self._get_axis_number(kwargs["axis"]) + + # Attempting to match pandas error behavior here + if not isinstance(periods, int): + raise TypeError(f"periods must be an int. got {type(periods)} instead") + + # Attempting to match pandas error behavior here + for dtype in self._get_dtypes(): + if not is_numeric_dtype(dtype): + raise TypeError( + f"cannot perform pct_change on non-numeric column with dtype {dtype}" + ) + + return self.__constructor__( + query_compiler=self._query_compiler.pct_change( + periods=periods, + fill_method=fill_method, + limit=limit, + freq=freq, + **kwargs, + ) + ) + + +@register_base_override("astype") +def astype( + self, + dtype: str | type | pd.Series | dict[str, type], + copy: bool = True, + errors: Literal["raise", "ignore"] = "raise", +) -> pd.DataFrame | pd.Series: + """ + Cast a Modin object to a specified dtype `dtype`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # dtype can be a series, a dict, or a scalar. If it's series or scalar, + # convert it to a dict before passing it to the query compiler. + raise_if_native_pandas_objects(dtype) + from snowflake.snowpark.modin.pandas import Series + + if isinstance(dtype, Series): + dtype = dtype.to_pandas() + if not dtype.index.is_unique: + raise ValueError( + "The new Series of types must have a unique index, i.e. " + + "it must be one-to-one mapping from column names to " + + " their new dtypes." + ) + dtype = dtype.to_dict() + # If we got a series or dict originally, dtype is a dict now. Its keys + # must be column names. + if isinstance(dtype, dict): + # Avoid materializing columns. The query compiler will handle errors where + # dtype dict includes keys that are not in columns. + col_dtypes = dtype + for col_name in col_dtypes: + if col_name not in self._query_compiler.columns: + raise KeyError( + "Only a column name can be used for the key in a dtype mappings argument. " + f"'{col_name}' not found in columns." + ) + else: + # Assume that the dtype is a scalar. + col_dtypes = {column: dtype for column in self._query_compiler.columns} + + # ensure values are pandas dtypes + col_dtypes = {k: pandas_dtype(v) for k, v in col_dtypes.items()} + new_query_compiler = self._query_compiler.astype(col_dtypes, errors=errors) + return self._create_or_update_from_compiler(new_query_compiler, not copy) + + +@register_base_override("drop") +def drop( + self, + labels: IndexLabel = None, + axis: Axis = 0, + index: IndexLabel = None, + columns: IndexLabel = None, + level: Level = None, + inplace: bool = False, + errors: IgnoreRaise = "raise", +) -> BasePandasDataset | None: + """ + Drop specified labels from `BasePandasDataset`. 
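+
+    Returns a new object with the requested labels removed, unless
+    `inplace=True`, in which case `self` is updated and `None` is returned.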
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + inplace = validate_bool_kwarg(inplace, "inplace") + if labels is not None: + if index is not None or columns is not None: + raise ValueError("Cannot specify both 'labels' and 'index'/'columns'") + axes = {self._get_axis_number(axis): labels} + elif index is not None or columns is not None: + axes = {0: index, 1: columns} + else: + raise ValueError( + "Need to specify at least one of 'labels', 'index' or 'columns'" + ) + + for axis, labels in axes.items(): + if labels is not None: + if level is not None and not self._query_compiler.has_multiindex(axis=axis): + # Same error as native pandas. + raise AssertionError("axis must be a MultiIndex") + # According to pandas documentation, a tuple will be used as a single + # label and not treated as a list-like. + if not is_list_like(labels) or isinstance(labels, tuple): + axes[axis] = [labels] + + new_query_compiler = self._query_compiler.drop( + index=axes.get(0), columns=axes.get(1), level=level, errors=errors + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + +@register_base_override("__len__") +def __len__(self) -> int: + """ + Return length of info axis. + + Returns + ------- + int + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._query_compiler.get_axis_len(axis=0) + + +@register_base_override("set_axis") +def set_axis( + self, + labels: IndexLabel, + *, + axis: Axis = 0, + copy: bool | NoDefault = no_default, +): + """ + Assign desired index to given axis. + """ + # Behavior based on copy: + # ----------------------------------- + # - In native pandas, copy determines whether to create a copy of the data (not DataFrame). + # - We cannot emulate the native pandas' copy behavior in Snowpark since a copy of only data + # cannot be created -- you can only copy the whole object (DataFrame/Series). + # + # Snowpark behavior: + # ------------------ + # - copy is kept for compatibility with native pandas but is ignored. The user is warned that copy is unused. + # Warn user that copy does not do anything. + if copy is not no_default: + WarningMessage.single_warning( + message=f"{type(self).__name__}.set_axis 'copy' keyword is unused and is ignored." + ) + if labels is None: + raise TypeError("None is not a valid value for the parameter 'labels'.") + + # Determine whether to update self or a copy and perform update. + obj = self.copy() + setattr(obj, axis, labels) + return obj + + +@register_base_override("describe") +def describe( + self, + percentiles: ListLike | None = None, + include: ListLike | Literal["all"] | None = None, + exclude: ListLike | None = None, +) -> BasePandasDataset: + """ + Generate descriptive statistics. 
+ """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + percentiles = _refine_percentiles(percentiles) + data = self + if self._is_dataframe: + # Upstream modin lacks this check because it defaults to pandas for describing empty dataframes + if len(self.columns) == 0: + raise ValueError("Cannot describe a DataFrame without columns") + + # include/exclude are ignored for Series + if (include is None) and (exclude is None): + # when some numerics are found, keep only numerics + default_include: list[npt.DTypeLike] = [np.number] + default_include.append("datetime") + data = self.select_dtypes(include=default_include) + if len(data.columns) == 0: + data = self + elif include == "all": + if exclude is not None: + raise ValueError("exclude must be None when include is 'all'") + data = self + else: + data = self.select_dtypes( + include=include, + exclude=exclude, + ) + # Upstream modin uses data.empty, but that incurs an extra row count query + if self._is_dataframe and len(data.columns) == 0: + # Match pandas error from concatenating empty list of series descriptions. + raise ValueError("No objects to concatenate") + + return self.__constructor__( + query_compiler=data._query_compiler.describe(percentiles=percentiles) + ) + + +@register_base_override("diff") +def diff(self, periods: int = 1, axis: Axis = 0): + """ + First discrete difference of element. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # We must only accept integer (or float values that are whole numbers) + # for periods. + int_periods = validate_int_kwarg(periods, "periods", float_allowed=True) + axis = self._get_axis_number(axis) + return self.__constructor__( + query_compiler=self._query_compiler.diff(axis=axis, periods=int_periods) + ) + + +@register_base_override("tail") +def tail(self, n: int = 5): + if n == 0: + return self.iloc[0:0] + return self.iloc[-n:] + + +@register_base_override("idxmax") +def idxmax(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200 + """ + Return index of first occurrence of maximum over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + dtypes = self._get_dtypes() + if ( + axis == 1 + and not numeric_only + and any(not is_numeric_dtype(d) for d in dtypes) + and len(set(dtypes)) > 1 + ): + # For numeric_only=False, if we have any non-numeric dtype, e.g. + # a string type, we need every other column to be of the same type. + # We can't compare two objects of different non-numeric types, e.g. + # a string and a timestamp. + # If we have only numeric data, we can compare columns even if they + # different types, e.g. we can compare an int column to a float + # column. + raise TypeError("'>' not supported for these dtypes") + axis = self._get_axis_number(axis) + return self._reduce_dimension( + self._query_compiler.idxmax(axis=axis, skipna=skipna, numeric_only=numeric_only) + ) + + +@register_base_override("idxmin") +def idxmin(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200 + """ + Return index of first occurrence of minimum over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + dtypes = self._get_dtypes() + if ( + axis == 1 + and not numeric_only + and any(not is_numeric_dtype(d) for d in dtypes) + and len(set(dtypes)) > 1 + ): + # For numeric_only=False, if we have any non-numeric dtype, e.g. + # a string type, we need every other column to be of the same type. 
+ # We can't compare two objects of different non-numeric types, e.g. + # a string and a timestamp. + # If we have only numeric data, we can compare columns even if they + # different types, e.g. we can compare an int column to a float + # column. + raise TypeError("'<' not supported for these dtypes") + axis = self._get_axis_number(axis) + return self._reduce_dimension( + self._query_compiler.idxmin(axis=axis, skipna=skipna, numeric_only=numeric_only) + ) + + +@register_base_override("__abs__") +def abs(self): # noqa: RT01, D200 + """ + Return a `BasePandasDataset` with absolute numeric value of each element. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.unary_op("abs")) + + +@register_base_override("__invert__") +def __invert__(self): + """ + Apply bitwise inverse to each element of the `BasePandasDataset`. + + Returns + ------- + BasePandasDataset + New BasePandasDataset containing bitwise inverse to each value. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.invert()) + + +@register_base_override("__neg__") +def __neg__(self): + """ + Change the sign for every value of self. + + Returns + ------- + BasePandasDataset + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.unary_op("__neg__")) + + +@register_base_override("rename_axis") +def rename_axis( + self, + mapper=lib.no_default, + *, + index=lib.no_default, + columns=lib.no_default, + axis=0, + copy=None, + inplace=False, +): # noqa: PR01, RT01, D200 + """ + Set the name of the axis for the index or columns. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axes = {"index": index, "columns": columns} + + if copy is None: + copy = True + + if axis is not None: + axis = self._get_axis_number(axis) + + inplace = validate_bool_kwarg(inplace, "inplace") + + if mapper is not lib.no_default and mapper is not None: + # Use v0.23 behavior if a scalar or list + non_mapper = is_scalar(mapper) or ( + is_list_like(mapper) and not is_dict_like(mapper) + ) + if non_mapper: + return self._set_axis_name(mapper, axis=axis, inplace=inplace) + else: + raise ValueError("Use `.rename` to alter labels with a mapper.") + else: + # Use new behavior. Means that index and/or columns is specified + result = self if inplace else self.copy(deep=copy) + + for axis in range(self.ndim): + v = axes.get(pandas.DataFrame._get_axis_name(axis)) + if v is lib.no_default: + continue + non_mapper = is_scalar(v) or (is_list_like(v) and not is_dict_like(v)) + if non_mapper: + newnames = v + else: + + def _get_rename_function(mapper): + if isinstance(mapper, (dict, BasePandasDataset)): + + def f(x): + if x in mapper: + return mapper[x] + else: + return x + + else: + f = mapper + + return f + + f = _get_rename_function(v) + curnames = self.index.names if axis == 0 else self.columns.names + newnames = [f(name) for name in curnames] + result._set_axis_name(newnames, axis=axis, inplace=True) + if not inplace: + return result + + +@register_base_override("__array_ufunc__") +def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): + """ + Apply the `ufunc` to the `BasePandasDataset`. + + Parameters + ---------- + ufunc : np.ufunc + The NumPy ufunc to apply. + method : str + The method to apply. + *inputs : tuple + The inputs to the ufunc. 
+    **kwargs : dict
+        Additional keyword arguments.
+
+    Returns
+    -------
+    BasePandasDataset
+        The result of the ufunc applied to the `BasePandasDataset`.
+    """
+    # Use pandas version of ufunc if it exists
+    if method != "__call__":
+        # Return sentinel value NotImplemented
+        return NotImplemented
+    from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import (
+        numpy_to_pandas_universal_func_map,
+    )
+
+    if ufunc.__name__ in numpy_to_pandas_universal_func_map:
+        ufunc = numpy_to_pandas_universal_func_map[ufunc.__name__]
+        return ufunc(self, inputs[1:], kwargs)
+    # return the sentinel NotImplemented if we do not support this function
+    return NotImplemented
+
+
+@register_base_override("reindex")
+def reindex(
+    self,
+    index=None,
+    columns=None,
+    copy=True,
+    **kwargs,
+):  # noqa: PR01, RT01, D200
+    """
+    Conform `BasePandasDataset` to new index with optional filling logic.
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    if kwargs.get("limit", None) is not None and kwargs.get("method", None) is None:
+        raise ValueError(
+            "limit argument only valid if doing pad, backfill or nearest reindexing"
+        )
+    new_query_compiler = None
+    if index is not None:
+        if not isinstance(index, pandas.Index) or not index.equals(self.index):
+            new_query_compiler = self._query_compiler.reindex(
+                axis=0, labels=index, **kwargs
+            )
+    if new_query_compiler is None:
+        new_query_compiler = self._query_compiler
+    final_query_compiler = None
+    if columns is not None:
+        if not isinstance(columns, pandas.Index) or not columns.equals(self.columns):
+            final_query_compiler = new_query_compiler.reindex(
+                axis=1, labels=columns, **kwargs
+            )
+    if final_query_compiler is None:
+        final_query_compiler = new_query_compiler
+    return self._create_or_update_from_compiler(
+        final_query_compiler, inplace=False if copy is None else not copy
+    )
+
+
+@register_base_override("all")
+def all(self, axis=0, bool_only=None, skipna=True, **kwargs):
+    """
+    Return whether all elements are True, potentially over an axis.
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    validate_bool_kwarg(skipna, "skipna", none_allowed=False)
+    if axis is not None:
+        axis = self._get_axis_number(axis)
+        if bool_only and axis == 0:
+            if hasattr(self, "dtype"):
+                ErrorMessage.not_implemented(
+                    "{}.{} does not implement numeric_only.".format(
+                        type(self).__name__, "all"
+                    )
+                )  # pragma: no cover
+            data_for_compute = self[self.columns[self.dtypes == np.bool_]]
+            return data_for_compute.all(
+                axis=axis, bool_only=False, skipna=skipna, **kwargs
+            )
+        result = self._reduce_dimension(
+            self._query_compiler.all(
+                axis=axis, bool_only=bool_only, skipna=skipna, **kwargs
+            )
+        )
+    else:
+        if bool_only:
+            raise ValueError(f"Axis must be 0 or 1 (got {axis})")
+        # Reduce to a scalar if axis is None.
+        result = self._reduce_dimension(
+            # FIXME: Judging by pandas docs `**kwargs` serves only compatibility
+            # purpose and does not affect the result, we shouldn't pass them to the query compiler.
+            self._query_compiler.all(
+                axis=0,
+                bool_only=bool_only,
+                skipna=skipna,
+                **kwargs,
+            )
+        )
+    if isinstance(result, BasePandasDataset):
+        return result.all(axis=axis, bool_only=bool_only, skipna=skipna, **kwargs)
+    return True if result is None else result
+
+
+@register_base_override("any")
+def any(self, axis=0, bool_only=None, skipna=True, **kwargs):
+    """
+    Return whether any element is True, potentially over an axis.
+    """
+    # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+    validate_bool_kwarg(skipna, "skipna", none_allowed=False)
+    if axis is not None:
+        axis = self._get_axis_number(axis)
+        if bool_only and axis == 0:
+            if hasattr(self, "dtype"):
+                ErrorMessage.not_implemented(
+                    "{}.{} does not implement numeric_only.".format(
+                        type(self).__name__, "any"
+                    )
+                )  # pragma: no cover
+            data_for_compute = self[self.columns[self.dtypes == np.bool_]]
+            return data_for_compute.any(
+                axis=axis, bool_only=False, skipna=skipna, **kwargs
+            )
+        result = self._reduce_dimension(
+            self._query_compiler.any(
+                axis=axis, bool_only=bool_only, skipna=skipna, **kwargs
+            )
+        )
+    else:
+        if bool_only:
+            raise ValueError(f"Axis must be 0 or 1 (got {axis})")
+        # Reduce to a scalar if axis is None.
+        result = self._reduce_dimension(
+            self._query_compiler.any(
+                axis=0,
+                bool_only=bool_only,
+                skipna=skipna,
+                **kwargs,
+            )
+        )
+    if isinstance(result, BasePandasDataset):
+        return result.any(axis=axis, bool_only=bool_only, skipna=skipna, **kwargs)
+    return False if result is None else result
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_extensions.py
index a2d4710bf66..b167c924452 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_extensions.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_extensions.py
@@ -254,3 +254,36 @@ def cache_result(self, inplace: bool = False) -> Optional[pd.DataFrame]:
         self._update_inplace(new_qc)
     else:
         return pd.DataFrame(query_compiler=new_qc)
+
+
+@register_dataframe_accessor("__array_function__")
+@snowpark_pandas_telemetry_method_decorator
+def __array_function__(self, func: callable, types: tuple, args: tuple, kwargs: dict):
+    """
+    Apply the `func` to the `BasePandasDataset`.
+
+    Parameters
+    ----------
+    func : callable
+        The NumPy function to apply.
+    types : tuple
+        The types of the args.
+    args : tuple
+        The args to the func.
+    kwargs : dict
+        Additional keyword arguments.
+
+    Returns
+    -------
+    BasePandasDataset
+        The result of `func` applied to the `BasePandasDataset`.
+    """
+    from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import (
+        numpy_to_pandas_func_map,
+    )
+
+    if func.__name__ in numpy_to_pandas_func_map:
+        return numpy_to_pandas_func_map[func.__name__](*args, **kwargs)
+    else:
+        # per NEP18 we return NotImplemented so that numpy can raise TypeError itself
+        return NotImplemented  # pragma: no cover
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py
index f5e27a44e80..729b6c3bb0a 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py
@@ -218,3 +218,36 @@ def cache_result(self, inplace: bool = False) -> Optional[pd.Series]:
         self._update_inplace(new_qc)
     else:
         return pd.Series(query_compiler=new_qc)
+
+
+@register_series_accessor("__array_function__")
+@snowpark_pandas_telemetry_method_decorator
+def __array_function__(self, func: callable, types: tuple, args: tuple, kwargs: dict):
+    """
+    Apply the `func` to the `BasePandasDataset`.
+
+    Parameters
+    ----------
+    func : callable
+        The NumPy function to apply.
+    types : tuple
+        The types of the args.
+    args : tuple
+        The args to the func.
+    kwargs : dict
+        Additional keyword arguments.
+ + Returns + ------- + BasePandasDataset + The result of the ufunc applied to the `BasePandasDataset`. + """ + from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import ( + numpy_to_pandas_func_map, + ) + + if func.__name__ in numpy_to_pandas_func_map: + return numpy_to_pandas_func_map[func.__name__](*args, **kwargs) + else: + # per NEP18 we raise NotImplementedError so that numpy can intercept + return NotImplemented # pragma: no cover diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index 0afea30e29a..a33d7702203 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -161,6 +161,7 @@ def plot( @register_series_accessor("transform") +@snowpark_pandas_telemetry_method_decorator @series_not_implemented() def transform(self, func, axis=0, *args, **kwargs): # noqa: PR01, RT01, D200 pass # pragma: no cover diff --git a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py index 3da545c64b6..f673bf157bf 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py +++ b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py @@ -3,8 +3,9 @@ # from typing import Any, Optional, Union +from modin.pandas.base import BasePandasDataset + import snowflake.snowpark.modin.pandas as pd -from snowflake.snowpark.modin.pandas.base import BasePandasDataset from snowflake.snowpark.modin.pandas.utils import is_scalar from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage diff --git a/tests/integ/modin/test_telemetry.py b/tests/integ/modin/test_telemetry.py index 9c24c6b6853..9d3f632dd01 100644 --- a/tests/integ/modin/test_telemetry.py +++ b/tests/integ/modin/test_telemetry.py @@ -398,7 +398,7 @@ def test_telemetry_getitem_setitem(): s = df["a"] assert len(df._query_compiler.snowpark_pandas_api_calls) == 0 assert s._query_compiler.snowpark_pandas_api_calls == [ - {"name": "DataFrame.BasePandasDataset.__getitem__"} + {"name": "DataFrame.__getitem__"} ] df["a"] = 0 df["b"] = 0 @@ -412,12 +412,12 @@ def test_telemetry_getitem_setitem(): # the telemetry log from the connector to validate _ = s[0] data = _extract_snowpark_pandas_telemetry_log_data( - expected_func_name="Series.BasePandasDataset.__getitem__", + expected_func_name="Series.__getitem__", session=s._query_compiler._modin_frame.ordered_dataframe.session, ) assert data["api_calls"] == [ - {"name": "DataFrame.BasePandasDataset.__getitem__"}, - {"name": "Series.BasePandasDataset.__getitem__"}, + {"name": "DataFrame.__getitem__"}, + {"name": "Series.__getitem__"}, ] diff --git a/tests/unit/modin/modin/test_envvars.py b/tests/unit/modin/modin/test_envvars.py index 4f3540a63bf..734b710c809 100644 --- a/tests/unit/modin/modin/test_envvars.py +++ b/tests/unit/modin/modin/test_envvars.py @@ -90,6 +90,32 @@ def test_custom_help(make_custom_envvar): assert "custom var" in make_custom_envvar.get_help() +def _init_doc_module(): + # Put the docs_module on the path + sys.path.append(f"{os.path.dirname(__file__)}") + # We use base.py from upstream modin, so we need to initialize its doc module + # However, since using the environment variable causes an importlib.reload call, + # we need to manually call _inherit_docstrings (https://github.com/modin-project/modin/issues/7138) + from .docs_module import classes + + # As a workaround for upstream modin 
bugs, we use our own _inherit_docstrings instead of the upstream + # function. We accordingly need to clear the docstring dictionary in testing because + # we manually called the annotation on initializing snowflake.snowpark.modin.pandas. + # snowflake.snowpark.modin.utils._attributes_with_docstrings_replaced.clear() + # TODO: once SNOW-1473605 (modin 0.30.1) is available, use the actual modin DocModule class + snowflake.snowpark.modin.utils._inherit_docstrings( + classes.BasePandasDataset, + overwrite_existing=True, + )(pd.base.BasePandasDataset) + DocModule.put("docs_module") + + +DOC_OVERRIDE_XFAIL_REASON = ( + "test docstring overrides currently cannot override real docstring overrides until " + "modin 0.30.1 is available (SNOW-1473605)" +) + + class TestDocModule: """ Test using a module to replace default docstrings. @@ -99,11 +125,9 @@ class TestDocModule: which we need to fix in upstream modin. """ + @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON) def test_overrides(self): - # Put the docs_module on the path - sys.path.append(f"{os.path.dirname(__file__)}") - DocModule.put("docs_module") - + _init_doc_module() # Test for override # TODO(https://github.com/modin-project/modin/issues/7134): Upstream # the BasePandasDataset tests to modin. @@ -144,11 +168,7 @@ def test_overrides(self): def test_not_redefining_classes_modin_issue_7138(self): original_dataframe_class = pd.DataFrame - - # Put the docs_module on the path - sys.path.append(f"{os.path.dirname(__file__)}") - DocModule.put("docs_module") - + _init_doc_module() # Test for override assert ( pd.DataFrame.apply.__doc__ @@ -157,22 +177,20 @@ def test_not_redefining_classes_modin_issue_7138(self): assert pd.DataFrame is original_dataframe_class + @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON) def test_base_docstring_override_with_no_dataframe_or_series_class_modin_issue_7113( self, ): # TODO(https://github.com/modin-project/modin/issues/7113): Upstream # this test case to Modin. This test case tests scenario 1 from issue 7113. - sys.path.append(f"{os.path.dirname(__file__)}") - DocModule.put("docs_module_with_just_base") + _init_doc_module() assert pd.base.BasePandasDataset.astype.__doc__ == ( "This is a test of the documentation module for BasePandasDataSet.astype." ) + @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON) def test_base_property_not_overridden_in_either_subclass_modin_issue_7113(self): - # Put the docs_module on the path - sys.path.append(f"{os.path.dirname(__file__)}") - DocModule.put("docs_module") - + _init_doc_module() assert ( pd.base.BasePandasDataset.loc.__doc__ == "This is a test of the documentation module for BasePandasDataset.loc." 
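Note on the `__array_function__` hooks added in this patch: registering `__array_function__` on both DataFrame and Series is what lets top-level numpy calls such as `np.where(df, x, y)` dispatch into Snowpark pandas through `numpy_to_pandas_func_map` instead of coercing the data to a local ndarray. For readers unfamiliar with the NEP 18 protocol these accessors implement, here is a minimal, self-contained sketch of the same mechanism; `MyArray` and its lone `np.where` handler are illustrative stand-ins, not part of this patch series:

    import numpy as np

    class MyArray:
        # A tiny container that opts in to NEP 18 dispatch.
        def __init__(self, data):
            self.data = np.asarray(data)

        def __array_function__(self, func, types, args, kwargs):
            # numpy passes the original function and its arguments;
            # translate only the functions we explicitly support.
            if func is np.where:
                cond, x, y = args
                return MyArray(np.where(cond.data, x, y))
            # Returning (not raising) NotImplemented tells numpy this type
            # does not handle `func`, so numpy raises TypeError instead.
            return NotImplemented

    arr = MyArray([1, 0, 3])
    print(np.where(arr, "pos", "neg").data)  # ['pos' 'neg' 'pos']
    try:
        np.concatenate([arr, arr])  # no handler registered above
    except TypeError as err:
        print(err)  # numpy: no implementation found for this type

The accessors in this patch perform the same translation, looking up handlers by `func.__name__` in `numpy_to_pandas_func_map`.
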
From 340d5069e79105c70a46aec9462b65bc3b496fc1 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Tue, 20 Aug 2024 14:09:58 -0700 Subject: [PATCH 02/14] remove base_not_implemented.py --- .../snowpark/modin/pandas/__init__.py | 2 - .../plugin/extensions/base_not_implemented.py | 414 ------------------ tests/unit/modin/modin/test_envvars.py | 2 +- 3 files changed, 1 insertion(+), 417 deletions(-) delete mode 100644 src/snowflake/snowpark/modin/plugin/extensions/base_not_implemented.py diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py index 02d8c950cce..274d5b3763f 100644 --- a/src/snowflake/snowpark/modin/pandas/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -147,8 +147,6 @@ _PD_EXTENSIONS_: dict = {} -import snowflake.snowpark.modin.plugin.extensions.base_overrides # isort: skip # noqa: E402,F401 -import snowflake.snowpark.modin.plugin.extensions.base_not_implemented # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.pd_extensions as pd_extensions # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.pd_overrides # isort: skip # noqa: E402,F401 from snowflake.snowpark.modin.plugin.extensions.pd_overrides import ( # isort: skip # noqa: E402,F401 diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_not_implemented.py b/src/snowflake/snowpark/modin/plugin/extensions/base_not_implemented.py deleted file mode 100644 index aaf8b86494d..00000000000 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_not_implemented.py +++ /dev/null @@ -1,414 +0,0 @@ -# -# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. -# - -""" -The functions in this file are not implemented in Snowpark pandas. In the future, they -should raise NotImplementedError at the query compiler layer, but doing so requires a longer-term -effort. - -We currently test unsupported APIs under tests/unit/modin/test_unsupported.py, which does not initialize -a session. As such, many frontend methods have additional query compiler API calls that would have to -be mocked before the NotImplementedError can appropriately be raised. 
-""" -from __future__ import annotations - -import pickle as pkl -from typing import Any - -import numpy as np -import pandas -from modin.pandas.base import BasePandasDataset -from pandas._libs import lib -from pandas._libs.lib import no_default -from pandas._typing import ( - Axis, - CompressionOptions, - StorageOptions, - TimedeltaConvertibleTypes, -) - -from snowflake.snowpark.modin.pandas.api.extensions import ( - register_dataframe_accessor, - register_series_accessor, -) -from snowflake.snowpark.modin.plugin._internal.telemetry import ( - snowpark_pandas_telemetry_method_decorator, -) -from snowflake.snowpark.modin.plugin.utils.error_message import base_not_implemented - - -def register_base_not_implemented(): - def decorator(base_method: Any): - func = snowpark_pandas_telemetry_method_decorator( - base_not_implemented()(base_method) - ) - register_series_accessor(base_method.__name__)(func) - register_dataframe_accessor(base_method.__name__)(func) - return func - - return decorator - - -@register_base_not_implemented() -def align( - self, - other, - join="outer", - axis=None, - level=None, - copy=None, - fill_value=None, - method=lib.no_default, - limit=lib.no_default, - fill_axis=lib.no_default, - broadcast_axis=lib.no_default, -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def asof(self, where, subset=None): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def at_time(self, time, asof=False, axis=None): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def between_time( - self: BasePandasDataset, - start_time, - end_time, - inclusive: str | None = None, - axis=None, -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def bool(self): # noqa: RT01, D200 - pass - - -@register_base_not_implemented() -def clip( - self, lower=None, upper=None, axis=None, inplace=False, *args, **kwargs -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def combine(self, other, func, fill_value=None, **kwargs): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def combine_first(self, other): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def droplevel(self, level, axis=0): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def explode(self, column, ignore_index: bool = False): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def ewm( - self, - com: float | None = None, - span: float | None = None, - halflife: float | TimedeltaConvertibleTypes | None = None, - alpha: float | None = None, - min_periods: int | None = 0, - adjust: bool = True, - ignore_na: bool = False, - axis: Axis = 0, - times: str | np.ndarray | BasePandasDataset | None = None, - method: str = "single", -) -> pandas.core.window.ewm.ExponentialMovingWindow: # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def filter( - self, items=None, like=None, regex=None, axis=None -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def infer_objects( - self, copy: bool | None = None -) -> BasePandasDataset: # pragma: no cover # noqa: RT01, D200 - pass - - -@register_base_not_implemented() -def kurt(self, axis=no_default, skipna=True, numeric_only=False, **kwargs): - pass - - -@register_base_not_implemented() -def kurtosis(self, axis=no_default, skipna=True, numeric_only=False, **kwargs): - pass - - -@register_base_not_implemented() -def mode(self, axis=0, numeric_only=False, dropna=True): # noqa: PR01, 
RT01, D200 - pass - - -@register_base_not_implemented() -def pipe(self, func, *args, **kwargs): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def pop(self, item): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def reindex_like( - self, other, method=None, copy=True, limit=None, tolerance=None -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def reorder_levels(self, order, axis=0): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def sem( - self, - axis: Axis | None = None, - skipna: bool = True, - ddof: int = 1, - numeric_only=False, - **kwargs, -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def set_flags( - self, *, copy: bool = False, allows_duplicate_labels: bool | None = None -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def swapaxes(self, axis1, axis2, copy=True): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def swaplevel(self, i=-2, j=-1, axis=0): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def to_clipboard( - self, excel=True, sep=None, **kwargs -): # pragma: no cover # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def to_excel( - self, - excel_writer, - sheet_name="Sheet1", - na_rep="", - float_format=None, - columns=None, - header=True, - index=True, - index_label=None, - startrow=0, - startcol=0, - engine=None, - merge_cells=True, - encoding=no_default, - inf_rep="inf", - verbose=no_default, - freeze_panes=None, - storage_options: StorageOptions = None, -): # pragma: no cover # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def to_hdf( - self, path_or_buf, key, format="table", **kwargs -): # pragma: no cover # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def to_json( - self, - path_or_buf=None, - orient=None, - date_format=None, - double_precision=10, - force_ascii=True, - date_unit="ms", - default_handler=None, - lines=False, - compression="infer", - index=True, - indent=None, - storage_options: StorageOptions = None, -): # pragma: no cover # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def to_latex( - self, - buf=None, - columns=None, - col_space=None, - header=True, - index=True, - na_rep="NaN", - formatters=None, - float_format=None, - sparsify=None, - index_names=True, - bold_rows=False, - column_format=None, - longtable=None, - escape=None, - encoding=None, - decimal=".", - multicolumn=None, - multicolumn_format=None, - multirow=None, - caption=None, - label=None, - position=None, -): # pragma: no cover # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def to_markdown( - self, - buf=None, - mode: str = "wt", - index: bool = True, - storage_options: StorageOptions = None, - **kwargs, -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def to_pickle( - self, - path, - compression: CompressionOptions = "infer", - protocol: int = pkl.HIGHEST_PROTOCOL, - storage_options: StorageOptions = None, -): # pragma: no cover # noqa: PR01, D200 - pass - - -@register_base_not_implemented() -def to_string( - self, - buf=None, - columns=None, - col_space=None, - header=True, - index=True, - na_rep="NaN", - formatters=None, - float_format=None, - sparsify=None, - index_names=True, - justify=None, - max_rows=None, - min_rows=None, - max_cols=None, - show_dimensions=False, - decimal=".", - line_width=None, - max_colwidth=None, - encoding=None, -): # noqa: 
PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def to_sql( - self, - name, - con, - schema=None, - if_exists="fail", - index=True, - index_label=None, - chunksize=None, - dtype=None, - method=None, -): # noqa: PR01, D200 - pass - - -@register_base_not_implemented() -def to_timestamp( - self, freq=None, how="start", axis=0, copy=True -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def to_xarray(self): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def truncate( - self, before=None, after=None, axis=None, copy=True -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def tz_convert(self, tz, axis=0, level=None, copy=True): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def tz_localize( - self, tz, axis=0, level=None, copy=True, ambiguous="raise", nonexistent="raise" -): # noqa: PR01, RT01, D200 - pass - - -@register_base_not_implemented() -def __array_wrap__(self, result, context=None): - pass - - -@register_base_not_implemented() -def __finalize__(self, other, method=None, **kwargs): - pass - - -@register_base_not_implemented() -def __sizeof__(self): - pass diff --git a/tests/unit/modin/modin/test_envvars.py b/tests/unit/modin/modin/test_envvars.py index 734b710c809..a3f8fca324a 100644 --- a/tests/unit/modin/modin/test_envvars.py +++ b/tests/unit/modin/modin/test_envvars.py @@ -102,7 +102,7 @@ def _init_doc_module(): # function. We accordingly need to clear the docstring dictionary in testing because # we manually called the annotation on initializing snowflake.snowpark.modin.pandas. # snowflake.snowpark.modin.utils._attributes_with_docstrings_replaced.clear() - # TODO: once SNOW-1473605 (modin 0.30.1) is available, use the actual modin DocModule class + # TODO: once modin 0.31.0 is available, use the actual modin DocModule class snowflake.snowpark.modin.utils._inherit_docstrings( classes.BasePandasDataset, overwrite_existing=True, From 899d90d47e566c29c00d540bbb437a490edf56e5 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Tue, 20 Aug 2024 15:40:53 -0700 Subject: [PATCH 03/14] fix any/all name mangling --- .../snowpark/modin/plugin/extensions/base_overrides.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index afde17548a8..f4cb3ba8ad9 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -2212,7 +2212,8 @@ def reindex( @register_base_override("all") -def all(self, axis=0, bool_only=None, skipna=True, **kwargs): +# Renamed to _all to avoid conflict with builtin python function all (override still has the correct name) +def _all(self, axis=0, bool_only=None, skipna=True, **kwargs): """ Return whether all elements are True, potentially over an axis. """ @@ -2256,7 +2257,8 @@ def all(self, axis=0, bool_only=None, skipna=True, **kwargs): @register_base_override("any") -def any(self, axis=0, bool_only=None, skipna=True, **kwargs): +# Renamed to _any to avoid conflict with builtin python function any (override still has the correct name) +def _any(self, axis=0, bool_only=None, skipna=True, **kwargs): """ Return whether any element is True, potentially over an axis. 
""" From 3d9c5c3f3f9c6389399868a74c909df3f98244d3 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Tue, 20 Aug 2024 16:56:14 -0700 Subject: [PATCH 04/14] actually remove base.py and fix to_csv na_rep --- .../snowpark/modin/plugin/extensions/base_overrides.py | 2 +- src/snowflake/snowpark/modin/plugin/extensions/index.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index f4cb3ba8ad9..9acabef9317 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -1563,7 +1563,7 @@ def to_csv( self, path_or_buf=None, sep=",", - na_rep=",", + na_rep="", float_format=None, columns=None, header=True, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 95fcf684924..808489b8917 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -29,6 +29,7 @@ import modin import numpy as np import pandas as native_pd +from modin.pandas.base import BasePandasDataset from pandas import get_option from pandas._libs import lib from pandas._libs.lib import is_list_like, is_scalar @@ -48,7 +49,6 @@ from pandas.core.dtypes.inference import is_hashable from snowflake.snowpark.modin.pandas import DataFrame, Series -from snowflake.snowpark.modin.pandas.base import BasePandasDataset from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta from snowflake.snowpark.modin.plugin._internal.timestamp_utils import DateTimeOrigin From 72f76c2ecb34f61e7d20de67e0db4a007f6037d9 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Wed, 21 Aug 2024 16:14:21 -0700 Subject: [PATCH 05/14] fix index and binary op stuff --- .../modin/plugin/extensions/base_overrides.py | 73 +++++++++++++------ 1 file changed, 50 insertions(+), 23 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 9acabef9317..456ffde2dfa 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -579,7 +579,7 @@ def _binary_op( self, op: str, other: BasePandasDataset, - axis: Axis, + axis: Axis = None, level: Level | None = None, fill_value: float | None = None, **kwargs: Any, @@ -609,6 +609,14 @@ def _binary_op( modin.pandas.BasePandasDataset Result of binary operation. """ + # In upstream modin, _axis indicates the operator will use the default axis + if kwargs.pop("_axis", None) is None: + if axis is not None: + axis = self._get_axis_number(axis) + else: + axis = 1 + else: + axis = 0 # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset raise_if_native_pandas_objects(other) axis = self._get_axis_number(axis) @@ -898,28 +906,6 @@ def _to_series_list(self, index: pd.Index) -> list[pd.Series]: raise Exception("invalid index: " + str(index)) -@register_base_override("_set_index") -def _set_index(self, new_index: Axes) -> None: - """ - Set the index for this DataFrame. - - Parameters - ---------- - new_index : pandas.Index - The new index to set this. 
- """ - # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset - self._update_inplace( - new_query_compiler=self._query_compiler.set_index( - [s._query_compiler for s in self._to_series_list(ensure_index(new_index))] - ) - ) - - -index = property(lambda self: self._query_compiler.index, _set_index) -register_base_override("index")(index) - - @register_base_override("shift") def shift( self, @@ -2297,3 +2283,44 @@ def _any(self, axis=0, bool_only=None, skipna=True, **kwargs): if isinstance(result, BasePandasDataset): return result.any(axis=axis, bool_only=bool_only, skipna=skipna, **kwargs) return False if result is None else result + + +def _get_index(self): + """ + Get the index for this DataFrame. + + Returns + ------- + pandas.Index + The union of all indexes across the partitions. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.plugin.extensions.index import Index + + if self._query_compiler.is_multiindex(): + # Lazy multiindex is not supported + return self._query_compiler.index + + idx = Index(query_compiler=self._query_compiler) + idx._set_parent(self) + return idx + + +def _set_index(self, new_index: Axes) -> None: + """ + Set the index for this DataFrame. + + Parameters + ---------- + new_index : pandas.Index + The new index to set this. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + self._update_inplace( + new_query_compiler=self._query_compiler.set_index( + [s._query_compiler for s in self._to_series_list(ensure_index(new_index))] + ) + ) + + +register_base_override("index")(property(_get_index, _set_index)) From 87d75fa305b8cabd2e758aa6d1024e2243a2e861 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Thu, 22 Aug 2024 14:52:46 -0700 Subject: [PATCH 06/14] workarounds for docstring inheritance --- .../snowpark/modin/plugin/__init__.py | 26 +++++++++---------- .../modin/plugin/docstrings/dataframe.py | 4 ++- .../modin/plugin/docstrings/series.py | 4 ++- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py index b46d11f7f8e..90445892497 100644 --- a/src/snowflake/snowpark/modin/plugin/__init__.py +++ b/src/snowflake/snowpark/modin/plugin/__init__.py @@ -67,20 +67,18 @@ # Upstream Modin has issues with certain docstring generation edge cases, so we should use our version instead _inherit_docstrings = snowflake.snowpark.modin.utils._inherit_docstrings -_inherit_docstrings( - docstrings.base.BasePandasDataset, - overwrite_existing=True, -)(modin.pandas.base.BasePandasDataset) - -_inherit_docstrings( - docstrings.series_utils.StringMethods, - overwrite_existing=True, -)(modin.pandas.series_utils.StringMethods) - -_inherit_docstrings( - docstrings.series_utils.CombinedDatetimelikeProperties, - overwrite_existing=True, -)(modin.pandas.series_utils.DatetimeProperties) +inherit_modules = [ + (docstrings.base.BasePandasDataset, modin.pandas.base.BasePandasDataset), + (docstrings.series_utils.StringMethods, modin.pandas.series_utils.StringMethods), + ( + docstrings.series_utils.CombinedDatetimelikeProperties, + modin.pandas.series_utils.DatetimeProperties, + ), +] + +for (doc_module, target_object) in inherit_modules: + _inherit_docstrings(doc_module, overwrite_existing=True)(target_object) + # Don't warn the user about our internal usage of private preview pivot # features. 
The user should have already been warned that Snowpark pandas diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py index 047b4592068..f0c02aa0e65 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py @@ -13,6 +13,8 @@ _shared_docs, ) +from .base import BasePandasDataset + _doc_binary_op_kwargs = {"returns": "BasePandasDataset", "left": "BasePandasDataset"} @@ -49,7 +51,7 @@ } -class DataFrame: +class DataFrame(BasePandasDataset): """ Snowpark pandas representation of ``pandas.DataFrame`` with a lazily-evaluated relational dataset. diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series.py b/src/snowflake/snowpark/modin/plugin/docstrings/series.py index 6e48a7e57f3..4878c82635a 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series.py @@ -15,6 +15,8 @@ ) from snowflake.snowpark.modin.utils import _create_operator_docstring +from .base import BasePandasDataset + _shared_doc_kwargs = { "axes": "index", "klass": "Series", @@ -35,7 +37,7 @@ } -class Series: +class Series(BasePandasDataset): """ Snowpark pandas representation of `pandas.Series` with a lazily-evaluated relational dataset. From 880aa0573e8abf4a7952358106c8033677c65c6f Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Thu, 22 Aug 2024 16:10:54 -0700 Subject: [PATCH 07/14] fix bizarre bug where np.dtype('bool') != bool --- .../snowpark/modin/plugin/extensions/base_overrides.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 456ffde2dfa..36e55115dc5 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -1393,11 +1393,11 @@ def __getitem__(self, key): # If the object is a boolean list-like object, use .loc[key] to filter index. # The if statement is structured this way to avoid calling dtype and reduce query count. if isinstance(key, pd.Series): - if key.dtype == bool: + if pandas.api.types.is_bool_dtype(key.dtype): return self.loc[key] elif is_list_like(key): if hasattr(key, "dtype"): - if key.dtype == bool: + if pandas.api.types.is_bool_dtype(key.dtype): return self.loc[key] if (all(is_bool(k) for k in key)) and len(key) > 0: return self.loc[key] From 3be10dcb34b84741d532109d66689cbb5ce388c5 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Mon, 26 Aug 2024 12:52:41 -0700 Subject: [PATCH 08/14] add comments on all overrides --- .../modin/plugin/extensions/base_overrides.py | 768 ++++++++---------- 1 file changed, 348 insertions(+), 420 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 36e55115dc5..50e4aea9c8b 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -6,6 +6,9 @@ Methods defined on BasePandasDataset that are overridden in Snowpark pandas. Adding a method to this file should be done with discretion, and only when relevant changes cannot be made to the query compiler or upstream frontend to accommodate Snowpark pandas. 
+ +If you must override a method in this file, please add a comment describing why it must be overridden, +and if possible, whether this can be reconciled with upstream Modin. """ from __future__ import annotations @@ -424,8 +427,25 @@ def __finalize__(self, other, method=None, **kwargs): # === OVERRIDDEN METHODS === - - +# The below methods have their frontend implementations overridden compared to the version present +# in base.py. This is usually for one of the following reasons: +# 1. The underlying QC interface used differs from that of modin. Notably, this applies to aggregate +# and binary operations; further work is needed to refactor either our implementation or upstream +# modin's implementation. +# 2. Modin performs extra validation queries that perform extra SQL queries. Some of these are already +# fixed on main; see https://github.com/modin-project/modin/issues/7340 for details. +# 3. Upstream Modin defaults to pandas for some edge cases. Defaulting to pandas at the query compiler +# layer is acceptable because we can force the method to raise NotImplementedError, but if a method +# defaults at the frontend, Modin raises a warning and performs the operation by coercing the +# dataset to a native pandas object. Removing these is tracked by +# https://github.com/modin-project/modin/issues/7104 +# 4. Snowpark pandas uses different default arguments from modin. This occurs if some parameters are +# only partially supported (like `numeric_only=True` for `skew`), but this behavior should likewise +# be revisited. + +# `aggregate` for axis=1 is performed as a call to `BasePandasDataset.apply` in upstream Modin, +# which is unacceptable for Snowpark pandas. Upstream Modin should be changed to allow the query +# compiler or a different layer to control dispatch. @register_base_override("aggregate") def aggregate( self, func: AggFuncType = None, axis: Axis | None = 0, *args: Any, **kwargs: Any @@ -570,10 +590,212 @@ def aggregate( return result +# `agg` is an alias of `aggregate`. agg = aggregate register_base_override("agg")(agg) +# `_agg_helper` is not defined in modin, and used by Snowpark pandas to do extra validation. +@register_base_override("_agg_helper") +def _agg_helper( + self, + func: str, + skipna: bool = True, + axis: int | None | NoDefault = no_default, + numeric_only: bool = False, + **kwargs: Any, +): + if not self._is_dataframe and numeric_only and not is_numeric_dtype(self.dtype): + # Series aggregations on non-numeric data do not support numeric_only: + # https://github.com/pandas-dev/pandas/blob/cece8c6579854f6b39b143e22c11cac56502c4fd/pandas/core/series.py#L6358 + raise TypeError( + f"Series.{func} does not allow numeric_only=True with non-numeric dtypes." + ) + axis = self._get_axis_number(axis) + numeric_only = validate_bool_kwarg(numeric_only, "numeric_only", none_allowed=True) + skipna = validate_bool_kwarg(skipna, "skipna", none_allowed=False) + agg_kwargs: dict[str, Any] = { + "numeric_only": numeric_only, + "skipna": skipna, + } + agg_kwargs.update(kwargs) + return self.aggregate(func=func, axis=axis, **agg_kwargs) + + +# See _agg_helper +@register_base_override("count") +def count( + self, + axis: Axis | None = 0, + numeric_only: bool = False, +): + """ + Count non-NA cells for `BasePandasDataset`. 
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ return self._agg_helper(
+ func="count",
+ axis=axis,
+ numeric_only=numeric_only,
+ )
+
+
+# See _agg_helper
+@register_base_override("max")
+def max(
+ self,
+ axis: Axis | None = 0,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ **kwargs: Any,
+):
+ """
+ Return the maximum of the values over the requested axis.
+ """
+ return self._agg_helper(
+ func="max",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("min")
+def min(
+ self,
+ axis: Axis | None | NoDefault = no_default,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ **kwargs,
+):
+ """
+ Return the minimum of the values over the requested axis.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ return self._agg_helper(
+ func="min",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("mean")
+def mean(
+ self,
+ axis: Axis | None | NoDefault = no_default,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ **kwargs: Any,
+):
+ """
+ Return the mean of the values over the requested axis.
+ """
+ return self._agg_helper(
+ func="mean",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("median")
+def median(
+ self,
+ axis: Axis | None | NoDefault = no_default,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ **kwargs: Any,
+):
+ """
+ Return the median of the values over the requested axis.
+ """
+ return self._agg_helper(
+ func="median",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("std")
+def std(
+ self,
+ axis: Axis | None = None,
+ skipna: bool = True,
+ ddof: int = 1,
+ numeric_only: bool = False,
+ **kwargs,
+):
+ """
+ Return sample standard deviation over requested axis.
+ """
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ kwargs.update({"ddof": ddof})
+ return self._agg_helper(
+ func="std",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("sum")
+def sum(
+ self,
+ axis: Axis | None = None,
+ skipna: bool = True,
+ numeric_only: bool = False,
+ min_count: int = 0,
+ **kwargs: Any,
+):
+ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
+ min_count = validate_int_kwarg(min_count, "min_count")
+ kwargs.update({"min_count": min_count})
+ return self._agg_helper(
+ func="sum",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# See _agg_helper
+@register_base_override("var")
+def var(
+ self,
+ axis: Axis | None = None,
+ skipna: bool = True,
+ ddof: int = 1,
+ numeric_only: bool = False,
+ **kwargs: Any,
+):
+ """
+ Return unbiased variance over requested axis.
+ """
+ kwargs.update({"ddof": ddof})
+ return self._agg_helper(
+ func="var",
+ axis=axis,
+ skipna=skipna,
+ numeric_only=numeric_only,
+ **kwargs,
+ )
+
+
+# Modin does not provide `MultiIndex` support and will default to pandas when `level` is specified,
+# and allows binary ops against native pandas objects that Snowpark pandas prohibits.
 @register_base_override("_binary_op")
 def _binary_op(
 self,
@@ -653,6 +875,9 @@ def _binary_op(
 )
 
 
+# Current Modin does not use _dropna and instead defines `dropna` directly, but Snowpark pandas
+# Series/DF still do.
Snowpark pandas still needs to add support for the `ignore_index` parameter +# (added in pandas 2.0), and should be able to refactor to remove this override. @register_base_override("_dropna") def _dropna( self, @@ -699,6 +924,8 @@ def _dropna( return self._create_or_update_from_compiler(new_query_compiler, inplace) +# Snowpark pandas uses `self_is_series` instead of `squeeze_self` and `squeeze_value` to determine +# the shape of `self` and `value`. Further work is needed to reconcile these two approaches. @register_base_override("fillna") def fillna( self, @@ -788,6 +1015,7 @@ def fillna( return self._create_or_update_from_compiler(new_query_compiler, inplace) +# Snowpark pandas passes the query compiler object from a BasePandasDataset, which Modin does not do. @register_base_override("isin") def isin( self, values: BasePandasDataset | ListLike | dict[Hashable, ListLike] @@ -812,6 +1040,9 @@ def isin( return self.__constructor__(query_compiler=self._query_compiler.isin(values=values)) +# Snowpark pandas uses the single `quantiles_along_axis0` query compiler method, while upstream +# Modin splits this into `quantile_for_single_value` and `quantile_for_list_of_values` calls. +# It should be possible to merge those two functions upstream and reconcile the implementations. @register_base_override("quantile") def quantile( self, @@ -885,6 +1116,9 @@ def quantile( return result +# Current Modin does not define this method. Snowpark pandas currently only uses it in +# `DataFrame.set_index`. Modin does not support MultiIndex, or have its own lazy index class, +# so we may need to keep this method for the foreseeable future. @register_base_override("_to_series_list") def _to_series_list(self, index: pd.Index) -> list[pd.Series]: """ @@ -906,6 +1140,7 @@ def _to_series_list(self, index: pd.Index) -> list[pd.Series]: raise Exception("invalid index: " + str(index)) +# Upstream modin defaults to pandas when `suffix` is provided. @register_base_override("shift") def shift( self, @@ -939,6 +1174,8 @@ def shift( return self._create_or_update_from_compiler(new_query_compiler, False) +# Snowpark pandas supports only `numeric_only=True`, which is not the default value of the argument, +# so we have this overridden. We should revisit this behavior. @register_base_override("skew") def skew( self, @@ -954,215 +1191,25 @@ def skew( return self._stat_operation("skew", axis, skipna, numeric_only, **kwargs) -@register_base_override("_agg_helper") -def _agg_helper( - self, - func: str, - skipna: bool = True, - axis: int | None | NoDefault = no_default, - numeric_only: bool = False, - **kwargs: Any, -): - if not self._is_dataframe and numeric_only and not is_numeric_dtype(self.dtype): - # Series aggregations on non-numeric data do not support numeric_only: - # https://github.com/pandas-dev/pandas/blob/cece8c6579854f6b39b143e22c11cac56502c4fd/pandas/core/series.py#L6358 - raise TypeError( - f"Series.{func} does not allow numeric_only=True with non-numeric dtypes." 
- ) - axis = self._get_axis_number(axis) - numeric_only = validate_bool_kwarg(numeric_only, "numeric_only", none_allowed=True) - skipna = validate_bool_kwarg(skipna, "skipna", none_allowed=False) - agg_kwargs: dict[str, Any] = { - "numeric_only": numeric_only, - "skipna": skipna, - } - agg_kwargs.update(kwargs) - return self.aggregate(func=func, axis=axis, **agg_kwargs) - - -@register_base_override("count") -def count( +@register_base_override("resample") +def resample( self, - axis: Axis | None = 0, - numeric_only: bool = False, -): + rule, + axis: Axis = lib.no_default, + closed: str | None = None, + label: str | None = None, + convention: str = "start", + kind: str | None = None, + on: Level = None, + level: Level = None, + origin: str | TimestampConvertibleTypes = "start_day", + offset: TimedeltaConvertibleTypes | None = None, + group_keys=no_default, +): # noqa: PR01, RT01, D200 """ - Count non-NA cells for `BasePandasDataset`. + Resample time-series data. """ - # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset - return self._agg_helper( - func="count", - axis=axis, - numeric_only=numeric_only, - ) - - -@register_base_override("max") -def max( - self, - axis: Axis | None = 0, - skipna: bool = True, - numeric_only: bool = False, - **kwargs: Any, -): - """ - Return the maximum of the values over the requested axis. - """ - return self._agg_helper( - func="max", - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - **kwargs, - ) - - -@register_base_override("min") -def min( - self, - axis: Axis | None | NoDefault = no_default, - skipna: bool = True, - numeric_only: bool = False, - **kwargs, -): - """ - Return the minimum of the values over the requested axis. - """ - # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset - return self._agg_helper( - func="min", - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - **kwargs, - ) - - -@register_base_override("mean") -def mean( - self, - axis: Axis | None | NoDefault = no_default, - skipna: bool = True, - numeric_only: bool = False, - **kwargs: Any, -): - """ - Return the mean of the values over the requested axis. - """ - return self._agg_helper( - func="mean", - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - **kwargs, - ) - - -@register_base_override("median") -def median( - self, - axis: Axis | None | NoDefault = no_default, - skipna: bool = True, - numeric_only: bool = False, - **kwargs: Any, -): - """ - Return the mean of the values over the requested axis. - """ - return self._agg_helper( - func="median", - axis=axis, - skipna=skipna, - numeric_only=numeric_only, - **kwargs, - ) - - -@register_base_override("std") -def std( - self, - axis: Axis | None = None, - skipna: bool = True, - ddof: int = 1, - numeric_only: bool = False, - **kwargs, -): - """ - Return sample standard deviation over requested axis. 
- """
- # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
- kwargs.update({"ddof": ddof})
- return self._agg_helper(
- func="std",
- axis=axis,
- skipna=skipna,
- numeric_only=numeric_only,
- **kwargs,
- )
-
-
-@register_base_override("sum")
-def sum(
- self,
- axis: Axis | None = None,
- skipna: bool = True,
- numeric_only: bool = False,
- min_count: int = 0,
- **kwargs: Any,
-):
- # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
- min_count = validate_int_kwarg(min_count, "min_count")
- kwargs.update({"min_count": min_count})
- return self._agg_helper(
- func="sum",
- axis=axis,
- skipna=skipna,
- numeric_only=numeric_only,
- **kwargs,
- )
-
-
-@register_base_override("var")
-def var(
- self,
- axis: Axis | None = None,
- skipna: bool = True,
- ddof: int = 1,
- numeric_only: bool = False,
- **kwargs: Any,
-):
- """
- Return unbiased variance over requested axis.
- """
- kwargs.update({"ddof": ddof})
- return self._agg_helper(
- func="var",
- axis=axis,
- skipna=skipna,
- numeric_only=numeric_only,
- **kwargs,
- )
-
-
-@register_base_override("resample")
-def resample(
- self,
- rule,
- axis: Axis = lib.no_default,
- closed: str | None = None,
- label: str | None = None,
- convention: str = "start",
- kind: str | None = None,
- on: Level = None,
- level: Level = None,
- origin: str | TimestampConvertibleTypes = "start_day",
- offset: TimedeltaConvertibleTypes | None = None,
- group_keys=no_default,
-): # noqa: PR01, RT01, D200
- """
- Resample time-series data.
- """
- from snowflake.snowpark.modin.pandas.resample import Resampler
+ from snowflake.snowpark.modin.pandas.resample import Resampler
 
 if axis is not lib.no_default: # pragma: no cover
 axis = self._get_axis_number(axis)
@@ -1199,6 +1246,9 @@ def resample(
 )
 
 
+# Snowpark pandas needs to return a custom Expanding window object. We cannot use the
+# extensions module for this at the moment because modin performs a relative import of
+# `from .window import Expanding`.
 @register_base_override("expanding")
 def expanding(self, min_periods=1, axis=0, method="single"): # noqa: PR01, RT01, D200
 """
@@ -1236,6 +1286,7 @@ def expanding(self, min_periods=1, axis=0, method="single"): # noqa: PR01, RT01
 )
 
 
+# Same as Expanding: Snowpark pandas needs to return a custom Window object.
 @register_base_override("rolling")
 def rolling(
 self,
@@ -1305,6 +1356,7 @@ def rolling(
 )
 
 
+# Snowpark pandas uses a custom indexer object for all indexing methods.
 @register_base_override("iloc")
 @property
 def iloc(self):
@@ -1318,6 +1370,7 @@ def iloc(self):
 return _iLocIndexer(self)
 
 
+# Snowpark pandas uses a custom indexer object for all indexing methods.
 @register_base_override("loc")
 @property
 def loc(self):
@@ -1332,6 +1385,7 @@ def loc(self):
 return _LocIndexer(self)
 
 
+# Snowpark pandas uses a custom indexer object for all indexing methods.
 @register_base_override("iat")
 @property
 def iat(self, axis=None): # noqa: PR01, RT01, D200
@@ -1344,6 +1398,7 @@ def iat(self, axis=None): # noqa: PR01, RT01, D200
 return _iAtIndexer(self)
 
 
+# Snowpark pandas uses a custom indexer object for all indexing methods.
 @register_base_override("at")
 @property
 def at(self, axis=None): # noqa: PR01, RT01, D200
@@ -1356,6 +1411,8 @@ def at(self, axis=None): # noqa: PR01, RT01, D200
 return _AtIndexer(self)
 
 
+# Snowpark pandas performs different dispatch logic; some changes may need to be upstreamed
+# to fix edge case indexing behaviors.
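+#
+# A hedged sketch of the dispatch (illustrative only; the names below are hypothetical
+# and not part of this change): a boolean Series key is routed through `.loc`, so the
+# filter is evaluated lazily in SQL rather than materialized locally:
+#
+#     df = pd.DataFrame({"a": [1, 2, 3]})
+#     mask = df["a"] > 1   # boolean Series; is_bool_dtype(mask.dtype) is True
+#     df[mask]             # dispatches to df.loc[mask]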
@register_base_override("__getitem__") def __getitem__(self, key): """ @@ -1406,6 +1463,7 @@ def __getitem__(self, key): return self.loc[:, key] +# Snowpark pandas does extra argument validation, which may need to be upstreamed. @register_base_override("sort_values") def sort_values( self, @@ -1483,6 +1541,8 @@ def sort_values( return self._create_or_update_from_compiler(result, inplace) +# Modin does not define `where` on BasePandasDataset, and defaults to pandas at the frontend +# layer for Series. @register_base_override("where") def where( self, @@ -1544,61 +1604,8 @@ def where( return self._create_or_update_from_compiler(query_compiler, inplace) -@register_base_override("to_csv") -def to_csv( - self, - path_or_buf=None, - sep=",", - na_rep="", - float_format=None, - columns=None, - header=True, - index=True, - index_label=None, - mode="w", - encoding=None, - compression="infer", - quoting=None, - quotechar='"', - lineterminator=None, - chunksize=None, - date_format=None, - doublequote=True, - escapechar=None, - decimal=".", - errors: str = "strict", - storage_options: StorageOptions = None, -): # pragma: no cover - from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( - FactoryDispatcher, - ) - - return FactoryDispatcher.to_csv( - self._query_compiler, - path_or_buf=path_or_buf, - sep=sep, - na_rep=na_rep, - float_format=float_format, - columns=columns, - header=header, - index=index, - index_label=index_label, - mode=mode, - encoding=encoding, - compression=compression, - quoting=quoting, - quotechar=quotechar, - lineterminator=lineterminator, - chunksize=chunksize, - date_format=date_format, - doublequote=doublequote, - escapechar=escapechar, - decimal=decimal, - errors=errors, - storage_options=storage_options, - ) - - +# Snowpark pandas performs extra argument validation, some of which should be pushed down +# to the QC layer. @register_base_override("mask") def mask( self, @@ -1660,6 +1667,63 @@ def mask( return self._create_or_update_from_compiler(query_compiler, inplace) +# Snowpark pandas uses a custom I/O dispatcher class. +@register_base_override("to_csv") +def to_csv( + self, + path_or_buf=None, + sep=",", + na_rep="", + float_format=None, + columns=None, + header=True, + index=True, + index_label=None, + mode="w", + encoding=None, + compression="infer", + quoting=None, + quotechar='"', + lineterminator=None, + chunksize=None, + date_format=None, + doublequote=True, + escapechar=None, + decimal=".", + errors: str = "strict", + storage_options: StorageOptions = None, +): # pragma: no cover + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return FactoryDispatcher.to_csv( + self._query_compiler, + path_or_buf=path_or_buf, + sep=sep, + na_rep=na_rep, + float_format=float_format, + columns=columns, + header=header, + index=index, + index_label=index_label, + mode=mode, + encoding=encoding, + compression=compression, + quoting=quoting, + quotechar=quotechar, + lineterminator=lineterminator, + chunksize=chunksize, + date_format=date_format, + doublequote=doublequote, + escapechar=escapechar, + decimal=decimal, + errors=errors, + storage_options=storage_options, + ) + + +# Modin performs extra argument validation and defaults to pandas for some edge cases. 
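+#
+# Hedged illustration (hypothetical calls, not part of this change): under this
+# override, common invocations stay in the query compiler instead of warning and
+# coercing to native pandas:
+#
+#     df.sample(n=10)      # row sampling handled lazily
+#     df.sample(frac=0.1)  # fraction-based sampling likewise stays lazy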
 @register_base_override("sample")
 def sample(
 self,
@@ -1705,6 +1769,7 @@ def sample(
 return self.__constructor__(query_compiler=query_compiler)
 
 
+# Modin performs an extra query calling self.isna() to raise a warning when fill_method is unspecified.
 @register_base_override("pct_change")
 def pct_change(
 self, periods=1, fill_method=no_default, limit=no_default, freq=None, **kwargs
@@ -1761,6 +1826,7 @@ def pct_change(
 )
 
 
+# Snowpark pandas has different `copy` behavior, and behaves differently with native series arguments.
 @register_base_override("astype")
 def astype(
 self,
@@ -1808,6 +1874,7 @@ def astype(
 return self._create_or_update_from_compiler(new_query_compiler, not copy)
 
 
+# Modin defaults to pandas when `level` is specified, and has some extra axis validation that
+# is guarded in newer versions.
 @register_base_override("drop")
 def drop(
 self,
@@ -1851,6 +1919,7 @@ def drop(
 return self._create_or_update_from_compiler(new_query_compiler, inplace)
 
 
+# Modin calls len(self.index) instead of a direct query compiler method.
 @register_base_override("__len__")
 def __len__(self) -> int:
 """
@@ -1864,6 +1933,7 @@ def __len__(self) -> int:
 return self._query_compiler.get_axis_len(axis=0)
 
 
+# Snowpark pandas ignores `copy`.
 @register_base_override("set_axis")
 def set_axis(
 self,
@@ -1898,6 +1968,7 @@ def set_axis(
 return obj
 
 
+# Modin has different behavior for empty dataframes and some slightly different length validation.
 @register_base_override("describe")
 def describe(
 self,
@@ -1943,6 +2014,7 @@ def describe(
 )
 
 
+# Modin does type validation on self that Snowpark pandas defers to SQL.
 @register_base_override("diff")
 def diff(self, periods: int = 1, axis: Axis = 0):
 """
@@ -1958,6 +2030,7 @@ def diff(self, periods: int = 1, axis: Axis = 0):
 )
 
 
+# Modin does an unnecessary len call when n == 0.
 @register_base_override("tail")
 def tail(self, n: int = 5):
 if n == 0:
@@ -1965,6 +2038,7 @@ def tail(self, n: int = 5):
 return self.iloc[-n:]
 
 
+# Snowpark pandas does extra argument validation (which should probably be deferred to SQL instead).
 @register_base_override("idxmax")
 def idxmax(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200
 """
@@ -1992,6 +2066,7 @@ def idxmax(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01,
 )
 
 
+# Snowpark pandas does extra argument validation (which should probably be deferred to SQL instead).
 @register_base_override("idxmin")
 def idxmin(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200
 """
@@ -2019,6 +2094,7 @@ def idxmin(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01,
 )
 
 
+# Modin does dtype validation on unary ops that Snowpark pandas does not.
 @register_base_override("__abs__")
 def abs(self): # noqa: RT01, D200
 """
@@ -2028,6 +2104,7 @@ def abs(self): # noqa: RT01, D200
 return self.__constructor__(query_compiler=self._query_compiler.unary_op("abs"))
 
 
+# Modin does dtype validation on unary ops that Snowpark pandas does not.
 @register_base_override("__invert__")
 def __invert__(self):
 """
@@ -2042,6 +2119,7 @@ def __invert__(self):
 return self.__constructor__(query_compiler=self._query_compiler.invert())
 
 
+# Modin does dtype validation on unary ops that Snowpark pandas does not.
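+#
+# Hedged example of the difference (illustrative only): upstream Modin checks dtypes
+# eagerly on the client, while Snowpark pandas builds the query without validation and
+# lets any dtype error surface from SQL when the result is materialized:
+#
+#     s = pd.Series([1, 2, 3])
+#     -s   # no client-side dtype validation before building the query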
@register_base_override("__neg__") def __neg__(self): """ @@ -2055,75 +2133,7 @@ def __neg__(self): return self.__constructor__(query_compiler=self._query_compiler.unary_op("__neg__")) -@register_base_override("rename_axis") -def rename_axis( - self, - mapper=lib.no_default, - *, - index=lib.no_default, - columns=lib.no_default, - axis=0, - copy=None, - inplace=False, -): # noqa: PR01, RT01, D200 - """ - Set the name of the axis for the index or columns. - """ - # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset - axes = {"index": index, "columns": columns} - - if copy is None: - copy = True - - if axis is not None: - axis = self._get_axis_number(axis) - - inplace = validate_bool_kwarg(inplace, "inplace") - - if mapper is not lib.no_default and mapper is not None: - # Use v0.23 behavior if a scalar or list - non_mapper = is_scalar(mapper) or ( - is_list_like(mapper) and not is_dict_like(mapper) - ) - if non_mapper: - return self._set_axis_name(mapper, axis=axis, inplace=inplace) - else: - raise ValueError("Use `.rename` to alter labels with a mapper.") - else: - # Use new behavior. Means that index and/or columns is specified - result = self if inplace else self.copy(deep=copy) - - for axis in range(self.ndim): - v = axes.get(pandas.DataFrame._get_axis_name(axis)) - if v is lib.no_default: - continue - non_mapper = is_scalar(v) or (is_list_like(v) and not is_dict_like(v)) - if non_mapper: - newnames = v - else: - - def _get_rename_function(mapper): - if isinstance(mapper, (dict, BasePandasDataset)): - - def f(x): - if x in mapper: - return mapper[x] - else: - return x - - else: - f = mapper - - return f - - f = _get_rename_function(v) - curnames = self.index.names if axis == 0 else self.columns.names - newnames = [f(name) for name in curnames] - result._set_axis_name(newnames, axis=axis, inplace=True) - if not inplace: - return result - - +# Snowpark pandas has custom dispatch logic for ufuncs, while modin defaults to pandas. @register_base_override("__array_ufunc__") def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): """ @@ -2160,6 +2170,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): return NotImplemented +# Snowpark pandas does extra argument validation. @register_base_override("reindex") def reindex( self, @@ -2197,94 +2208,8 @@ def reindex( ) -@register_base_override("all") -# Renamed to _all to avoid conflict with builtin python function all (override still has the correct name) -def _all(self, axis=0, bool_only=None, skipna=True, **kwargs): - """ - Return whether all elements are True, potentially over an axis. - """ - # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset - validate_bool_kwarg(skipna, "skipna", none_allowed=False) - if axis is not None: - axis = self._get_axis_number(axis) - if bool_only and axis == 0: - if hasattr(self, "dtype"): - ErrorMessage.not_implemented( - "{}.{} does not implement numeric_only.".format( - type(self).__name__, "all" - ) - ) # pragma: no cover - data_for_compute = self[self.columns[self.dtypes == np.bool_]] - return data_for_compute.all( - axis=axis, bool_only=False, skipna=skipna, **kwargs - ) - result = self._reduce_dimension( - self._query_compiler.all( - axis=axis, bool_only=bool_only, skipna=skipna, **kwargs - ) - ) - else: - if bool_only: - raise ValueError(f"Axis must be 0 or 1 (got {axis})") - # Reduce to a scalar if axis is None. 
- result = self._reduce_dimension( - # FIXME: Judging by pandas docs `**kwargs` serves only compatibility - # purpose and does not affect the result, we shouldn't pass them to the query compiler. - self._query_compiler.all( - axis=0, - bool_only=bool_only, - skipna=skipna, - **kwargs, - ) - ) - if isinstance(result, BasePandasDataset): - return result.all(axis=axis, bool_only=bool_only, skipna=skipna, **kwargs) - return True if result is None else result - - -@register_base_override("any") -# Renamed to _any to avoid conflict with builtin python function any (override still has the correct name) -def _any(self, axis=0, bool_only=None, skipna=True, **kwargs): - """ - Return whether any element is True, potentially over an axis. - """ - # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset - validate_bool_kwarg(skipna, "skipna", none_allowed=False) - if axis is not None: - axis = self._get_axis_number(axis) - if bool_only and axis == 0: - if hasattr(self, "dtype"): - ErrorMessage.not_implemented( - "{}.{} does not implement numeric_only.".format( - type(self).__name__, "all" - ) - ) # pragma: no cover - data_for_compute = self[self.columns[self.dtypes == np.bool_]] - return data_for_compute.any( - axis=axis, bool_only=False, skipna=skipna, **kwargs - ) - result = self._reduce_dimension( - self._query_compiler.any( - axis=axis, bool_only=bool_only, skipna=skipna, **kwargs - ) - ) - else: - if bool_only: - raise ValueError(f"Axis must be 0 or 1 (got {axis})") - # Reduce to a scalar if axis is None. - result = self._reduce_dimension( - self._query_compiler.any( - axis=0, - bool_only=bool_only, - skipna=skipna, - **kwargs, - ) - ) - if isinstance(result, BasePandasDataset): - return result.any(axis=axis, bool_only=bool_only, skipna=skipna, **kwargs) - return False if result is None else result - - +# No direct override annotation; used as part of `property`. +# Snowpark pandas may return a custom lazy index object. def _get_index(self): """ Get the index for this DataFrame. @@ -2306,6 +2231,8 @@ def _get_index(self): return idx +# No direct override annotation; used as part of `property`. +# Snowpark pandas may return a custom lazy index object. def _set_index(self, new_index: Axes) -> None: """ Set the index for this DataFrame. @@ -2323,4 +2250,5 @@ def _set_index(self, new_index: Axes) -> None: ) +# Snowpark pandas may return a custom lazy index object. register_base_override("index")(property(_get_index, _set_index)) From 2b587642aea6533d3d734fb590f456ed1595daa0 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Mon, 26 Aug 2024 14:32:02 -0700 Subject: [PATCH 09/14] add back rename_axis --- .../modin/plugin/extensions/base_overrides.py | 70 +++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 50e4aea9c8b..5756bb2e7dc 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -2133,6 +2133,76 @@ def __neg__(self): return self.__constructor__(query_compiler=self._query_compiler.unary_op("__neg__")) +# Modin needs to add a check for mapper is not None, which changes query counts in test_concat.py +# if not present. 
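+#
+# Hedged sketch of the behavior with the extra check (hypothetical call): an explicit
+# `mapper=None` now falls through to the index/columns branch below instead of being
+# routed to `_set_axis_name`:
+#
+#     df.rename_axis(None, axis=0)   # treated like the lib.no_default case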
+@register_base_override("rename_axis") +def rename_axis( + self, + mapper=lib.no_default, + *, + index=lib.no_default, + columns=lib.no_default, + axis=0, + copy=None, + inplace=False, +): # noqa: PR01, RT01, D200 + """ + Set the name of the axis for the index or columns. + """ + axes = {"index": index, "columns": columns} + + if copy is None: + copy = True + + if axis is not None: + axis = self._get_axis_number(axis) + + inplace = validate_bool_kwarg(inplace, "inplace") + + if mapper is not lib.no_default and mapper is not None: + # Use v0.23 behavior if a scalar or list + non_mapper = is_scalar(mapper) or ( + is_list_like(mapper) and not is_dict_like(mapper) + ) + if non_mapper: + return self._set_axis_name(mapper, axis=axis, inplace=inplace) + else: + raise ValueError("Use `.rename` to alter labels with a mapper.") + else: + # Use new behavior. Means that index and/or columns is specified + result = self if inplace else self.copy(deep=copy) + + for axis in range(self.ndim): + v = axes.get(pandas.DataFrame._get_axis_name(axis)) + if v is lib.no_default: + continue + non_mapper = is_scalar(v) or (is_list_like(v) and not is_dict_like(v)) + if non_mapper: + newnames = v + else: + + def _get_rename_function(mapper): + if isinstance(mapper, (dict, BasePandasDataset)): + + def f(x): + if x in mapper: + return mapper[x] + else: + return x + + else: + f = mapper + + return f + + f = _get_rename_function(v) + curnames = self.index.names if axis == 0 else self.columns.names + newnames = [f(name) for name in curnames] + result._set_axis_name(newnames, axis=axis, inplace=True) + if not inplace: + return result + + # Snowpark pandas has custom dispatch logic for ufuncs, while modin defaults to pandas. @register_base_override("__array_ufunc__") def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): From 3be561659f884c0c12244010a71520ce560bc434 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Wed, 28 Aug 2024 16:40:26 -0700 Subject: [PATCH 10/14] respond to reviews, fix telemetry --- .../snowpark/modin/pandas/dataframe.py | 3 ++- src/snowflake/snowpark/modin/pandas/series.py | 25 +------------------ .../modin/plugin/_internal/telemetry.py | 24 +++++++++++++++--- .../compiler/snowflake_query_compiler.py | 8 +++--- .../modin/plugin/extensions/base_overrides.py | 9 +++++++ tests/integ/modin/test_telemetry.py | 15 +++++++++++ tests/unit/modin/modin/test_envvars.py | 2 +- 7 files changed, 52 insertions(+), 34 deletions(-) diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py index a199d73e999..b42ad5a04c7 100644 --- a/src/snowflake/snowpark/modin/pandas/dataframe.py +++ b/src/snowflake/snowpark/modin/pandas/dataframe.py @@ -78,7 +78,7 @@ DataFrameGroupBy, validate_groupby_args, ) -from snowflake.snowpark.modin.pandas.series import _ATTRS_NO_LOOKUP, Series +from snowflake.snowpark.modin.pandas.series import Series from snowflake.snowpark.modin.pandas.snow_partition_iterator import ( SnowparkPandasRowPartitionIterator, ) @@ -98,6 +98,7 @@ ErrorMessage, dataframe_not_implemented, ) +from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP from snowflake.snowpark.modin.plugin.utils.warning_message import ( SET_DATAFRAME_ATTRIBUTE_WARNING, WarningMessage, diff --git a/src/snowflake/snowpark/modin/pandas/series.py b/src/snowflake/snowpark/modin/pandas/series.py index 59f00ef2574..6e1b93437a8 100644 --- a/src/snowflake/snowpark/modin/pandas/series.py +++ b/src/snowflake/snowpark/modin/pandas/series.py @@ 
-63,6 +63,7 @@ ErrorMessage, series_not_implemented, ) +from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage from snowflake.snowpark.modin.utils import ( MODIN_UNNAMED_SERIES_LABEL, @@ -97,30 +98,6 @@ _SERIES_EXTENSIONS_ = {} -# Do not look up certain attributes in columns or index, as they're used for some -# special purposes, like serving remote context -_ATTRS_NO_LOOKUP = { - "____id_pack__", - "__name__", - "_cache", - "_ipython_canary_method_should_not_exist_", - "_ipython_display_", - "_repr_html_", - "_repr_javascript_", - "_repr_jpeg_", - "_repr_json_", - "_repr_latex_", - "_repr_markdown_", - "_repr_mimebundle_", - "_repr_pdf_", - "_repr_png_", - "_repr_svg_", - "__array_struct__", - "__array_interface__", - "_typ", -} - - @_inherit_docstrings( pandas.Series, excluded=[ diff --git a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py index 0a022b0d588..fcc61ab66af 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py @@ -9,6 +9,7 @@ from enum import Enum, unique from typing import Any, Callable, Optional, TypeVar, Union, cast +import modin from typing_extensions import ParamSpec import snowflake.snowpark.session @@ -535,17 +536,30 @@ def __new__( snowflake.snowpark.modin.pandas.window.Rolling]: The modified class with decorated methods. """ - for attr_name, attr_value in attrs.items(): + attr_dict = dict(attrs.items()) + # If BasePandasDataset, defined exclusively by upstream modin, is a parent of this class, + # then apply the telemetry decorator to it. + # https://stackoverflow.com/a/71105206 + # TODO figure out solution for dataframe/series when those directly use modin frontend + for base in bases: + if base is modin.pandas.base.BasePandasDataset: + # Newly defined attrs should take precedence over those defined in base, + # so the keys in attr_dict should overwrite those in base_dict + base_dict = dict(vars(base).items()) + base_dict.update(attr_dict) + attr_dict = base_dict + new_attrs = {} + for attr_name, attr_value in attr_dict.items(): if callable(attr_value) and ( not attr_name.startswith("_") or (attr_name in TELEMETRY_PRIVATE_METHODS) ): - attrs[attr_name] = snowpark_pandas_telemetry_method_decorator( + new_attrs[attr_name] = snowpark_pandas_telemetry_method_decorator( attr_value ) elif isinstance(attr_value, property): # wrap on getter and setter - attrs[attr_name] = property( + new_attrs[attr_name] = property( snowpark_pandas_telemetry_method_decorator( cast( # add a cast because mypy doesn't recognize that @@ -575,4 +589,6 @@ def __new__( ), doc=attr_value.__doc__, ) - return type.__new__(cls, name, bases, attrs) + else: + new_attrs[attr_name] = attr_value + return type.__new__(cls, name, bases, new_attrs) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 8e1abce8d1e..70ab595070b 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -1464,10 +1464,10 @@ def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler": return SnowflakeQueryCompiler(new_internal_frame) # TODO SNOW-837664: add more tests for df.columns - columns: native_pd.Index = property( - lambda self: 
self._modin_frame.data_columns_index, - lambda self, labels: self.set_columns(labels), - ) + def _get_columns(self) -> native_pd.Index: + return self._modin_frame.data_columns_index + + columns: native_pd.Index = property(_get_columns, set_columns) def _shift_values( self, periods: int, axis: Union[Literal[0], Literal[1]], fill_value: Hashable diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 5756bb2e7dc..3d645074a0d 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -87,6 +87,15 @@ def register_base_override(method_name: str): + """ + Decorator function to override a method on BasePandasDataset. Since Modin does not provide a mechanism + for directly overriding methods on BasePandasDataset, we mock this by performing the override on + DataFrame and Series, and manually performing a `setattr` on the base class. These steps are necessary + to allow both the docstring extension and method dispatch to work properly. + + Methods annotated here also are automatically instrumented with Snowpark pandas telemetry. + """ + def decorator(base_method: Any): if callable(base_method) and ( not method_name.startswith("_") diff --git a/tests/integ/modin/test_telemetry.py b/tests/integ/modin/test_telemetry.py index 9d3f632dd01..06fbc71eec7 100644 --- a/tests/integ/modin/test_telemetry.py +++ b/tests/integ/modin/test_telemetry.py @@ -547,3 +547,18 @@ def test_telemetry_repr(): {"name": "Series.property.name_set"}, {"name": "Series.Series.__repr__"}, ] + + +@sql_count_checker(query_count=0) +def test_telemetry_copy(): + # copy() is defined in upstream modin's BasePandasDataset class, and not overridden by any + # child class or the extensions module. + s = pd.Series([1, 2, 3, 4]) + copied = s.copy() + assert s._query_compiler.snowpark_pandas_api_calls == [ + {"name": "Series.property.name_set"} + ] + assert copied._query_compiler.snowpark_pandas_api_calls == [ + {"name": "Series.property.name_set"}, + {"name": "Series.BasePandasDataset.copy"}, + ] diff --git a/tests/unit/modin/modin/test_envvars.py b/tests/unit/modin/modin/test_envvars.py index a3f8fca324a..7c5e3a40bb0 100644 --- a/tests/unit/modin/modin/test_envvars.py +++ b/tests/unit/modin/modin/test_envvars.py @@ -112,7 +112,7 @@ def _init_doc_module(): DOC_OVERRIDE_XFAIL_REASON = ( "test docstring overrides currently cannot override real docstring overrides until " - "modin 0.30.1 is available (SNOW-1473605)" + "modin 0.31.0 is available" ) From 0939b5291d67311c7e6c0aebbe7948d45db8579d Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Wed, 28 Aug 2024 17:36:38 -0700 Subject: [PATCH 11/14] add missing file --- .../modin/plugin/utils/frontend_constants.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py diff --git a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py new file mode 100644 index 00000000000..26b24886f4a --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py @@ -0,0 +1,26 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. 
+# + +# Do not look up certain attributes in columns or index, as they're used for some +# special purposes, like serving remote context +_ATTRS_NO_LOOKUP = { + "____id_pack__", + "__name__", + "_cache", + "_ipython_canary_method_should_not_exist_", + "_ipython_display_", + "_repr_html_", + "_repr_javascript_", + "_repr_jpeg_", + "_repr_json_", + "_repr_latex_", + "_repr_markdown_", + "_repr_mimebundle_", + "_repr_pdf_", + "_repr_png_", + "_repr_svg_", + "__array_struct__", + "__array_interface__", + "_typ", +} From 01f40d99b429368519d7fb8f15333ade45538b35 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Wed, 28 Aug 2024 21:08:28 -0700 Subject: [PATCH 12/14] more review fixes --- src/snowflake/snowpark/modin/plugin/__init__.py | 3 ++- .../modin/plugin/compiler/snowflake_query_compiler.py | 4 ++-- .../snowpark/modin/plugin/utils/frontend_constants.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py index 90445892497..c4172f26696 100644 --- a/src/snowflake/snowpark/modin/plugin/__init__.py +++ b/src/snowflake/snowpark/modin/plugin/__init__.py @@ -63,7 +63,8 @@ import modin.utils # type: ignore[import] # isort: skip # noqa: E402 import modin.pandas.series_utils # type: ignore[import] # isort: skip # noqa: E402 -# TODO: https://github.com/modin-project/modin/issues/7113 and https://github.com/modin-project/modin/issues/7134 +# TODO: SNOW-1643979 pull in fixes for +# https://github.com/modin-project/modin/issues/7113 and https://github.com/modin-project/modin/issues/7134 # Upstream Modin has issues with certain docstring generation edge cases, so we should use our version instead _inherit_docstrings = snowflake.snowpark.modin.utils._inherit_docstrings diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index 70ab595070b..50ce5e71310 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -1464,10 +1464,10 @@ def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler": return SnowflakeQueryCompiler(new_internal_frame) # TODO SNOW-837664: add more tests for df.columns - def _get_columns(self) -> native_pd.Index: + def get_columns(self) -> native_pd.Index: return self._modin_frame.data_columns_index - columns: native_pd.Index = property(_get_columns, set_columns) + columns: native_pd.Index = property(get_columns, set_columns) def _shift_values( self, periods: int, axis: Union[Literal[0], Literal[1]], fill_value: Hashable diff --git a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py index 26b24886f4a..f2b28e8bfc1 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py +++ b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py @@ -4,6 +4,7 @@ # Do not look up certain attributes in columns or index, as they're used for some # special purposes, like serving remote context +# TODO: SNOW-1643986 examine whether to update upstream modin to follow this _ATTRS_NO_LOOKUP = { "____id_pack__", "__name__", From c42b4d0b576e278642244a3e3a2ff7ebec8f2956 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Thu, 29 Aug 2024 16:21:34 -0700 Subject: [PATCH 13/14] fix telemetry --- .../snowpark/modin/pandas/__init__.py | 72 +++++++++++- 
.../modin/plugin/_internal/telemetry.py | 103 ++++++++---------- .../plugin/extensions/base_extensions.py | 46 ++++++++ .../modin/plugin/extensions/base_overrides.py | 44 +------- .../plugin/extensions/dataframe_extensions.py | 33 ------ .../plugin/extensions/series_extensions.py | 33 ------ 6 files changed, 165 insertions(+), 166 deletions(-) create mode 100644 src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py index 274d5b3763f..b51a47b64b3 100644 --- a/src/snowflake/snowpark/modin/pandas/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -85,10 +85,16 @@ timedelta_range, ) +import modin.pandas + # TODO: SNOW-851745 make sure add all Snowpark pandas API general functions from modin.pandas import plotting # type: ignore[import] -from snowflake.snowpark.modin.pandas.dataframe import DataFrame +from snowflake.snowpark.modin.pandas.api.extensions import ( + register_dataframe_accessor, + register_series_accessor, +) +from snowflake.snowpark.modin.pandas.dataframe import _DATAFRAME_EXTENSIONS_, DataFrame from snowflake.snowpark.modin.pandas.general import ( concat, crosstab, @@ -140,8 +146,11 @@ read_xml, to_pickle, ) -from snowflake.snowpark.modin.pandas.series import Series +from snowflake.snowpark.modin.pandas.series import _SERIES_EXTENSIONS_, Series from snowflake.snowpark.modin.plugin._internal.session import SnowpandasSessionHolder +from snowflake.snowpark.modin.plugin._internal.telemetry import ( + try_add_telemetry_to_attribute, +) # The extensions assigned to this module _PD_EXTENSIONS_: dict = {} @@ -154,12 +163,71 @@ DatetimeIndex, TimedeltaIndex, ) + +# this must occur before overrides are applied +_attrs_defined_on_modin_base = set(dir(modin.pandas.base.BasePandasDataset)) +_attrs_defined_on_series = set( + dir(Series) +) # TODO: SNOW-1063347 revisit when series.py is removed +_attrs_defined_on_dataframe = set( + dir(DataFrame) +) # TODO: SNOW-1063346 revisit when dataframe.py is removed + +# base overrides occur before subclass overrides in case subclasses override a base method +import snowflake.snowpark.modin.plugin.extensions.base_extensions # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.base_overrides # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.dataframe_extensions # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.dataframe_overrides # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.series_extensions # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.series_overrides # isort: skip # noqa: E402,F401 +# For any method defined on Series/DF, add telemetry to it if it meets all of the following conditions: +# 1. The method was defined directly on upstream BasePandasDataset (_attrs_defined_on_modin_base) +# 2. The method is not overridden by a child class (this will change) +# 3. The method is not overridden by an extensions module +# 4. 
The method name does not start with an _
+#
+# TODO: SNOW-1063347
+# Since we still use the vendored version of Series and the overrides for the top-level
+# namespace haven't been performed yet, we need to set properties on the vendored version
+_base_telemetry_added_attrs = set()
+
+_series_ext = _SERIES_EXTENSIONS_.copy()
+for attr_name in dir(Series):
+ if (
+ attr_name in _attrs_defined_on_modin_base
+ and attr_name in _attrs_defined_on_series
+ and attr_name not in _series_ext
+ and not attr_name.startswith("_")
+ ):
+ register_series_accessor(attr_name)(
+ try_add_telemetry_to_attribute(attr_name, getattr(Series, attr_name))
+ )
+ _base_telemetry_added_attrs.add(attr_name)
+
+# TODO: SNOW-1063346
+# Since we still use the vendored version of DataFrame and the overrides for the top-level
+# namespace haven't been performed yet, we need to set properties on the vendored version
+_dataframe_ext = _DATAFRAME_EXTENSIONS_.copy()
+for attr_name in dir(DataFrame):
+ if (
+ attr_name in _attrs_defined_on_modin_base
+ and attr_name in _attrs_defined_on_dataframe
+ and attr_name not in _dataframe_ext
+ and not attr_name.startswith("_")
+ ):
+ # If telemetry was already added via Series, register the override but don't re-wrap
+ # the method in the telemetry annotation. If we don't do this check, we will end up
+ # double-reporting telemetry on some methods.
+ original_attr = getattr(DataFrame, attr_name)
+ new_attr = (
+ original_attr
+ if attr_name in _base_telemetry_added_attrs
+ else try_add_telemetry_to_attribute(attr_name, original_attr)
+ )
+ register_dataframe_accessor(attr_name)(new_attr)
+ _base_telemetry_added_attrs.add(attr_name)
+
 
 def __getattr__(name: str) -> Any:
 """
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py
index fcc61ab66af..8057cf93885 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py
@@ -9,7 +9,6 @@
 from enum import Enum, unique
 from typing import Any, Callable, Optional, TypeVar, Union, cast
 
-import modin
 from typing_extensions import ParamSpec
 
 import snowflake.snowpark.session
@@ -496,6 +495,49 @@ def wrap(*args, **kwargs): # type: ignore
 }
 
 
+def try_add_telemetry_to_attribute(attr_name: str, attr_value: Any) -> Any:
+ """
+ Attempts to add telemetry to an attribute.
+
+ If the attribute is a callable whose name starts with an underscore and is not in
+ TELEMETRY_PRIVATE_METHODS, or is neither a callable nor a property, the original
+ attribute is returned as-is. Otherwise, a version of the method/property annotated
+ with Snowpark pandas telemetry is returned.
+ """
+ if callable(attr_value) and (
+ not attr_name.startswith("_") or (attr_name in TELEMETRY_PRIVATE_METHODS)
+ ):
+ return snowpark_pandas_telemetry_method_decorator(attr_value)
+ elif isinstance(attr_value, property):
+ # wrap on getter and setter
+ return property(
+ snowpark_pandas_telemetry_method_decorator(
+ cast(
+ # add a cast because mypy doesn't recognize that
+ # non-None fget and __get__ are both callable
+ # arguments to snowpark_pandas_telemetry_method_decorator.
+ Callable,
+ attr_value.__get__ # pragma: no cover: we don't encounter this case in pandas or modin because every property has an fget method.
+ if attr_value.fget is None + else attr_value.fget, + ), + property_name=attr_name, + property_method_type=PropertyMethodType.FGET, + ), + snowpark_pandas_telemetry_method_decorator( + attr_value.__set__ if attr_value.fset is None else attr_value.fset, + property_name=attr_name, + property_method_type=PropertyMethodType.FSET, + ), + snowpark_pandas_telemetry_method_decorator( + attr_value.__delete__ if attr_value.fdel is None else attr_value.fdel, + property_name=attr_name, + property_method_type=PropertyMethodType.FDEL, + ), + doc=attr_value.__doc__, + ) + return attr_value + + class TelemetryMeta(type): def __new__( cls, name: str, bases: tuple, attrs: dict[str, Any] @@ -536,59 +578,6 @@ def __new__( snowflake.snowpark.modin.pandas.window.Rolling]: The modified class with decorated methods. """ - attr_dict = dict(attrs.items()) - # If BasePandasDataset, defined exclusively by upstream modin, is a parent of this class, - # then apply the telemetry decorator to it. - # https://stackoverflow.com/a/71105206 - # TODO figure out solution for dataframe/series when those directly use modin frontend - for base in bases: - if base is modin.pandas.base.BasePandasDataset: - # Newly defined attrs should take precedence over those defined in base, - # so the keys in attr_dict should overwrite those in base_dict - base_dict = dict(vars(base).items()) - base_dict.update(attr_dict) - attr_dict = base_dict - new_attrs = {} - for attr_name, attr_value in attr_dict.items(): - if callable(attr_value) and ( - not attr_name.startswith("_") - or (attr_name in TELEMETRY_PRIVATE_METHODS) - ): - new_attrs[attr_name] = snowpark_pandas_telemetry_method_decorator( - attr_value - ) - elif isinstance(attr_value, property): - # wrap on getter and setter - new_attrs[attr_name] = property( - snowpark_pandas_telemetry_method_decorator( - cast( - # add a cast because mypy doesn't recognize that - # non-None fget and __get__ are both callable - # arguments to snowpark_pandas_telemetry_method_decorator. - Callable, - attr_value.__get__ # pragma: no cover: we don't encounter this case in pandas or modin because every property has an fget method. - if attr_value.fget is None - else attr_value.fget, - ), - property_name=attr_name, - property_method_type=PropertyMethodType.FGET, - ), - snowpark_pandas_telemetry_method_decorator( - attr_value.__set__ - if attr_value.fset is None - else attr_value.fset, - property_name=attr_name, - property_method_type=PropertyMethodType.FSET, - ), - snowpark_pandas_telemetry_method_decorator( - attr_value.__delete__ - if attr_value.fdel is None - else attr_value.fdel, - property_name=attr_name, - property_method_type=PropertyMethodType.FDEL, - ), - doc=attr_value.__doc__, - ) - else: - new_attrs[attr_name] = attr_value - return type.__new__(cls, name, bases, new_attrs) + for attr_name, attr_value in attrs.items(): + attrs[attr_name] = try_add_telemetry_to_attribute(attr_name, attr_value) + return type.__new__(cls, name, bases, attrs) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py new file mode 100644 index 00000000000..496136d736e --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py @@ -0,0 +1,46 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +""" +File containing BasePandasDataset APIs defined in Snowpark pandas but not the Modin API layer. 
+""" + +from snowflake.snowpark.modin.plugin._internal.telemetry import ( + snowpark_pandas_telemetry_method_decorator, +) + +from .base_overrides import register_base_override + + +@register_base_override("__array_function__") +@snowpark_pandas_telemetry_method_decorator +def __array_function__(self, func: callable, types: tuple, args: tuple, kwargs: dict): + """ + Apply the `func` to the `BasePandasDataset`. + + Parameters + ---------- + func : np.func + The NumPy func to apply. + types : tuple + The types of the args. + args : tuple + The args to the func. + kwargs : dict + Additional keyword arguments. + + Returns + ------- + BasePandasDataset + The result of the ufunc applied to the `BasePandasDataset`. + """ + from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import ( + numpy_to_pandas_func_map, + ) + + if func.__name__ in numpy_to_pandas_func_map: + return numpy_to_pandas_func_map[func.__name__](*args, **kwargs) + else: + # per NEP18 we raise NotImplementedError so that numpy can intercept + return NotImplemented # pragma: no cover diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 3d645074a0d..5aa937b809b 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -15,7 +15,7 @@ import pickle as pkl import warnings from collections.abc import Sequence -from typing import Any, Callable, Hashable, Literal, Mapping, cast, get_args +from typing import Any, Callable, Hashable, Literal, Mapping, get_args import modin.pandas as pd import numpy as np @@ -73,9 +73,8 @@ validate_and_try_convert_agg_func_arg_func_to_str, ) from snowflake.snowpark.modin.plugin._internal.telemetry import ( - TELEMETRY_PRIVATE_METHODS, - PropertyMethodType, snowpark_pandas_telemetry_method_decorator, + try_add_telemetry_to_attribute, ) from snowflake.snowpark.modin.plugin._typing import ListLike from snowflake.snowpark.modin.plugin.utils.error_message import ( @@ -97,44 +96,7 @@ def register_base_override(method_name: str): """ def decorator(base_method: Any): - if callable(base_method) and ( - not method_name.startswith("_") - or (method_name in TELEMETRY_PRIVATE_METHODS) - ): - base_method = snowpark_pandas_telemetry_method_decorator(base_method) - elif isinstance(base_method, property): - base_method = property( - snowpark_pandas_telemetry_method_decorator( - cast( - # add a cast because mypy doesn't recognize that - # non-None fget and __get__ are both callable - # arguments to snowpark_pandas_telemetry_method_decorator. 
- Callable, - base_method.fget, # all properties defined in this file have an fget - ), - property_name=method_name, - property_method_type=PropertyMethodType.FGET, - ), - snowpark_pandas_telemetry_method_decorator( - ( - base_method.__set__ - if base_method.fset is None - else base_method.fset - ), - property_name=method_name, - property_method_type=PropertyMethodType.FSET, - ), - snowpark_pandas_telemetry_method_decorator( - ( - base_method.__delete__ - if base_method.fdel is None - else base_method.fdel - ), - property_name=method_name, - property_method_type=PropertyMethodType.FDEL, - ), - doc=base_method.__doc__, - ) + base_method = try_add_telemetry_to_attribute(method_name, base_method) parent_method = getattr(BasePandasDataset, method_name, None) if isinstance(parent_method, property): parent_method = parent_method.fget diff --git a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_extensions.py index b167c924452..a2d4710bf66 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/dataframe_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/dataframe_extensions.py @@ -254,36 +254,3 @@ def cache_result(self, inplace: bool = False) -> Optional[pd.DataFrame]: self._update_inplace(new_qc) else: return pd.DataFrame(query_compiler=new_qc) - - -@register_dataframe_accessor("__array_function__") -@snowpark_pandas_telemetry_method_decorator -def __array_function__(self, func: callable, types: tuple, args: tuple, kwargs: dict): - """ - Apply the `func` to the `BasePandasDataset`. - - Parameters - ---------- - func : np.func - The NumPy func to apply. - types : tuple - The types of the args. - args : tuple - The args to the func. - kwargs : dict - Additional keyword arguments. - - Returns - ------- - BasePandasDataset - The result of the ufunc applied to the `BasePandasDataset`. - """ - from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import ( - numpy_to_pandas_func_map, - ) - - if func.__name__ in numpy_to_pandas_func_map: - return numpy_to_pandas_func_map[func.__name__](*args, **kwargs) - else: - # per NEP18 we raise NotImplementedError so that numpy can intercept - return NotImplemented # pragma: no cover diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py index 729b6c3bb0a..f5e27a44e80 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_extensions.py @@ -218,36 +218,3 @@ def cache_result(self, inplace: bool = False) -> Optional[pd.Series]: self._update_inplace(new_qc) else: return pd.Series(query_compiler=new_qc) - - -@register_series_accessor("__array_function__") -@snowpark_pandas_telemetry_method_decorator -def __array_function__(self, func: callable, types: tuple, args: tuple, kwargs: dict): - """ - Apply the `func` to the `BasePandasDataset`. - - Parameters - ---------- - func : np.func - The NumPy func to apply. - types : tuple - The types of the args. - args : tuple - The args to the func. - kwargs : dict - Additional keyword arguments. - - Returns - ------- - BasePandasDataset - The result of the ufunc applied to the `BasePandasDataset`. 
- """ - from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import ( - numpy_to_pandas_func_map, - ) - - if func.__name__ in numpy_to_pandas_func_map: - return numpy_to_pandas_func_map[func.__name__](*args, **kwargs) - else: - # per NEP18 we raise NotImplementedError so that numpy can intercept - return NotImplemented # pragma: no cover From 0956d1cd11166e567327859fe8a12ae9185830e1 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Thu, 29 Aug 2024 17:08:03 -0700 Subject: [PATCH 14/14] add no covers + change qc dispatch --- .../snowpark/modin/plugin/extensions/base_overrides.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 5aa937b809b..abbcb9bc762 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -2072,7 +2072,7 @@ def abs(self): # noqa: RT01, D200 Return a `BasePandasDataset` with absolute numeric value of each element. """ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset - return self.__constructor__(query_compiler=self._query_compiler.unary_op("abs")) + return self.__constructor__(query_compiler=self._query_compiler.abs()) # Modin does dtype validation on unary ops that Snowpark pandas does not. @@ -2101,7 +2101,7 @@ def __neg__(self): BasePandasDataset """ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset - return self.__constructor__(query_compiler=self._query_compiler.unary_op("__neg__")) + return self.__constructor__(query_compiler=self._query_compiler.negative()) # Modin needs to add a check for mapper is not None, which changes query counts in test_concat.py @@ -2199,7 +2199,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): # Use pandas version of ufunc if it exists if method != "__call__": # Return sentinel value NotImplemented - return NotImplemented + return NotImplemented # pragma: no cover from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import ( numpy_to_pandas_universal_func_map, ) @@ -2208,7 +2208,7 @@ def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): ufunc = numpy_to_pandas_universal_func_map[ufunc.__name__] return ufunc(self, inputs[1:], kwargs) # return the sentinel NotImplemented if we do not support this function - return NotImplemented + return NotImplemented # pragma: no cover # Snowpark pandas does extra argument validation.
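
A hedged usage sketch of the ufunc dispatch above (assumes an active Snowpark pandas
session; np.add is shown only as a representative mapped entry, since the exact
contents of numpy_to_pandas_universal_func_map are not reproduced in this patch):

    import numpy as np
    import modin.pandas as pd
    import snowflake.snowpark.modin.plugin  # noqa: F401

    s = pd.Series([1, 2, 3])
    np.add(s, 1)  # "__call__" is routed through __array_ufunc__ to the mapped pandas op
    # A ufunc missing from the map returns the NotImplemented sentinel, so NumPy
    # raises TypeError instead of silently coercing to native pandas.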