Skip to content

Commit

Permalink
[Local Testing] SNOW-929078 Support Table.update/delete/merge (#1107)
Browse files Browse the repository at this point in the history
* Add changes and enable tests: need refactor

* Add documentation

* Add refactor changes

* Address comments

* Remove ROW_ID
  • Loading branch information
sfc-gh-stan authored Nov 17, 2023
1 parent cb4d1e4 commit bff4206
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 15 deletions.
12 changes: 3 additions & 9 deletions src/snowflake/snowpark/mock/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,24 +743,18 @@ def do_resolve_with_resolved_children(
)

if isinstance(logical_plan, TableUpdate):
raise NotImplementedError(
"[Local Testing] Table update is not implemented."
)
return MockExecutionPlan(logical_plan, self.session)

if isinstance(logical_plan, TableDelete):
raise NotImplementedError(
"[Local Testing] Table delete is not implemented."
)
return MockExecutionPlan(logical_plan, self.session)

if isinstance(logical_plan, CreateDynamicTableCommand):
raise NotImplementedError(
"[Local Testing] Dynamic tables are currently not supported."
)

if isinstance(logical_plan, TableMerge):
raise NotImplementedError(
"[Local Testing] Table merge is currently not implemented."
)
return MockExecutionPlan(logical_plan, self.session)

if isinstance(logical_plan, MockSelectable):
return MockExecutionPlan(logical_plan, self.session)
Expand Down
280 changes: 279 additions & 1 deletion src/snowflake/snowpark/mock/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
from typing import TYPE_CHECKING, Dict, List, NoReturn, Optional, Union
from unittest.mock import MagicMock

from snowflake.snowpark._internal.analyzer.table_merge_expression import (
DeleteMergeExpression,
InsertMergeExpression,
TableDelete,
TableMerge,
TableUpdate,
UpdateMergeExpression,
)
from snowflake.snowpark._internal.analyzer.window_expression import (
FirstValue,
Lag,
Expand All @@ -25,6 +33,7 @@
UnboundedPreceding,
WindowExpression,
)
from snowflake.snowpark._internal.utils import generate_random_alphanumeric
from snowflake.snowpark.mock.window_utils import (
EntireWindowIndexer,
RowFrameIndexer,
Expand Down Expand Up @@ -86,6 +95,7 @@
from snowflake.snowpark._internal.analyzer.snowflake_plan_node import (
LogicalPlan,
Range,
SaveMode,
SnowflakeCreateTable,
SnowflakeValues,
UnresolvedRelation,
Expand Down Expand Up @@ -821,11 +831,279 @@ def outer_join(base_df):
frac=source_plan.probability_fraction,
random_state=source_plan.seed,
)
elif isinstance(source_plan, CreateViewCommand):
if isinstance(source_plan, CreateViewCommand):
from_df = execute_mock_plan(source_plan.child, expr_to_alias)
view_name = source_plan.name
entity_registry.create_or_replace_view(source_plan.child, view_name)
return from_df

if isinstance(source_plan, TableUpdate):
target = entity_registry.read_table(source_plan.table_name)
ROW_ID = "row_id_" + generate_random_alphanumeric()
target.insert(0, ROW_ID, range(len(target)))

if source_plan.source_data:
# Calculate cartesian product
source = execute_mock_plan(source_plan.source_data, expr_to_alias)
cartesian_product = target.merge(source, on=None, how="cross")
cartesian_product.sf_types.update(target.sf_types)
cartesian_product.sf_types.update(source.sf_types)
intermediate = cartesian_product
else:
intermediate = target

if source_plan.condition:
# Select rows to be updated based on condition
condition = calculate_expression(
source_plan.condition, intermediate, analyzer, expr_to_alias
)

matched = target.apply(tuple, 1).isin(
intermediate[condition][target.columns].apply(tuple, 1)
)
matched.sf_type = ColumnType(BooleanType(), True)
matched_rows = target[matched]
intermediate = intermediate[condition]
else:
matched_rows = target

# Count matched target rows that appear more than once in the joined result (multi-joined rows)
matched_count = intermediate[target.columns].value_counts()[
matched_rows.apply(tuple, 1)
]
multi_joins = matched_count.where(lambda x: x > 1).count()

# Select rows that match the condition to be updated
rows_to_update = intermediate.drop_duplicates(
subset=matched_rows.columns, keep="first"
).reset_index( # ERROR_ON_NONDETERMINISTIC_UPDATE is by default False, pick one row to update
drop=True
)
rows_to_update.sf_types = intermediate.sf_types

# Update rows in place
for attr, new_expr in source_plan.assignments.items():
column_name = analyzer.analyze(attr, expr_to_alias)
target_index = target.loc[rows_to_update[ROW_ID]].index
new_val = calculate_expression(
new_expr, rows_to_update, analyzer, expr_to_alias
)
new_val.index = target_index
target.loc[rows_to_update[ROW_ID], column_name] = new_val

# Delete row_id
target = target.drop(ROW_ID, axis=1)

# Write result back to table
entity_registry.write_table(source_plan.table_name, target, SaveMode.OVERWRITE)
return [Row(len(rows_to_update), multi_joins)]
elif isinstance(source_plan, TableDelete):
target = entity_registry.read_table(source_plan.table_name)

if source_plan.source_data:
# Calculate cartesian product
source = execute_mock_plan(source_plan.source_data, expr_to_alias)
cartesian_product = target.merge(source, on=None, how="cross")
cartesian_product.sf_types.update(target.sf_types)
cartesian_product.sf_types.update(source.sf_types)
intermediate = cartesian_product
else:
intermediate = target

# Select rows to keep based on condition
if source_plan.condition:
condition = calculate_expression(
source_plan.condition, intermediate, analyzer, expr_to_alias
)
intermediate = intermediate[condition]
matched = target.apply(tuple, 1).isin(
intermediate[target.columns].apply(tuple, 1)
)
matched.sf_type = ColumnType(BooleanType(), True)
rows_to_keep = target[~matched]
else:
rows_to_keep = target.head(0)

# Write rows to keep to table registry
entity_registry.write_table(
source_plan.table_name, rows_to_keep, SaveMode.OVERWRITE
)
return [Row(len(target) - len(rows_to_keep))]
elif isinstance(source_plan, TableMerge):
target = entity_registry.read_table(source_plan.table_name)
ROW_ID = "row_id_" + generate_random_alphanumeric()
SOURCE_ROW_ID = "source_row_id_" + generate_random_alphanumeric()
# Execute the source plan; the cartesian product with the target is computed below
source = execute_mock_plan(source_plan.source, expr_to_alias)

# Insert row_id and source row_id
target.insert(0, ROW_ID, range(len(target)))
source.insert(0, SOURCE_ROW_ID, range(len(source)))

cartesian_product = target.merge(source, on=None, how="cross")
cartesian_product.sf_types.update(target.sf_types)
cartesian_product.sf_types.update(source.sf_types)
join_condition = calculate_expression(
source_plan.join_expr, cartesian_product, analyzer, expr_to_alias
)
join_result = cartesian_product[join_condition]
join_result.sf_types = cartesian_product.sf_types

# TODO [GA]: ERROR_ON_NONDETERMINISTIC_MERGE is by default True, raise error if
# (1) A target row is selected to be updated with multiple values OR
# (2) A target row is selected to be both updated and deleted

inserted_rows = []
inserted_row_idx = set() # source_row_id
deleted_row_idx = set()
updated_row_idx = set()
for clause in source_plan.clauses:
if isinstance(clause, UpdateMergeExpression):
# Select rows to update
if clause.condition:
condition = calculate_expression(
clause.condition, join_result, analyzer, expr_to_alias
)
rows_to_update = join_result[condition]
else:
rows_to_update = join_result

rows_to_update = rows_to_update[
~rows_to_update[ROW_ID]
.isin(updated_row_idx.union(deleted_row_idx))
.values
]

# Update rows in place
for attr, new_expr in clause.assignments.items():
column_name = analyzer.analyze(attr, expr_to_alias)
target_index = target.loc[rows_to_update[ROW_ID]].index
new_val = calculate_expression(
new_expr, rows_to_update, analyzer, expr_to_alias
)
new_val.index = target_index
target.loc[rows_to_update[ROW_ID], column_name] = new_val

# Update updated row id set
for _, row in rows_to_update.iterrows():
updated_row_idx.add(row[ROW_ID])

elif isinstance(clause, DeleteMergeExpression):
# Select rows to delete
if clause.condition:
condition = calculate_expression(
clause.condition, join_result, analyzer, expr_to_alias
)
intermediate = join_result[condition]
else:
intermediate = join_result

matched = target.apply(tuple, 1).isin(
intermediate[target.columns].apply(tuple, 1)
)
matched.sf_type = ColumnType(BooleanType(), True)

# Update deleted row id set
for _, row in target[matched].iterrows():
deleted_row_idx.add(row[ROW_ID])

# Delete rows in place
target = target[~matched]

elif isinstance(clause, InsertMergeExpression):
# calculate unmatched rows in the source
matched = source.apply(tuple, 1).isin(
join_result[source.columns].apply(tuple, 1)
)
matched.sf_type = ColumnType(BooleanType(), True)
unmatched_rows_in_source = source[~matched]

# select unmatched rows that qualify the condition
if clause.condition:
condition = calculate_expression(
clause.condition,
unmatched_rows_in_source,
analyzer,
expr_to_alias,
)
unmatched_rows_in_source = unmatched_rows_in_source[condition]

# filter out the unmatched rows that have been inserted in previous clauses
unmatched_rows_in_source = unmatched_rows_in_source[
~unmatched_rows_in_source[SOURCE_ROW_ID]
.isin(inserted_row_idx)
.values
]

# update inserted row idx set
for _, row in unmatched_rows_in_source.iterrows():
inserted_row_idx.add(row[SOURCE_ROW_ID])

# Calculate rows to insert
rows_to_insert = TableEmulator(
[], columns=target.drop(ROW_ID, axis=1).columns
)
rows_to_insert.sf_types = target.sf_types
if clause.keys:
# Keep track of specified columns
inserted_columns = set()
for k, v in zip(clause.keys, clause.values):
column_name = analyzer.analyze(k, expr_to_alias)
if column_name not in rows_to_insert.columns:
raise SnowparkSQLException(
f"Error: invalid identifier '{column_name}'"
)
inserted_columns.add(column_name)
new_val = calculate_expression(
v, unmatched_rows_in_source, analyzer, expr_to_alias
)
rows_to_insert[column_name] = new_val

# For unspecified columns, use None as default value
for unspecified_col in set(rows_to_insert.columns).difference(
inserted_columns
):
rows_to_insert[unspecified_col].replace(
np.nan, None, inplace=True
)

else:
if len(clause.values) != len(rows_to_insert.columns):
raise SnowparkSQLException(
f"Insert value list does not match column list expecting {len(rows_to_insert.columns)} but got {len(clause.values)}"
)
for col, v in zip(rows_to_insert.columns, clause.values):
new_val = calculate_expression(
v, unmatched_rows_in_source, analyzer, expr_to_alias
)
rows_to_insert[col] = new_val

inserted_rows.append(rows_to_insert)

# Remove inserted ROW ID column
target = target.drop(ROW_ID, axis=1)

# Process inserted rows
if inserted_rows:
res = pd.concat([target] + inserted_rows)
res.sf_types = target.sf_types
else:
res = target

# Write the result back to table
entity_registry.write_table(source_plan.table_name, res, SaveMode.OVERWRITE)

# Generate metadata result
res = []
if inserted_rows:
res.append(len(inserted_row_idx))
if updated_row_idx:
res.append(len(updated_row_idx))
if deleted_row_idx:
res.append(len(deleted_row_idx))

return [Row(*res)]

raise NotImplementedError(
f"[Local Testing] Mocking SnowflakePlan {type(source_plan).__name__} is not implemented."
)
Expand Down
Loading

0 comments on commit bff4206

Please sign in to comment.