Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1348700 Fix index issues in handling of CaseExpr, IsNull, IsNotNull and Like expressions #1704

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
- Fixed a bug in convert_timezone that made the setting the source_timezone parameter return an error.
- Fixed a bug where creating DataFrame with empty data of type `DateType` raises `AttributeError`.
- Fixed a bug that table merge fails when update clause exists but no update takes place.
- Fixed a bug in mock implementation of `to_char` that raises `IndexError` when incoming column has inconsecutive row index.
- Fixed a bug in mock implementation of `to_char` that raises `IndexError` when incoming column has nonconsecutive row index.
- Fixed a bug in handling of `CaseExpr` expressions that raises `IndexError` when incoming column has nonconsecutive row index.
- Fixed a bug in implementation of `Column.like` that raises `IndexError` when incoming column has nonconsecutive row index.

#### Improvements

Expand Down
24 changes: 11 additions & 13 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,18 +1643,16 @@ def calculate_expression(
child_column = calculate_expression(
exp.child, input_data, analyzer, expr_to_alias
)
return ColumnEmulator(
data=[bool(data is None) for data in child_column],
sf_type=ColumnType(BooleanType(), True),
)
res = child_column.apply(lambda x: bool(x is None))
res.sf_type = ColumnType(BooleanType(), True)
return res
if isinstance(exp, IsNotNull):
child_column = calculate_expression(
exp.child, input_data, analyzer, expr_to_alias
)
return ColumnEmulator(
data=[bool(data is not None) for data in child_column],
sf_type=ColumnType(BooleanType(), True),
)
res = child_column.apply(lambda x: bool(x is not None))
res.sf_type = ColumnType(BooleanType(), True)
return res
if isinstance(exp, IsNaN):
child_column = calculate_expression(
exp.child, input_data, analyzer, expr_to_alias
Expand Down Expand Up @@ -1813,11 +1811,12 @@ def _match_pattern(row) -> bool:
return result
if isinstance(exp, Like):
lhs = calculate_expression(exp.expr, input_data, analyzer, expr_to_alias)

pattern = convert_wildcard_to_regex(
str(
calculate_expression(exp.pattern, input_data, analyzer, expr_to_alias)[
0
]
calculate_expression(
exp.pattern, input_data, analyzer, expr_to_alias
).iloc[0]
)
)
result = lhs.str.match(pattern)
Expand Down Expand Up @@ -1868,8 +1867,7 @@ def _match_pattern(row) -> bool:
return res
if isinstance(exp, CaseWhen):
remaining = input_data
output_data = ColumnEmulator([None] * len(input_data))
output_data.sf_type = None
output_data = ColumnEmulator([None] * len(input_data), index=input_data.index)
for case in exp.branches:
condition = calculate_expression(
case[0], input_data, analyzer, expr_to_alias
Expand Down
20 changes: 20 additions & 0 deletions tests/mock/test_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from snowflake.snowpark.functions import col, when
from snowflake.snowpark.row import Row


def test_casewhen_with_non_zero_row_index(session):
df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
assert df.filter(col("a") > 1).select(
when(col("a").is_null(), 5).when(col("a") == 1, 6).otherwise(7).as_("a")
).collect() == [Row(A=7)]


def test_like_with_non_zero_row_index(session):
df = session.create_dataframe([["1", 2], ["3", 4]], schema=["a", "b"])
assert df.filter(col("b") > 2).select(
col("a").like("1").alias("res")
sfc-gh-stan marked this conversation as resolved.
Show resolved Hide resolved
).collect() == [Row(RES=False)]
Loading