Skip to content

Commit

Permalink
[SNOW-1541087] Allow deep copy of SnowflakePlan and Selectable (#1937)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-yzou authored Jul 23, 2024
1 parent 50275a7 commit 6f83b75
Show file tree
Hide file tree
Showing 3 changed files with 416 additions and 0 deletions.
91 changes: 91 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,27 @@ def __setitem__(self, col_name: str, col_state: ColumnState) -> None:
self.has_new_columns = True


def _deepcopy_selectable_fields(
    from_selectable: "Selectable", to_selectable: "Selectable"
) -> None:
    """
    Populate ``to_selectable`` with the shared Selectable state of ``from_selectable``.

    Container-valued fields are deep copied so that the two objects do not
    share mutable state; ``flatten_disabled`` is assigned as-is. The cached
    ``_snowflake_plan`` of the target is always reset to ``None``.

    Args:
        from_selectable: the object whose fields are read.
        to_selectable: the object whose fields are overwritten in place.
    """
    # (attribute name, deep copy required?) listed in assignment order.
    field_plan = (
        ("pre_actions", True),
        ("post_actions", True),
        ("flatten_disabled", False),
        ("_column_states", True),
        ("expr_to_alias", True),
        ("df_aliased_col_name_to_real_col_name", True),
    )
    for attr_name, needs_deepcopy in field_plan:
        value = getattr(from_selectable, attr_name)
        if needs_deepcopy:
            value = deepcopy(value)
        setattr(to_selectable, attr_name, value)

    # The snowflake plan of a selectable typically just points back to the
    # selectable itself. Copying it here would recurse into copying self, so
    # it is dropped and left to be rebuilt; as long as the other fields are
    # copied correctly the plan should be recoverable.
    to_selectable._snowflake_plan = None


class Selectable(LogicalPlan, ABC):
"""The parent abstract class of a DataFrame's logical plan. It can be converted to and from a SnowflakePlan."""

Expand Down Expand Up @@ -359,6 +380,12 @@ def __init__(
super().__init__(analyzer)
self.entity = entity

def __deepcopy__(self, memodict={}) -> "SelectableEntity":  # noqa: B006
    """
    Return a deep copy of this SelectableEntity.

    The new instance is built from a deep copy of ``self.entity`` (the
    attribute set by ``__init__``) and the same analyzer object; the shared
    Selectable fields are then copied by ``_deepcopy_selectable_fields``.

    Args:
        memodict: accepted for compatibility with the ``copy.deepcopy``
            protocol; it is not consulted here.
    """
    # The constructor stores its first argument as ``self.entity``; there is
    # no ``entity_name`` attribute to read from, so copy the entity itself.
    copied = SelectableEntity(deepcopy(self.entity), analyzer=self.analyzer)
    _deepcopy_selectable_fields(from_selectable=self, to_selectable=copied)

    return copied

@property
def sql_query(self) -> str:
return f"{analyzer_utils.SELECT}{analyzer_utils.STAR}{analyzer_utils.FROM}{self.entity.name}"
Expand Down Expand Up @@ -419,6 +446,26 @@ def __init__(
self._schema_query = sql
self._query_param = params

def __deepcopy__(self, memodict={}) -> "SelectSQL":  # noqa: B006
    """
    Return a deep copy of this SelectSQL.

    The clone is first constructed with ``convert_to_select=False`` and then
    patched with the real field values of this object.

    Args:
        memodict: accepted for compatibility with the ``copy.deepcopy``
            protocol; it is not consulted here.
    """
    original_sql = self.original_sql
    analyzer = self.analyzer
    params_copy = deepcopy(self.query_params)
    # Constructing with convert_to_select=True might trigger a describe call
    # to build the schema query. Copying must stay a pure in-memory operation,
    # so the flag is forced to False here and restored on the clone below.
    clone = SelectSQL(
        sql=original_sql,
        convert_to_select=False,
        analyzer=analyzer,
        params=params_copy,
    )
    _deepcopy_selectable_fields(from_selectable=self, to_selectable=clone)

    # Fields that are carried over by plain assignment.
    for attr_name in ("convert_to_select", "_sql_query", "_schema_query"):
        setattr(clone, attr_name, getattr(self, attr_name))
    # The parameter container is copied so the clone does not share it.
    clone._query_param = deepcopy(self._query_param)

    return clone

@property
def sql_query(self) -> str:
return self._sql_query
Expand Down Expand Up @@ -485,6 +532,15 @@ def __init__(self, snowflake_plan: LogicalPlan, *, analyzer: "Analyzer") -> None
if query.params:
self._query_params.extend(query.params)

def __deepcopy__(self, memodict={}) -> "SelectSnowflakePlan":  # noqa: B006
    """
    Return a deep copy of this SelectSnowflakePlan.

    The copy is constructed from a deep copy of the wrapped plan and the same
    analyzer object. The original object is left untouched.

    Args:
        memodict: accepted for compatibility with the ``copy.deepcopy``
            protocol; it is not consulted here.
    """
    copied = SelectSnowflakePlan(
        snowflake_plan=deepcopy(self._snowflake_plan), analyzer=self.analyzer
    )
    _deepcopy_selectable_fields(from_selectable=self, to_selectable=copied)
    # The copied params belong on the new object; assigning them to ``self``
    # would rebind the original's list and leave the copy without them.
    copied._query_params = deepcopy(self._query_params)
    # _deepcopy_selectable_fields resets _snowflake_plan to None, and the
    # snowflake_plan property returns that attribute directly, so the plan
    # has to be put back explicitly.
    copied._snowflake_plan = deepcopy(self._snowflake_plan)
    return copied

@property
def snowflake_plan(self):
return self._snowflake_plan
Expand Down Expand Up @@ -577,6 +633,23 @@ def __copy__(self):

return new

def __deepcopy__(self, memodict={}) -> "SelectStatement":  # noqa: B006
    """
    Return a deep copy of this SelectStatement.

    The clause fields (projection, from_, where, order_by, limit_) are deep
    copied; offset, analyzer and schema_query are passed through unchanged.

    Args:
        memodict: accepted for compatibility with the ``copy.deepcopy``
            protocol; it is not consulted here.
    """
    # Constructor arguments, evaluated in declaration order.
    init_kwargs = {
        "projection": deepcopy(self.projection),
        "from_": deepcopy(self.from_),
        "where": deepcopy(self.where),
        "order_by": deepcopy(self.order_by),
        "limit_": deepcopy(self.limit_),
        "offset": self.offset,
        "analyzer": self.analyzer,
        "schema_query": self.schema_query,
    }
    clone = SelectStatement(**init_kwargs)

    _deepcopy_selectable_fields(from_selectable=self, to_selectable=clone)
    # State that is not covered by the constructor or the shared helper.
    clone._projection_in_str = self._projection_in_str
    clone._query_params = deepcopy(self._query_params)
    return clone

@property
def column_states(self) -> ColumnStateDict:
if self._column_states is None:
Expand Down Expand Up @@ -1043,6 +1116,16 @@ def __init__(
self.post_actions = self._snowflake_plan.post_actions
self._api_calls = self._snowflake_plan.api_calls

def __deepcopy__(self, memodict={}) -> "SelectTableFunction":  # noqa: B006
    """
    Return a deep copy of this SelectTableFunction.

    The function expression and the snowflake plan are deep copied; the
    analyzer object is shared with the original.

    Args:
        memodict: accepted for compatibility with the ``copy.deepcopy``
            protocol; it is not consulted here.
    """
    func_expr_copy = deepcopy(self.func_expr)
    clone = SelectTableFunction(func_expr=func_expr_copy, analyzer=self.analyzer)
    _deepcopy_selectable_fields(from_selectable=self, to_selectable=clone)
    # The shared helper clears _snowflake_plan; SelectTableFunction needs its
    # own copy of the SnowflakePlan, so it is assigned explicitly.
    plan_copy = deepcopy(self._snowflake_plan)
    clone._snowflake_plan = plan_copy

    return clone

@property
def snowflake_plan(self):
return self._snowflake_plan
Expand Down Expand Up @@ -1093,6 +1176,14 @@ def __init__(self, *set_operands: SetOperand, analyzer: "Analyzer") -> None:
self.post_actions.extend(operand.selectable.post_actions)
self._nodes.append(operand.selectable)

def __deepcopy__(self, memodict={}) -> "SetStatement":  # noqa: B006
    """
    Return a deep copy of this SetStatement.

    The set operands are deep copied and used to build the new statement;
    the placeholder query and the sql query are assigned over unchanged.

    Args:
        memodict: accepted for compatibility with the ``copy.deepcopy``
            protocol; it is not consulted here.
    """
    operands_copy = deepcopy(self.set_operands)
    clone = SetStatement(*operands_copy, analyzer=self.analyzer)
    _deepcopy_selectable_fields(from_selectable=self, to_selectable=clone)
    # Query text carried over by plain assignment.
    for attr_name in ("_placeholder_query", "_sql_query"):
        setattr(clone, attr_name, getattr(self, attr_name))

    return clone

@property
def sql_query(self) -> str:
if not self._sql_query:
Expand Down
26 changes: 26 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/snowflake_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,32 @@ def __copy__(self) -> "SnowflakePlan":
placeholder_query=self.placeholder_query,
)

def __deepcopy__(self, memodict={}) -> "SnowflakePlan":  # noqa: B006
    """
    Return a deep copy of this SnowflakePlan.

    Falsy fields are not copied: ``queries`` falls back to an empty list and
    the other optional fields fall back to ``None``. ``schema_query``,
    ``is_ddl_on_temp_object`` and ``placeholder_query`` are passed through
    unchanged, and the session object is shared with the original.

    Args:
        memodict: accepted for compatibility with the ``copy.deepcopy``
            protocol; it is not consulted here.
    """

    def _copy_or(value, fallback):
        # Deep copy truthy values; substitute the fallback for falsy ones.
        if value:
            return copy.deepcopy(value)
        return fallback

    queries_copy = _copy_or(self.queries, [])
    schema_query = self.schema_query
    post_actions_copy = _copy_or(self.post_actions, None)
    expr_to_alias_copy = _copy_or(self.expr_to_alias, None)
    source_plan_copy = _copy_or(self.source_plan, None)
    is_ddl_on_temp_object = self.is_ddl_on_temp_object
    api_calls_copy = _copy_or(self.api_calls, None)
    alias_map_copy = _copy_or(self.df_aliased_col_name_to_real_col_name, None)
    placeholder_query = self.placeholder_query

    return SnowflakePlan(
        queries=queries_copy,
        schema_query=schema_query,
        post_actions=post_actions_copy,
        expr_to_alias=expr_to_alias_copy,
        source_plan=source_plan_copy,
        is_ddl_on_temp_object=is_ddl_on_temp_object,
        api_calls=api_calls_copy,
        df_aliased_col_name_to_real_col_name=alias_map_copy,
        placeholder_query=placeholder_query,
        # Note that the session object is NOT copied; be careful when using
        # the session object after a deepcopy.
        session=self.session,
    )

def add_aliases(self, to_add: Dict) -> None:
self.expr_to_alias = {**self.expr_to_alias, **to_add}

Expand Down
Loading

0 comments on commit 6f83b75

Please sign in to comment.