From 4327cb75861a034fe235c238a90e408c60e93a38 Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Sun, 13 Aug 2023 22:40:05 -0400 Subject: [PATCH 1/4] Add changes --- .../_internal/analyzer/select_statement.py | 42 ++++++++++++++----- tests/integ/test_dataframe.py | 9 ++++ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index a9ec90de9b2..af022aeebca 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -254,6 +254,15 @@ def column_states(self) -> ColumnStateDict: ) return self._column_states + @column_states.setter + def column_states(self, value: ColumnStateDict): + """A dictionary that contains the column states of a query. + Refer to class ColumnStateDict. + """ + self._column_states = copy(value) + if value is not None: + self._column_states.projection = [copy(attr) for attr in value.projection] + class SelectableEntity(Selectable): """Query from a table, view, or any other Snowflake objects. @@ -337,7 +346,7 @@ def to_subqueryable(self) -> "SelectSQL": analyzer=self.analyzer, params=self.query_params, ) - new._column_states = self.column_states + new.column_states = self.column_states new._api_calls = self._api_calls return new @@ -446,11 +455,20 @@ def __copy__(self): def column_states(self) -> ColumnStateDict: if self._column_states is None: if not self.projection and not self.has_clause: - self._column_states = self.from_.column_states + self.column_states = self.from_.column_states else: super().column_states # will assign value to self._column_states return self._column_states + @column_states.setter + def column_states(self, value: ColumnStateDict): + """A dictionary that contains the column states of a query. + Refer to class ColumnStateDict. 
+ """ + self._column_states = copy(value) + if value is not None: + self._column_states.projection = [copy(attr) for attr in value.projection] + @property def has_clause_using_columns(self) -> bool: return any( @@ -533,7 +551,7 @@ def to_subqueryable(self) -> "Selectable": new.pre_actions = from_subqueryable.pre_actions new.post_actions = from_subqueryable.post_actions new.from_ = from_subqueryable - new._column_states = self._column_states + new.column_states = self.column_states return new return self @@ -552,7 +570,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement": new = copy(self) # it copies the api_calls new._projection_in_str = self._projection_in_str new._schema_query = self._schema_query - new._column_states = self._column_states + new.column_states = self.column_states new._snowflake_plan = self._snowflake_plan new.flatten_disabled = self.flatten_disabled return new @@ -644,7 +662,7 @@ def filter(self, col: Expression) -> "SelectStatement": new.from_ = self.from_.to_subqueryable() new.pre_actions = new.from_.pre_actions new.post_actions = new.from_.post_actions - new._column_states = self._column_states + new.column_states = self.column_states new.where = And(self.where, col) if self.where is not None else col else: new = SelectStatement( @@ -665,7 +683,7 @@ def sort(self, cols: List[Expression]) -> "SelectStatement": new.pre_actions = new.from_.pre_actions new.post_actions = new.from_.post_actions new.order_by = cols + (self.order_by or []) - new._column_states = self._column_states + new.column_states = self.column_states else: new = SelectStatement( from_=self.to_subqueryable(), @@ -731,7 +749,7 @@ def set_operator( api_calls.extend(s.api_calls) set_statement.api_calls = api_calls new = SelectStatement(analyzer=self.analyzer, from_=set_statement) - new._column_states = set_statement.column_states + new.column_states = set_statement.column_states return new def limit(self, n: int, *, offset: int = 0) -> "SelectStatement": @@ -739,7 
+757,7 @@ def limit(self, n: int, *, offset: int = 0) -> "SelectStatement": new.from_ = self.from_.to_subqueryable() new.limit_ = min(self.limit_, n) if self.limit_ else n new.offset = (self.offset + offset) if self.offset else offset - new._column_states = self._column_states + new.column_states = self.column_states return new @@ -987,7 +1005,9 @@ def initiate_column_states( referenced_by_same_level_columns=COLUMN_DEPENDENCY_EMPTY, state_dict=column_states, ) - column_states.projection = column_attrs + column_states.projection = [ + copy(attr) for attr in column_attrs + ] # copy to re-generate expr_id return column_states @@ -1035,7 +1055,7 @@ def derive_column_states_from_subquery( from_.df_aliased_col_name_to_real_col_name, ) ) - column_states.projection.extend(columns_from_star) + column_states.projection.extend([copy(c) for c in columns_from_star]) continue c_name = parse_column_name( c, analyzer, from_.df_aliased_col_name_to_real_col_name @@ -1046,7 +1066,7 @@ def derive_column_states_from_subquery( # if c is not an Attribute object, we will only care about the column name, # so we can build a dummy Attribute with the column name column_states.projection.append( - c if isinstance(c, Attribute) else Attribute(quoted_c_name) + copy(c) if isinstance(c, Attribute) else Attribute(quoted_c_name) ) from_c_state = from_.column_states.get(quoted_c_name) if from_c_state and from_c_state.change_state != ColumnChangeState.DROPPED: diff --git a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index da549b72ae5..0fdbed48077 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -3218,3 +3218,12 @@ def test_dataframe_result_cache_changing_schema(session): old_cached_df = df.cache_result() session.use_schema("public") # schema change old_cached_df.show() + + +def test_dataframe_diamond_join(session): + session.sql_simplifier_enabled = True # use False behavior + df1 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", 
"b"]) + df2 = df1.filter(col("a") > 3) + assert df1.a._expression.expr_id != df2.a._expression.expr_id + df3 = df1.join(df2, df1.a == df2.a) + df3.select(df1.a, df2.a).show() From 35a54dc6453ba06a6141db8f758eeaccb4ef228f Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Mon, 14 Aug 2023 10:59:45 -0400 Subject: [PATCH 2/4] Fix df.alias tests --- .../_internal/analyzer/select_statement.py | 21 +++++++++++++++++-- src/snowflake/snowpark/dataframe.py | 2 +- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index af022aeebca..a10d1aa1ea8 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -3,7 +3,7 @@ # from abc import ABC, abstractmethod -from collections import UserDict +from collections import UserDict, defaultdict from copy import copy from enum import Enum from typing import ( @@ -24,6 +24,7 @@ TableFunctionJoin, TableFunctionRelation, ) +from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages if TYPE_CHECKING: from snowflake.snowpark._internal.analyzer.analyzer import ( @@ -179,7 +180,9 @@ def __init__( self._column_states: Optional[ColumnStateDict] = None self._snowflake_plan: Optional[SnowflakePlan] = None self.expr_to_alias = {} - self.df_aliased_col_name_to_real_col_name: DefaultDict[str, Dict[str, str]] = {} + self.df_aliased_col_name_to_real_col_name: DefaultDict[ + str, Dict[str, str] + ] = defaultdict(dict) self._api_calls = api_calls.copy() if api_calls is not None else None @property @@ -564,6 +567,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement": and isinstance(cols[0], UnresolvedAlias) and isinstance(cols[0].child, Star) and not cols[0].child.expressions + and not cols[0].child.df_alias # df.select("*") doesn't have the child.expressions # df.select(df["*"]) has the 
child.expressions ): @@ -1046,6 +1050,19 @@ def derive_column_states_from_subquery( if c.child.expressions: # df.select(df["*"]) will have child expressions. df.select("*") doesn't. columns_from_star = [copy(e) for e in c.child.expressions] + elif c.child.df_alias: + if c.child.df_alias not in from_.df_aliased_col_name_to_real_col_name: + raise SnowparkClientExceptionMessages.DF_ALIAS_NOT_RECOGNIZED( + c.child.df_alias + ) + aliased_cols = from_.df_aliased_col_name_to_real_col_name[ + c.child.df_alias + ].values() + columns_from_star = [ + copy(e) + for e in from_.column_states.projection + if e.name in aliased_cols + ] else: columns_from_star = [copy(e) for e in from_.column_states.projection] column_states.update( diff --git a/src/snowflake/snowpark/dataframe.py b/src/snowflake/snowpark/dataframe.py index 8f0219e05b0..ece5ffe7a02 100644 --- a/src/snowflake/snowpark/dataframe.py +++ b/src/snowflake/snowpark/dataframe.py @@ -710,7 +710,7 @@ def to_local_iterator( ) def __copy__(self) -> "DataFrame": - return DataFrame(self._session, copy.copy(self._plan)) + return DataFrame(self._session, copy.copy(self._select_statement or self._plan)) if installed_pandas: import pandas # pragma: no cover From 9a99bd76b2b552640e41184668ce835d29a44c56 Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Mon, 14 Aug 2023 11:42:14 -0400 Subject: [PATCH 3/4] Refine --- .../integ/scala/test_dataframe_join_suite.py | 22 ++++++++++++++++--- tests/integ/test_dataframe.py | 9 -------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/tests/integ/scala/test_dataframe_join_suite.py b/tests/integ/scala/test_dataframe_join_suite.py index b3e3be45fb6..fd47079d4a3 100644 --- a/tests/integ/scala/test_dataframe_join_suite.py +++ b/tests/integ/scala/test_dataframe_join_suite.py @@ -1118,7 +1118,7 @@ def test_select_columns_on_join_result_with_conflict_name(session): assert df4.collect() == [Row(3, 4, 1)] -def test_join_diamond_shape_error(session): +def 
test_nested_join_diamond_shape_error(session): """This is supposed to work but currently we don't handle it correctly. We should fix this with a good design.""" df1 = session.create_dataframe([[1]], schema=["a"]) df2 = session.create_dataframe([[1]], schema=["a"]) @@ -1126,7 +1126,7 @@ def test_join_diamond_shape_error(session): df4 = df3.select(df1["a"].as_("a")) # df1["a"] and df4["a"] has the same expr_id in map expr_to_alias. When they join, only one will be in df5's alias # map. It leaves the other one resolved to "a" instead of the alias. - df5 = df1.join(df4, df1["a"] == df4["a"]) + df5 = df1.join(df4, df1["a"] == df4["a"]) # (df1) JOIN ((df1 JOIN df2)->df4) with pytest.raises( SnowparkSQLAmbiguousJoinException, match="The reference to the column 'A' is ambiguous.", @@ -1134,7 +1134,7 @@ def test_join_diamond_shape_error(session): df5.collect() -def test_join_diamond_shape_workaround(session): +def test_nested_join_diamond_shape_workaround(session): df1 = session.create_dataframe([[1]], schema=["a"]) df2 = session.create_dataframe([[1]], schema=["a"]) df3 = df1.join(df2, df1["a"] == df2["a"]) @@ -1143,3 +1143,19 @@ def test_join_diamond_shape_workaround(session): df1_converted = df1.select(df1["a"]) df5 = df1_converted.join(df4, df1_converted["a"] == df4["a"]) Utils.check_answer(df5, [Row(1, 1)]) + + +def test_dataframe_basic_diamond_shaped_join(session): + df1 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"]) + df2 = df1.filter(col("a") > 1).with_column("c", lit(7)) + assert df1.a._expression.expr_id != df2.a._expression.expr_id + + # (df1) JOIN (df1->df2) + Utils.check_answer( + df1.join(df2, df1.a == df2.a).select(df1.a, df2.c), [Row(3, 7), Row(5, 7)] + ) + + # (df1->df3) JOIN (df1-> df2) + df3 = df1.filter(col("b") < 6).with_column("d", lit(8)) + assert df2.b._expression.expr_id != df3.b._expression.expr_id + Utils.check_answer(df3.join(df2, df2.b == df3.b).select(df2.a, df3.d), [Row(3, 8)]) diff --git 
a/tests/integ/test_dataframe.py b/tests/integ/test_dataframe.py index 0fdbed48077..da549b72ae5 100644 --- a/tests/integ/test_dataframe.py +++ b/tests/integ/test_dataframe.py @@ -3218,12 +3218,3 @@ def test_dataframe_result_cache_changing_schema(session): old_cached_df = df.cache_result() session.use_schema("public") # schema change old_cached_df.show() - - -def test_dataframe_diamond_join(session): - session.sql_simplifier_enabled = True # use False behavior - df1 = session.create_dataframe([[1, 2], [3, 4], [5, 6]], schema=["a", "b"]) - df2 = df1.filter(col("a") > 3) - assert df1.a._expression.expr_id != df2.a._expression.expr_id - df3 = df1.join(df2, df1.a == df2.a) - df3.select(df1.a, df2.a).show() From 40cf9e47f8f20253883e7a8780894875570883f3 Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Fri, 8 Sep 2023 19:34:48 -0400 Subject: [PATCH 4/4] Address comments and add changelog --- CHANGELOG.md | 2 ++ .../snowpark/_internal/analyzer/select_statement.py | 6 ++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fb03cda5d8c..4a9d9136fff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ - Fixed a bug where automatic package upload would raise `ValueError` even when compatible package version were added in `session.add_packages`. - Fixed a bug where table stored procedures were not registered correctly when using `register_from_file`. +- Fixed a bug where dataframe joins failed with `invalid_identifier` error. +- Fixed a bug where `DataFrame.copy` disables SQL simplifier for the returned copy. 
## 1.7.0 (2023-08-28) diff --git a/src/snowflake/snowpark/_internal/analyzer/select_statement.py b/src/snowflake/snowpark/_internal/analyzer/select_statement.py index a10d1aa1ea8..5be840ddb92 100644 --- a/src/snowflake/snowpark/_internal/analyzer/select_statement.py +++ b/src/snowflake/snowpark/_internal/analyzer/select_statement.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from copy import copy +from copy import copy, deepcopy from enum import Enum from typing import ( TYPE_CHECKING, @@ -262,9 +262,7 @@ def column_states(self, value: ColumnStateDict): """A dictionary that contains the column states of a query. Refer to class ColumnStateDict. """ - self._column_states = copy(value) - if value is not None: - self._column_states.projection = [copy(attr) for attr in value.projection] + self._column_states = deepcopy(value) class SelectableEntity(Selectable):