diff --git a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py index 7fd351b16dc..e22355674ee 100644 --- a/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py +++ b/src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py @@ -856,6 +856,9 @@ def create_table_as_select_statement( max_data_extension_time: Optional[int] = None, change_tracking: Optional[bool] = None, copy_grants: bool = False, + *, + use_scoped_temp_objects: bool = False, + is_generated: bool = False, ) -> str: column_definition_sql = ( f"{LEFT_PARENTHESIS}{column_definition}{RIGHT_PARENTHESIS}" @@ -877,8 +880,9 @@ def create_table_as_select_statement( } ) return ( - f"{CREATE}{OR + REPLACE if replace else EMPTY_STRING} {table_type.upper()} {TABLE}" - f"{IF + NOT + EXISTS if not replace and not error else EMPTY_STRING} " + f"{CREATE}{OR + REPLACE if replace else EMPTY_STRING}" + f" {(get_temp_type_for_object(use_scoped_temp_objects, is_generated) if table_type.lower() in TEMPORARY_STRING_SET else table_type).upper()} " + f"{TABLE}{IF + NOT + EXISTS if not replace and not error else EMPTY_STRING} " f"{table_name}{column_definition_sql}{cluster_by_clause}{options_statement}" f"{COPY_GRANTS if copy_grants else EMPTY_STRING}{comment_sql} {AS}{project_statement([], child)}" ) diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index aad369a8b83..559cbeb3cc5 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -929,6 +929,8 @@ def get_create_table_as_select_plan(child: SnowflakePlan, replace, error): max_data_extension_time=max_data_extension_time, change_tracking=change_tracking, copy_grants=copy_grants, + use_scoped_temp_objects=use_scoped_temp_objects, + is_generated=is_generated, ), child, source_plan, diff --git a/src/snowflake/snowpark/modin/pandas/__init__.py b/src/snowflake/snowpark/modin/pandas/__init__.py index c4eb07d9589..b51a47b64b3 100644 --- a/src/snowflake/snowpark/modin/pandas/__init__.py +++ b/src/snowflake/snowpark/modin/pandas/__init__.py @@ -85,10 +85,16 @@ timedelta_range, ) +import modin.pandas + # TODO: SNOW-851745 make sure add all Snowpark pandas API general functions from modin.pandas import plotting # type: ignore[import] -from snowflake.snowpark.modin.pandas.dataframe import DataFrame +from snowflake.snowpark.modin.pandas.api.extensions import ( + register_dataframe_accessor, + register_series_accessor, +) +from snowflake.snowpark.modin.pandas.dataframe import _DATAFRAME_EXTENSIONS_, DataFrame from snowflake.snowpark.modin.pandas.general import ( concat, crosstab, @@ -140,15 +146,15 @@ read_xml, to_pickle, ) -from snowflake.snowpark.modin.pandas.series import Series +from snowflake.snowpark.modin.pandas.series import _SERIES_EXTENSIONS_, Series from snowflake.snowpark.modin.plugin._internal.session import SnowpandasSessionHolder +from snowflake.snowpark.modin.plugin._internal.telemetry import ( + try_add_telemetry_to_attribute, +) # The extensions assigned to this module _PD_EXTENSIONS_: dict = {} -# base needs to be re-exported in order to properly override docstrings for BasePandasDataset -# moving this import higher prevents sphinx from building documentation (??) -from snowflake.snowpark.modin.pandas import base # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.pd_extensions as pd_extensions # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.pd_overrides # isort: skip # noqa: E402,F401 @@ -157,12 +163,71 @@ DatetimeIndex, TimedeltaIndex, ) + +# this must occur before overrides are applied +_attrs_defined_on_modin_base = set(dir(modin.pandas.base.BasePandasDataset)) +_attrs_defined_on_series = set( + dir(Series) +) # TODO: SNOW-1063347 revisit when series.py is removed +_attrs_defined_on_dataframe = set( + dir(DataFrame) +) # TODO: SNOW-1063346 revisit when dataframe.py is removed + +# base overrides occur before subclass overrides in case subclasses override a base method +import snowflake.snowpark.modin.plugin.extensions.base_extensions # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.base_overrides # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.dataframe_extensions # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.dataframe_overrides # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.series_extensions # isort: skip # noqa: E402,F401 import snowflake.snowpark.modin.plugin.extensions.series_overrides # isort: skip # noqa: E402,F401 +# For any method defined on Series/DF, add telemetry to it if it meets all of the following conditions: +# 1. The method was defined directly on upstream BasePandasDataset (_attrs_defined_on_modin_base) +# 2. The method is not overridden by a child class (this will change) +# 3. The method is not overridden by an extensions module +# 4. The method name does not start with an _ +# +# TODO: SNOW-1063347 +# Since we still use the vendored version of Series and the overrides for the top-level +# namespace haven't been performed yet, we need to set properties on the vendored version +_base_telemetry_added_attrs = set() + +_series_ext = _SERIES_EXTENSIONS_.copy() +for attr_name in dir(Series): + if ( + attr_name in _attrs_defined_on_modin_base + and attr_name in _attrs_defined_on_series + and attr_name not in _series_ext + and not attr_name.startswith("_") + ): + register_series_accessor(attr_name)( + try_add_telemetry_to_attribute(attr_name, getattr(Series, attr_name)) + ) + _base_telemetry_added_attrs.add(attr_name) + +# TODO: SNOW-1063346 +# Since we still use the vendored version of DataFrame and the overrides for the top-level +# namespace haven't been performed yet, we need to set properties on the vendored version +_dataframe_ext = _DATAFRAME_EXTENSIONS_.copy() +for attr_name in dir(DataFrame): + if ( + attr_name in _attrs_defined_on_modin_base + and attr_name in _attrs_defined_on_dataframe + and attr_name not in _dataframe_ext + and not attr_name.startswith("_") + ): + # If telemetry was already added via Series, register the override but don't re-wrap + # the method in the telemetry annotation. If we don't do this check, we will end up + # double-reporting telemetry on some methods. + original_attr = getattr(DataFrame, attr_name) + new_attr = ( + original_attr + if attr_name in _base_telemetry_added_attrs + else try_add_telemetry_to_attribute(attr_name, original_attr) + ) + register_dataframe_accessor(attr_name)(new_attr) + _base_telemetry_added_attrs.add(attr_name) + def __getattr__(name: str) -> Any: """ @@ -220,7 +285,6 @@ def __getattr__(name: str) -> Any: "date_range", "Index", "MultiIndex", - "Series", "bdate_range", "period_range", "DatetimeIndex", @@ -318,8 +382,7 @@ def __getattr__(name: str) -> Any: # Manually re-export the members of the pd_extensions namespace, which are not declared in __all__. _EXTENSION_ATTRS = ["read_snowflake", "to_snowflake", "to_snowpark", "to_pandas"] # We also need to re-export native_pd.offsets, since modin.pandas doesn't re-export it. -# snowflake.snowpark.pandas.base also needs to be re-exported to make docstring overrides for BasePandasDataset work. -_ADDITIONAL_ATTRS = ["offsets", "base"] +_ADDITIONAL_ATTRS = ["offsets"] # This code should eventually be moved into the `snowflake.snowpark.modin.plugin` module instead. # Currently, trying to do so would result in incorrect results because `snowflake.snowpark.modin.pandas` diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py index a7d53813779..b42ad5a04c7 100644 --- a/src/snowflake/snowpark/modin/pandas/dataframe.py +++ b/src/snowflake/snowpark/modin/pandas/dataframe.py @@ -37,6 +37,7 @@ import numpy as np import pandas from modin.pandas.accessor import CachedAccessor, SparseFrameAccessor +from modin.pandas.base import BasePandasDataset # from . import _update_engine from modin.pandas.iterator import PartitionIterator @@ -73,7 +74,6 @@ from pandas.util._validators import validate_bool_kwarg from snowflake.snowpark.modin import pandas as pd -from snowflake.snowpark.modin.pandas.base import _ATTRS_NO_LOOKUP, BasePandasDataset from snowflake.snowpark.modin.pandas.groupby import ( DataFrameGroupBy, validate_groupby_args, @@ -91,12 +91,14 @@ replace_external_data_keys_with_empty_pandas_series, replace_external_data_keys_with_query_compiler, ) +from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta from snowflake.snowpark.modin.plugin._internal.utils import is_repr_truncated from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike from snowflake.snowpark.modin.plugin.utils.error_message import ( ErrorMessage, dataframe_not_implemented, ) +from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP from snowflake.snowpark.modin.plugin.utils.warning_message import ( SET_DATAFRAME_ATTRIBUTE_WARNING, WarningMessage, @@ -136,7 +138,7 @@ ], apilink="pandas.DataFrame", ) -class DataFrame(BasePandasDataset): +class DataFrame(BasePandasDataset, metaclass=TelemetryMeta): _pandas_class = pandas.DataFrame def __init__( diff --git a/src/snowflake/snowpark/modin/pandas/general.py b/src/snowflake/snowpark/modin/pandas/general.py index f14e14840bb..07f0617d612 100644 --- a/src/snowflake/snowpark/modin/pandas/general.py +++ b/src/snowflake/snowpark/modin/pandas/general.py @@ -30,6 +30,7 @@ import numpy as np import pandas import pandas.core.common as common +from modin.pandas.base import BasePandasDataset from pandas import IntervalIndex, NaT, Timedelta, Timestamp from pandas._libs import NaTType, lib from pandas._libs.tslibs import to_offset @@ -61,7 +62,6 @@ # add this line to make doctests runnable from snowflake.snowpark.modin import pandas as pd # noqa: F401 -from snowflake.snowpark.modin.pandas.base import BasePandasDataset from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.series import Series from snowflake.snowpark.modin.pandas.utils import ( diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py index 0ac62f504ce..c83e3fe41c4 100644 --- a/src/snowflake/snowpark/modin/pandas/indexing.py +++ b/src/snowflake/snowpark/modin/pandas/indexing.py @@ -43,6 +43,7 @@ import numpy as np import pandas +from modin.pandas.base import BasePandasDataset from pandas._libs.tslibs import Resolution, parsing from pandas._typing import AnyArrayLike, Scalar from pandas.api.types import is_bool, is_list_like @@ -58,7 +59,6 @@ import snowflake.snowpark.modin.pandas as pd import snowflake.snowpark.modin.pandas.utils as frontend_utils -from snowflake.snowpark.modin.pandas.base import BasePandasDataset from snowflake.snowpark.modin.pandas.dataframe import DataFrame from snowflake.snowpark.modin.pandas.series import ( SERIES_SETITEM_LIST_LIKE_KEY_AND_RANGE_LIKE_VALUE_ERROR_MESSAGE, diff --git a/src/snowflake/snowpark/modin/pandas/series.py b/src/snowflake/snowpark/modin/pandas/series.py index 1ce3ecfc997..6e1b93437a8 100644 --- a/src/snowflake/snowpark/modin/pandas/series.py +++ b/src/snowflake/snowpark/modin/pandas/series.py @@ -31,6 +31,7 @@ import numpy.typing as npt import pandas from modin.pandas.accessor import CachedAccessor, SparseAccessor +from modin.pandas.base import BasePandasDataset from modin.pandas.iterator import PartitionIterator from pandas._libs.lib import NoDefault, is_integer, no_default from pandas._typing import ( @@ -51,17 +52,18 @@ from pandas.core.series import _coerce_method from pandas.util._validators import validate_bool_kwarg -from snowflake.snowpark.modin.pandas.base import _ATTRS_NO_LOOKUP, BasePandasDataset from snowflake.snowpark.modin.pandas.utils import ( from_pandas, is_scalar, try_convert_index_to_native, ) +from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike from snowflake.snowpark.modin.plugin.utils.error_message import ( ErrorMessage, series_not_implemented, ) +from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage from snowflake.snowpark.modin.utils import ( MODIN_UNNAMED_SERIES_LABEL, @@ -108,7 +110,7 @@ ], apilink="pandas.Series", ) -class Series(BasePandasDataset): +class Series(BasePandasDataset, metaclass=TelemetryMeta): _pandas_class = pandas.Series __array_priority__ = pandas.Series.__array_priority__ diff --git a/src/snowflake/snowpark/modin/pandas/utils.py b/src/snowflake/snowpark/modin/pandas/utils.py index f971e0ff964..32702c8b1a4 100644 --- a/src/snowflake/snowpark/modin/pandas/utils.py +++ b/src/snowflake/snowpark/modin/pandas/utils.py @@ -170,10 +170,9 @@ def is_scalar(obj): bool True if given object is scalar and False otherwise. """ + from modin.pandas.base import BasePandasDataset from pandas.api.types import is_scalar as pandas_is_scalar - from .base import BasePandasDataset - return not isinstance(obj, BasePandasDataset) and pandas_is_scalar(obj) diff --git a/src/snowflake/snowpark/modin/plugin/__init__.py b/src/snowflake/snowpark/modin/plugin/__init__.py index a76b9fe1613..c4172f26696 100644 --- a/src/snowflake/snowpark/modin/plugin/__init__.py +++ b/src/snowflake/snowpark/modin/plugin/__init__.py @@ -63,15 +63,23 @@ import modin.utils # type: ignore[import] # isort: skip # noqa: E402 import modin.pandas.series_utils # type: ignore[import] # isort: skip # noqa: E402 -modin.utils._inherit_docstrings( - docstrings.series_utils.StringMethods, - overwrite_existing=True, -)(modin.pandas.series_utils.StringMethods) - -modin.utils._inherit_docstrings( - docstrings.series_utils.CombinedDatetimelikeProperties, - overwrite_existing=True, -)(modin.pandas.series_utils.DatetimeProperties) +# TODO: SNOW-1643979 pull in fixes for +# https://github.com/modin-project/modin/issues/7113 and https://github.com/modin-project/modin/issues/7134 +# Upstream Modin has issues with certain docstring generation edge cases, so we should use our version instead +_inherit_docstrings = snowflake.snowpark.modin.utils._inherit_docstrings + +inherit_modules = [ + (docstrings.base.BasePandasDataset, modin.pandas.base.BasePandasDataset), + (docstrings.series_utils.StringMethods, modin.pandas.series_utils.StringMethods), + ( + docstrings.series_utils.CombinedDatetimelikeProperties, + modin.pandas.series_utils.DatetimeProperties, + ), +] + +for (doc_module, target_object) in inherit_modules: + _inherit_docstrings(doc_module, overwrite_existing=True)(target_object) + # Don't warn the user about our internal usage of private preview pivot # features. The user should have already been warned that Snowpark pandas diff --git a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py index 0a022b0d588..8057cf93885 100644 --- a/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py +++ b/src/snowflake/snowpark/modin/plugin/_internal/telemetry.py @@ -495,6 +495,49 @@ def wrap(*args, **kwargs): # type: ignore } +def try_add_telemetry_to_attribute(attr_name: str, attr_value: Any) -> Any: + """ + Attempts to add telemetry to an attribute. + + If the attribute is callable with name in TELEMETRY_PRIVATE_METHODS, or is a callable that + starts with an underscore, the original attribute will be returned as-is. Otherwise, a version + of the method/property annotated with Snowpark pandas telemetry is returned. + """ + if callable(attr_value) and ( + not attr_name.startswith("_") or (attr_name in TELEMETRY_PRIVATE_METHODS) + ): + return snowpark_pandas_telemetry_method_decorator(attr_value) + elif isinstance(attr_value, property): + # wrap on getter and setter + return property( + snowpark_pandas_telemetry_method_decorator( + cast( + # add a cast because mypy doesn't recognize that + # non-None fget and __get__ are both callable + # arguments to snowpark_pandas_telemetry_method_decorator. + Callable, + attr_value.__get__ # pragma: no cover: we don't encounter this case in pandas or modin because every property has an fget method. + if attr_value.fget is None + else attr_value.fget, + ), + property_name=attr_name, + property_method_type=PropertyMethodType.FGET, + ), + snowpark_pandas_telemetry_method_decorator( + attr_value.__set__ if attr_value.fset is None else attr_value.fset, + property_name=attr_name, + property_method_type=PropertyMethodType.FSET, + ), + snowpark_pandas_telemetry_method_decorator( + attr_value.__delete__ if attr_value.fdel is None else attr_value.fdel, + property_name=attr_name, + property_method_type=PropertyMethodType.FDEL, + ), + doc=attr_value.__doc__, + ) + return attr_value + + class TelemetryMeta(type): def __new__( cls, name: str, bases: tuple, attrs: dict[str, Any] @@ -536,43 +579,5 @@ def __new__( The modified class with decorated methods. """ for attr_name, attr_value in attrs.items(): - if callable(attr_value) and ( - not attr_name.startswith("_") - or (attr_name in TELEMETRY_PRIVATE_METHODS) - ): - attrs[attr_name] = snowpark_pandas_telemetry_method_decorator( - attr_value - ) - elif isinstance(attr_value, property): - # wrap on getter and setter - attrs[attr_name] = property( - snowpark_pandas_telemetry_method_decorator( - cast( - # add a cast because mypy doesn't recognize that - # non-None fget and __get__ are both callable - # arguments to snowpark_pandas_telemetry_method_decorator. - Callable, - attr_value.__get__ # pragma: no cover: we don't encounter this case in pandas or modin because every property has an fget method. - if attr_value.fget is None - else attr_value.fget, - ), - property_name=attr_name, - property_method_type=PropertyMethodType.FGET, - ), - snowpark_pandas_telemetry_method_decorator( - attr_value.__set__ - if attr_value.fset is None - else attr_value.fset, - property_name=attr_name, - property_method_type=PropertyMethodType.FSET, - ), - snowpark_pandas_telemetry_method_decorator( - attr_value.__delete__ - if attr_value.fdel is None - else attr_value.fdel, - property_name=attr_name, - property_method_type=PropertyMethodType.FDEL, - ), - doc=attr_value.__doc__, - ) + attrs[attr_name] = try_add_telemetry_to_attribute(attr_name, attr_value) return type.__new__(cls, name, bases, attrs) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index c7ba0180fd1..079f132f372 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -395,6 +395,8 @@ class SnowflakeQueryCompiler(BaseQueryCompiler): this class is best explained by looking at https://github.com/modin-project/modin/blob/a8be482e644519f2823668210cec5cf1564deb7e/modin/experimental/core/storage_formats/hdk/query_compiler.py """ + lazy_execution = True + def __init__(self, frame: InternalFrame) -> None: """this stores internally a local pandas object (refactor this)""" assert frame is not None and isinstance( @@ -774,6 +776,7 @@ def execute(self) -> None: def to_numpy( self, dtype: Optional[npt.DTypeLike] = None, + copy: Optional[bool] = False, na_value: object = lib.no_default, **kwargs: Any, ) -> np.ndarray: @@ -781,6 +784,12 @@ def to_numpy( # i.e., for something like df.values internally to_numpy().flatten() is called # with flatten being another query compiler call into the numpy frontend layer. # here it's overwritten to actually perform numpy conversion, i.e. return an actual numpy object + if copy: + WarningMessage.ignored_argument( + operation="to_numpy", + argument="copy", + message="copy is ignored in Snowflake backend", + ) return self.to_pandas().to_numpy(dtype=dtype, na_value=na_value, **kwargs) def repartition(self, axis: Any = None) -> "SnowflakeQueryCompiler": @@ -1407,17 +1416,6 @@ def cache_result(self) -> "SnowflakeQueryCompiler": """ return SnowflakeQueryCompiler(self._modin_frame.persist_to_temporary_table()) - @property - def columns(self) -> native_pd.Index: - """ - Get pandas column labels. - - Returns: - an index containing all pandas column labels - """ - # TODO SNOW-837664: add more tests for df.columns - return self._modin_frame.data_columns_index - @snowpark_pandas_type_immutable_check def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler": """ @@ -1472,6 +1470,12 @@ def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler": ) return SnowflakeQueryCompiler(new_internal_frame) + # TODO SNOW-837664: add more tests for df.columns + def get_columns(self) -> native_pd.Index: + return self._modin_frame.data_columns_index + + columns: native_pd.Index = property(get_columns, set_columns) + def _shift_values( self, periods: int, axis: Union[Literal[0], Literal[1]], fill_value: Hashable ) -> "SnowflakeQueryCompiler": @@ -2814,6 +2818,8 @@ def reset_index( Returns: A new SnowflakeQueryCompiler instance with updated index. """ + if allow_duplicates is no_default: + allow_duplicates = False # These levels will be moved from index columns to data columns levels_to_be_reset = self._modin_frame.parse_levels_to_integer_levels( level, allow_duplicates=False @@ -3014,9 +3020,11 @@ def first_last_valid_index( def sort_index( self, + *, axis: int, level: Optional[list[Union[str, int]]], ascending: Union[bool, list[bool]], + inplace: bool = False, kind: SortKind, na_position: NaPosition, sort_remaining: bool, @@ -3032,6 +3040,8 @@ def sort_index( level: If not None, sort on values in specified index level(s). ascending: A list of bools to represent ascending vs descending sort. Defaults to True. When the index is a MultiIndex the sort direction can be controlled for each level individually. + inplace: Whether or not the sort occurs in-place. This argument is ignored and only provided + for compatibility with Modin. kind: Choice of sorting algorithm. Perform stable sort if 'stable'. Defaults to unstable sort. Snowpark pandas ignores choice of sorting algorithm except 'stable'. na_position: Puts NaNs at the beginning if 'first'; 'last' puts NaNs at the end. Defaults to 'last' @@ -10866,6 +10876,12 @@ def is_multiindex(self, *, axis: int = 0) -> bool: """ return self._modin_frame.is_multiindex(axis=axis) + def abs(self) -> "SnowflakeQueryCompiler": + return self.unary_op("abs") + + def negative(self) -> "SnowflakeQueryCompiler": + return self.unary_op("__neg__") + def unary_op(self, op: str) -> "SnowflakeQueryCompiler": """ Applies a unary operation `op` on each element of the `SnowflakeQueryCompiler`. diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py index 047b4592068..f0c02aa0e65 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py @@ -13,6 +13,8 @@ _shared_docs, ) +from .base import BasePandasDataset + _doc_binary_op_kwargs = {"returns": "BasePandasDataset", "left": "BasePandasDataset"} @@ -49,7 +51,7 @@ } -class DataFrame: +class DataFrame(BasePandasDataset): """ Snowpark pandas representation of ``pandas.DataFrame`` with a lazily-evaluated relational dataset. diff --git a/src/snowflake/snowpark/modin/plugin/docstrings/series.py b/src/snowflake/snowpark/modin/plugin/docstrings/series.py index 6e48a7e57f3..4878c82635a 100644 --- a/src/snowflake/snowpark/modin/plugin/docstrings/series.py +++ b/src/snowflake/snowpark/modin/plugin/docstrings/series.py @@ -15,6 +15,8 @@ ) from snowflake.snowpark.modin.utils import _create_operator_docstring +from .base import BasePandasDataset + _shared_doc_kwargs = { "axes": "index", "klass": "Series", @@ -35,7 +37,7 @@ } -class Series: +class Series(BasePandasDataset): """ Snowpark pandas representation of `pandas.Series` with a lazily-evaluated relational dataset. diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py b/src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py new file mode 100644 index 00000000000..496136d736e --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_extensions.py @@ -0,0 +1,46 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +""" +File containing BasePandasDataset APIs defined in Snowpark pandas but not the Modin API layer. +""" + +from snowflake.snowpark.modin.plugin._internal.telemetry import ( + snowpark_pandas_telemetry_method_decorator, +) + +from .base_overrides import register_base_override + + +@register_base_override("__array_function__") +@snowpark_pandas_telemetry_method_decorator +def __array_function__(self, func: callable, types: tuple, args: tuple, kwargs: dict): + """ + Apply the `func` to the `BasePandasDataset`. + + Parameters + ---------- + func : np.func + The NumPy func to apply. + types : tuple + The types of the args. + args : tuple + The args to the func. + kwargs : dict + Additional keyword arguments. + + Returns + ------- + BasePandasDataset + The result of the ufunc applied to the `BasePandasDataset`. + """ + from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import ( + numpy_to_pandas_func_map, + ) + + if func.__name__ in numpy_to_pandas_func_map: + return numpy_to_pandas_func_map[func.__name__](*args, **kwargs) + else: + # per NEP18 we raise NotImplementedError so that numpy can intercept + return NotImplemented # pragma: no cover diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py index 332df757787..abbcb9bc762 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py @@ -6,31 +6,123 @@ Methods defined on BasePandasDataset that are overridden in Snowpark pandas. Adding a method to this file should be done with discretion, and only when relevant changes cannot be made to the query compiler or upstream frontend to accommodate Snowpark pandas. + +If you must override a method in this file, please add a comment describing why it must be overridden, +and if possible, whether this can be reconciled with upstream Modin. """ from __future__ import annotations import pickle as pkl -from typing import Any +import warnings +from collections.abc import Sequence +from typing import Any, Callable, Hashable, Literal, Mapping, get_args +import modin.pandas as pd import numpy as np +import numpy.typing as npt import pandas from modin.pandas.base import BasePandasDataset -from pandas._libs.lib import no_default +from pandas._libs import lib +from pandas._libs.lib import NoDefault, is_bool, no_default from pandas._typing import ( + AggFuncType, + AnyArrayLike, + Axes, Axis, CompressionOptions, + FillnaOptions, + IgnoreRaise, + IndexKeyFunc, + IndexLabel, + Level, + NaPosition, + RandomState, + Scalar, StorageOptions, TimedeltaConvertibleTypes, + TimestampConvertibleTypes, +) +from pandas.core.common import apply_if_callable +from pandas.core.dtypes.common import ( + is_dict_like, + is_dtype_equal, + is_list_like, + is_numeric_dtype, + pandas_dtype, +) +from pandas.core.dtypes.inference import is_integer +from pandas.core.methods.describe import _refine_percentiles +from pandas.errors import SpecificationError +from pandas.util._validators import ( + validate_ascending, + validate_bool_kwarg, + validate_percentile, ) +import snowflake.snowpark.modin.pandas as spd from snowflake.snowpark.modin.pandas.api.extensions import ( register_dataframe_accessor, register_series_accessor, ) +from snowflake.snowpark.modin.pandas.utils import ( + ensure_index, + extract_validate_and_try_convert_named_aggs_from_kwargs, + get_as_shape_compatible_dataframe_or_series, + is_scalar, + raise_if_native_pandas_objects, + validate_and_try_convert_agg_func_arg_func_to_str, +) from snowflake.snowpark.modin.plugin._internal.telemetry import ( snowpark_pandas_telemetry_method_decorator, + try_add_telemetry_to_attribute, +) +from snowflake.snowpark.modin.plugin._typing import ListLike +from snowflake.snowpark.modin.plugin.utils.error_message import ( + ErrorMessage, + base_not_implemented, ) -from snowflake.snowpark.modin.plugin.utils.error_message import base_not_implemented +from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage +from snowflake.snowpark.modin.utils import validate_int_kwarg + + +def register_base_override(method_name: str): + """ + Decorator function to override a method on BasePandasDataset. Since Modin does not provide a mechanism + for directly overriding methods on BasePandasDataset, we mock this by performing the override on + DataFrame and Series, and manually performing a `setattr` on the base class. These steps are necessary + to allow both the docstring extension and method dispatch to work properly. + + Methods annotated here also are automatically instrumented with Snowpark pandas telemetry. + """ + + def decorator(base_method: Any): + base_method = try_add_telemetry_to_attribute(method_name, base_method) + parent_method = getattr(BasePandasDataset, method_name, None) + if isinstance(parent_method, property): + parent_method = parent_method.fget + # If the method was not defined on Series/DataFrame and instead inherited from the superclass + # we need to override it as well because the MRO was already determined or something? + # TODO: SNOW-1063347 + # Since we still use the vendored version of Series and the overrides for the top-level + # namespace haven't been performed yet, we need to set properties on the vendored version + series_method = getattr(spd.series.Series, method_name, None) + if isinstance(series_method, property): + series_method = series_method.fget + if series_method is None or series_method is parent_method: + register_series_accessor(method_name)(base_method) + # TODO: SNOW-1063346 + # Since we still use the vendored version of DataFrame and the overrides for the top-level + # namespace haven't been performed yet, we need to set properties on the vendored version + df_method = getattr(spd.dataframe.DataFrame, method_name, None) + if isinstance(df_method, property): + df_method = df_method.fget + if df_method is None or df_method is parent_method: + register_dataframe_accessor(method_name)(base_method) + # Replace base method + setattr(BasePandasDataset, method_name, base_method) + return base_method + + return decorator def register_base_not_implemented(): @@ -303,3 +395,1901 @@ def truncate( @register_base_not_implemented() def __finalize__(self, other, method=None, **kwargs): pass # pragma: no cover + + +# === OVERRIDDEN METHODS === +# The below methods have their frontend implementations overridden compared to the version present +# in base.py. This is usually for one of the following reasons: +# 1. The underlying QC interface used differs from that of modin. Notably, this applies to aggregate +# and binary operations; further work is needed to refactor either our implementation or upstream +# modin's implementation. +# 2. Modin performs extra validation queries that perform extra SQL queries. Some of these are already +# fixed on main; see https://github.com/modin-project/modin/issues/7340 for details. +# 3. Upstream Modin defaults to pandas for some edge cases. Defaulting to pandas at the query compiler +# layer is acceptable because we can force the method to raise NotImplementedError, but if a method +# defaults at the frontend, Modin raises a warning and performs the operation by coercing the +# dataset to a native pandas object. Removing these is tracked by +# https://github.com/modin-project/modin/issues/7104 +# 4. Snowpark pandas uses different default arguments from modin. This occurs if some parameters are +# only partially supported (like `numeric_only=True` for `skew`), but this behavior should likewise +# be revisited. + +# `aggregate` for axis=1 is performed as a call to `BasePandasDataset.apply` in upstream Modin, +# which is unacceptable for Snowpark pandas. Upstream Modin should be changed to allow the query +# compiler or a different layer to control dispatch. +@register_base_override("aggregate") +def aggregate( + self, func: AggFuncType = None, axis: Axis | None = 0, *args: Any, **kwargs: Any +): + """ + Aggregate using one or more operations over the specified axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.pandas import Series + + origin_axis = axis + axis = self._get_axis_number(axis) + + if axis == 1 and isinstance(self, Series): + raise ValueError(f"No axis named {origin_axis} for object type Series") + + if len(self._query_compiler.columns) == 0: + # native pandas raise error with message "no result", here we raise a more readable error. + raise ValueError("No column to aggregate on.") + + # If we are using named kwargs, then we do not clear the kwargs (need them in the QC for processing + # order, as well as formatting error messages.) + uses_named_kwargs = False + # If aggregate is called on a Series, named aggregations can be passed in via a dictionary + # to func. + if func is None or (is_dict_like(func) and not self._is_dataframe): + if axis == 1: + raise ValueError( + "`func` must not be `None` when `axis=1`. Named aggregations are not supported with `axis=1`." + ) + if func is not None: + # If named aggregations are passed in via a dictionary to func, then we + # ignore the kwargs. + if any(is_dict_like(value) for value in func.values()): + # We can only get to this codepath if self is a Series, and func is a dictionary. + # In this case, if any of the values of func are themselves dictionaries, we must raise + # a Specification Error, as that is what pandas does. + raise SpecificationError("nested renamer is not supported") + kwargs = func + func = extract_validate_and_try_convert_named_aggs_from_kwargs( + self, allow_duplication=False, axis=axis, **kwargs + ) + uses_named_kwargs = True + else: + func = validate_and_try_convert_agg_func_arg_func_to_str( + agg_func=func, + obj=self, + allow_duplication=False, + axis=axis, + ) + + # This is to stay consistent with pandas result format, when the func is single + # aggregation function in format of callable or str, reduce the result dimension to + # convert dataframe to series, or convert series to scalar. + # Note: When named aggregations are used, the result is not reduced, even if there + # is only a single function. + # needs_reduce_dimension cannot be True if we are using named aggregations, since + # the values for func in that case are either NamedTuples (AggFuncWithLabels) or + # lists of NamedTuples, both of which are list like. + need_reduce_dimension = ( + (callable(func) or isinstance(func, str)) + # A Series should be returned when a single scalar string/function aggregation function, or a + # dict of scalar string/functions is specified. In all other cases (including if the function + # is a 1-element list), the result is a DataFrame. + # + # The examples below have axis=1, but the same logic is applied for axis=0. + # >>> df = pd.DataFrame({"a": [0, 1], "b": [2, 3]}) + # + # single aggregation: return Series + # >>> df.agg("max", axis=1) + # 0 2 + # 1 3 + # dtype: int64 + # + # list of aggregations: return DF + # >>> df.agg(["max"], axis=1) + # max + # 0 2 + # 1 3 + # + # dict where all aggregations are strings: return Series + # >>> df.agg({1: "max", 0: "min"}, axis=1) + # 1 3 + # 0 0 + # dtype: int64 + # + # dict where one element is a list: return DF + # >>> df.agg({1: "max", 0: ["min"]}, axis=1) + # max min + # 1 3.0 NaN + # 0 NaN 0.0 + or ( + is_dict_like(func) + and all(not is_list_like(value) for value in func.values()) + ) + ) + + # If func is a dict, pandas will not respect kwargs for each aggregation function, and + # we should drop them before passing the to the query compiler. + # + # >>> native_pd.DataFrame({"a": [0, 1], "b": [np.nan, 0]}).agg("max", skipna=False, axis=1) + # 0 NaN + # 1 1.0 + # dtype: float64 + # >>> native_pd.DataFrame({"a": [0, 1], "b": [np.nan, 0]}).agg(["max"], skipna=False, axis=1) + # max + # 0 0.0 + # 1 1.0 + # >>> pd.DataFrame([[np.nan], [0]]).aggregate("count", skipna=True, axis=0) + # 0 1 + # dtype: int8 + # >>> pd.DataFrame([[np.nan], [0]]).count(skipna=True, axis=0) + # TypeError: got an unexpected keyword argument 'skipna' + if is_dict_like(func) and not uses_named_kwargs: + kwargs.clear() + + result = self.__constructor__( + query_compiler=self._query_compiler.agg( + func=func, + axis=axis, + args=args, + kwargs=kwargs, + ) + ) + + if need_reduce_dimension: + if self._is_dataframe: + result = Series(query_compiler=result._query_compiler) + + if isinstance(result, Series): + # When func is just "quantile" with a scalar q, result has quantile value as name + q = kwargs.get("q", 0.5) + if func == "quantile" and is_scalar(q): + result.name = q + else: + result.name = None + + # handle case for single scalar (same as result._reduce_dimension()) + if isinstance(self, Series): + return result.to_pandas().squeeze() + + return result + + +# `agg` is an alias of `aggregate`. +agg = aggregate +register_base_override("agg")(agg) + + +# `_agg_helper` is not defined in modin, and used by Snowpark pandas to do extra validation. +@register_base_override("_agg_helper") +def _agg_helper( + self, + func: str, + skipna: bool = True, + axis: int | None | NoDefault = no_default, + numeric_only: bool = False, + **kwargs: Any, +): + if not self._is_dataframe and numeric_only and not is_numeric_dtype(self.dtype): + # Series aggregations on non-numeric data do not support numeric_only: + # https://github.com/pandas-dev/pandas/blob/cece8c6579854f6b39b143e22c11cac56502c4fd/pandas/core/series.py#L6358 + raise TypeError( + f"Series.{func} does not allow numeric_only=True with non-numeric dtypes." + ) + axis = self._get_axis_number(axis) + numeric_only = validate_bool_kwarg(numeric_only, "numeric_only", none_allowed=True) + skipna = validate_bool_kwarg(skipna, "skipna", none_allowed=False) + agg_kwargs: dict[str, Any] = { + "numeric_only": numeric_only, + "skipna": skipna, + } + agg_kwargs.update(kwargs) + return self.aggregate(func=func, axis=axis, **agg_kwargs) + + +# See _agg_helper +@register_base_override("count") +def count( + self, + axis: Axis | None = 0, + numeric_only: bool = False, +): + """ + Count non-NA cells for `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._agg_helper( + func="count", + axis=axis, + numeric_only=numeric_only, + ) + + +# See _agg_helper +@register_base_override("max") +def max( + self, + axis: Axis | None = 0, + skipna: bool = True, + numeric_only: bool = False, + **kwargs: Any, +): + """ + Return the maximum of the values over the requested axis. + """ + return self._agg_helper( + func="max", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +# See _agg_helper +@register_base_override("min") +def min( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only: bool = False, + **kwargs, +): + """ + Return the minimum of the values over the requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._agg_helper( + func="min", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +# See _agg_helper +@register_base_override("mean") +def mean( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only: bool = False, + **kwargs: Any, +): + """ + Return the mean of the values over the requested axis. + """ + return self._agg_helper( + func="mean", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +# See _agg_helper +@register_base_override("median") +def median( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only: bool = False, + **kwargs: Any, +): + """ + Return the mean of the values over the requested axis. + """ + return self._agg_helper( + func="median", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +# See _agg_helper +@register_base_override("std") +def std( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs, +): + """ + Return sample standard deviation over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + kwargs.update({"ddof": ddof}) + return self._agg_helper( + func="std", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +# See _agg_helper +@register_base_override("sum") +def sum( + self, + axis: Axis | None = None, + skipna: bool = True, + numeric_only: bool = False, + min_count: int = 0, + **kwargs: Any, +): + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + min_count = validate_int_kwarg(min_count, "min_count") + kwargs.update({"min_count": min_count}) + return self._agg_helper( + func="sum", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +# See _agg_helper +@register_base_override("var") +def var( + self, + axis: Axis | None = None, + skipna: bool = True, + ddof: int = 1, + numeric_only: bool = False, + **kwargs: Any, +): + """ + Return unbiased variance over requested axis. + """ + kwargs.update({"ddof": ddof}) + return self._agg_helper( + func="var", + axis=axis, + skipna=skipna, + numeric_only=numeric_only, + **kwargs, + ) + + +# Modin does not provide `MultiIndex` support and will default to pandas when `level` is specified, +# and allows binary ops against native pandas objects that Snowpark pandas prohibits. +@register_base_override("_binary_op") +def _binary_op( + self, + op: str, + other: BasePandasDataset, + axis: Axis = None, + level: Level | None = None, + fill_value: float | None = None, + **kwargs: Any, +): + """ + Do binary operation between two datasets. + + Parameters + ---------- + op : str + Name of binary operation. + other : modin.pandas.BasePandasDataset + Second operand of binary operation. + axis: Whether to compare by the index (0 or ‘index’) or columns. (1 or ‘columns’). + level: Broadcast across a level, matching Index values on the passed MultiIndex level. + fill_value: Fill existing missing (NaN) values, and any new element needed for + successful DataFrame alignment, with this value before computation. + If data in both corresponding DataFrame locations is missing the result will be missing. + only arithmetic binary operation has this parameter (e.g., add() has, but eq() doesn't have). + + kwargs can contain the following parameters passed in at the frontend: + func: Only used for `combine` method. Function that takes two series as inputs and + return a Series or a scalar. Used to merge the two dataframes column by columns. + + Returns + ------- + modin.pandas.BasePandasDataset + Result of binary operation. + """ + # In upstream modin, _axis indicates the operator will use the default axis + if kwargs.pop("_axis", None) is None: + if axis is not None: + axis = self._get_axis_number(axis) + else: + axis = 1 + else: + axis = 0 + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + raise_if_native_pandas_objects(other) + axis = self._get_axis_number(axis) + squeeze_self = isinstance(self, pd.Series) + + # pandas itself will ignore the axis argument when using Series.. + # Per default, it is set to axis=0. However, for the case of a Series interacting with + # a DataFrame the behavior is axis=1. Manually check here for this case and adjust the axis. + + is_lhs_series_and_rhs_dataframe = ( + True + if isinstance(self, pd.Series) and isinstance(other, pd.DataFrame) + else False + ) + + new_query_compiler = self._query_compiler.binary_op( + op=op, + other=other, + axis=1 if is_lhs_series_and_rhs_dataframe else axis, + level=level, + fill_value=fill_value, + squeeze_self=squeeze_self, + **kwargs, + ) + + from snowflake.snowpark.modin.pandas.dataframe import DataFrame + + # Modin Bug: https://github.com/modin-project/modin/issues/7236 + # For a Series interacting with a DataFrame, always return a DataFrame + return ( + DataFrame(query_compiler=new_query_compiler) + if is_lhs_series_and_rhs_dataframe + else self._create_or_update_from_compiler(new_query_compiler) + ) + + +# Current Modin does not use _dropna and instead defines `dropna` directly, but Snowpark pandas +# Series/DF still do. Snowpark pandas still needs to add support for the `ignore_index` parameter +# (added in pandas 2.0), and should be able to refactor to remove this override. +@register_base_override("_dropna") +def _dropna( + self, + axis: Axis = 0, + how: str | NoDefault = no_default, + thresh: int | NoDefault = no_default, + subset: IndexLabel = None, + inplace: bool = False, +): + inplace = validate_bool_kwarg(inplace, "inplace") + + if is_list_like(axis): + raise TypeError("supplying multiple axes to axis is no longer supported.") + + axis = self._get_axis_number(axis) + + if (how is not no_default) and (thresh is not no_default): + raise TypeError( + "You cannot set both the how and thresh arguments at the same time." + ) + + if how is no_default: + how = "any" + if how not in ["any", "all"]: + raise ValueError("invalid how option: %s" % how) + if subset is not None: + if axis == 1: + indices = self.index.get_indexer_for(subset) + check = indices == -1 + if check.any(): + raise KeyError(list(np.compress(check, subset))) + else: + indices = self.columns.get_indexer_for(subset) + check = indices == -1 + if check.any(): + raise KeyError(list(np.compress(check, subset))) + + new_query_compiler = self._query_compiler.dropna( + axis=axis, + how=how, + thresh=thresh, + subset=subset, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + +# Snowpark pandas uses `self_is_series` instead of `squeeze_self` and `squeeze_value` to determine +# the shape of `self` and `value`. Further work is needed to reconcile these two approaches. +@register_base_override("fillna") +def fillna( + self, + self_is_series, + value: Hashable | Mapping | pd.Series | pd.DataFrame = None, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, +): + """ + Fill NA/NaN values using the specified method. + + Parameters + ---------- + self_is_series : bool + If True then self contains a Series object, if False then self contains + a DataFrame object. + value : scalar, dict, Series, or DataFrame, default: None + Value to use to fill holes (e.g. 0), alternately a + dict/Series/DataFrame of values specifying which value to use for + each index (for a Series) or column (for a DataFrame). Values not + in the dict/Series/DataFrame will not be filled. This value cannot + be a list. + method : {'backfill', 'bfill', 'pad', 'ffill', None}, default: None + Method to use for filling holes in reindexed Series + pad / ffill: propagate last valid observation forward to next valid + backfill / bfill: use next valid observation to fill gap. + axis : {None, 0, 1}, default: None + Axis along which to fill missing values. + inplace : bool, default: False + If True, fill in-place. Note: this will modify any + other views on this object (e.g., a no-copy slice for a column in a + DataFrame). + limit : int, default: None + If method is specified, this is the maximum number of consecutive + NaN values to forward/backward fill. In other words, if there is + a gap with more than this number of consecutive NaNs, it will only + be partially filled. If method is not specified, this is the + maximum number of entries along the entire axis where NaNs will be + filled. Must be greater than 0 if not None. + downcast : dict, default: None + A dict of item->dtype of what to downcast if possible, + or the string 'infer' which will try to downcast to an appropriate + equal type (e.g. float64 to int64 if possible). + + Returns + ------- + Series, DataFrame or None + Object with missing values filled or None if ``inplace=True``. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + raise_if_native_pandas_objects(value) + inplace = validate_bool_kwarg(inplace, "inplace") + axis = self._get_axis_number(axis) + if isinstance(value, (list, tuple)): + raise TypeError( + '"value" parameter must be a scalar or dict, but ' + + f'you passed a "{type(value).__name__}"' + ) + if value is None and method is None: + # same as pandas + raise ValueError("Must specify a fill 'value' or 'method'.") + if value is not None and method is not None: + raise ValueError("Cannot specify both 'value' and 'method'.") + if method is not None and method not in ["backfill", "bfill", "pad", "ffill"]: + expecting = "pad (ffill) or backfill (bfill)" + msg = "Invalid fill method. Expecting {expecting}. Got {method}".format( + expecting=expecting, method=method + ) + raise ValueError(msg) + if limit is not None: + if not isinstance(limit, int): + raise ValueError("Limit must be an integer") + elif limit <= 0: + raise ValueError("Limit must be greater than 0") + + new_query_compiler = self._query_compiler.fillna( + self_is_series=self_is_series, + value=value, + method=method, + axis=axis, + limit=limit, + downcast=downcast, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + +# Snowpark pandas passes the query compiler object from a BasePandasDataset, which Modin does not do. +@register_base_override("isin") +def isin( + self, values: BasePandasDataset | ListLike | dict[Hashable, ListLike] +) -> BasePandasDataset: # noqa: PR01, RT01, D200 + """ + Whether elements in `BasePandasDataset` are contained in `values`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + + # Pass as query compiler if values is BasePandasDataset. + if isinstance(values, BasePandasDataset): + values = values._query_compiler + + # Convert non-dict values to List if values is neither List[Any] nor np.ndarray. SnowflakeQueryCompiler + # expects for the non-lazy case, where values is not a BasePandasDataset, the data to be materialized + # as list or numpy array. Because numpy may perform implicit type conversions, use here list to be more general. + elif not isinstance(values, dict) and ( + not isinstance(values, list) or not isinstance(values, np.ndarray) + ): + values = list(values) + + return self.__constructor__(query_compiler=self._query_compiler.isin(values=values)) + + +# Snowpark pandas uses the single `quantiles_along_axis0` query compiler method, while upstream +# Modin splits this into `quantile_for_single_value` and `quantile_for_list_of_values` calls. +# It should be possible to merge those two functions upstream and reconcile the implementations. +@register_base_override("quantile") +def quantile( + self, + q: Scalar | ListLike = 0.5, + axis: Axis = 0, + numeric_only: bool = False, + interpolation: Literal[ + "linear", "lower", "higher", "midpoint", "nearest" + ] = "linear", + method: Literal["single", "table"] = "single", +) -> float | BasePandasDataset: + """ + Return values at the given quantile over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + + # TODO + # - SNOW-1008361: support axis=1 + # - SNOW-1008367: support when q is Snowpandas DF/Series (need to require QC interface to accept QC q values) + # - SNOW-1003587: support datetime/timedelta columns + + if axis == 1 or interpolation not in ["linear", "nearest"] or method != "single": + ErrorMessage.not_implemented( + f"quantile function with parameters axis={axis}, interpolation={interpolation}, method={method} not supported" + ) + + if not numeric_only: + # If not numeric_only and columns, then check all columns are either + # numeric, timestamp, or timedelta + # Check if dtype is numeric, timedelta ("m"), or datetime ("M") + if not axis and not all( + is_numeric_dtype(t) or lib.is_np_dtype(t, "mM") for t in self._get_dtypes() + ): + raise TypeError("can't multiply sequence by non-int of type 'float'") + # If over rows, then make sure that all dtypes are equal for not + # numeric_only + elif axis: + for i in range(1, len(self._get_dtypes())): + pre_dtype = self._get_dtypes()[i - 1] + curr_dtype = self._get_dtypes()[i] + if not is_dtype_equal(pre_dtype, curr_dtype): + raise TypeError( + "Cannot compare type '{}' with type '{}'".format( + pre_dtype, curr_dtype + ) + ) + else: + # Normally pandas returns this near the end of the quantile, but we + # can't afford the overhead of running the entire operation before + # we error. + if not any(is_numeric_dtype(t) for t in self._get_dtypes()): + raise ValueError("need at least one array to concatenate") + + # check that all qs are between 0 and 1 + validate_percentile(q) + axis = self._get_axis_number(axis) + query_compiler = self._query_compiler.quantiles_along_axis0( + q=q if is_list_like(q) else [q], + numeric_only=numeric_only, + interpolation=interpolation, + method=method, + ) + if is_list_like(q): + return self.__constructor__(query_compiler=query_compiler) + else: + # result is either a scalar or Series + result = self._reduce_dimension(query_compiler.transpose_single_row()) + if isinstance(result, BasePandasDataset): + result.name = q + return result + + +# Current Modin does not define this method. Snowpark pandas currently only uses it in +# `DataFrame.set_index`. Modin does not support MultiIndex, or have its own lazy index class, +# so we may need to keep this method for the foreseeable future. +@register_base_override("_to_series_list") +def _to_series_list(self, index: pd.Index) -> list[pd.Series]: + """ + Convert index to a list of series + Args: + index: can be single or multi index + + Returns: + the list of series + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if isinstance(index, pd.MultiIndex): + return [ + pd.Series(index.get_level_values(level)) for level in range(index.nlevels) + ] + elif isinstance(index, pd.Index): + return [pd.Series(index)] + else: + raise Exception("invalid index: " + str(index)) + + +# Upstream modin defaults to pandas when `suffix` is provided. +@register_base_override("shift") +def shift( + self, + periods: int | Sequence[int] = 1, + freq=None, + axis: Axis = 0, + fill_value: Hashable = no_default, + suffix: str | None = None, +) -> BasePandasDataset: + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if periods == 0 and freq is None: + # Check obvious case first, freq manipulates the index even for periods == 0 so check for it in addition. + return self.copy() + + # pandas compatible ValueError for freq='infer' + # TODO: Test as part of SNOW-1023324. + if freq == "infer": # pragma: no cover + if not hasattr(self, "freq") and not hasattr( # pragma: no cover + self, "inferred_freq" # pragma: no cover + ): # pragma: no cover + raise ValueError() # pragma: no cover + + axis = self._get_axis_number(axis) + + if fill_value == no_default: + fill_value = None + + new_query_compiler = self._query_compiler.shift( + periods, freq, axis, fill_value, suffix + ) + return self._create_or_update_from_compiler(new_query_compiler, False) + + +# Snowpark pandas supports only `numeric_only=True`, which is not the default value of the argument, +# so we have this overridden. We should revisit this behavior. +@register_base_override("skew") +def skew( + self, + axis: Axis | None | NoDefault = no_default, + skipna: bool = True, + numeric_only=True, + **kwargs, +): # noqa: PR01, RT01, D200 + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + """ + Return unbiased skew over requested axis. + """ + return self._stat_operation("skew", axis, skipna, numeric_only, **kwargs) + + +@register_base_override("resample") +def resample( + self, + rule, + axis: Axis = lib.no_default, + closed: str | None = None, + label: str | None = None, + convention: str = "start", + kind: str | None = None, + on: Level = None, + level: Level = None, + origin: str | TimestampConvertibleTypes = "start_day", + offset: TimedeltaConvertibleTypes | None = None, + group_keys=no_default, +): # noqa: PR01, RT01, D200 + """ + Resample time-series data. + """ + from snowflake.snowpark.modin.pandas.resample import Resampler + + if axis is not lib.no_default: # pragma: no cover + axis = self._get_axis_number(axis) + if axis == 1: + warnings.warn( + "DataFrame.resample with axis=1 is deprecated. Do " + + "`frame.T.resample(...)` without axis instead.", + FutureWarning, + stacklevel=1, + ) + else: + warnings.warn( + f"The 'axis' keyword in {type(self).__name__}.resample is " + + "deprecated and will be removed in a future version.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + return Resampler( + dataframe=self, + rule=rule, + axis=axis, + closed=closed, + label=label, + convention=convention, + kind=kind, + on=on, + level=level, + origin=origin, + offset=offset, + group_keys=group_keys, + ) + + +# Snowpark pandas needs to return a custom Expanding window object. We cannot use the +# extensions module for this at the moment because modin performs a relative import of +# `from .window import Expanding`. +@register_base_override("expanding") +def expanding(self, min_periods=1, axis=0, method="single"): # noqa: PR01, RT01, D200 + """ + Provide expanding window calculations. + """ + from snowflake.snowpark.modin.pandas.window import Expanding + + if axis is not lib.no_default: + axis = self._get_axis_number(axis) + name = "expanding" + if axis == 1: + warnings.warn( + f"Support for axis=1 in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. " + + f"Use obj.T.{name}(...) instead", + FutureWarning, + stacklevel=1, + ) + else: + warnings.warn( + f"The 'axis' keyword in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. " + + "Call the method without the axis keyword instead.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + return Expanding( + self, + min_periods=min_periods, + axis=axis, + method=method, + ) + + +# Same as Expanding: Snowpark pandas needs to return a custmo Window object. +@register_base_override("rolling") +def rolling( + self, + window, + min_periods: int | None = None, + center: bool = False, + win_type: str | None = None, + on: str | None = None, + axis: Axis = lib.no_default, + closed: str | None = None, + step: int | None = None, + method: str = "single", +): # noqa: PR01, RT01, D200 + """ + Provide rolling window calculations. + """ + if axis is not lib.no_default: + axis = self._get_axis_number(axis) + name = "rolling" + if axis == 1: + warnings.warn( + f"Support for axis=1 in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. " + + f"Use obj.T.{name}(...) instead", + FutureWarning, + stacklevel=1, + ) + else: # pragma: no cover + warnings.warn( + f"The 'axis' keyword in {type(self).__name__}.{name} is " + + "deprecated and will be removed in a future version. " + + "Call the method without the axis keyword instead.", + FutureWarning, + stacklevel=1, + ) + else: + axis = 0 + + if win_type is not None: + from snowflake.snowpark.modin.pandas.window import Window + + return Window( + self, + window=window, + min_periods=min_periods, + center=center, + win_type=win_type, + on=on, + axis=axis, + closed=closed, + step=step, + method=method, + ) + from snowflake.snowpark.modin.pandas.window import Rolling + + return Rolling( + self, + window=window, + min_periods=min_periods, + center=center, + win_type=win_type, + on=on, + axis=axis, + closed=closed, + step=step, + method=method, + ) + + +# Snowpark pandas uses a custom indexer object for all indexing methods. +@register_base_override("iloc") +@property +def iloc(self): + """ + Purely integer-location based indexing for selection by position. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # TODO: SNOW-930028 enable all skipped doctests + from snowflake.snowpark.modin.pandas.indexing import _iLocIndexer + + return _iLocIndexer(self) + + +# Snowpark pandas uses a custom indexer object for all indexing methods. +@register_base_override("loc") +@property +def loc(self): + """ + Get a group of rows and columns by label(s) or a boolean array. + """ + # TODO: SNOW-935444 fix doctest where index key has name + # TODO: SNOW-933782 fix multiindex transpose bug, e.g., Name: (cobra, mark ii) => Name: ('cobra', 'mark ii') + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.pandas.indexing import _LocIndexer + + return _LocIndexer(self) + + +# Snowpark pandas uses a custom indexer object for all indexing methods. +@register_base_override("iat") +@property +def iat(self, axis=None): # noqa: PR01, RT01, D200 + """ + Get a single value for a row/column pair by integer position. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.pandas.indexing import _iAtIndexer + + return _iAtIndexer(self) + + +# Snowpark pandas uses a custom indexer object for all indexing methods. +@register_base_override("at") +@property +def at(self, axis=None): # noqa: PR01, RT01, D200 + """ + Get a single value for a row/column label pair. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.pandas.indexing import _AtIndexer + + return _AtIndexer(self) + + +# Snowpark pandas performs different dispatch logic; some changes may need to be upstreamed +# to fix edge case indexing behaviors. +@register_base_override("__getitem__") +def __getitem__(self, key): + """ + Retrieve dataset according to `key`. + + Parameters + ---------- + key : callable, scalar, slice, str or tuple + The global row index to retrieve data from. + + Returns + ------- + BasePandasDataset + Located dataset. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + key = apply_if_callable(key, self) + # If a slice is passed in, use .iloc[key]. + if isinstance(key, slice): + if (is_integer(key.start) or key.start is None) and ( + is_integer(key.stop) or key.stop is None + ): + return self.iloc[key] + else: + return self.loc[key] + + # If the object calling getitem is a Series, only use .loc[key] to filter index. + if isinstance(self, pd.Series): + return self.loc[key] + + # Sometimes the result of a callable is a DataFrame (e.g. df[df > 0]) - use where. + elif isinstance(key, pd.DataFrame): + return self.where(cond=key) + + # If the object is a boolean list-like object, use .loc[key] to filter index. + # The if statement is structured this way to avoid calling dtype and reduce query count. + if isinstance(key, pd.Series): + if pandas.api.types.is_bool_dtype(key.dtype): + return self.loc[key] + elif is_list_like(key): + if hasattr(key, "dtype"): + if pandas.api.types.is_bool_dtype(key.dtype): + return self.loc[key] + if (all(is_bool(k) for k in key)) and len(key) > 0: + return self.loc[key] + + # In all other cases, use .loc[:, key] to filter columns. + return self.loc[:, key] + + +# Snowpark pandas does extra argument validation, which may need to be upstreamed. +@register_base_override("sort_values") +def sort_values( + self, + by, + axis=0, + ascending=True, + inplace: bool = False, + kind="quicksort", + na_position="last", + ignore_index: bool = False, + key: IndexKeyFunc | None = None, +): # noqa: PR01, RT01, D200 + """ + Sort by the values along either axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + axis = self._get_axis_number(axis) + inplace = validate_bool_kwarg(inplace, "inplace") + ascending = validate_ascending(ascending) + if axis == 0: + # If any column is None raise KeyError (same a native pandas). + if by is None or (isinstance(by, list) and None in by): + # Same error message as native pandas. + raise KeyError(None) + if not isinstance(by, list): + by = [by] + + # Convert 'ascending' to sequence if needed. + if not isinstance(ascending, Sequence): + ascending = [ascending] * len(by) + if len(by) != len(ascending): + # Same error message as native pandas. + raise ValueError( + f"Length of ascending ({len(ascending)})" + f" != length of by ({len(by)})" + ) + + columns = self._query_compiler.columns.values.tolist() + index_names = self._query_compiler.get_index_names() + for by_col in by: + col_count = columns.count(by_col) + index_count = index_names.count(by_col) + if col_count == 0 and index_count == 0: + # Same error message as native pandas. + raise KeyError(by_col) + if col_count and index_count: + # Same error message as native pandas. + raise ValueError( + f"'{by_col}' is both an index level and a column label, which is ambiguous." + ) + if col_count > 1: + # Same error message as native pandas. + raise ValueError(f"The column label '{by_col}' is not unique.") + + if na_position not in get_args(NaPosition): + # Same error message as native pandas for invalid 'na_position' value. + raise ValueError(f"invalid na_position: {na_position}") + result = self._query_compiler.sort_rows_by_column_values( + by, + ascending=ascending, + kind=kind, + na_position=na_position, + ignore_index=ignore_index, + key=key, + ) + else: + result = self._query_compiler.sort_columns_by_row_values( + by, + ascending=ascending, + kind=kind, + na_position=na_position, + ignore_index=ignore_index, + key=key, + ) + return self._create_or_update_from_compiler(result, inplace) + + +# Modin does not define `where` on BasePandasDataset, and defaults to pandas at the frontend +# layer for Series. +@register_base_override("where") +def where( + self, + cond: BasePandasDataset | Callable | AnyArrayLike, + other: BasePandasDataset | Callable | Scalar | None = np.nan, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + """ + Replace values where the condition is False. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # TODO: SNOW-985670: Refactor `where` and `mask` + # will move pre-processing to QC layer. + inplace = validate_bool_kwarg(inplace, "inplace") + if cond is None: + raise ValueError("Array conditional must be same shape as self") + + cond = apply_if_callable(cond, self) + + if isinstance(cond, Callable): + raise NotImplementedError("Do not support callable for 'cond' parameter.") + + from snowflake.snowpark.modin.pandas import Series + + if isinstance(cond, Series): + cond._query_compiler._shape_hint = "column" + if isinstance(self, Series): + self._query_compiler._shape_hint = "column" + if isinstance(other, Series): + other._query_compiler._shape_hint = "column" + + if not isinstance(cond, BasePandasDataset): + cond = get_as_shape_compatible_dataframe_or_series(cond, self) + cond._query_compiler._shape_hint = "array" + + if other is not None: + other = apply_if_callable(other, self) + + if isinstance(other, np.ndarray): + other = get_as_shape_compatible_dataframe_or_series( + other, + self, + shape_mismatch_message="other must be the same shape as self when an ndarray", + ) + other._query_compiler._shape_hint = "array" + + if isinstance(other, BasePandasDataset): + other = other._query_compiler + + query_compiler = self._query_compiler.where( + cond._query_compiler, + other, + axis, + level, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace) + + +# Snowpark pandas performs extra argument validation, some of which should be pushed down +# to the QC layer. +@register_base_override("mask") +def mask( + self, + cond: BasePandasDataset | Callable | AnyArrayLike, + other: BasePandasDataset | Callable | Scalar | None = np.nan, + inplace: bool = False, + axis: Axis | None = None, + level: Level | None = None, +): + """ + Replace values where the condition is True. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # TODO: https://snowflakecomputing.atlassian.net/browse/SNOW-985670 + # will move pre-processing to QC layer. + inplace = validate_bool_kwarg(inplace, "inplace") + if cond is None: + raise ValueError("Array conditional must be same shape as self") + + cond = apply_if_callable(cond, self) + + if isinstance(cond, Callable): + raise NotImplementedError("Do not support callable for 'cond' parameter.") + + from snowflake.snowpark.modin.pandas import Series + + if isinstance(cond, Series): + cond._query_compiler._shape_hint = "column" + if isinstance(self, Series): + self._query_compiler._shape_hint = "column" + if isinstance(other, Series): + other._query_compiler._shape_hint = "column" + + if not isinstance(cond, BasePandasDataset): + cond = get_as_shape_compatible_dataframe_or_series(cond, self) + cond._query_compiler._shape_hint = "array" + + if other is not None: + other = apply_if_callable(other, self) + + if isinstance(other, np.ndarray): + other = get_as_shape_compatible_dataframe_or_series( + other, + self, + shape_mismatch_message="other must be the same shape as self when an ndarray", + ) + other._query_compiler._shape_hint = "array" + + if isinstance(other, BasePandasDataset): + other = other._query_compiler + + query_compiler = self._query_compiler.mask( + cond._query_compiler, + other, + axis, + level, + ) + + return self._create_or_update_from_compiler(query_compiler, inplace) + + +# Snowpark pandas uses a custom I/O dispatcher class. +@register_base_override("to_csv") +def to_csv( + self, + path_or_buf=None, + sep=",", + na_rep="", + float_format=None, + columns=None, + header=True, + index=True, + index_label=None, + mode="w", + encoding=None, + compression="infer", + quoting=None, + quotechar='"', + lineterminator=None, + chunksize=None, + date_format=None, + doublequote=True, + escapechar=None, + decimal=".", + errors: str = "strict", + storage_options: StorageOptions = None, +): # pragma: no cover + from snowflake.snowpark.modin.core.execution.dispatching.factories.dispatcher import ( + FactoryDispatcher, + ) + + return FactoryDispatcher.to_csv( + self._query_compiler, + path_or_buf=path_or_buf, + sep=sep, + na_rep=na_rep, + float_format=float_format, + columns=columns, + header=header, + index=index, + index_label=index_label, + mode=mode, + encoding=encoding, + compression=compression, + quoting=quoting, + quotechar=quotechar, + lineterminator=lineterminator, + chunksize=chunksize, + date_format=date_format, + doublequote=doublequote, + escapechar=escapechar, + decimal=decimal, + errors=errors, + storage_options=storage_options, + ) + + +# Modin performs extra argument validation and defaults to pandas for some edge cases. +@register_base_override("sample") +def sample( + self, + n: int | None = None, + frac: float | None = None, + replace: bool = False, + weights: str | np.ndarray | None = None, + random_state: RandomState | None = None, + axis: Axis | None = None, + ignore_index: bool = False, +): + """ + Return a random sample of items from an axis of object. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if self._get_axis_number(axis): + if weights is not None and isinstance(weights, str): + raise ValueError( + "Strings can only be passed to weights when sampling from rows on a DataFrame" + ) + else: + if n is None and frac is None: + n = 1 + elif n is not None and frac is not None: + raise ValueError("Please enter a value for `frac` OR `n`, not both") + else: + if n is not None: + if n < 0: + raise ValueError( + "A negative number of rows requested. Please provide `n` >= 0." + ) + if n % 1 != 0: + raise ValueError("Only integers accepted as `n` values") + else: + if frac < 0: + raise ValueError( + "A negative number of rows requested. Please provide `frac` >= 0." + ) + + query_compiler = self._query_compiler.sample( + n, frac, replace, weights, random_state, axis, ignore_index + ) + return self.__constructor__(query_compiler=query_compiler) + + +# Modin performs an extra query calling self.isna() to raise a warning when fill_method is unspecified. +@register_base_override("pct_change") +def pct_change( + self, periods=1, fill_method=no_default, limit=no_default, freq=None, **kwargs +): # noqa: PR01, RT01, D200 + """ + Percentage change between the current and a prior element. + """ + if fill_method not in (lib.no_default, None) or limit is not lib.no_default: + warnings.warn( + "The 'fill_method' keyword being not None and the 'limit' keyword in " + + f"{type(self).__name__}.pct_change are deprecated and will be removed " + + "in a future version. Either fill in any non-leading NA values prior " + + "to calling pct_change or specify 'fill_method=None' to not fill NA " + + "values.", + FutureWarning, + stacklevel=1, + ) + if fill_method is lib.no_default: + warnings.warn( + f"The default fill_method='pad' in {type(self).__name__}.pct_change is " + + "deprecated and will be removed in a future version. Either fill in any " + + "non-leading NA values prior to calling pct_change or specify 'fill_method=None' " + + "to not fill NA values.", + FutureWarning, + stacklevel=1, + ) + fill_method = "pad" + + if limit is lib.no_default: + limit = None + + if "axis" in kwargs: + kwargs["axis"] = self._get_axis_number(kwargs["axis"]) + + # Attempting to match pandas error behavior here + if not isinstance(periods, int): + raise TypeError(f"periods must be an int. got {type(periods)} instead") + + # Attempting to match pandas error behavior here + for dtype in self._get_dtypes(): + if not is_numeric_dtype(dtype): + raise TypeError( + f"cannot perform pct_change on non-numeric column with dtype {dtype}" + ) + + return self.__constructor__( + query_compiler=self._query_compiler.pct_change( + periods=periods, + fill_method=fill_method, + limit=limit, + freq=freq, + **kwargs, + ) + ) + + +# Snowpark pandas has different `copy` behavior, and some different behavior with native series arguments. +@register_base_override("astype") +def astype( + self, + dtype: str | type | pd.Series | dict[str, type], + copy: bool = True, + errors: Literal["raise", "ignore"] = "raise", +) -> pd.DataFrame | pd.Series: + """ + Cast a Modin object to a specified dtype `dtype`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # dtype can be a series, a dict, or a scalar. If it's series or scalar, + # convert it to a dict before passing it to the query compiler. + raise_if_native_pandas_objects(dtype) + from snowflake.snowpark.modin.pandas import Series + + if isinstance(dtype, Series): + dtype = dtype.to_pandas() + if not dtype.index.is_unique: + raise ValueError( + "The new Series of types must have a unique index, i.e. " + + "it must be one-to-one mapping from column names to " + + " their new dtypes." + ) + dtype = dtype.to_dict() + # If we got a series or dict originally, dtype is a dict now. Its keys + # must be column names. + if isinstance(dtype, dict): + # Avoid materializing columns. The query compiler will handle errors where + # dtype dict includes keys that are not in columns. + col_dtypes = dtype + for col_name in col_dtypes: + if col_name not in self._query_compiler.columns: + raise KeyError( + "Only a column name can be used for the key in a dtype mappings argument. " + f"'{col_name}' not found in columns." + ) + else: + # Assume that the dtype is a scalar. + col_dtypes = {column: dtype for column in self._query_compiler.columns} + + # ensure values are pandas dtypes + col_dtypes = {k: pandas_dtype(v) for k, v in col_dtypes.items()} + new_query_compiler = self._query_compiler.astype(col_dtypes, errors=errors) + return self._create_or_update_from_compiler(new_query_compiler, not copy) + + +# Modin defaults to pandsa when `level` is specified, and has some extra axis validation that +# is guarded in newer versions. +@register_base_override("drop") +def drop( + self, + labels: IndexLabel = None, + axis: Axis = 0, + index: IndexLabel = None, + columns: IndexLabel = None, + level: Level = None, + inplace: bool = False, + errors: IgnoreRaise = "raise", +) -> BasePandasDataset | None: + """ + Drop specified labels from `BasePandasDataset`. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + inplace = validate_bool_kwarg(inplace, "inplace") + if labels is not None: + if index is not None or columns is not None: + raise ValueError("Cannot specify both 'labels' and 'index'/'columns'") + axes = {self._get_axis_number(axis): labels} + elif index is not None or columns is not None: + axes = {0: index, 1: columns} + else: + raise ValueError( + "Need to specify at least one of 'labels', 'index' or 'columns'" + ) + + for axis, labels in axes.items(): + if labels is not None: + if level is not None and not self._query_compiler.has_multiindex(axis=axis): + # Same error as native pandas. + raise AssertionError("axis must be a MultiIndex") + # According to pandas documentation, a tuple will be used as a single + # label and not treated as a list-like. + if not is_list_like(labels) or isinstance(labels, tuple): + axes[axis] = [labels] + + new_query_compiler = self._query_compiler.drop( + index=axes.get(0), columns=axes.get(1), level=level, errors=errors + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + +# Modin calls len(self.index) instead of a direct query compiler method. +@register_base_override("__len__") +def __len__(self) -> int: + """ + Return length of info axis. + + Returns + ------- + int + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self._query_compiler.get_axis_len(axis=0) + + +# Snowpark pandas ignores `copy`. +@register_base_override("set_axis") +def set_axis( + self, + labels: IndexLabel, + *, + axis: Axis = 0, + copy: bool | NoDefault = no_default, +): + """ + Assign desired index to given axis. + """ + # Behavior based on copy: + # ----------------------------------- + # - In native pandas, copy determines whether to create a copy of the data (not DataFrame). + # - We cannot emulate the native pandas' copy behavior in Snowpark since a copy of only data + # cannot be created -- you can only copy the whole object (DataFrame/Series). + # + # Snowpark behavior: + # ------------------ + # - copy is kept for compatibility with native pandas but is ignored. The user is warned that copy is unused. + # Warn user that copy does not do anything. + if copy is not no_default: + WarningMessage.single_warning( + message=f"{type(self).__name__}.set_axis 'copy' keyword is unused and is ignored." + ) + if labels is None: + raise TypeError("None is not a valid value for the parameter 'labels'.") + + # Determine whether to update self or a copy and perform update. + obj = self.copy() + setattr(obj, axis, labels) + return obj + + +# Modin has different behavior for empty dataframes and some slightly different length validation. +@register_base_override("describe") +def describe( + self, + percentiles: ListLike | None = None, + include: ListLike | Literal["all"] | None = None, + exclude: ListLike | None = None, +) -> BasePandasDataset: + """ + Generate descriptive statistics. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + percentiles = _refine_percentiles(percentiles) + data = self + if self._is_dataframe: + # Upstream modin lacks this check because it defaults to pandas for describing empty dataframes + if len(self.columns) == 0: + raise ValueError("Cannot describe a DataFrame without columns") + + # include/exclude are ignored for Series + if (include is None) and (exclude is None): + # when some numerics are found, keep only numerics + default_include: list[npt.DTypeLike] = [np.number] + default_include.append("datetime") + data = self.select_dtypes(include=default_include) + if len(data.columns) == 0: + data = self + elif include == "all": + if exclude is not None: + raise ValueError("exclude must be None when include is 'all'") + data = self + else: + data = self.select_dtypes( + include=include, + exclude=exclude, + ) + # Upstream modin uses data.empty, but that incurs an extra row count query + if self._is_dataframe and len(data.columns) == 0: + # Match pandas error from concatenating empty list of series descriptions. + raise ValueError("No objects to concatenate") + + return self.__constructor__( + query_compiler=data._query_compiler.describe(percentiles=percentiles) + ) + + +# Modin does type validation on self that Snowpark pandas defers to SQL. +@register_base_override("diff") +def diff(self, periods: int = 1, axis: Axis = 0): + """ + First discrete difference of element. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + # We must only accept integer (or float values that are whole numbers) + # for periods. + int_periods = validate_int_kwarg(periods, "periods", float_allowed=True) + axis = self._get_axis_number(axis) + return self.__constructor__( + query_compiler=self._query_compiler.diff(axis=axis, periods=int_periods) + ) + + +# Modin does an unnecessary len call when n == 0. +@register_base_override("tail") +def tail(self, n: int = 5): + if n == 0: + return self.iloc[0:0] + return self.iloc[-n:] + + +# Snowpark pandas does extra argument validation (which should probably be deferred to SQL instead). +@register_base_override("idxmax") +def idxmax(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200 + """ + Return index of first occurrence of maximum over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + dtypes = self._get_dtypes() + if ( + axis == 1 + and not numeric_only + and any(not is_numeric_dtype(d) for d in dtypes) + and len(set(dtypes)) > 1 + ): + # For numeric_only=False, if we have any non-numeric dtype, e.g. + # a string type, we need every other column to be of the same type. + # We can't compare two objects of different non-numeric types, e.g. + # a string and a timestamp. + # If we have only numeric data, we can compare columns even if they + # different types, e.g. we can compare an int column to a float + # column. + raise TypeError("'>' not supported for these dtypes") + axis = self._get_axis_number(axis) + return self._reduce_dimension( + self._query_compiler.idxmax(axis=axis, skipna=skipna, numeric_only=numeric_only) + ) + + +# Snowpark pandas does extra argument validation (which should probably be deferred to SQL instead). +@register_base_override("idxmin") +def idxmin(self, axis=0, skipna=True, numeric_only=False): # noqa: PR01, RT01, D200 + """ + Return index of first occurrence of minimum over requested axis. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + dtypes = self._get_dtypes() + if ( + axis == 1 + and not numeric_only + and any(not is_numeric_dtype(d) for d in dtypes) + and len(set(dtypes)) > 1 + ): + # For numeric_only=False, if we have any non-numeric dtype, e.g. + # a string type, we need every other column to be of the same type. + # We can't compare two objects of different non-numeric types, e.g. + # a string and a timestamp. + # If we have only numeric data, we can compare columns even if they + # different types, e.g. we can compare an int column to a float + # column. + raise TypeError("'<' not supported for these dtypes") + axis = self._get_axis_number(axis) + return self._reduce_dimension( + self._query_compiler.idxmin(axis=axis, skipna=skipna, numeric_only=numeric_only) + ) + + +# Modin does dtype validation on unary ops that Snowpark pandas does not. +@register_base_override("__abs__") +def abs(self): # noqa: RT01, D200 + """ + Return a `BasePandasDataset` with absolute numeric value of each element. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.abs()) + + +# Modin does dtype validation on unary ops that Snowpark pandas does not. +@register_base_override("__invert__") +def __invert__(self): + """ + Apply bitwise inverse to each element of the `BasePandasDataset`. + + Returns + ------- + BasePandasDataset + New BasePandasDataset containing bitwise inverse to each value. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.invert()) + + +# Modin does dtype validation on unary ops that Snowpark pandas does not. +@register_base_override("__neg__") +def __neg__(self): + """ + Change the sign for every value of self. + + Returns + ------- + BasePandasDataset + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + return self.__constructor__(query_compiler=self._query_compiler.negative()) + + +# Modin needs to add a check for mapper is not None, which changes query counts in test_concat.py +# if not present. +@register_base_override("rename_axis") +def rename_axis( + self, + mapper=lib.no_default, + *, + index=lib.no_default, + columns=lib.no_default, + axis=0, + copy=None, + inplace=False, +): # noqa: PR01, RT01, D200 + """ + Set the name of the axis for the index or columns. + """ + axes = {"index": index, "columns": columns} + + if copy is None: + copy = True + + if axis is not None: + axis = self._get_axis_number(axis) + + inplace = validate_bool_kwarg(inplace, "inplace") + + if mapper is not lib.no_default and mapper is not None: + # Use v0.23 behavior if a scalar or list + non_mapper = is_scalar(mapper) or ( + is_list_like(mapper) and not is_dict_like(mapper) + ) + if non_mapper: + return self._set_axis_name(mapper, axis=axis, inplace=inplace) + else: + raise ValueError("Use `.rename` to alter labels with a mapper.") + else: + # Use new behavior. Means that index and/or columns is specified + result = self if inplace else self.copy(deep=copy) + + for axis in range(self.ndim): + v = axes.get(pandas.DataFrame._get_axis_name(axis)) + if v is lib.no_default: + continue + non_mapper = is_scalar(v) or (is_list_like(v) and not is_dict_like(v)) + if non_mapper: + newnames = v + else: + + def _get_rename_function(mapper): + if isinstance(mapper, (dict, BasePandasDataset)): + + def f(x): + if x in mapper: + return mapper[x] + else: + return x + + else: + f = mapper + + return f + + f = _get_rename_function(v) + curnames = self.index.names if axis == 0 else self.columns.names + newnames = [f(name) for name in curnames] + result._set_axis_name(newnames, axis=axis, inplace=True) + if not inplace: + return result + + +# Snowpark pandas has custom dispatch logic for ufuncs, while modin defaults to pandas. +@register_base_override("__array_ufunc__") +def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs): + """ + Apply the `ufunc` to the `BasePandasDataset`. + + Parameters + ---------- + ufunc : np.ufunc + The NumPy ufunc to apply. + method : str + The method to apply. + *inputs : tuple + The inputs to the ufunc. + **kwargs : dict + Additional keyword arguments. + + Returns + ------- + BasePandasDataset + The result of the ufunc applied to the `BasePandasDataset`. + """ + # Use pandas version of ufunc if it exists + if method != "__call__": + # Return sentinel value NotImplemented + return NotImplemented # pragma: no cover + from snowflake.snowpark.modin.plugin.utils.numpy_to_pandas import ( + numpy_to_pandas_universal_func_map, + ) + + if ufunc.__name__ in numpy_to_pandas_universal_func_map: + ufunc = numpy_to_pandas_universal_func_map[ufunc.__name__] + return ufunc(self, inputs[1:], kwargs) + # return the sentinel NotImplemented if we do not support this function + return NotImplemented # pragma: no cover + + +# Snowpark pandas does extra argument validation. +@register_base_override("reindex") +def reindex( + self, + index=None, + columns=None, + copy=True, + **kwargs, +): # noqa: PR01, RT01, D200 + """ + Conform `BasePandasDataset` to new index with optional filling logic. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + if kwargs.get("limit", None) is not None and kwargs.get("method", None) is None: + raise ValueError( + "limit argument only valid if doing pad, backfill or nearest reindexing" + ) + new_query_compiler = None + if index is not None: + if not isinstance(index, pandas.Index) or not index.equals(self.index): + new_query_compiler = self._query_compiler.reindex( + axis=0, labels=index, **kwargs + ) + if new_query_compiler is None: + new_query_compiler = self._query_compiler + final_query_compiler = None + if columns is not None: + if not isinstance(index, pandas.Index) or not columns.equals(self.columns): + final_query_compiler = new_query_compiler.reindex( + axis=1, labels=columns, **kwargs + ) + if final_query_compiler is None: + final_query_compiler = new_query_compiler + return self._create_or_update_from_compiler( + final_query_compiler, inplace=False if copy is None else not copy + ) + + +# No direct override annotation; used as part of `property`. +# Snowpark pandas may return a custom lazy index object. +def _get_index(self): + """ + Get the index for this DataFrame. + + Returns + ------- + pandas.Index + The union of all indexes across the partitions. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + from snowflake.snowpark.modin.plugin.extensions.index import Index + + if self._query_compiler.is_multiindex(): + # Lazy multiindex is not supported + return self._query_compiler.index + + idx = Index(query_compiler=self._query_compiler) + idx._set_parent(self) + return idx + + +# No direct override annotation; used as part of `property`. +# Snowpark pandas may return a custom lazy index object. +def _set_index(self, new_index: Axes) -> None: + """ + Set the index for this DataFrame. + + Parameters + ---------- + new_index : pandas.Index + The new index to set this. + """ + # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset + self._update_inplace( + new_query_compiler=self._query_compiler.set_index( + [s._query_compiler for s in self._to_series_list(ensure_index(new_index))] + ) + ) + + +# Snowpark pandas may return a custom lazy index object. +register_base_override("index")(property(_get_index, _set_index)) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 95fcf684924..808489b8917 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -29,6 +29,7 @@ import modin import numpy as np import pandas as native_pd +from modin.pandas.base import BasePandasDataset from pandas import get_option from pandas._libs import lib from pandas._libs.lib import is_list_like, is_scalar @@ -48,7 +49,6 @@ from pandas.core.dtypes.inference import is_hashable from snowflake.snowpark.modin.pandas import DataFrame, Series -from snowflake.snowpark.modin.pandas.base import BasePandasDataset from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta from snowflake.snowpark.modin.plugin._internal.timestamp_utils import DateTimeOrigin diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index 0afea30e29a..a33d7702203 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -161,6 +161,7 @@ def plot( @register_series_accessor("transform") +@snowpark_pandas_telemetry_method_decorator @series_not_implemented() def transform(self, func, axis=0, *args, **kwargs): # noqa: PR01, RT01, D200 pass # pragma: no cover diff --git a/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py new file mode 100644 index 00000000000..f2b28e8bfc1 --- /dev/null +++ b/src/snowflake/snowpark/modin/plugin/utils/frontend_constants.py @@ -0,0 +1,27 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +# Do not look up certain attributes in columns or index, as they're used for some +# special purposes, like serving remote context +# TODO: SNOW-1643986 examine whether to update upstream modin to follow this +_ATTRS_NO_LOOKUP = { + "____id_pack__", + "__name__", + "_cache", + "_ipython_canary_method_should_not_exist_", + "_ipython_display_", + "_repr_html_", + "_repr_javascript_", + "_repr_jpeg_", + "_repr_json_", + "_repr_latex_", + "_repr_markdown_", + "_repr_mimebundle_", + "_repr_pdf_", + "_repr_png_", + "_repr_svg_", + "__array_struct__", + "__array_interface__", + "_typ", +} diff --git a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py index 3da545c64b6..f673bf157bf 100644 --- a/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py +++ b/src/snowflake/snowpark/modin/plugin/utils/numpy_to_pandas.py @@ -3,8 +3,9 @@ # from typing import Any, Optional, Union +from modin.pandas.base import BasePandasDataset + import snowflake.snowpark.modin.pandas as pd -from snowflake.snowpark.modin.pandas.base import BasePandasDataset from snowflake.snowpark.modin.pandas.utils import is_scalar from snowflake.snowpark.modin.plugin.utils.warning_message import WarningMessage diff --git a/tests/integ/compiler/test_query_generator.py b/tests/integ/compiler/test_query_generator.py index 507b338d6e7..5ce4c005ad3 100644 --- a/tests/integ/compiler/test_query_generator.py +++ b/tests/integ/compiler/test_query_generator.py @@ -197,7 +197,7 @@ def test_table_create_from_large_query_breakdown(session, plan_source_generator) assert ( queries[PlanQueryType.QUERIES][0].sql - == f" CREATE TEMP TABLE {table_name} AS SELECT * FROM (select 1 as a, 2 as b)" + == f" CREATE SCOPED TEMPORARY TABLE {table_name} AS SELECT * FROM (select 1 as a, 2 as b)" ) diff --git a/tests/integ/modin/test_telemetry.py b/tests/integ/modin/test_telemetry.py index 9c24c6b6853..06fbc71eec7 100644 --- a/tests/integ/modin/test_telemetry.py +++ b/tests/integ/modin/test_telemetry.py @@ -398,7 +398,7 @@ def test_telemetry_getitem_setitem(): s = df["a"] assert len(df._query_compiler.snowpark_pandas_api_calls) == 0 assert s._query_compiler.snowpark_pandas_api_calls == [ - {"name": "DataFrame.BasePandasDataset.__getitem__"} + {"name": "DataFrame.__getitem__"} ] df["a"] = 0 df["b"] = 0 @@ -412,12 +412,12 @@ def test_telemetry_getitem_setitem(): # the telemetry log from the connector to validate _ = s[0] data = _extract_snowpark_pandas_telemetry_log_data( - expected_func_name="Series.BasePandasDataset.__getitem__", + expected_func_name="Series.__getitem__", session=s._query_compiler._modin_frame.ordered_dataframe.session, ) assert data["api_calls"] == [ - {"name": "DataFrame.BasePandasDataset.__getitem__"}, - {"name": "Series.BasePandasDataset.__getitem__"}, + {"name": "DataFrame.__getitem__"}, + {"name": "Series.__getitem__"}, ] @@ -547,3 +547,18 @@ def test_telemetry_repr(): {"name": "Series.property.name_set"}, {"name": "Series.Series.__repr__"}, ] + + +@sql_count_checker(query_count=0) +def test_telemetry_copy(): + # copy() is defined in upstream modin's BasePandasDataset class, and not overridden by any + # child class or the extensions module. + s = pd.Series([1, 2, 3, 4]) + copied = s.copy() + assert s._query_compiler.snowpark_pandas_api_calls == [ + {"name": "Series.property.name_set"} + ] + assert copied._query_compiler.snowpark_pandas_api_calls == [ + {"name": "Series.property.name_set"}, + {"name": "Series.BasePandasDataset.copy"}, + ] diff --git a/tests/integ/scala/test_snowflake_plan_suite.py b/tests/integ/scala/test_snowflake_plan_suite.py index 25ee097d27b..e5971e2d2f5 100644 --- a/tests/integ/scala/test_snowflake_plan_suite.py +++ b/tests/integ/scala/test_snowflake_plan_suite.py @@ -317,7 +317,7 @@ def test_create_scoped_temp_table(session): ) .queries[0] .sql - == f" CREATE TEMP TABLE {temp_table_name} AS SELECT * FROM ( SELECT * FROM ({table_name}))" + == f" CREATE TEMPORARY TABLE {temp_table_name} AS SELECT * FROM ( SELECT * FROM ({table_name}))" ) expected_sql = f' CREATE TEMPORARY TABLE {temp_table_name}("NUM" BIGINT, "STR" STRING(8))' assert expected_sql in ( @@ -342,7 +342,9 @@ def test_create_scoped_temp_table(session): .queries[0] .sql ) - expected_sql = f" CREATE TEMPORARY TABLE {temp_table_name} AS SELECT" + expected_sql = ( + f" CREATE SCOPED TEMPORARY TABLE {temp_table_name} AS SELECT" + ) assert expected_sql in ( session._plan_builder.save_as_table( table_name=[temp_table_name], diff --git a/tests/integ/test_large_query_breakdown.py b/tests/integ/test_large_query_breakdown.py index 1368bf460f2..5381019cb6c 100644 --- a/tests/integ/test_large_query_breakdown.py +++ b/tests/integ/test_large_query_breakdown.py @@ -99,7 +99,7 @@ def test_large_query_breakdown_with_cte_optimization(session): check_result_with_and_without_breakdown(session, df4) assert len(df4.queries["queries"]) == 2 - assert df4.queries["queries"][0].startswith("CREATE TEMP TABLE") + assert df4.queries["queries"][0].startswith("CREATE SCOPED TEMPORARY TABLE") assert df4.queries["queries"][1].startswith("WITH SNOWPARK_TEMP_CTE_") assert len(df4.queries["post_actions"]) == 1 @@ -115,7 +115,7 @@ def test_save_as_table(session, large_query_df): assert len(history.queries) == 4 assert history.queries[0].sql_text == "SELECT CURRENT_TRANSACTION()" - assert history.queries[1].sql_text.startswith("CREATE TEMP TABLE") + assert history.queries[1].sql_text.startswith("CREATE SCOPED TEMPORARY TABLE") assert history.queries[2].sql_text.startswith( f"CREATE OR REPLACE TABLE {table_name}" ) @@ -135,7 +135,7 @@ def test_update_delete_merge(session, large_query_df): t.update({"B": 0}, t.a == large_query_df.a, large_query_df) assert len(history.queries) == 4 assert history.queries[0].sql_text == "SELECT CURRENT_TRANSACTION()" - assert history.queries[1].sql_text.startswith("CREATE TEMP TABLE") + assert history.queries[1].sql_text.startswith("CREATE SCOPED TEMPORARY TABLE") assert history.queries[2].sql_text.startswith(f"UPDATE {table_name}") assert history.queries[3].sql_text.startswith("DROP TABLE If EXISTS") @@ -144,7 +144,7 @@ def test_update_delete_merge(session, large_query_df): t.delete(t.a == large_query_df.a, large_query_df) assert len(history.queries) == 4 assert history.queries[0].sql_text == "SELECT CURRENT_TRANSACTION()" - assert history.queries[1].sql_text.startswith("CREATE TEMP TABLE") + assert history.queries[1].sql_text.startswith("CREATE SCOPED TEMPORARY TABLE") assert history.queries[2].sql_text.startswith(f"DELETE FROM {table_name} USING") assert history.queries[3].sql_text.startswith("DROP TABLE If EXISTS") @@ -157,7 +157,7 @@ def test_update_delete_merge(session, large_query_df): ) assert len(history.queries) == 4 assert history.queries[0].sql_text == "SELECT CURRENT_TRANSACTION()" - assert history.queries[1].sql_text.startswith("CREATE TEMP TABLE") + assert history.queries[1].sql_text.startswith("CREATE SCOPED TEMPORARY TABLE") assert history.queries[2].sql_text.startswith(f"MERGE INTO {table_name} USING") assert history.queries[3].sql_text.startswith("DROP TABLE If EXISTS") @@ -176,7 +176,7 @@ def test_copy_into_location(session, large_query_df): ) assert len(history.queries) == 4, history.queries assert history.queries[0].sql_text == "SELECT CURRENT_TRANSACTION()" - assert history.queries[1].sql_text.startswith("CREATE TEMP TABLE") + assert history.queries[1].sql_text.startswith("CREATE SCOPED TEMPORARY TABLE") assert history.queries[2].sql_text.startswith(f"COPY INTO '{remote_file_path}'") assert history.queries[3].sql_text.startswith("DROP TABLE If EXISTS") @@ -215,7 +215,7 @@ def test_pivot_unpivot(session): plan_queries = final_df.queries assert len(plan_queries["queries"]) == 2 - assert plan_queries["queries"][0].startswith("CREATE TEMP TABLE") + assert plan_queries["queries"][0].startswith("CREATE SCOPED TEMPORARY TABLE") assert len(plan_queries["post_actions"]) == 1 assert plan_queries["post_actions"][0].startswith("DROP TABLE If EXISTS") @@ -239,7 +239,7 @@ def test_sort(session): plan_queries = final_df.queries assert len(plan_queries["queries"]) == 2 - assert plan_queries["queries"][0].startswith("CREATE TEMP TABLE") + assert plan_queries["queries"][0].startswith("CREATE SCOPED TEMPORARY TABLE") assert len(plan_queries["post_actions"]) == 1 assert plan_queries["post_actions"][0].startswith("DROP TABLE If EXISTS") @@ -283,7 +283,7 @@ def test_multiple_query_plan(session, large_query_df): "CREATE OR REPLACE SCOPED TEMPORARY TABLE" ) assert plan_queries["queries"][1].startswith("INSERT INTO") - assert plan_queries["queries"][2].startswith("CREATE TEMP TABLE") + assert plan_queries["queries"][2].startswith("CREATE SCOPED TEMPORARY TABLE") assert len(plan_queries["post_actions"]) == 2 for query in plan_queries["post_actions"]: @@ -349,7 +349,9 @@ def test_async_job_with_large_query_breakdown(session, large_query_df): result = job.result() assert result == large_query_df.collect() assert len(large_query_df.queries["queries"]) == 2 - assert large_query_df.queries["queries"][0].startswith("CREATE TEMP TABLE") + assert large_query_df.queries["queries"][0].startswith( + "CREATE SCOPED TEMPORARY TABLE" + ) assert len(large_query_df.queries["post_actions"]) == 1 assert large_query_df.queries["post_actions"][0].startswith( @@ -365,7 +367,9 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): session._large_query_breakdown_enabled = True assert len(large_query_df.queries["queries"]) == 2 assert len(large_query_df.queries["post_actions"]) == 1 - assert large_query_df.queries["queries"][0].startswith("CREATE TEMP TABLE") + assert large_query_df.queries["queries"][0].startswith( + "CREATE SCOPED TEMPORARY TABLE" + ) assert large_query_df.queries["post_actions"][0].startswith( "DROP TABLE If EXISTS" ) @@ -374,8 +378,12 @@ def test_complexity_bounds_affect_num_partitions(session, large_query_df): session._large_query_breakdown_enabled = True assert len(large_query_df.queries["queries"]) == 3 assert len(large_query_df.queries["post_actions"]) == 2 - assert large_query_df.queries["queries"][0].startswith("CREATE TEMP TABLE") - assert large_query_df.queries["queries"][1].startswith("CREATE TEMP TABLE") + assert large_query_df.queries["queries"][0].startswith( + "CREATE SCOPED TEMPORARY TABLE" + ) + assert large_query_df.queries["queries"][1].startswith( + "CREATE SCOPED TEMPORARY TABLE" + ) assert large_query_df.queries["post_actions"][0].startswith( "DROP TABLE If EXISTS" ) diff --git a/tests/unit/modin/modin/test_envvars.py b/tests/unit/modin/modin/test_envvars.py index 4f3540a63bf..7c5e3a40bb0 100644 --- a/tests/unit/modin/modin/test_envvars.py +++ b/tests/unit/modin/modin/test_envvars.py @@ -90,6 +90,32 @@ def test_custom_help(make_custom_envvar): assert "custom var" in make_custom_envvar.get_help() +def _init_doc_module(): + # Put the docs_module on the path + sys.path.append(f"{os.path.dirname(__file__)}") + # We use base.py from upstream modin, so we need to initialize its doc module + # However, since using the environment variable causes an importlib.reload call, + # we need to manually call _inherit_docstrings (https://github.com/modin-project/modin/issues/7138) + from .docs_module import classes + + # As a workaround for upstream modin bugs, we use our own _inherit_docstrings instead of the upstream + # function. We accordingly need to clear the docstring dictionary in testing because + # we manually called the annotation on initializing snowflake.snowpark.modin.pandas. + # snowflake.snowpark.modin.utils._attributes_with_docstrings_replaced.clear() + # TODO: once modin 0.31.0 is available, use the actual modin DocModule class + snowflake.snowpark.modin.utils._inherit_docstrings( + classes.BasePandasDataset, + overwrite_existing=True, + )(pd.base.BasePandasDataset) + DocModule.put("docs_module") + + +DOC_OVERRIDE_XFAIL_REASON = ( + "test docstring overrides currently cannot override real docstring overrides until " + "modin 0.31.0 is available" +) + + class TestDocModule: """ Test using a module to replace default docstrings. @@ -99,11 +125,9 @@ class TestDocModule: which we need to fix in upstream modin. """ + @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON) def test_overrides(self): - # Put the docs_module on the path - sys.path.append(f"{os.path.dirname(__file__)}") - DocModule.put("docs_module") - + _init_doc_module() # Test for override # TODO(https://github.com/modin-project/modin/issues/7134): Upstream # the BasePandasDataset tests to modin. @@ -144,11 +168,7 @@ def test_overrides(self): def test_not_redefining_classes_modin_issue_7138(self): original_dataframe_class = pd.DataFrame - - # Put the docs_module on the path - sys.path.append(f"{os.path.dirname(__file__)}") - DocModule.put("docs_module") - + _init_doc_module() # Test for override assert ( pd.DataFrame.apply.__doc__ @@ -157,22 +177,20 @@ def test_not_redefining_classes_modin_issue_7138(self): assert pd.DataFrame is original_dataframe_class + @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON) def test_base_docstring_override_with_no_dataframe_or_series_class_modin_issue_7113( self, ): # TODO(https://github.com/modin-project/modin/issues/7113): Upstream # this test case to Modin. This test case tests scenario 1 from issue 7113. - sys.path.append(f"{os.path.dirname(__file__)}") - DocModule.put("docs_module_with_just_base") + _init_doc_module() assert pd.base.BasePandasDataset.astype.__doc__ == ( "This is a test of the documentation module for BasePandasDataSet.astype." ) + @pytest.mark.xfail(strict=True, reason=DOC_OVERRIDE_XFAIL_REASON) def test_base_property_not_overridden_in_either_subclass_modin_issue_7113(self): - # Put the docs_module on the path - sys.path.append(f"{os.path.dirname(__file__)}") - DocModule.put("docs_module") - + _init_doc_module() assert ( pd.base.BasePandasDataset.loc.__doc__ == "This is a test of the documentation module for BasePandasDataset.loc."