Skip to content

Commit

Permalink
push unimplemented back to overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-joshi committed Sep 5, 2024
1 parent 3c2392d commit 0fca10f
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -759,24 +759,6 @@ def copy(self) -> "SnowflakeQueryCompiler":
qc.snowpark_pandas_api_calls = self.snowpark_pandas_api_calls.copy()
return qc

def to_list(self) -> list:
"""
Return a native Python list of the values.

Only called if the frontend object was a Series.
"""
return self.to_pandas().squeeze().to_list()

def series_to_dict(self, into=dict) -> dict: # type: ignore
"""
Convert the Series to a dictionary.

Returns
-------
dict or `into` instance
"""
return self.to_pandas().squeeze().to_dict(into=into)

def to_pandas(
self,
*,
Expand Down Expand Up @@ -1849,11 +1831,6 @@ def get_index_names(self, axis: int = 0) -> list[Hashable]:
else self._modin_frame.data_column_pandas_index_names
)

def rdivmod(self, other: "SnowflakeQueryCompiler", **kwargs: Any) -> None:
ErrorMessage.method_not_implemented_error(
name="rdivmod", class_="Series"
) # pragma: no cover

def _binary_op_scalar_rhs(
self, op: str, other: Scalar, fill_value: Scalar
) -> "SnowflakeQueryCompiler":
Expand Down Expand Up @@ -6685,11 +6662,6 @@ def series_to_datetime(
).frame
)

def series_view(self, dtype: npt.DTypeLike) -> None:
ErrorMessage.method_not_implemented_error(
name="view", class_="Series"
) # pragma: no cover

def concat(
self,
axis: Axis,
Expand Down Expand Up @@ -12393,18 +12365,6 @@ def _quantiles_single_col(

return SnowflakeQueryCompiler(internal_frame)

def mode(
self, axis: Axis = 0, numeric_only: bool = False, dropna: bool = True
) -> None:
ErrorMessage.method_not_implemented_error(
name="mode", class_="Series"
) # pragma: no cover

def repeat(self, repeats: Union[int, ListLike], axis: Axis = None) -> None:
ErrorMessage.method_not_implemented_error(
name="repeat", class_="Series"
) # pragma: no cover

def skew(
self,
axis: int,
Expand Down
25 changes: 25 additions & 0 deletions src/snowflake/snowpark/modin/plugin/extensions/series_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ def items(self): # noqa: RT01, D200
pass # pragma: no cover


@register_series_not_implemented()
def mode(self, dropna=True): # noqa: PR01, RT01, D200
pass


@register_series_not_implemented()
def prod(
self,
Expand All @@ -258,6 +263,11 @@ def ravel(self, order="C"): # noqa: PR01, RT01, D200
pass # pragma: no cover


@register_series_not_implemented()
def rdivmod(self, other, level=None, fill_value=None, axis=0): # noqa: PR01, RT01, D200
pass # pragma: no cover


@register_series_not_implemented()
def reindex_like(
self,
Expand All @@ -275,6 +285,11 @@ def reorder_levels(self, order): # noqa: PR01, RT01, D200
pass # pragma: no cover


@register_series_not_implemented()
def repeat(self, repeats, axis=None): # noqa: PR01, RT01, D200
pass # pragma: no cover


@register_series_not_implemented()
def searchsorted(self, value, side="left", sorter=None): # noqa: PR01, RT01, D200
pass # pragma: no cover
Expand All @@ -285,6 +300,11 @@ def swaplevel(self, i=-2, j=-1, copy=True): # noqa: PR01, RT01, D200
pass # pragma: no cover


@register_series_not_implemented()
def to_dict(self, into=dict): # noqa: PR01, RT01, D200
pass # pragma: no cover


@register_series_not_implemented()
def to_period(self, freq=None, copy=True): # noqa: PR01, RT01, D200
pass # pragma: no cover
Expand Down Expand Up @@ -312,6 +332,11 @@ def to_timestamp(self, freq=None, how="start", copy=True): # noqa: PR01, RT01,
pass # pragma: no cover


@register_series_not_implemented()
def view(self, dtype=None): # noqa: PR01, RT01, D200
pass # pragma: no cover


@register_series_not_implemented()
def array(self): # noqa: PR01, RT01, D200
pass # pragma: no cover
Expand Down
13 changes: 3 additions & 10 deletions tests/integ/modin/test_unimplemented.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,11 @@ def helper(df):
# This set triggers SeriesDefault.register
UNSUPPORTED_SERIES_METHODS = [
(lambda df: df.transform(lambda x: x + 1), "transform"),
(lambda df: df.repeat(2), "repeat"),
(lambda df: df.view(), "view"),
]

# unsupported binary operations that can be applied on both dataframe and series
# this set triggers default_to_pandas test with Snowpark pandas objects in arguments
UNSUPPORTED_DATAFRAME_SERIES_BINARY_METHODS = [
UNSUPPORTED_BINARY_METHODS = [
# TODO SNOW-862664, support together with combine
# (lambda dfs: dfs[0].combine(dfs[1], np.minimum, fill_value=1), "combine"),
(lambda dfs: dfs[0].align(dfs[1]), "align"),
Expand All @@ -98,11 +96,6 @@ def helper(df):
(lambda dfs: dfs[0].update(dfs[1]), "update"),
]

# unsupported binary operations that are only series
UNSUPPORTED_SERIES_BINARY_METHODS = [
(lambda dfs: dfs[0].rdivmod(dfs[1]), "rdivmod"),
]


# When any unsupported method gets supported, we should run the test to verify (expect failure)
# and remove the corresponding method in the above list.
Expand Down Expand Up @@ -131,7 +124,7 @@ def test_unsupported_series_methods(func, func_name, caplog) -> None:

@pytest.mark.parametrize(
"func, func_name",
UNSUPPORTED_DATAFRAME_SERIES_BINARY_METHODS,
UNSUPPORTED_BINARY_METHODS,
)
@sql_count_checker(query_count=0)
def test_unsupported_dataframe_binary_methods(func, func_name, caplog) -> None:
Expand All @@ -150,7 +143,7 @@ def test_unsupported_dataframe_binary_methods(func, func_name, caplog) -> None:

@pytest.mark.parametrize(
"func, func_name",
UNSUPPORTED_DATAFRAME_SERIES_BINARY_METHODS + UNSUPPORTED_SERIES_BINARY_METHODS,
UNSUPPORTED_BINARY_METHODS,
)
@sql_count_checker(query_count=0)
def test_unsupported_series_binary_methods(func, func_name, caplog) -> None:
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/modin/test_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def test_unsupported_df(df_method, kwargs):
["prod", {}],
["ravel", {}],
["reorder_levels", {"order": ""}],
["repeat", {"repeats": ""}],
["rdivmod", {"other": ""}],
["searchsorted", {"value": ""}],
["set_flags", {}],
["swapaxes", {"axis1": "", "axis2": ""}],
Expand All @@ -161,6 +163,7 @@ def test_unsupported_df(df_method, kwargs):
["to_timestamp", {}],
["to_xarray", {}],
["truncate", {}],
["view", {}],
["xs", {"key": ""}],
],
)
Expand Down

0 comments on commit 0fca10f

Please sign in to comment.