Skip to content

Commit

Permalink
[SNOW-1527902]: Add support for limit parameter in fillna. (#1891)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-rdurrani authored Jul 10, 2024
1 parent bae2812 commit be50b7b
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#### New Features
- Added partial support for `Series.str.translate` where the values in the `table` are single-codepoint strings.
- Added support for `DataFrame.corr`.
- Added support for `limit` parameter when `method` parameter is used in `fillna`.

#### Bug Fixes
- Fixed an issue when using np.where and df.where when the scalar 'other' is the literal 0.
Expand Down
4 changes: 3 additions & 1 deletion docs/source/modin/supported/dataframe_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``explode`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``ffill`` | P | | ``N`` if param ``limit`` is set |
| ``ffill`` | P | | ``N`` if parameter ``downcast`` is set. ``limit`` |
| | | | parameter only supported if ``method`` parameter |
| | | | is used. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``fillna`` | P | | See ``ffill`` |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
4 changes: 3 additions & 1 deletion docs/source/modin/supported/series_supported.rst
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ Methods
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``factorize`` | N | | |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``ffill`` | P | | ``N`` if parameter ``limit`` is set |
| ``ffill`` | P | | ``N`` if parameter ``downcast`` is set. ``limit`` |
| | | | parameter only supported if ``method`` parameter |
| | | | is used. |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
| ``fillna`` | P | | See ``ffill`` |
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8783,7 +8783,10 @@ def where(
return SnowflakeQueryCompiler(new_frame)

def _make_fill_expression_for_column_wise_fillna(
self, snowflake_quoted_identifier: str, method: FillNAMethod
self,
snowflake_quoted_identifier: str,
method: FillNAMethod,
limit: Optional[int] = None,
) -> SnowparkColumn:
"""
Helper function to get the Snowpark Column expression corresponding to snowflake_quoted_id when doing a column wise fillna.
Expand All @@ -8794,6 +8797,8 @@ def _make_fill_expression_for_column_wise_fillna(
The snowflake quoted identifier of the column that we are generating the expression for.
method : FillNAMethod
Enum representing if this method is a ffill method or a bfill method.
limit : optional, int
Maximum number of consecutive NA values to fill.

Returns
-------
Expand All @@ -8815,17 +8820,23 @@ def _make_fill_expression_for_column_wise_fillna(
):
return col(snowflake_quoted_identifier)
if method_is_ffill:
start_pos = 0
if limit is not None:
start_pos = max(col_pos - limit, start_pos)
return coalesce(
snowflake_quoted_identifier,
*self._modin_frame.data_column_snowflake_quoted_identifiers[:col_pos][
::-1
],
*self._modin_frame.data_column_snowflake_quoted_identifiers[
start_pos:col_pos
][::-1],
)
else:
start_pos = len_ids
if limit is not None:
start_pos = min(col_pos + limit, len_ids)
return coalesce(
snowflake_quoted_identifier,
*self._modin_frame.data_column_snowflake_quoted_identifiers[
len_ids:col_pos:-1
start_pos:col_pos:-1
][::-1],
)

Expand Down Expand Up @@ -8857,10 +8868,9 @@ def fillna(
BaseQueryCompiler
New QueryCompiler with all null values filled.
"""
# TODO: SNOW-891788 support limit
if limit:
if value is not None and limit is not None:
ErrorMessage.not_implemented(
"Snowpark pandas fillna API doesn't yet support 'limit' parameter"
"Snowpark pandas fillna API doesn't yet support 'limit' parameter with 'value' parameter"
)
if downcast:
ErrorMessage.not_implemented(
Expand All @@ -8883,12 +8893,18 @@ def fillna(
self._modin_frame = self._modin_frame.ensure_row_position_column()
if method_is_ffill:
func = last_value
window_start = Window.UNBOUNDED_PRECEDING
if limit is None:
window_start = Window.UNBOUNDED_PRECEDING
else:
window_start = -1 * limit
window_end = Window.CURRENT_ROW
else:
func = first_value
window_start = Window.CURRENT_ROW
window_end = Window.UNBOUNDED_FOLLOWING
if limit is None:
window_end = Window.UNBOUNDED_FOLLOWING
else:
window_end = limit
fillna_column_map = {
snowflake_quoted_id: coalesce(
snowflake_quoted_id,
Expand All @@ -8903,7 +8919,7 @@ def fillna(
else:
fillna_column_map = {
snowflake_quoted_id: self._make_fill_expression_for_column_wise_fillna(
snowflake_quoted_id, method
snowflake_quoted_id, method, limit=limit
)
for snowflake_quoted_id in self._modin_frame.data_column_snowflake_quoted_identifiers
}
Expand Down
13 changes: 13 additions & 0 deletions src/snowflake/snowpark/modin/plugin/docstrings/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,15 @@ def fillna():
2 0.0 1.0 2.0 3.0
3 0.0 3.0 2.0 4.0
Only replace the first NaN element.
>>> df.fillna(method="ffill", limit=1)
A B C D
0 NaN 2.0 NaN 0.0
1 3.0 4.0 NaN 1.0
2 3.0 4.0 NaN 1.0
3 NaN 3.0 NaN 4.0
When filling using a DataFrame, replacement happens along
the same column names and same indices
Expand All @@ -1389,6 +1398,10 @@ def fillna():
3 0.0 3.0 0.0 4.0
Note that column D is not affected since it is not present in df2.
Notes
-----
`limit` parameter is only supported when using `method` parameter.
"""

def floordiv():
Expand Down
13 changes: 13 additions & 0 deletions src/snowflake/snowpark/modin/plugin/docstrings/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,6 +1247,15 @@ def fillna():
2 0.0 1.0 2.0 3.0
3 0.0 3.0 2.0 4.0
Only replace the first NaN element.
>>> df.fillna(method="ffill", limit=1)
A B C D
0 NaN 2.0 NaN 0.0
1 3.0 4.0 NaN 1.0
2 3.0 4.0 NaN 1.0
3 NaN 3.0 NaN 4.0
When filling using a DataFrame, replacement happens along
the same column names and same indices
Expand All @@ -1259,6 +1268,10 @@ def fillna():
3 0.0 3.0 0.0 4.0
Note that column D is not affected since it is not present in df2.
Notes
-----
`limit` parameter is only supported when using `method` parameter.
"""

@_create_operator_docstring(
Expand Down
31 changes: 25 additions & 6 deletions tests/integ/modin/frame/test_fillna.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ def test_fillna_df():
)


@pytest.fixture(scope="function")
def test_fillna_df_limit():
return native_pd.DataFrame(
[
[1, 2, np.nan, 4],
[np.nan, np.nan, 7, np.nan],
[np.nan, 10, np.nan, 12],
[np.nan, np.nan, 15, 16],
],
columns=list("ABCD"),
)


@pytest.fixture(scope="function")
def test_fillna_df_none_index():
# test case to make sure fillna only fill missing values in data columns not index columns
Expand Down Expand Up @@ -260,12 +273,18 @@ def test_value_scalar_inplace(test_fillna_df):
)


@sql_count_checker(query_count=0)
def test_value_scalar_limit_not_implemented(test_fillna_df):
df = pd.DataFrame(test_fillna_df)
msg = "Snowpark pandas fillna API doesn't yet support 'limit' parameter"
with pytest.raises(NotImplementedError, match=msg):
df.fillna(1, limit=1)
@sql_count_checker(query_count=1)
@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("limit", [1, 2, 3, 100])
@pytest.mark.parametrize("method", ["ffill", "bfill"])
def test_fillna_limit(test_fillna_df_limit, method, limit, axis):
native_df = test_fillna_df_limit
if axis == 1:
native_df = native_df.T
snow_df = pd.DataFrame(native_df)
eval_snowpark_pandas_result(
snow_df, native_df, lambda df: df.fillna(method=method, limit=limit, axis=axis)
)


@sql_count_checker(query_count=0)
Expand Down
18 changes: 18 additions & 0 deletions tests/integ/modin/series/test_fillna.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def test_fillna_series_2():
return native_pd.Series([np.nan, 2, np.nan, 0], list("abcd"))


@pytest.fixture(scope="function")
def test_fillna_series_limit():
return native_pd.Series(
[np.nan, 1, 2, np.nan, np.nan, np.nan, np.nan, 7, np.nan, 9]
)


@pytest.fixture(scope="function")
def test_fillna_df():
return native_pd.DataFrame(
Expand Down Expand Up @@ -125,6 +132,17 @@ def test_value_scalar(test_fillna_series):
)


@sql_count_checker(query_count=1)
@pytest.mark.parametrize("limit", [1, 2, 3, 100])
@pytest.mark.parametrize("method", ["ffill", "bfill"])
def test_fillna_limit(test_fillna_series_limit, limit, method):
eval_snowpark_pandas_result(
pd.Series(test_fillna_series_limit),
test_fillna_series_limit,
lambda s: s.fillna(method=method, limit=limit),
)


@sql_count_checker(query_count=1, join_count=1)
def test_value_series(test_fillna_series, test_fillna_series_2):
eval_snowpark_pandas_result(
Expand Down

0 comments on commit be50b7b

Please sign in to comment.