Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-841405 Fix df copy and enable basic diamond shaped joins for simplifier #1003

Merged
merged 5 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

- Fixed a bug where automatic package upload would raise `ValueError` even when compatible package versions were added in `session.add_packages`.
- Fixed a bug where table stored procedures were not registered correctly when using `register_from_file`.
- Fixed a bug where dataframe joins failed with `invalid_identifier` error.
- Fixed a bug where `DataFrame.copy` disabled the SQL simplifier for the returned copy.

## 1.7.0 (2023-08-28)

Expand Down
63 changes: 49 additions & 14 deletions src/snowflake/snowpark/_internal/analyzer/select_statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
#

from abc import ABC, abstractmethod
from collections import UserDict
from copy import copy
from collections import UserDict, defaultdict
from copy import copy, deepcopy
from enum import Enum
from typing import (
TYPE_CHECKING,
Expand All @@ -24,6 +24,7 @@
TableFunctionJoin,
TableFunctionRelation,
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages

if TYPE_CHECKING:
from snowflake.snowpark._internal.analyzer.analyzer import (
Expand Down Expand Up @@ -179,7 +180,9 @@ def __init__(
self._column_states: Optional[ColumnStateDict] = None
self._snowflake_plan: Optional[SnowflakePlan] = None
self.expr_to_alias = {}
self.df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]] = {}
self.df_aliased_col_name_to_real_col_name: DefaultDict[
str, Dict[str, str]
] = defaultdict(dict)
Comment on lines +183 to +185
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious what triggered this change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug actually; I intended it to be what the type hint suggests: a defaultdict of dictionaries. This PR fixes dataframe.copy to work for the simplifier path, so this was caught.

self._api_calls = api_calls.copy() if api_calls is not None else None

@property
Expand Down Expand Up @@ -254,6 +257,13 @@ def column_states(self) -> ColumnStateDict:
)
return self._column_states

@column_states.setter
def column_states(self, value: ColumnStateDict):
"""A dictionary that contains the column states of a query.
Refer to class ColumnStateDict.
"""
self._column_states = deepcopy(value)


