diff --git a/CHANGELOG.md b/CHANGELOG.md index b887aafce07..bf440b675ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,9 @@ - Fixed a bug in convert_timezone that made the setting the source_timezone parameter return an error. - Fixed a bug where creating DataFrame with empty data of type `DateType` raises `AttributeError`. - Fixed a bug that table merge fails when update clause exists but no update takes place. -- Fixed a bug in mock implementation of `to_char` that raises `IndexError` when incoming column has inconsecutive row index. +- Fixed a bug in mock implementation of `to_char` that raises `IndexError` when incoming column has nonconsecutive row index. +- Fixed a bug in handling of `CaseExpr` expressions that raises `IndexError` when incoming column has nonconsecutive row index. +- Fixed a bug in implementation of `Column.like` that raises `IndexError` when incoming column has nonconsecutive row index. #### Improvements diff --git a/src/snowflake/snowpark/mock/_plan.py b/src/snowflake/snowpark/mock/_plan.py index e88edca5fa3..d2503b80e21 100644 --- a/src/snowflake/snowpark/mock/_plan.py +++ b/src/snowflake/snowpark/mock/_plan.py @@ -1643,18 +1643,16 @@ def calculate_expression( child_column = calculate_expression( exp.child, input_data, analyzer, expr_to_alias ) - return ColumnEmulator( - data=[bool(data is None) for data in child_column], - sf_type=ColumnType(BooleanType(), True), - ) + res = child_column.apply(lambda x: bool(x is None)) + res.sf_type = ColumnType(BooleanType(), True) + return res if isinstance(exp, IsNotNull): child_column = calculate_expression( exp.child, input_data, analyzer, expr_to_alias ) - return ColumnEmulator( - data=[bool(data is not None) for data in child_column], - sf_type=ColumnType(BooleanType(), True), - ) + res = child_column.apply(lambda x: bool(x is not None)) + res.sf_type = ColumnType(BooleanType(), True) + return res if isinstance(exp, IsNaN): child_column = calculate_expression( exp.child, input_data, analyzer, expr_to_alias @@ -1813,11 +1811,12 @@ def _match_pattern(row) -> bool: return result if isinstance(exp, Like): lhs = calculate_expression(exp.expr, input_data, analyzer, expr_to_alias) + pattern = convert_wildcard_to_regex( str( - calculate_expression(exp.pattern, input_data, analyzer, expr_to_alias)[ - 0 - ] + calculate_expression( + exp.pattern, input_data, analyzer, expr_to_alias + ).iloc[0] ) ) result = lhs.str.match(pattern) @@ -1868,8 +1867,7 @@ def _match_pattern(row) -> bool: return res if isinstance(exp, CaseWhen): remaining = input_data - output_data = ColumnEmulator([None] * len(input_data)) - output_data.sf_type = None + output_data = ColumnEmulator([None] * len(input_data), index=input_data.index) for case in exp.branches: condition = calculate_expression( case[0], input_data, analyzer, expr_to_alias diff --git a/tests/mock/test_column.py b/tests/mock/test_column.py new file mode 100644 index 00000000000..9a2e27f37ef --- /dev/null +++ b/tests/mock/test_column.py @@ -0,0 +1,20 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +from snowflake.snowpark.functions import col, when +from snowflake.snowpark.row import Row + + +def test_casewhen_with_non_zero_row_index(session): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + assert df.filter(col("a") > 1).select( + when(col("a").is_null(), 5).when(col("a") == 1, 6).otherwise(7).as_("a") + ).collect() == [Row(A=7)] + + +def test_like_with_non_zero_row_index(session): + df = session.create_dataframe([["1", 2], ["3", 4]], schema=["a", "b"]) + assert df.filter(col("b") > 2).select( + col("a").like("1").alias("res") + ).collect() == [Row(RES=False)]