Skip to content

Commit

Permalink
SNOW-1373556: Implement dropna using a new expression (#1748)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jdu authored Jun 10, 2024
1 parent ffdffa6 commit 23db87b
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#### Bug Fixes

- Fixed a bug where a Python stored procedure with a table return type fails when run in a task.
- Fixed a bug where df.dropna fails due to `RecursionError: maximum recursion depth exceeded` when the DataFrame has more than 500 columns.

### Snowpark Local Testing Updates

Expand Down
12 changes: 12 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
case_when_expression,
cast_expression,
collate_expression,
column_sum,
delete_merge_statement,
empty_values_statement,
flatten_expression,
Expand Down Expand Up @@ -55,6 +56,7 @@
Attribute,
CaseWhen,
Collate,
ColumnSum,
Expression,
FunctionExpression,
InExpression,
Expand Down Expand Up @@ -520,6 +522,16 @@ def analyze(
expr.is_distinct,
)

if isinstance(expr, ColumnSum):
return column_sum(
[
self.analyze(
col, df_aliased_col_name_to_real_col_name, parse_local_name
)
for col in expr.exprs
]
)

if isinstance(expr, RankRelatedFunctionExpression):
return rank_related_function_expression(
expr.sql,
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,6 +1443,10 @@ def list_agg(col: str, delimiter: str, is_distinct: bool) -> str:
)


def column_sum(cols: List[str]) -> str:
    """Build one parenthesized expression that joins ``cols`` with ``PLUS``.

    Args:
        cols: Already-rendered expression strings, in the order they should
            appear in the output.

    Returns:
        ``LEFT_PARENTHESIS``, then every entry of ``cols`` separated by
        ``PLUS``, then ``RIGHT_PARENTHESIS``, as a single string.
    """
    # A single join keeps the result flat no matter how many operands are
    # passed in; no nesting is introduced per operand.
    joined_terms = PLUS.join(cols)
    return "".join((LEFT_PARENTHESIS, joined_terms, RIGHT_PARENTHESIS))


def generator(row_count: int) -> str:
return (
GENERATOR
Expand Down
9 changes: 9 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,3 +414,12 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None:

def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.col)


class ColumnSum(Expression):
    """Expression node that holds a flat list of operand expressions.

    The operands are kept side by side in ``exprs`` rather than being chained
    into nested two-operand nodes, so the depth of the expression tree does
    not grow with the number of operands.
    """

    def __init__(self, exprs: List[Expression]) -> None:
        """Store the operand expressions.

        Args:
            exprs: The operand expressions; the list is stored as given
                (it is not copied).
        """
        super().__init__()
        self.exprs = exprs

    def dependent_column_names(self) -> Optional[AbstractSet[str]]:
        """Return the column names that the operands depend on.

        Returns:
            Whatever ``derive_dependent_columns`` reports when given every
            operand in ``exprs``.
        """
        operands = self.exprs
        return derive_dependent_columns(*operands)
10 changes: 5 additions & 5 deletions src/snowflake/snowpark/dataframe_na_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Dict, Optional, Union

import snowflake.snowpark
from snowflake.snowpark._internal.analyzer.expression import ColumnSum
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import add_api_call, adjust_api_subcalls
from snowflake.snowpark._internal.type_utils import (
Expand All @@ -18,6 +19,7 @@
python_type_to_snow_type,
)
from snowflake.snowpark._internal.utils import quote_name
from snowflake.snowpark.column import Column
from snowflake.snowpark.functions import iff, lit, when
from snowflake.snowpark.types import (
DataType,
Expand Down Expand Up @@ -192,7 +194,7 @@ def drop(
for field in self._df.schema.fields
}
normalized_col_name_set = {quote_name(col_name) for col_name in subset}
col_counter = None
is_na_columns = []
for normalized_col_name in normalized_col_name_set:
if normalized_col_name not in df_col_type_dict:
raise SnowparkClientExceptionMessages.DF_CANNOT_RESOLVE_COLUMN_NAME(
Expand All @@ -207,10 +209,8 @@ def drop(
else:
# iff(col is null, 0, 1)
is_na = iff(col.is_null(), 0, 1)
if col_counter is not None:
col_counter += is_na
else:
col_counter = is_na
is_na_columns.append(is_na)
col_counter = Column(ColumnSum([c._expression for c in is_na_columns]))
new_df = self._df.where(col_counter >= thresh)
adjust_api_subcalls(new_df, "DataFrameNaFunctions.drop", len_subcalls=1)
return new_df
Expand Down
10 changes: 10 additions & 0 deletions src/snowflake/snowpark/mock/_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
case_when_expression,
cast_expression,
collate_expression,
column_sum,
delete_merge_statement,
flatten_expression,
function_expression,
Expand Down Expand Up @@ -52,6 +53,7 @@
Attribute,
CaseWhen,
Collate,
ColumnSum,
Expression,
FunctionExpression,
InExpression,
Expand Down Expand Up @@ -416,6 +418,14 @@ def analyze(
expr.is_distinct,
)

if isinstance(expr, ColumnSum):
return column_sum(
[
self.analyze(col, expr_to_alias, parse_local_name)
for col in expr.exprs
]
)

if isinstance(expr, RankRelatedFunctionExpression):
return rank_related_function_expression(
expr.sql,
Expand Down
9 changes: 8 additions & 1 deletion src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import uuid
from collections.abc import Iterable
from enum import Enum
from functools import cached_property, partial
from functools import cached_property, partial, reduce
from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Union
from unittest.mock import MagicMock

Expand Down Expand Up @@ -83,6 +83,7 @@
from snowflake.snowpark._internal.analyzer.expression import (
Attribute,
CaseWhen,
ColumnSum,
Expression,
FunctionExpression,
InExpression,
Expand Down Expand Up @@ -1771,6 +1772,12 @@ def calculate_expression(
raise_error=NotImplementedError,
)
return new_column
elif isinstance(exp, ColumnSum):
cols = [
calculate_expression(e, input_data, analyzer, expr_to_alias)
for e in exp.exprs
]
return reduce(ColumnEmulator.add, cols)
if isinstance(exp, UnaryMinus):
res = calculate_expression(exp.child, input_data, analyzer, expr_to_alias)
return -res
Expand Down
9 changes: 9 additions & 0 deletions tests/integ/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2236,6 +2236,15 @@ def test_dropna(session, local_testing_mode):
assert "subset should be a list or tuple of column names" in str(ex_info)


@pytest.mark.localtest
def test_dropna_large_num_of_columns(session):
    """Check ``dropna(how="all")`` on a DataFrame with 1000 columns.

    The frame has two rows: one fully populated and one made entirely of
    ``None``. Only the populated row is expected to remain.
    """
    column_count = 1000
    # The stringified indices serve both as the column names and as the
    # values of the fully populated row.
    populated_row = list(map(str, range(column_count)))
    all_null_row = [None] * column_count
    wide_df = session.create_dataframe(
        [populated_row, all_null_row], schema=populated_row
    )
    surviving_rows = wide_df.dropna(how="all")
    Utils.check_answer(surviving_rows, [Row(*populated_row)])


@pytest.mark.localtest
def test_fillna(session, local_testing_mode):
if not local_testing_mode: # Enable for local testing after coercion support
Expand Down

0 comments on commit 23db87b

Please sign in to comment.