class SelectableEntity(Selectable):
"""Query from a table, view, or any other Snowflake objects.
Expand Down Expand Up @@ -337,7 +347,7 @@ def to_subqueryable(self) -> "SelectSQL":
analyzer=self.analyzer,
params=self.query_params,
)
new._column_states = self.column_states
new.column_states = self.column_states
new._api_calls = self._api_calls
return new

Expand Down Expand Up @@ -446,11 +456,20 @@ def __copy__(self):
def column_states(self) -> ColumnStateDict:
if self._column_states is None:
if not self.projection and not self.has_clause:
self._column_states = self.from_.column_states
self.column_states = self.from_.column_states
else:
super().column_states # will assign value to self._column_states
return self._column_states

@column_states.setter
def column_states(self, value: ColumnStateDict):
"""A dictionary that contains the column states of a query.
Refer to class ColumnStateDict.
"""
self._column_states = copy(value)
if value is not None:
self._column_states.projection = [copy(attr) for attr in value.projection]
sfc-gh-stan marked this conversation as resolved.
Show resolved Hide resolved

@property
def has_clause_using_columns(self) -> bool:
return any(
Expand Down Expand Up @@ -533,7 +552,7 @@ def to_subqueryable(self) -> "Selectable":
new.pre_actions = from_subqueryable.pre_actions
new.post_actions = from_subqueryable.post_actions
new.from_ = from_subqueryable
new._column_states = self._column_states
new.column_states = self.column_states
return new
return self

Expand All @@ -546,13 +565,14 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
and isinstance(cols[0], UnresolvedAlias)
and isinstance(cols[0].child, Star)
and not cols[0].child.expressions
and not cols[0].child.df_alias
# df.select("*") doesn't have the child.expressions
# df.select(df["*"]) has the child.expressions
):
new = copy(self) # it copies the api_calls
new._projection_in_str = self._projection_in_str
new._schema_query = self._schema_query
new._column_states = self._column_states
new.column_states = self.column_states
new._snowflake_plan = self._snowflake_plan
new.flatten_disabled = self.flatten_disabled
return new
Expand Down Expand Up @@ -644,7 +664,7 @@ def filter(self, col: Expression) -> "SelectStatement":
new.from_ = self.from_.to_subqueryable()
new.pre_actions = new.from_.pre_actions
new.post_actions = new.from_.post_actions
new._column_states = self._column_states
new.column_states = self.column_states
new.where = And(self.where, col) if self.where is not None else col
else:
new = SelectStatement(
Expand All @@ -665,7 +685,7 @@ def sort(self, cols: List[Expression]) -> "SelectStatement":
new.pre_actions = new.from_.pre_actions
new.post_actions = new.from_.post_actions
new.order_by = cols + (self.order_by or [])
new._column_states = self._column_states
new.column_states = self.column_states
else:
new = SelectStatement(
from_=self.to_subqueryable(),
Expand Down Expand Up @@ -731,15 +751,15 @@ def set_operator(
api_calls.extend(s.api_calls)
set_statement.api_calls = api_calls
new = SelectStatement(analyzer=self.analyzer, from_=set_statement)
new._column_states = set_statement.column_states
new.column_states = set_statement.column_states
return new

def limit(self, n: int, *, offset: int = 0) -> "SelectStatement":
new = copy(self)
new.from_ = self.from_.to_subqueryable()
new.limit_ = min(self.limit_, n) if self.limit_ else n
new.offset = (self.offset + offset) if self.offset else offset
new._column_states = self._column_states
new.column_states = self.column_states
return new


Expand Down Expand Up @@ -987,7 +1007,9 @@ def initiate_column_states(
referenced_by_same_level_columns=COLUMN_DEPENDENCY_EMPTY,
state_dict=column_states,
)
column_states.projection = column_attrs
column_states.projection = [
copy(attr) for attr in column_attrs
] # copy to re-generate expr_id
return column_states


Expand Down Expand Up @@ -1026,6 +1048,19 @@ def derive_column_states_from_subquery(
if c.child.expressions:
# df.select(df["*"]) will have child expressions. df.select("*") doesn't.
columns_from_star = [copy(e) for e in c.child.expressions]
elif c.child.df_alias:
if c.child.df_alias not in from_.df_aliased_col_name_to_real_col_name:
raise SnowparkClientExceptionMessages.DF_ALIAS_NOT_RECOGNIZED(
c.child.df_alias
)
aliased_cols = from_.df_aliased_col_name_to_real_col_name[
c.child.df_alias
].values()
columns_from_star = [
copy(e)
for e in from_.column_states.projection
if e.name in aliased_cols
]
else:
columns_from_star = [copy(e) for e in from_.column_states.projection]
column_states.update(
Expand All @@ -1035,7 +1070,7 @@ def derive_column_states_from_subquery(
from_.df_aliased_col_name_to_real_col_name,
)
)
column_states.projection.extend(columns_from_star)
column_states.projection.extend([copy(c) for c in columns_from_star])
continue
c_name = parse_column_name(
c, analyzer, from_.df_aliased_col_name_to_real_col_name
Expand All @@ -1046,7 +1081,7 @@ def derive_column_states_from_subquery(
# if c is not an Attribute object, we will only care about the column name,
# so we can build a dummy Attribute with the column name
column_states.projection.append(
c if isinstance(c, Attribute) else Attribute(quoted_c_name)
copy(c) if isinstance(c, Attribute) else Attribute(quoted_c_name)
)
from_c_state = from_.column_states.get(quoted_c_name)
if from_c_state and from_c_state.change_state != ColumnChangeState.DROPPED:
Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,7 +710,7 @@ def to_local_iterator(
)

def __copy__(self) -> "DataFrame":
return DataFrame(self._session, copy.copy(self._plan))
return DataFrame(self._session, copy.copy(self._select_statement or self._plan))

if installed_pandas:
import pandas # pragma: no cover
Expand Down
22 changes: 19 additions & 3 deletions tests/integ/scala/test_dataframe_join_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,23 +1118,23 @@ def test_select_columns_on_join_result_with_conflict_name(session):
assert df4.collect() == [Row(3, 4, 1)]


def test_join_diamond_shape_error(session):
def test_nested_join_diamond_shape_error(session):
"""This is supposed to work but currently we don't handle it correctly. We should fix this with a good design."""
df1 = session.create_dataframe([[1]], schema=["a"])
df2 = session.create_dataframe([[1]], schema=["a"])
df3 = df1.join(df2, df1["a"] == df2["a"])
df4 = df3.select(df1["a"].as_("a"))
# df1["a"] and df4["a"] has the same expr_id in map expr_to_alias. When they join, only one will be in df5's alias
# map. It leaves the other one resolved to "a" instead of the alias.
df5 = df1.join(df4, df1["a"] == df4["a"])
df5 = df1.join(df4, df1["a"] == df4["a"]) # (df1) JOIN ((df1 JOIN df2)->df4)
with pytest.raises(
SnowparkSQLAmbiguousJoinException,
match="The reference to the column 'A' is ambiguous.",
):
df5.collect()


def test_join_diamond_shape_workaround(session):
def test_nested_join_diamond_shape_workaround(session):
df1 = session.create_dataframe([[1]], schema=["a"])
df2 = session.create_dataframe([[1]], schema=["a"])
df3 = df1.join(df2, df1["a"] == df2["a"])
Expand All @@ -1143,3 +1143,19 @@ def test_join_diamond_shape_workaround(session):
df1_converted = df1.select(df1["a"])
df5 = df1_converted.join(df4, df1_converted["a"] == df4["a"])
Utils.check_answer(df5, [Row(1, 1)])


def test_dataframe_basic_diamond_shaped_join(session):
df1 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"])
df2 = df1.filter(col("a") > 1).with_column("c", lit(7))
assert df1.a._expression.expr_id != df2.a._expression.expr_id

Comment on lines +1151 to +1152
Copy link
Contributor

@sfc-gh-aalam sfc-gh-aalam Aug 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why are we testing expr_ids are not the same. Does it affect dataframes in weird ways if they are same? Can we test that those subsequent side effects of expr_ids being the same are tested instead?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If they are the same we won't be able to perform the join.

for v in resolved_children.values():
in the analyzer counts duplicate expression id's in the resolved children and delete the <expr_id, alias name> entries in the analyzer's map to prevent ambiguous references, but then analyzer will translate df1.a to the "A", where the actual name of df1.a is suffixed. So the generated query will fail.

# (df1) JOIN (df1->df2)
Utils.check_answer(
df1.join(df2, df1.a == df2.a).select(df1.a, df2.c), [Row(3, 7), Row(5, 7)]
)

# (df1->df3) JOIN (df1-> df2)
df3 = df1.filter(col("b") < 6).with_column("d", lit(8))
assert df2.b._expression.expr_id != df3.b._expression.expr_id
Utils.check_answer(df3.join(df2, df2.b == df3.b).select(df2.a, df3.d), [Row(3, 8)])
Loading