From 91466e7d39877d9abd3e1cdef62df5bb259e3fd5 Mon Sep 17 00:00:00 2001 From: Jonathan Shi Date: Fri, 30 Aug 2024 17:59:33 -0700 Subject: [PATCH] fillna, reset_index, set_axis --- .../compiler/snowflake_query_compiler.py | 12 ++- .../plugin/extensions/series_overrides.py | 102 +++++++++++++++++- 2 files changed, 111 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index cf200226012..b46b7f48aa6 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -755,7 +755,17 @@ def to_list(self) -> list: Only called if the frontend object was a Series. """ - return native_pd.Series(self.to_pandas()).to_list() + return self.to_pandas().squeeze().to_list() + + def series_to_dict(self, into=dict) -> dict: # type: ignore + """ + Convert the Series to a dictionary. + + Returns + ------- + dict or `into` instance + """ + return self.to_pandas().squeeze().to_dict(into=into) def to_pandas( self, diff --git a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py index 8a548489e9b..1ec455894de 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py @@ -9,7 +9,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable +from typing import TYPE_CHECKING, Any, Callable, Hashable, Mapping import modin.pandas as pd import numpy as np @@ -17,7 +17,15 @@ from modin.pandas import Series from modin.pandas.base import BasePandasDataset from pandas._libs.lib import NoDefault, is_integer, no_default -from pandas._typing import AggFuncType, AnyArrayLike, Axis, IndexLabel, Level, Scalar +from pandas._typing import ( + AggFuncType, + AnyArrayLike, + Axis, + FillnaOptions, + IndexLabel, + Level, + Scalar, +) from pandas.core.common import apply_if_callable, is_bool_indexer from pandas.core.dtypes.common import is_bool_dtype, is_list_like @@ -959,6 +967,36 @@ def empty(self) -> bool: return _old_empty_fget(self) +# Upstream modin uses squeeze_self instead of self_is_series. +@register_series_accessor("fillna") +@snowpark_pandas_telemetry_method_decorator +def fillna( + self, + value: Hashable | Mapping | Series = None, + *, + method: FillnaOptions | None = None, + axis: Axis | None = None, + inplace: bool = False, + limit: int | None = None, + downcast: dict | None = None, +) -> Series | None: + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if isinstance(value, BasePandasDataset) and not isinstance(value, Series): + raise TypeError( + '"value" parameter must be a scalar, dict or Series, but ' + + f'you passed a "{type(value).__name__}"' + ) + return super(Series, self).fillna( + self_is_series=True, + value=value, + method=method, + axis=axis, + inplace=inplace, + limit=limit, + downcast=downcast, + ) + + # Snowpark pandas defines a custom GroupBy object @register_series_accessor("groupby") @property @@ -1003,6 +1041,66 @@ def groupby( ) +# Upstream Modin reset_index produces an extra query. +@register_series_accessor("reset_index") +@snowpark_pandas_telemetry_method_decorator +def reset_index( + self, + level=None, + drop=False, + name=no_default, + inplace=False, + allow_duplicates=False, +): + """ + Generate a new Series with the index reset. + """ + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if drop: + name = self.name + elif name is no_default: + # For backwards compatibility, keep columns as [0] instead of + # [None] when self.name is None + name = 0 if self.name is None else self.name + + if not drop and inplace: + raise TypeError("Cannot reset_index inplace on a Series to create a DataFrame") + else: + obj = self.copy() + obj.name = name + new_query_compiler = obj._query_compiler.reset_index( + drop=drop, + level=level, + col_level=0, + col_fill="", + allow_duplicates=allow_duplicates, + names=None, + ) + return self._create_or_update_from_compiler(new_query_compiler, inplace) + + +# Snowpark pandas performs additional type validation. +@register_series_accessor("set_axis") +@snowpark_pandas_telemetry_method_decorator +def set_axis( + self, + labels: IndexLabel, + *, + axis: Axis = 0, + copy: bool | NoDefault = no_default, # ignored +): + # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions + if not is_scalar(axis): + raise TypeError(f"{type(axis).__name__} is not a valid type for axis.") + return super(Series, self).set_axis( + labels=labels, + # 'rows', 'index, and 0 are valid axis values for Series. + # 'columns' and 1 are valid axis values only for DataFrame. + axis=native_pd.Series._get_axis_name(axis), + copy=copy, + ) + + # Upstream Modin defaults at the frontend layer. @register_series_accessor("where") @snowpark_pandas_telemetry_method_decorator