From 96ec174257e726124472b1a6cf44447ebfd1048e Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Mon, 9 Sep 2024 14:17:24 -0700 Subject: [PATCH 1/9] fix error --- .../snowpark/_internal/analyzer/expression.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index a2d21db4eb2..a167020367f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -4,6 +4,7 @@ import copy import uuid +from functools import cached_property from typing import TYPE_CHECKING, AbstractSet, Any, Dict, List, Optional, Tuple import snowflake.snowpark._internal.utils @@ -35,6 +36,13 @@ def derive_dependent_columns( *expressions: "Optional[Expression]", ) -> Optional[AbstractSet[str]]: + """ + Given set of expressions, derive the set of columns that the expressions dependents on. + + Note, the returned dependent columns is a set without duplication. For example, given expression + concat(col1, upper(co1), upper(col2)), the result will be {col1, col2} even if col1 has + occurred in the given expression twice. + """ result = set() for exp in expressions: if exp is not None: @@ -48,6 +56,20 @@ def derive_dependent_columns( return result +def derive_dependent_columns_with_duplication(*expressions: "Optional[Expression]") -> list[str]: + """ + Given set of expressions, derive the list of columns that the expression dependents on. + + Note, the returned columns will have duplication if the column occurred more than once in + the given expression. For example, concat(col1, upper(co1), upper(col2)) will have result + [col1, col1, col2], where col1 occurred twice in the result. + """ + result: List[str] = [] + for exp in expressions: + result.extend(exp.dependent_column_names_with_duplication) + return result + + class Expression: """Consider removing attributes, and adding properties and methods. A subclass of Expression may have no child, one child, or multiple children. @@ -68,6 +90,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. return COLUMN_DEPENDENCY_EMPTY + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return [] + @property def pretty_name(self) -> str: """Returns a user-facing string representation of this expression's name. @@ -143,6 +169,10 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return list(COLUMN_DEPENDENCY_DOLLAR) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return self.plan.cumulative_node_complexity @@ -156,6 +186,10 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(*self.expressions) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( @@ -172,6 +206,10 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.columns, *self.values) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.IN @@ -212,6 +250,10 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return [self.name] + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -235,6 +277,14 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: else COLUMN_DEPENDENCY_ALL ) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return ( + derive_dependent_columns_with_duplication(*self.expressions) + if self.expressions + else [] # we currently do not handle * dependency + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1} From b2461032e1e652a0a5c130ac2636c7b4c4c209da Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Tue, 10 Sep 2024 13:17:06 -0700 Subject: [PATCH 2/9] fix error --- .../snowpark/_internal/analyzer/expression.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index a167020367f..6201278f216 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -328,6 +328,10 @@ def __hash__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return [] if self._dependent_column_names == COLUMN_DEPENDENCY_ALL else list(self._dependent_column_names) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.COLUMN @@ -421,6 +425,10 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr LIKE pattern @@ -450,6 +458,10 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.expr, self.pattern) + @property def plan_node_category(self) -> PlanNodeCategory: # expr REG_EXP pattern @@ -473,6 +485,10 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # expr COLLATE collate_spec From 94090f0c1bd662d3cc4324773feed7982e5071cc Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Wed, 11 Sep 2024 13:38:58 -0700 Subject: [PATCH 3/9] fix error --- .../_internal/analyzer/binary_expression.py | 6 +++ .../snowpark/_internal/analyzer/expression.py | 37 ++++++++++++++++--- .../_internal/analyzer/unary_expression.py | 6 +++ 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 3ed969caada..612f07e7ce1 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,11 +2,13 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from functools import cached_property from typing import AbstractSet, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -29,6 +31,10 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.left, self.right) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 6201278f216..85f4b887699 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -56,7 +56,9 @@ def derive_dependent_columns( return result -def derive_dependent_columns_with_duplication(*expressions: "Optional[Expression]") -> list[str]: +def derive_dependent_columns_with_duplication( + *expressions: "Optional[Expression]", +) -> list[str]: """ Given set of expressions, derive the list of columns that the expression dependents on. @@ -282,7 +284,7 @@ def dependent_column_names_with_duplication(self) -> list[str]: return ( derive_dependent_columns_with_duplication(*self.expressions) if self.expressions - else [] # we currently do not handle * dependency + else [] # we currently do not handle * dependency ) @property @@ -330,7 +332,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: @cached_property def dependent_column_names_with_duplication(self) -> list[str]: - return [] if self._dependent_column_names == COLUMN_DEPENDENCY_ALL else list(self._dependent_column_names) + return ( + [] + if self._dependent_column_names == COLUMN_DEPENDENCY_ALL + else list(self._dependent_column_names) + ) @property def plan_node_category(self) -> PlanNodeCategory: @@ -615,13 +621,22 @@ def __init__( self.branches = branches self.else_value = else_value - def dependent_column_names(self) -> Optional[AbstractSet[str]]: + @property + def _child_expressions(self) -> list[Expression]: exps = [] for exp_tuple in self.branches: exps.extend(exp_tuple) if self.else_value is not None: exps.append(self.else_value) - return derive_dependent_columns(*exps) + + return exps + + def dependent_column_names(self) -> Optional[AbstractSet[str]]: + return derive_dependent_columns(*self._child_expressions) + + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(*self._child_expressions) @property def plan_node_category(self) -> PlanNodeCategory: @@ -668,6 +683,10 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -683,6 +702,10 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.col) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -702,6 +725,10 @@ def __init__(self, exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.exprs) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(*self.exprs) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index e5886e11069..0800a33ba93 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,12 +2,14 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # +from functools import cached_property from typing import AbstractSet, Dict, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, NamedExpression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -36,6 +38,10 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + @cached_property + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.child) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT From 1849f2adbce820295de01467240532fb30ef9f06 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Wed, 11 Sep 2024 14:22:45 -0700 Subject: [PATCH 4/9] add test and other chagns --- .../_internal/analyzer/binary_expression.py | 2 -- .../snowpark/_internal/analyzer/expression.py | 29 +++++++-------- .../_internal/analyzer/grouping_set.py | 8 +++++ .../_internal/analyzer/sort_expression.py | 4 +++ .../_internal/analyzer/unary_expression.py | 2 -- .../_internal/analyzer/window_expression.py | 15 ++++++++ .../unit/test_expression_dependent_columns.py | 36 +++++++++++++++++++ 7 files changed, 76 insertions(+), 20 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index 612f07e7ce1..bcc6cfb1e30 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,7 +2,6 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from functools import cached_property from typing import AbstractSet, Optional from snowflake.snowpark._internal.analyzer.expression import ( @@ -31,7 +30,6 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(self.left, self.right) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 85f4b887699..6a36c9f1c0f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -4,7 +4,6 @@ import copy import uuid -from functools import cached_property from typing import TYPE_CHECKING, AbstractSet, Any, Dict, List, Optional, Tuple import snowflake.snowpark._internal.utils @@ -68,7 +67,7 @@ def derive_dependent_columns_with_duplication( """ result: List[str] = [] for exp in expressions: - result.extend(exp.dependent_column_names_with_duplication) + result.extend(exp.dependent_column_names_with_duplication()) return result @@ -92,7 +91,6 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. return COLUMN_DEPENDENCY_EMPTY - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return [] @@ -171,7 +169,6 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return list(COLUMN_DEPENDENCY_DOLLAR) @@ -188,7 +185,6 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(*self.expressions) @@ -208,7 +204,6 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(self.columns, *self.values) @@ -252,7 +247,6 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return [self.name] @@ -279,7 +273,6 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: else COLUMN_DEPENDENCY_ALL ) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return ( derive_dependent_columns_with_duplication(*self.expressions) @@ -330,7 +323,6 @@ def __hash__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return ( [] @@ -431,7 +423,6 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(self.expr, self.pattern) @@ -464,7 +455,6 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(self.expr, self.pattern) @@ -491,7 +481,6 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(self.expr) @@ -516,6 +505,9 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -538,6 +530,9 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.expr) + @property def plan_node_category(self) -> PlanNodeCategory: # the literal corresponds to the contribution from self.field @@ -582,6 +577,9 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(*self.children) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.FUNCTION @@ -597,6 +595,9 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols) + @property def plan_node_category(self) -> PlanNodeCategory: # expr WITHIN GROUP (ORDER BY cols) @@ -634,7 +635,6 @@ def _child_expressions(self) -> list[Expression]: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self._child_expressions) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(*self._child_expressions) @@ -683,7 +683,6 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(*self.children) @@ -702,7 +701,6 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(self.col) @@ -725,7 +723,6 @@ def __init__(self, exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.exprs) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(*self.exprs) diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 84cd63fd87d..0bba259f68f 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -23,6 +24,9 @@ def __init__(self, group_by_exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(*self.group_by_exprs) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -45,6 +49,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) + def dependent_column_names_with_duplication(self) -> list[str]: + flattened_args = [exp for sublist in self.args for exp in sublist] + return derive_dependent_columns_with_duplication(*flattened_args) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: return sum_node_complexities( diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index 1d06f7290a0..e5b673a402c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) @@ -55,3 +56,6 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) + + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.child) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index 0800a33ba93..d47e01417f6 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,7 +2,6 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from functools import cached_property from typing import AbstractSet, Dict, Optional from snowflake.snowpark._internal.analyzer.expression import ( @@ -38,7 +37,6 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) - @cached_property def dependent_column_names_with_duplication(self) -> list[str]: return derive_dependent_columns_with_duplication(self.child) diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 69db3f265ce..b864ab2ed0d 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -7,6 +7,7 @@ from snowflake.snowpark._internal.analyzer.expression import ( Expression, derive_dependent_columns, + derive_dependent_columns_with_duplication, ) from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import ( PlanNodeCategory, @@ -71,6 +72,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.lower, self.upper) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.LOW_IMPACT @@ -102,6 +106,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: *self.partition_spec, *self.order_spec, self.frame_spec ) + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication( + *self.partition_spec, *self.order_spec, self.frame_spec + ) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # partition_spec order_by_spec frame_spec @@ -138,6 +147,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.window_function, self.window_spec) + @property def plan_node_category(self) -> PlanNodeCategory: return PlanNodeCategory.WINDOW @@ -171,6 +183,9 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) + def dependent_column_names_with_duplication(self) -> list[str]: + return derive_dependent_columns_with_duplication(self.expr, self.default) + @property def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]: # for func_name diff --git a/tests/unit/test_expression_dependent_columns.py b/tests/unit/test_expression_dependent_columns.py index c31e5cc6290..fdc8981a734 100644 --- a/tests/unit/test_expression_dependent_columns.py +++ b/tests/unit/test_expression_dependent_columns.py @@ -87,30 +87,37 @@ def test_expression(): a = Expression() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] b = Expression(child=UnresolvedAttribute("a")) assert b.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert b.dependent_column_names_with_duplication() == [] # root class Expression always returns empty dependency def test_literal(): a = Literal(5) assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] def test_attribute(): a = Attribute("A", IntegerType()) assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] def test_unresolved_attribute(): a = UnresolvedAttribute("A") assert a.dependent_column_names() == {"A"} + assert a.dependent_column_names_with_duplication() == ["A"] b = UnresolvedAttribute("a > 1", is_sql_text=True) assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] c = UnresolvedAttribute("$1 > 1", is_sql_text=True) assert c.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert c.dependent_column_names_with_duplication() == ['$'] def test_case_when(): @@ -118,46 +125,56 @@ def test_case_when(): b = Column("b") z = when(a > b, col("c")).when(a < b, col("d")).else_(col("e")) assert z._expression.dependent_column_names() == {'"A"', '"B"', '"C"', '"D"', '"E"'} + # verify column '"A"', '"B"' occurred twice in the dependency columns + assert z._expression.dependent_column_names_with_duplication() == ['"A"', '"B"', '"C"', '"A"', '"B"', '"D"', '"E"'] def test_collate(): a = Collate(UnresolvedAttribute("a"), "spec") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_function_expression(): a = FunctionExpression("test_func", [UnresolvedAttribute(x) for x in "abcd"], False) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") def test_in_expression(): a = InExpression(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") def test_like(): a = Like(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_list_agg(): a = ListAgg(UnresolvedAttribute("a"), ",", True) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_multiple_expression(): a = MultipleExpression([UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") def test_reg_exp(): a = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_scalar_subquery(): a = ScalarSubquery(None) assert a.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR + assert a.dependent_column_names_with_duplication() == list(COLUMN_DEPENDENCY_DOLLAR) def test_snowflake_udf(): @@ -165,21 +182,25 @@ def test_snowflake_udf(): "udf_name", [UnresolvedAttribute(x) for x in "abcd"], IntegerType() ) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") def test_star(): a = Star([Attribute(x, IntegerType()) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") + assert a.dependent_column_names_with_duplication() == list("abcd") def test_subfield_string(): a = SubfieldString(UnresolvedAttribute("a"), "field") assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_within_group(): a = WithinGroup(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("eabcd") @pytest.mark.parametrize( @@ -189,16 +210,19 @@ def test_within_group(): def test_unary_expression(expression_class): a = expression_class(child=UnresolvedAttribute("a")) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_alias(): a = Alias(child=Add(UnresolvedAttribute("a"), UnresolvedAttribute("b")), name="c") assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_cast(): a = Cast(UnresolvedAttribute("a"), IntegerType()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] @pytest.mark.parametrize( @@ -234,6 +258,10 @@ def test_binary_expression(expression_class): assert b.dependent_column_names() == {"B"} assert binary_expression.dependent_column_names() == {"A", "B"} + assert a.dependent_column_names_with_duplication() == ["A"] + assert b.dependent_column_names_with_duplication() == ["B"] + assert binary_expression.dependent_column_names_with_duplication() == ["A", "B"] + @pytest.mark.parametrize( "expression_class", @@ -253,6 +281,7 @@ def test_grouping_set(expression_class): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] def test_grouping_sets_expression(): @@ -263,11 +292,13 @@ def test_grouping_sets_expression(): ] ) assert a.dependent_column_names() == {"a", "b", "c", "d"} + assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] def test_sort_order(): a = SortOrder(UnresolvedAttribute("a"), Ascending()) assert a.dependent_column_names() == {"a"} + assert a.dependent_column_names_with_duplication() == ["a"] def test_specified_window_frame(): @@ -275,12 +306,14 @@ def test_specified_window_frame(): RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("b") ) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] @pytest.mark.parametrize("expression_class", [RankRelatedFunctionExpression, Lag, Lead]) def test_rank_related_function_expression(expression_class): a = expression_class(UnresolvedAttribute("a"), 1, UnresolvedAttribute("b"), False) assert a.dependent_column_names() == {"a", "b"} + assert a.dependent_column_names_with_duplication() == ["a", "b"] def test_window_spec_definition(): @@ -295,6 +328,7 @@ def test_window_spec_definition(): ), ) assert a.dependent_column_names() == set("abcdef") + assert a.dependent_column_names_with_duplication() == list("abcdef") def test_window_expression(): @@ -310,6 +344,7 @@ def test_window_expression(): ) a = WindowExpression(UnresolvedAttribute("x"), window_spec_definition) assert a.dependent_column_names() == set("abcdefx") + assert a.dependent_column_names_with_duplication() == list("xabcdef") @pytest.mark.parametrize( @@ -325,3 +360,4 @@ def test_window_expression(): def test_other_window_expressions(expression_class): a = expression_class() assert a.dependent_column_names() == COLUMN_DEPENDENCY_EMPTY + assert a.dependent_column_names_with_duplication() == [] From 0103502fad523cd46c0cb01eeb0e1eb958a21d46 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Wed, 11 Sep 2024 15:30:52 -0700 Subject: [PATCH 5/9] add test --- .../_internal/analyzer/window_expression.py | 4 +- .../unit/test_expression_dependent_columns.py | 71 ++++++++++++++++++- 2 files changed, 72 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index b864ab2ed0d..2540d98e938 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -148,7 +148,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) def dependent_column_names_with_duplication(self) -> list[str]: - return derive_dependent_columns_with_duplication(self.window_function, self.window_spec) + return derive_dependent_columns_with_duplication( + self.window_function, self.window_spec + ) @property def plan_node_category(self) -> PlanNodeCategory: diff --git a/tests/unit/test_expression_dependent_columns.py b/tests/unit/test_expression_dependent_columns.py index fdc8981a734..c1f4e163f9c 100644 --- a/tests/unit/test_expression_dependent_columns.py +++ b/tests/unit/test_expression_dependent_columns.py @@ -117,7 +117,7 @@ def test_unresolved_attribute(): c = UnresolvedAttribute("$1 > 1", is_sql_text=True) assert c.dependent_column_names() == COLUMN_DEPENDENCY_DOLLAR - assert c.dependent_column_names_with_duplication() == ['$'] + assert c.dependent_column_names_with_duplication() == ["$"] def test_case_when(): @@ -126,7 +126,15 @@ def test_case_when(): z = when(a > b, col("c")).when(a < b, col("d")).else_(col("e")) assert z._expression.dependent_column_names() == {'"A"', '"B"', '"C"', '"D"', '"E"'} # verify column '"A"', '"B"' occurred twice in the dependency columns - assert z._expression.dependent_column_names_with_duplication() == ['"A"', '"B"', '"C"', '"A"', '"B"', '"D"', '"E"'] + assert z._expression.dependent_column_names_with_duplication() == [ + '"A"', + '"B"', + '"C"', + '"A"', + '"B"', + '"D"', + '"E"', + ] def test_collate(): @@ -140,6 +148,13 @@ def test_function_expression(): assert a.dependent_column_names() == set("abcd") assert a.dependent_column_names_with_duplication() == list("abcd") + # expressions with duplicated dependent column + b = FunctionExpression( + "test_func", [UnresolvedAttribute(x) for x in "abcdad"], False + ) + assert b.dependent_column_names() == set("abcd") + assert b.dependent_column_names_with_duplication() == list("abcdad") + def test_in_expression(): a = InExpression(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcd"]) @@ -152,6 +167,11 @@ def test_like(): assert a.dependent_column_names() == {"a", "b"} assert a.dependent_column_names_with_duplication() == ["a", "b"] + # with duplication + b = Like(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] + def test_list_agg(): a = ListAgg(UnresolvedAttribute("a"), ",", True) @@ -164,12 +184,21 @@ def test_multiple_expression(): assert a.dependent_column_names() == set("abcd") assert a.dependent_column_names_with_duplication() == list("abcd") + # with duplication + a = MultipleExpression([UnresolvedAttribute(x) for x in "abcdbea"]) + assert a.dependent_column_names() == set("abcde") + assert a.dependent_column_names_with_duplication() == list("abcdbea") + def test_reg_exp(): a = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("b")) assert a.dependent_column_names() == {"a", "b"} assert a.dependent_column_names_with_duplication() == ["a", "b"] + b = RegExp(UnresolvedAttribute("a"), UnresolvedAttribute("a")) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] + def test_scalar_subquery(): a = ScalarSubquery(None) @@ -184,12 +213,23 @@ def test_snowflake_udf(): assert a.dependent_column_names() == set("abcd") assert a.dependent_column_names_with_duplication() == list("abcd") + # with duplication + b = SnowflakeUDF( + "udf_name", [UnresolvedAttribute(x) for x in "abcdfc"], IntegerType() + ) + assert b.dependent_column_names() == set("abcdf") + assert b.dependent_column_names_with_duplication() == list("abcdfc") + def test_star(): a = Star([Attribute(x, IntegerType()) for x in "abcd"]) assert a.dependent_column_names() == set("abcd") assert a.dependent_column_names_with_duplication() == list("abcd") + b = Star([]) + assert b.dependent_column_names() == COLUMN_DEPENDENCY_ALL + assert b.dependent_column_names_with_duplication() == [] + def test_subfield_string(): a = SubfieldString(UnresolvedAttribute("a"), "field") @@ -202,6 +242,10 @@ def test_within_group(): assert a.dependent_column_names() == set("abcde") assert a.dependent_column_names_with_duplication() == list("eabcd") + b = WithinGroup(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcdea"]) + assert b.dependent_column_names() == set("abcde") + assert b.dependent_column_names_with_duplication() == list("eabcdea") + @pytest.mark.parametrize( "expression_class", @@ -262,6 +306,11 @@ def test_binary_expression(expression_class): assert b.dependent_column_names_with_duplication() == ["B"] assert binary_expression.dependent_column_names_with_duplication() == ["A", "B"] + # hierarchical expressions with duplication + hierarchical_binary_expression = expression_class(expression_class(a, b), b) + assert hierarchical_binary_expression.dependent_column_names() == {"A", "B"} + assert hierarchical_binary_expression.dependent_column_names_with_duplication() == ["A", "B", "B"] + @pytest.mark.parametrize( "expression_class", @@ -283,6 +332,17 @@ def test_grouping_set(expression_class): assert a.dependent_column_names() == {"a", "b", "c", "d"} assert a.dependent_column_names_with_duplication() == ["a", "b", "c", "d"] + # with duplication + b = expression_class( + [ + UnresolvedAttribute("a"), + UnresolvedAttribute("a"), + UnresolvedAttribute("c"), + ] + ) + assert b.dependent_column_names() == {"a", "c"} + assert b.dependent_column_names_with_duplication() == ["a", "a", "c"] + def test_grouping_sets_expression(): a = GroupingSetsExpression( @@ -308,6 +368,13 @@ def test_specified_window_frame(): assert a.dependent_column_names() == {"a", "b"} assert a.dependent_column_names_with_duplication() == ["a", "b"] + # with duplication + b = SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("a"), UnresolvedAttribute("a") + ) + assert b.dependent_column_names() == {"a"} + assert b.dependent_column_names_with_duplication() == ["a", "a"] + @pytest.mark.parametrize("expression_class", [RankRelatedFunctionExpression, Lag, Lead]) def test_rank_related_function_expression(expression_class): From d2539f1b6ba1c8df9618f59a213786f0b1613056 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Wed, 11 Sep 2024 16:12:32 -0700 Subject: [PATCH 6/9] fix error --- .../unit/test_expression_dependent_columns.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_expression_dependent_columns.py b/tests/unit/test_expression_dependent_columns.py index c1f4e163f9c..c9b8a1ce38d 100644 --- a/tests/unit/test_expression_dependent_columns.py +++ b/tests/unit/test_expression_dependent_columns.py @@ -242,7 +242,9 @@ def test_within_group(): assert a.dependent_column_names() == set("abcde") assert a.dependent_column_names_with_duplication() == list("eabcd") - b = WithinGroup(UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcdea"]) + b = WithinGroup( + UnresolvedAttribute("e"), [UnresolvedAttribute(x) for x in "abcdea"] + ) assert b.dependent_column_names() == set("abcde") assert b.dependent_column_names_with_duplication() == list("eabcdea") @@ -309,7 +311,11 @@ def test_binary_expression(expression_class): # hierarchical expressions with duplication hierarchical_binary_expression = expression_class(expression_class(a, b), b) assert hierarchical_binary_expression.dependent_column_names() == {"A", "B"} - assert hierarchical_binary_expression.dependent_column_names_with_duplication() == ["A", "B", "B"] + assert hierarchical_binary_expression.dependent_column_names_with_duplication() == [ + "A", + "B", + "B", + ] @pytest.mark.parametrize( @@ -414,6 +420,22 @@ def test_window_expression(): assert a.dependent_column_names_with_duplication() == list("xabcdef") +def test_window_expression_with_duplication_columns(): + window_spec_definition = WindowSpecDefinition( + [UnresolvedAttribute("a"), UnresolvedAttribute("b")], + [ + SortOrder(UnresolvedAttribute("c"), Ascending()), + SortOrder(UnresolvedAttribute("a"), Ascending()), + ], + SpecifiedWindowFrame( + RowFrame(), UnresolvedAttribute("e"), UnresolvedAttribute("f") + ), + ) + a = WindowExpression(UnresolvedAttribute("e"), window_spec_definition) + assert a.dependent_column_names() == set("abcef") + assert a.dependent_column_names_with_duplication() == list("eabcaef") + + @pytest.mark.parametrize( "expression_class", [ From ebcb213f7e836cbb1ef599cf6a6368f29e5346b4 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Wed, 11 Sep 2024 16:30:33 -0700 Subject: [PATCH 7/9] test --- .../_internal/analyzer/binary_expression.py | 4 +- .../snowpark/_internal/analyzer/expression.py | 40 +++++++++---------- .../_internal/analyzer/grouping_set.py | 2 +- .../_internal/analyzer/sort_expression.py | 4 +- .../_internal/analyzer/unary_expression.py | 4 +- .../_internal/analyzer/window_expression.py | 8 ++-- 6 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py index bcc6cfb1e30..22591f55e47 100644 --- a/src/snowflake/snowpark/_internal/analyzer/binary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/binary_expression.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional +from typing import AbstractSet, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, @@ -30,7 +30,7 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.left, self.right) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.left, self.right) @property diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 6a36c9f1c0f..2d50bc2f260 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -57,7 +57,7 @@ def derive_dependent_columns( def derive_dependent_columns_with_duplication( *expressions: "Optional[Expression]", -) -> list[str]: +) -> List[str]: """ Given set of expressions, derive the list of columns that the expression dependents on. @@ -65,7 +65,7 @@ def derive_dependent_columns_with_duplication( the given expression. For example, concat(col1, upper(co1), upper(col2)) will have result [col1, col1, col2], where col1 occurred twice in the result. """ - result: List[str] = [] + result = [] for exp in expressions: result.extend(exp.dependent_column_names_with_duplication()) return result @@ -91,7 +91,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: # TODO: consider adding it to __init__ or use cached_property. return COLUMN_DEPENDENCY_EMPTY - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return [] @property @@ -185,7 +185,7 @@ def __init__(self, expressions: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.expressions) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(*self.expressions) @property @@ -204,7 +204,7 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.columns, *self.values) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.columns, *self.values) @property @@ -247,7 +247,7 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return {self.name} - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return [self.name] @property @@ -273,7 +273,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: else COLUMN_DEPENDENCY_ALL ) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return ( derive_dependent_columns_with_duplication(*self.expressions) if self.expressions @@ -323,7 +323,7 @@ def __hash__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return self._dependent_column_names - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return ( [] if self._dependent_column_names == COLUMN_DEPENDENCY_ALL @@ -423,7 +423,7 @@ def __init__(self, expr: Expression, pattern: Expression) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.expr, self.pattern) @property @@ -455,7 +455,7 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.pattern) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.expr, self.pattern) @property @@ -481,7 +481,7 @@ def __init__(self, expr: Expression, collation_spec: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.expr) @property @@ -505,7 +505,7 @@ def __init__(self, expr: Expression, field: str) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.expr) @property @@ -530,7 +530,7 @@ def __init__(self, expr: Expression, field: int) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.expr) @property @@ -577,7 +577,7 @@ def sql(self) -> str: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(*self.children) @property @@ -595,7 +595,7 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, *self.order_by_cols) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols) @property @@ -623,7 +623,7 @@ def __init__( self.else_value = else_value @property - def _child_expressions(self) -> list[Expression]: + def _child_expressions(self) -> List[Expression]: exps = [] for exp_tuple in self.branches: exps.extend(exp_tuple) @@ -635,7 +635,7 @@ def _child_expressions(self) -> list[Expression]: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self._child_expressions) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(*self._child_expressions) @property @@ -683,7 +683,7 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.children) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(*self.children) @property @@ -701,7 +701,7 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.col) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.col) @property @@ -723,7 +723,7 @@ def __init__(self, exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.exprs) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(*self.exprs) @property diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 0bba259f68f..906b15ffd9c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -24,7 +24,7 @@ def __init__(self, group_by_exprs: List[Expression]) -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(*self.group_by_exprs) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(*self.group_by_exprs) @property diff --git a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py index e5b673a402c..82451245e4c 100644 --- a/src/snowflake/snowpark/_internal/analyzer/sort_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/sort_expression.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Optional, Type +from typing import AbstractSet, List, Optional, Type from snowflake.snowpark._internal.analyzer.expression import ( Expression, @@ -57,5 +57,5 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.child) diff --git a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py index d47e01417f6..1ae08e8fde2 100644 --- a/src/snowflake/snowpark/_internal/analyzer/unary_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/unary_expression.py @@ -2,7 +2,7 @@ # Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved. # -from typing import AbstractSet, Dict, Optional +from typing import AbstractSet, Dict, List, Optional from snowflake.snowpark._internal.analyzer.expression import ( Expression, @@ -37,7 +37,7 @@ def __str__(self): def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.child) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.child) @property diff --git a/src/snowflake/snowpark/_internal/analyzer/window_expression.py b/src/snowflake/snowpark/_internal/analyzer/window_expression.py index 2540d98e938..4381c4a2e22 100644 --- a/src/snowflake/snowpark/_internal/analyzer/window_expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/window_expression.py @@ -72,7 +72,7 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.lower, self.upper) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.lower, self.upper) @property @@ -106,7 +106,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: *self.partition_spec, *self.order_spec, self.frame_spec ) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication( *self.partition_spec, *self.order_spec, self.frame_spec ) @@ -147,7 +147,7 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.window_function, self.window_spec) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication( self.window_function, self.window_spec ) @@ -185,7 +185,7 @@ def __init__( def dependent_column_names(self) -> Optional[AbstractSet[str]]: return derive_dependent_columns(self.expr, self.default) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return derive_dependent_columns_with_duplication(self.expr, self.default) @property From 53579516e9c590e68fad4238f57077ef37ea262b Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Wed, 11 Sep 2024 16:44:55 -0700 Subject: [PATCH 8/9] fix error --- src/snowflake/snowpark/_internal/analyzer/expression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 2d50bc2f260..2b8ae96c840 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -169,7 +169,7 @@ def __init__(self, plan: "SnowflakePlan") -> None: def dependent_column_names(self) -> Optional[AbstractSet[str]]: return COLUMN_DEPENDENCY_DOLLAR - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: return list(COLUMN_DEPENDENCY_DOLLAR) @property From a1bed4bee11570419696453b2591303561a1b3f9 Mon Sep 17 00:00:00 2001 From: Yun Zou Date: Wed, 11 Sep 2024 17:36:27 -0700 Subject: [PATCH 9/9] fix failure --- src/snowflake/snowpark/_internal/analyzer/expression.py | 6 ++++-- src/snowflake/snowpark/_internal/analyzer/grouping_set.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/_internal/analyzer/expression.py b/src/snowflake/snowpark/_internal/analyzer/expression.py index 2b8ae96c840..a7cb5fd97a9 100644 --- a/src/snowflake/snowpark/_internal/analyzer/expression.py +++ b/src/snowflake/snowpark/_internal/analyzer/expression.py @@ -67,7 +67,8 @@ def derive_dependent_columns_with_duplication( """ result = [] for exp in expressions: - result.extend(exp.dependent_column_names_with_duplication()) + if exp is not None: + result.extend(exp.dependent_column_names_with_duplication()) return result @@ -326,7 +327,8 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: def dependent_column_names_with_duplication(self) -> List[str]: return ( [] - if self._dependent_column_names == COLUMN_DEPENDENCY_ALL + if (self._dependent_column_names == COLUMN_DEPENDENCY_ALL) + or (self._dependent_column_names is None) else list(self._dependent_column_names) ) diff --git a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py index 906b15ffd9c..012940471d0 100644 --- a/src/snowflake/snowpark/_internal/analyzer/grouping_set.py +++ b/src/snowflake/snowpark/_internal/analyzer/grouping_set.py @@ -49,7 +49,7 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns(*flattened_args) - def dependent_column_names_with_duplication(self) -> list[str]: + def dependent_column_names_with_duplication(self) -> List[str]: flattened_args = [exp for sublist in self.args for exp in sublist] return derive_dependent_columns_with_duplication(*flattened_args)