Skip to content

Commit

Permalink
fillna, reset_index, set_axis
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-joshi committed Sep 3, 2024
1 parent bdb8008 commit 91466e7
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -755,7 +755,17 @@ def to_list(self) -> list:

Only called if the frontend object was a Series.
"""
return native_pd.Series(self.to_pandas()).to_list()
return self.to_pandas().squeeze().to_list()

def series_to_dict(self, into=dict) -> dict: # type: ignore
"""
Convert the Series to a dictionary.

Returns
-------
dict or `into` instance
"""
return self.to_pandas().squeeze().to_dict(into=into)

def to_pandas(
self,
Expand Down
102 changes: 100 additions & 2 deletions src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,23 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Hashable, Mapping

import modin.pandas as pd
import numpy as np
import pandas as native_pd
from modin.pandas import Series
from modin.pandas.base import BasePandasDataset
from pandas._libs.lib import NoDefault, is_integer, no_default
from pandas._typing import AggFuncType, AnyArrayLike, Axis, IndexLabel, Level, Scalar
from pandas._typing import (
AggFuncType,
AnyArrayLike,
Axis,
FillnaOptions,
IndexLabel,
Level,
Scalar,
)
from pandas.core.common import apply_if_callable, is_bool_indexer
from pandas.core.dtypes.common import is_bool_dtype, is_list_like

Expand Down Expand Up @@ -959,6 +967,36 @@ def empty(self) -> bool:
return _old_empty_fget(self)


# Upstream modin uses squeeze_self instead of self_is_series.
@register_series_accessor("fillna")
@snowpark_pandas_telemetry_method_decorator
def fillna(
self,
value: Hashable | Mapping | Series = None,
*,
method: FillnaOptions | None = None,
axis: Axis | None = None,
inplace: bool = False,
limit: int | None = None,
downcast: dict | None = None,
) -> Series | None:
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
if isinstance(value, BasePandasDataset) and not isinstance(value, Series):
raise TypeError(
'"value" parameter must be a scalar, dict or Series, but '
+ f'you passed a "{type(value).__name__}"'
)
return super(Series, self).fillna(
self_is_series=True,
value=value,
method=method,
axis=axis,
inplace=inplace,
limit=limit,
downcast=downcast,
)


# Snowpark pandas defines a custom GroupBy object
@register_series_accessor("groupby")
@property
Expand Down Expand Up @@ -1003,6 +1041,66 @@ def groupby(
)


# Upstream Modin reset_index produces an extra query.
@register_series_accessor("reset_index")
@snowpark_pandas_telemetry_method_decorator
def reset_index(
self,
level=None,
drop=False,
name=no_default,
inplace=False,
allow_duplicates=False,
):
"""
Generate a new Series with the index reset.
"""
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
if drop:
name = self.name
elif name is no_default:
# For backwards compatibility, keep columns as [0] instead of
# [None] when self.name is None
name = 0 if self.name is None else self.name

if not drop and inplace:
raise TypeError("Cannot reset_index inplace on a Series to create a DataFrame")
else:
obj = self.copy()
obj.name = name
new_query_compiler = obj._query_compiler.reset_index(
drop=drop,
level=level,
col_level=0,
col_fill="",
allow_duplicates=allow_duplicates,
names=None,
)
return self._create_or_update_from_compiler(new_query_compiler, inplace)


# Snowpark pandas performs additional type validation.
@register_series_accessor("set_axis")
@snowpark_pandas_telemetry_method_decorator
def set_axis(
self,
labels: IndexLabel,
*,
axis: Axis = 0,
copy: bool | NoDefault = no_default, # ignored
):
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
if not is_scalar(axis):
raise TypeError(f"{type(axis).__name__} is not a valid type for axis.")
return super(Series, self).set_axis(
labels=labels,
# 'rows', 'index, and 0 are valid axis values for Series.
# 'columns' and 1 are valid axis values only for DataFrame.
axis=native_pd.Series._get_axis_name(axis),
copy=copy,
)


# Upstream Modin defaults at the frontend layer.
@register_series_accessor("where")
@snowpark_pandas_telemetry_method_decorator
Expand Down

0 comments on commit 91466e7

Please sign in to comment.