
Commit

accessors and relative imports
sfc-gh-joshi committed Sep 3, 2024
1 parent 91021d8 commit 6e3a8bb
Showing 2 changed files with 251 additions and 2 deletions.
9 changes: 9 additions & 0 deletions src/snowflake/snowpark/modin/pandas/__init__.py
@@ -151,6 +151,7 @@
from snowflake.snowpark.modin.plugin._internal.telemetry import (
try_add_telemetry_to_attribute,
)
from snowflake.snowpark.modin.plugin.utils.frontend_constants import _ATTRS_NO_LOOKUP

# The extensions assigned to this module
_PD_EXTENSIONS_: dict = {}
@@ -179,6 +180,14 @@
import snowflake.snowpark.modin.plugin.extensions.series_extensions # isort: skip # noqa: E402,F401
import snowflake.snowpark.modin.plugin.extensions.series_overrides # isort: skip # noqa: E402,F401


# The dt and str accessors raise AttributeErrors that get caught by Modin's __getattr__ fallback.
# Whitelist them in _ATTRS_NO_LOOKUP here to avoid this.
modin.pandas.base._ATTRS_NO_LOOKUP.add("dt")
modin.pandas.base._ATTRS_NO_LOOKUP.add("str")
modin.pandas.base._ATTRS_NO_LOOKUP.update(_ATTRS_NO_LOOKUP)

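For context, a rough sketch of the attribute-lookup fallback that the whitelist short-circuits. This is an approximation for illustration only, not the exact Modin source, and the class and whitelist names below are stand-ins:

_SKETCH_ATTRS_NO_LOOKUP = {"dt", "str"}  # simplified stand-in for the real whitelist

class SketchBase:
    def __getattr__(self, key):
        # Reached only after normal lookup raised AttributeError, e.g. when the
        # .dt/.str property rejected the Series dtype.
        try:
            return object.__getattribute__(self, key)
        except AttributeError as err:
            if key not in _SKETCH_ATTRS_NO_LOOKUP:
                # Without a whitelist entry, lookup falls back to item access,
                # masking the accessor's informative AttributeError.
                try:
                    return self[key]
                except Exception:
                    pass
            raise err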

# For any method defined on Series/DF, add telemetry to it if it meets all of the following conditions:
# 1. The method was defined directly on an upstream class (_attrs_defined_on_modin_base, _attrs_defined_on_modin_series)
# 1a. (DataFrame only): The method is not overridden by DataFrame (not applicable to Series, since we use the upstream version)
244 changes: 242 additions & 2 deletions src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
@@ -22,10 +22,14 @@
AnyArrayLike,
Axis,
FillnaOptions,
IgnoreRaise,
IndexKeyFunc,
IndexLabel,
Level,
Renamer,
Scalar,
)
from pandas.api.types import is_datetime64_any_dtype, is_string_dtype
from pandas.core.common import apply_if_callable, is_bool_indexer
from pandas.core.dtypes.common import is_bool_dtype, is_dict_like, is_list_like
from pandas.util._validators import validate_bool_kwarg
@@ -41,7 +45,7 @@
snowpark_pandas_telemetry_method_decorator,
try_add_telemetry_to_attribute,
)
from snowflake.snowpark.modin.plugin._typing import ListLike
from snowflake.snowpark.modin.plugin._typing import DropKeep, ListLike
from snowflake.snowpark.modin.plugin.utils.error_message import (
ErrorMessage,
series_not_implemented,
@@ -904,6 +908,39 @@ def argmin(self, axis=None, skipna=True, *args, **kwargs): # noqa: PR01, RT01,
return result


# Snowpark pandas does not respect `ignore_index`, and upstream Modin does not respect `how`.
@register_series_accessor("dropna")
@snowpark_pandas_telemetry_method_decorator
def dropna(
self,
*,
axis: Axis = 0,
inplace: bool = False,
how: str | NoDefault = no_default,
):
"""
Return a new Series with missing values removed.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
return super(Series, self)._dropna(axis=axis, inplace=inplace, how=how)


# Upstream Modin does not preserve the series name.
# https://github.com/modin-project/modin/issues/7375
@register_series_accessor("duplicated")
@snowpark_pandas_telemetry_method_decorator
def duplicated(self, keep: DropKeep = "first"):
"""
Indicate duplicate Series values.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
name = self.name
series = self.to_frame().duplicated(keep=keep)
# We use the DataFrame.duplicated method for the Series, but its result loses the Series name, so we preserve it here
series.name = name
return series

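A plain-pandas sketch of the round trip this override performs and the name it restores; the data and name are illustrative:

import pandas as pd

s = pd.Series([1, 1, 2], name="orders")
result = s.to_frame().duplicated(keep="first")  # DataFrame.duplicated drops the Series name
result.name = s.name                            # the override restores it
assert result.name == "orders"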

# Upstream Modin defines sum differently for series/DF, but we use the same implementation for both.
# Even though we already override sum in base_overrides, we need to do another override here because
# Modin has a separate definition in both series.py and base.py. In general, we cannot force base_overrides
@@ -1038,6 +1075,39 @@ def cat(self) -> CategoryMethods:
return CategoryMethods(self)


# Snowpark pandas performs type validation that Modin does not
@register_series_accessor("dt")
@property
@snowpark_pandas_telemetry_method_decorator
def dt(self): # noqa: RT01, D200
"""
Accessor object for datetimelike properties of the Series values.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
current_dtype = self.dtype
if not is_datetime64_any_dtype(current_dtype):
raise AttributeError("Can only use .dt accessor with datetimelike values")

from modin.pandas.series_utils import DatetimeProperties

return DatetimeProperties(self)


@register_series_accessor("str")
@property
@snowpark_pandas_telemetry_method_decorator
def str(self): # noqa: RT01, D200
"""
Vectorized string functions for Series and Index.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
current_dtype = self.dtype
if not is_string_dtype(current_dtype):
raise AttributeError("Can only use .str accessor with string values!")

from modin.pandas.series_utils import StringMethods

return StringMethods(self)

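A small plain-pandas sketch of the dtype checks these accessors add, using the same helpers imported above; values are illustrative:

import pandas as pd
from pandas.api.types import is_datetime64_any_dtype, is_string_dtype

s = pd.Series(["a", "b"])
assert is_string_dtype(s.dtype)                # .str is allowed
assert not is_datetime64_any_dtype(s.dtype)    # .dt would raise AttributeError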

# Snowpark pandas uses an update_in_place call that upstream Modin does not.
def _set_name(self, name):
"""
@@ -1203,7 +1273,7 @@ def replace(
return self._create_or_update_from_compiler(new_query_compiler, inplace)


# Upstream Modin reset_index produces an extra query.
# Upstream Modin reset_index produces an extra query and performs a relative import of DataFrame.
@register_series_accessor("reset_index")
@snowpark_pandas_telemetry_method_decorator
def reset_index(
@@ -1263,6 +1333,140 @@ def set_axis(
)


# TODO: SNOW-1063346
# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once
# our vendored copy of DataFrame is removed.
@register_series_accessor("rename")
@snowpark_pandas_telemetry_method_decorator
def rename(
self,
index: Renamer | Hashable | None = None,
*,
axis: Axis | None = None,
copy: bool | None = None,
inplace: bool = False,
level: Level | None = None,
errors: IgnoreRaise = "ignore",
) -> Series | None:
"""
Alter Series index labels or name.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
if axis is not None:
# make sure we raise if an invalid 'axis' is passed.
# note: axis is unused. It's needed for compatibility with DataFrame.
self._get_axis_number(axis)

if copy is not None:
WarningMessage.ignored_argument(
operation="series.rename",
argument="copy",
message="copy parameter has been ignored with Snowflake execution engine",
)

if callable(index) or is_dict_like(index):
if isinstance(index, dict):
index = Series(index)
new_qc = self._query_compiler.rename(
index_renamer=index, level=level, errors=errors
)
new_series = self._create_or_update_from_compiler(
new_query_compiler=new_qc, inplace=inplace
)
if not inplace and hasattr(self, "name"):
new_series.name = self.name
return new_series
else:
# just change Series.name
if inplace:
self.name = index
else:
self_cp = self.copy()
self_cp.name = index
return self_cp

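A brief plain-pandas sketch of the two branches above (mapping-like versus scalar), which this override mirrors; values are illustrative:

import pandas as pd

s = pd.Series([1, 2], index=["a", "b"], name="x")
relabeled = s.rename({"a": "A"})   # dict-like: relabels the index, keeps name "x"
renamed = s.rename("y")            # Hashable: only changes the Series name
assert relabeled.name == "x" and renamed.name == "y"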

# TODO: SNOW-1063346
# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once
# our vendored copy of DataFrame is removed.
@register_series_accessor("sort_values")
@snowpark_pandas_telemetry_method_decorator
def sort_values(
self,
axis: Axis = 0,
ascending: bool | int | Sequence[bool] | Sequence[int] = True,
inplace: bool = False,
kind: str = "quicksort",
na_position: str = "last",
ignore_index: bool = False,
key: IndexKeyFunc | None = None,
):
"""
Sort by the values.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
from snowflake.snowpark.modin.pandas.dataframe import DataFrame

if is_list_like(ascending) and len(ascending) != 1:
raise ValueError(f"Length of ascending ({len(ascending)}) must be 1 for Series")

if axis is not None:
# Validate `axis`
self._get_axis_number(axis)

# When we convert to a DataFrame, the name is automatically converted to 0 if it
# is None, so we do this to avoid a KeyError.
by = self.name if self.name is not None else 0
result = (
DataFrame(self.copy())
.sort_values(
by=by,
ascending=ascending,
inplace=False,
kind=kind,
na_position=na_position,
ignore_index=ignore_index,
key=key,
)
.squeeze(axis=1)
)
result.name = self.name
return self._create_or_update_from_compiler(result._query_compiler, inplace=inplace)

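Why the `by` fallback matters, shown with plain pandas: converting an unnamed Series to a DataFrame labels its single column 0, so sorting by the missing name would raise a KeyError. Values are illustrative:

import pandas as pd

s = pd.Series([3, 1, 2])            # name is None
df = s.to_frame()
assert list(df.columns) == [0]      # the single column is labeled 0, not None
sorted_df = df.sort_values(by=0)    # hence by=0 when the Series has no name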

# TODO: SNOW-1063346
# Modin does a relative import (from .dataframe import DataFrame). We should revisit this once
# our vendored copy of DataFrame is removed.
# Modin also defaults to pandas for some arguments to unstack
@register_series_accessor("unstack")
@snowpark_pandas_telemetry_method_decorator
def unstack(
self,
level: int | str | list = -1,
fill_value: int | str | dict = None,
sort: bool = True,
):
"""
Unstack, also known as pivot, Series with MultiIndex to produce DataFrame.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
from snowflake.snowpark.modin.pandas.dataframe import DataFrame

# We can't unstack a Series object if we don't have a MultiIndex.
if self._query_compiler.has_multiindex:
result = DataFrame(
query_compiler=self._query_compiler.unstack(
level, fill_value, sort, is_series_input=True
)
)
else:
raise ValueError( # pragma: no cover
f"index must be a MultiIndex to unstack, {type(self.index)} was passed"
)

return result

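A minimal plain-pandas sketch of the MultiIndex requirement enforced above; data is illustrative:

import pandas as pd

idx = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1)])
wide = pd.Series([10, 20, 30], index=idx).unstack(level=-1)  # DataFrame with columns 1 and 2
# pd.Series([10, 20, 30]).unstack() raises the ValueError shown above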

# Upstream Modin defaults at the frontend layer.
@register_series_accessor("where")
@snowpark_pandas_telemetry_method_decorator
@@ -1401,6 +1605,42 @@ def to_list(self) -> list:
return self.values.tolist()


# TODO: SNOW-1063346
# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored
# version of DataFrame, we must keep this override.
@register_series_accessor("_create_or_update_from_compiler")
def _create_or_update_from_compiler(self, new_query_compiler, inplace=False):
"""
Return or update a Series with given `new_query_compiler`.
Parameters
----------
new_query_compiler : PandasQueryCompiler
QueryCompiler to use to manage the data.
inplace : bool, default: False
Whether or not to perform update or creation inplace.
Returns
-------
Series, DataFrame or None
None if update was done, Series or DataFrame otherwise.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
assert (
isinstance(new_query_compiler, type(self._query_compiler))
or type(new_query_compiler) in self._query_compiler.__class__.__bases__
), f"Invalid Query Compiler object: {type(new_query_compiler)}"
if not inplace and new_query_compiler.is_series_like():
return self.__constructor__(query_compiler=new_query_compiler)
elif not inplace:
# This can happen with things like `reset_index` where we can add columns.
from snowflake.snowpark.modin.pandas.dataframe import DataFrame

return DataFrame(query_compiler=new_query_compiler)
else:
self._update_inplace(new_query_compiler=new_query_compiler)

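A plain-pandas sketch of why a Series operation can hand back a frame-shaped compiler, which is the case the DataFrame branch above covers; values are illustrative:

import pandas as pd

s = pd.Series([1, 2], name="x")
out = s.reset_index()              # the index becomes a column, so the result is 2-dimensional
assert isinstance(out, pd.DataFrame)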

# TODO: SNOW-1063346
# Modin does a relative import (from .dataframe import DataFrame), so until we stop using the vendored
# version of DataFrame, we must keep this override.
