Skip to content

Commit

Permalink
SNOW-1526571: fix implementation of lead and lag handling ignore null…
Browse files Browse the repository at this point in the history
… option (#1959)
  • Loading branch information
sfc-gh-aling authored Jul 29, 2024
1 parent 8166d7a commit 61ba224
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 17 deletions.
6 changes: 3 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
- Fixed a bug in `DataFrame.to_pandas_batches` where the iterator could throw an error if certain transformations are made to the pandas dataframe due to a wrong isolation level.

### Snowpark Local Testing Updates
#### New Features

#### New Features
- Added support for the following APIs:
- snowflake.snowpark.functions
- rank
Expand All @@ -25,8 +25,8 @@
- datediff

#### Bug Fixes

Fixed a bug where values were not populated into the result DataFrame during the insertion of table merge operation.
- Fixed a bug where the window functions `LEAD` and `LAG` did not handle the `ignore_nulls` option properly.
- Fixed a bug where values were not populated into the result DataFrame during the insertion of table merge operation.

### Snowpark pandas API Updates
#### New Features
Expand Down
51 changes: 37 additions & 14 deletions src/snowflake/snowpark/mock/_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2103,29 +2103,52 @@ def _match_pattern(row) -> bool:
delta = 1 if offset > 0 else -1
cur_idx = row_idx + delta
cur_count = 0
# default calc_expr is None for the case of cur_idx < 0 or cur_idx >= len(w)
# if cur_idx is within the value, it will be overwritten by the following valid value
calc_expr = ColumnEmulator(
[None], sf_type=ColumnType(NullType(), True), dtype=object
)
target_value = calc_expr.iloc[0]
while 0 <= cur_idx < len(w):
target_expr = calculate_expression(
calc_expr = calculate_expression(
window_function.expr,
w.iloc[[cur_idx]],
analyzer,
expr_to_alias,
).iloc[0]
if target_expr is not None:
)
target_value = calc_expr.iloc[0]
if target_value is not None:
cur_count += 1
if cur_count == abs(offset):
break
cur_idx += delta
if cur_idx < 0 or cur_idx >= len(w):
res_cols.append(
calculate_expression(
window_function.default,
w,
analyzer,
expr_to_alias,
).iloc[0]
)
else:
res_cols.append(target_expr)
if not calculated_sf_type:
calculated_sf_type = calc_expr.sf_type
elif calculated_sf_type.datatype != calc_expr.sf_type.datatype:
if isinstance(calculated_sf_type.datatype, NullType):
calculated_sf_type = calc_expr.sf_type
# the result calculated upon a windows can be None, this is still valid and we can keep
# the calculation
elif not isinstance( # pragma: no cover
calc_expr.sf_type.datatype, NullType
):
analyzer.session._conn.log_not_supported_error( # pragma: no cover
external_feature_name=f"Coercion of detected type"
f" {type(calculated_sf_type.datatype).__name__}"
f" and type {type(calc_expr.sf_type.datatype).__name__}",
internal_feature_name=type(exp).__name__,
parameters_info={
"window_function": type(window_function).__name__,
"calc_expr.sf_type.datatype": str(
type(calc_expr.sf_type.datatype).__name__
),
"calculated_sf_type.datatype": str(
type(calculated_sf_type.datatype).__name__
),
},
raise_error=SnowparkLocalTestingException,
)
res_cols.append(target_value)
res_col = ColumnEmulator(
data=res_cols, dtype=object
) # dtype=object prevents implicit converting None to Nan
Expand Down
37 changes: 37 additions & 0 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -4648,6 +4648,43 @@ def test_lead(session, col_z, local_testing_mode):
)


def test_lead_lag_nulls(session):
    """Verify LEAD/LAG with ignore_nulls=True skip over NULL gaps.

    All rows live in a single partition (COL1 = COL2 = 1) ordered by COL3;
    COLUMN_TO_FILL is mostly NULL, so the window functions must reach past
    the gaps to the nearest non-null neighbour in each direction.
    """
    rows = [
        (1, 1, 1, None),
        (1, 1, 2, None),
        (1, 1, 3, 1),
        (1, 1, 4, None),
        (1, 1, 5, None),
        (1, 1, 6, 2),
        (1, 1, 7, None),
        (1, 1, 8, None),
    ]
    df = session.create_dataframe(
        data=rows, schema=["COL1", "COL2", "COL3", "COLUMN_TO_FILL"]
    )

    win = Window.partition_by(["COL1", "COL2"]).order_by("COL3")
    next_val = lead(df.col("COLUMN_TO_FILL"), ignore_nulls=True).over(win)
    prev_val = lag(df.col("COLUMN_TO_FILL"), ignore_nulls=True).over(win)

    # iff propagates SQL NULL comparison semantics: when either side is
    # NULL the condition is not TRUE, so the lag value is selected.
    expected = [
        Row(1, 1, 1, None, None),
        Row(1, 1, 2, None, None),
        Row(1, 1, 3, 1, None),
        Row(1, 1, 4, None, 2),
        Row(1, 1, 5, None, 2),
        Row(1, 1, 6, 2, 1),
        Row(1, 1, 7, None, 2),
        Row(1, 1, 8, None, 2),
    ]
    Utils.check_answer(
        df.with_column("MAX_LEAD_LAG", iff(next_val > prev_val, next_val, prev_val)),
        expected,
        sort=False,
    )


@pytest.mark.parametrize("col_z", ["Z", col("Z")])
def test_last_value(session, col_z):
Utils.check_answer(
Expand Down

0 comments on commit 61ba224

Please sign in to comment.