diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index a42e0682d6e..5e01650b03a 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -184,6 +184,27 @@ def __setitem__(self, col_name: str, col_state: ColumnState) -> None: self.has_new_columns = True +def _deepcopy_selectable_fields( + from_selectable: "Selectable", to_selectable: "Selectable" +) -> None: + """ + Make a deep copy of the fields from the from_selectable to the to_selectable + """ + to_selectable.pre_actions = deepcopy(from_selectable.pre_actions) + to_selectable.post_actions = deepcopy(from_selectable.post_actions) + to_selectable.flatten_disabled = from_selectable.flatten_disabled + to_selectable._column_states = deepcopy(from_selectable._column_states) + to_selectable.expr_to_alias = deepcopy(from_selectable.expr_to_alias) + to_selectable.df_aliased_col_name_to_real_col_name = deepcopy( + from_selectable.df_aliased_col_name_to_real_col_name + ) + # the snowflake plan for selectable typically just point to self, + # to avoid run into recursively copy self problem, we always let it + # rebuild, as far as we have other fields copied correctly, we should + # be able to recover the plan. + to_selectable._snowflake_plan = None + + class Selectable(LogicalPlan, ABC): """The parent abstract class of a DataFrame's logical plan. It can be converted to and from a SnowflakePlan.""" @@ -359,6 +380,12 @@ def __init__( super().__init__(analyzer) self.entity = entity + def __deepcopy__(self, memodict={}) -> "SelectableEntity": # noqa: B006 + copied = SelectableEntity(self.entity_name, analyzer=self.analyzer) + _deepcopy_selectable_fields(from_selectable=self, to_selectable=copied) + + return copied + @property def sql_query(self) -> str: return f"{analyzer_utils.SELECT}{analyzer_utils.STAR}{analyzer_utils.FROM}{self.entity.name}" @@ -419,6 +446,26 @@ def __init__( self._schema_query = sql self._query_param = params + def __deepcopy__(self, memodict={}) -> "SelectSQL": # noqa: B006 + copied = SelectSQL( + sql=self.original_sql, + # when convert_to_select is True, a describe call might be triggered + # to construct the schema query. Since this is a pure copy method, and all + # fields can be done with a pure copy, we set this parameter to False on + # object construct, and correct the fields after. + convert_to_select=False, + analyzer=self.analyzer, + params=deepcopy(self.query_params), + ) + _deepcopy_selectable_fields(from_selectable=self, to_selectable=copied) + # copy over the other fields + copied.convert_to_select = self.convert_to_select + copied._sql_query = self._sql_query + copied._schema_query = self._schema_query + copied._query_param = deepcopy(self._query_param) + + return copied + @property def sql_query(self) -> str: return self._sql_query @@ -485,6 +532,15 @@ def __init__(self, snowflake_plan: LogicalPlan, *, analyzer: "Analyzer") -> None if query.params: self._query_params.extend(query.params) + def __deepcopy__(self, memodict={}) -> "SelectSnowflakePlan": # noqa: B006 + copied = SelectSnowflakePlan( + snowflake_plan=deepcopy(self._snowflake_plan), analyzer=self.analyzer + ) + _deepcopy_selectable_fields(from_selectable=self, to_selectable=copied) + self._query_params = deepcopy(self._query_params) + copied._snowflake_plan = deepcopy(self._snowflake_plan) + return copied + @property def snowflake_plan(self): return self._snowflake_plan @@ -577,6 +633,23 @@ def __copy__(self): return new + def __deepcopy__(self, memodict={}) -> "SelectStatement": # noqa: B006 + copied = SelectStatement( + projection=deepcopy(self.projection), + from_=deepcopy(self.from_), + where=deepcopy(self.where), + order_by=deepcopy(self.order_by), + limit_=deepcopy(self.limit_), + offset=self.offset, + analyzer=self.analyzer, + schema_query=self.schema_query, + ) + + _deepcopy_selectable_fields(from_selectable=self, to_selectable=copied) + copied._projection_in_str = self._projection_in_str + copied._query_params = deepcopy(self._query_params) + return copied + @property def column_states(self) -> ColumnStateDict: if self._column_states is None: @@ -1043,6 +1116,16 @@ def __init__( self.post_actions = self._snowflake_plan.post_actions self._api_calls = self._snowflake_plan.api_calls + def __deepcopy__(self, memodict={}) -> "SelectTableFunction": # noqa: B006 + copied = SelectTableFunction( + func_expr=deepcopy(self.func_expr), analyzer=self.analyzer + ) + _deepcopy_selectable_fields(from_selectable=self, to_selectable=copied) + # need to make a copy of the SnowflakePlan for SelectTableFunction + copied._snowflake_plan = deepcopy(self._snowflake_plan) + + return copied + @property def snowflake_plan(self): return self._snowflake_plan @@ -1093,6 +1176,14 @@ def __init__(self, *set_operands: SetOperand, analyzer: "Analyzer") -> None: self.post_actions.extend(operand.selectable.post_actions) self._nodes.append(operand.selectable) + def __deepcopy__(self, memodict={}) -> "SetStatement": # noqa: B006 + copied = SetStatement(*deepcopy(self.set_operands), analyzer=self.analyzer) + _deepcopy_selectable_fields(from_selectable=self, to_selectable=copied) + copied._placeholder_query = self._placeholder_query + copied._sql_query = self._sql_query + + return copied + @property def sql_query(self) -> str: if not self._sql_query: diff --git a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py index 6cc20f71f4b..fe596f13d19 100644 --- a/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py +++ b/src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py @@ -441,6 +441,32 @@ def __copy__(self) -> "SnowflakePlan": placeholder_query=self.placeholder_query, ) + def __deepcopy__(self, memodict={}) -> "SnowflakePlan": # noqa: B006 + copied_plan = SnowflakePlan( + queries=copy.deepcopy(self.queries) if self.queries else [], + schema_query=self.schema_query, + post_actions=copy.deepcopy(self.post_actions) + if self.post_actions + else None, + expr_to_alias=copy.deepcopy(self.expr_to_alias) + if self.expr_to_alias + else None, + source_plan=copy.deepcopy(self.source_plan) if self.source_plan else None, + is_ddl_on_temp_object=self.is_ddl_on_temp_object, + api_calls=copy.deepcopy(self.api_calls) if self.api_calls else None, + df_aliased_col_name_to_real_col_name=copy.deepcopy( + self.df_aliased_col_name_to_real_col_name + ) + if self.df_aliased_col_name_to_real_col_name + else None, + placeholder_query=self.placeholder_query, + # note that there is no copy of the session object, be careful when using the + # session object after deepcopy + session=self.session, + ) + + return copied_plan + def add_aliases(self, to_add: Dict) -> None: self.expr_to_alias = {**self.expr_to_alias, **to_add} diff --git a/tests/integ/test_deepcopy.py b/tests/integ/test_deepcopy.py new file mode 100644 index 00000000000..d0b6d4cb1d5 --- /dev/null +++ b/tests/integ/test_deepcopy.py @@ -0,0 +1,299 @@ +# +# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. +# + +import copy +from typing import List + +import pytest + +from snowflake.snowpark._internal.analyzer.expression import Attribute +from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( + PlanNodeCategory, +) +from snowflake.snowpark._internal.analyzer.select_statement import ( + ColumnStateDict, + Selectable, + SelectableEntity, + SelectSQL, + SelectTableFunction, + SetStatement, +) +from snowflake.snowpark._internal.analyzer.snowflake_plan import SnowflakePlan +from snowflake.snowpark._internal.analyzer.snowflake_plan_node import ( + LogicalPlan, + SaveMode, + SnowflakeCreateTable, +) +from snowflake.snowpark._internal.analyzer.unary_plan_node import ( + CreateViewCommand, + LocalTempView, +) +from snowflake.snowpark._internal.utils import ( + TempObjectType, + random_name_for_temp_object, +) +from snowflake.snowpark.functions import col, lit, seq1, uniform + +pytestmark = [ + pytest.mark.xfail( + "config.getoption('local_testing_mode', default=False)", + reason="deepcopy is not supported and required by local testing", + run=False, + ) +] + + +def verify_column_state( + copied_state: ColumnStateDict, original_state: ColumnStateDict +) -> None: + assert copied_state.has_changed_columns == original_state.has_changed_columns + assert copied_state.has_new_columns == original_state.has_new_columns + assert copied_state.has_dropped_columns == original_state.has_dropped_columns + assert copied_state.dropped_columns == original_state.dropped_columns + assert copied_state.active_columns == original_state.active_columns + assert ( + copied_state.columns_referencing_all_columns + == original_state.columns_referencing_all_columns + ) + + +def verify_logical_plan_node( + copied_node: LogicalPlan, original_node: LogicalPlan +) -> None: + if copied_node is None and original_node is None: + return + + assert type(copied_node) == type(original_node) + # verify the node complexity + assert ( + copied_node.individual_node_complexity + == original_node.individual_node_complexity + ) + assert ( + copied_node.cumulative_node_complexity + == original_node.cumulative_node_complexity + ) + # verify update accumulative complexity of copied node doesn't impact original node + original_complexity = copied_node.cumulative_node_complexity + copied_node.cumulative_node_complexity = {PlanNodeCategory.OTHERS: 10000} + assert copied_node.cumulative_node_complexity == {PlanNodeCategory.OTHERS: 10000} + assert original_node.cumulative_node_complexity == original_complexity + copied_node.cumulative_node_complexity = original_complexity + + if isinstance(copied_node, Selectable) and isinstance(original_node, Selectable): + verify_column_state(copied_node.column_states, original_node.column_states) + assert copied_node.flatten_disabled == original_node.flatten_disabled + assert ( + copied_node.df_aliased_col_name_to_real_col_name + == original_node.df_aliased_col_name_to_real_col_name + ) + if isinstance(copied_node, SetStatement) and isinstance( + original_node, SetStatement + ): + assert copied_node._sql_query == original_node._sql_query + if isinstance(copied_node, SelectTableFunction) and isinstance( + original_node, SelectTableFunction + ): + # check the source snowflake_plan + assert (copied_node._snowflake_plan is not None) and ( + original_node._snowflake_plan is not None + ) + check_copied_plan(copied_node._snowflake_plan, original_node._snowflake_plan) + + if isinstance(copied_node, Selectable) and isinstance(original_node, Selectable): + copied_child_plan_nodes = copied_node.children_plan_nodes + original_child_plan_nodes = original_node.children_plan_nodes + for (copied_plan_node, original_plan_node) in zip( + copied_child_plan_nodes, original_child_plan_nodes + ): + verify_logical_plan_node(copied_plan_node, original_plan_node) + + +def verify_snowflake_plan_attribute( + copied_plan_attribute: List[Attribute], original_plan_attributes: List[Attribute] +) -> None: + for copied_attribute, original_attribute in zip( + copied_plan_attribute, original_plan_attributes + ): + assert copied_attribute.name == original_attribute.name + assert copied_attribute.datatype == original_attribute.datatype + assert copied_attribute.nullable == original_attribute.nullable + + +def check_copied_plan(copied_plan: SnowflakePlan, original_plan: SnowflakePlan) -> None: + # verify the instance type is the same + assert type(copied_plan) == type(original_plan) + assert copied_plan.queries == original_plan.queries + assert copied_plan.post_actions == original_plan.post_actions + assert ( + copied_plan.df_aliased_col_name_to_real_col_name + == original_plan.df_aliased_col_name_to_real_col_name + ) + assert ( + copied_plan.cumulative_node_complexity + == original_plan.cumulative_node_complexity + ) + assert ( + copied_plan.individual_node_complexity + == original_plan.individual_node_complexity + ) + assert copied_plan.is_ddl_on_temp_object == original_plan.is_ddl_on_temp_object + assert copied_plan.api_calls == original_plan.api_calls + assert copied_plan.expr_to_alias == original_plan.expr_to_alias + assert copied_plan.schema_query == original_plan.schema_query + verify_snowflake_plan_attribute(copied_plan.attributes, original_plan.attributes) + + # verify changes in the copied plan doesn't impact original plan + original_sql = original_plan.queries[-1].sql + copied_plan.queries[-1].sql = "NEW TEST SQL" + assert original_plan.queries[-1].sql == original_sql + # should reset the query back for later comparison + copied_plan.queries[-1].sql = original_sql + + # verify the source plan root node + copied_source_plan = copied_plan.source_plan + original_source_plan = original_plan.source_plan + verify_logical_plan_node(copied_source_plan, original_source_plan) + + +@pytest.mark.parametrize( + "action", + [ + lambda x: x.select("a", "b").select("b"), + lambda x: x.filter(col("a") == 1).select("b"), + lambda x: x.drop("b").sort("a", ascending=False), + lambda x: x.to_df("a1", "b1").alias("L"), + ], +) +def test_selectable_deepcopy(session, action): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df_res = action(df) + # make a copy of the plan for df_res + copied_plan = copy.deepcopy(df_res._plan) + # verify copied plan + check_copied_plan(copied_plan, df_res._plan) + + +@pytest.mark.parametrize( + "action", + [ + lambda x, y: x.union_all(y), + lambda x, y: x.except_(y), + lambda x, y: x.select("a").intersect(y.select("a")), + lambda x, y: x.select("a").join(y, how="outer", rsuffix="_y"), + lambda x, y: x.join(y.select("a"), how="left", rsuffix="_y"), + ], +) +def test_setstatement_deepcopy(session, action): + df1 = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df2 = session.create_dataframe([[3, 4], [2, 1]], schema=["a", "b"]) + df_res = action(df1, df2) + copied_plan = copy.deepcopy(df_res._plan) + check_copied_plan(copied_plan, df_res._plan) + + +def test_selectsql(session): + query = "show tables in schema limit 10" + df = session.sql(query).filter(lit(True)) + + def verify_selectsql(copied_node: SelectSQL, original_node: SelectSQL) -> None: + assert copied_node.original_sql == original_node.original_sql + assert copied_node.convert_to_select == original_node.convert_to_select + assert copied_node.convert_to_select is True + assert copied_node._sql_query == original_node._sql_query + assert copied_node._schema_query == original_node._schema_query + assert copied_node._query_param == original_node._query_param + assert copied_node.pre_actions == original_node.pre_actions + + if session.sql_simplifier_enabled: + assert len(df._plan.children_plan_nodes) == 1 + assert isinstance(df._plan.children_plan_nodes[0], SelectSQL) + select_plan = df._plan.children_plan_nodes[0] + copied_select = copy.deepcopy(select_plan) + verify_logical_plan_node(copied_select, select_plan) + verify_selectsql(copied_select, select_plan) + else: + copied_plan = copy.deepcopy(df._plan) + check_copied_plan(copied_plan, df._plan) + + +def test_selectentity(session): + temp_table_name = random_name_for_temp_object(TempObjectType.TABLE) + session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]).write.save_as_table( + temp_table_name, table_type="temp" + ) + df = session.table(temp_table_name).filter(col("a") == 1) + if session.sql_simplifier_enabled: + assert len(df._plan.children_plan_nodes) == 1 + assert isinstance(df._plan.children_plan_nodes[0], SelectableEntity) + + select_plan = df._plan.children_plan_nodes[0] + copied_select = copy.deepcopy(select_plan) + verify_logical_plan_node(copied_select, select_plan) + assert copied_select.entity_name == select_plan.entity_name + else: + copied_plan = copy.deepcopy(df._plan) + check_copied_plan(copied_plan, df._plan) + + +def test_df_alias_deepcopy(session): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + df_res = df.to_df("a1", "b1").alias("L") + copied_plan = copy.deepcopy(df_res._plan) + check_copied_plan(copied_plan, df_res._plan) + + +def test_table_function(session): + df = ( + session.generator(seq1(1), uniform(1, 10, 2), rowcount=150) + .order_by(seq1(1)) + .limit(3, offset=20) + ) + df_copied_plan = copy.deepcopy(df._plan) + check_copied_plan(df_copied_plan, df._plan) + df_res = df.union_all(df).select("*") + df_res_copied = copy.deepcopy(df_res._plan) + check_copied_plan(df_res_copied, df_res._plan) + + +@pytest.mark.parametrize( + "mode", [SaveMode.APPEND, SaveMode.TRUNCATE, SaveMode.ERROR_IF_EXISTS] +) +def test_table_creation(session, mode): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + create_table_logic_plan = SnowflakeCreateTable( + [random_name_for_temp_object(TempObjectType.TABLE)], + column_names=None, + mode=mode, + query=df._plan, + table_type="temp", + clustering_exprs=None, + comment=None, + ) + snowflake_plan = session._analyzer.resolve(create_table_logic_plan) + copied_plan = copy.deepcopy(snowflake_plan) + check_copied_plan(copied_plan, snowflake_plan) + # The snowflake plan resolved for SnowflakeCreateTable doesn't have source plan attached today + # make another copy to check for logical plan copy + copied_logical_plan = copy.deepcopy(create_table_logic_plan) + verify_logical_plan_node(copied_logical_plan, create_table_logic_plan) + + +def test_create_or_replace_view(session): + df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"]) + create_view_logical_plan = CreateViewCommand( + random_name_for_temp_object(TempObjectType.VIEW), + LocalTempView(), + None, + df._plan, + ) + + snowflake_plan = session._analyzer.resolve(create_view_logical_plan) + copied_plan = copy.deepcopy(snowflake_plan) + check_copied_plan(copied_plan, snowflake_plan) + + # The snowflake plan resolved for CreateViewCommand doesn't have source plan attached today + # make another copy to check for logical plan copy + copied_logical_plan = copy.deepcopy(create_view_logical_plan) + verify_logical_plan_node(copied_logical_plan, create_view_logical_plan)