Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SNOW-1632895] Add derive_dependent_columns_with_duplication capability #2272

Merged
merged 9 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import AbstractSet, Optional
from typing import AbstractSet, List, Optional

from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
Expand All @@ -29,6 +30,9 @@ def __str__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.left, self.right)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.left, self.right)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
Expand Down
96 changes: 94 additions & 2 deletions src/snowflake/snowpark/_internal/analyzer/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@
def derive_dependent_columns(
*expressions: "Optional[Expression]",
) -> Optional[AbstractSet[str]]:
"""
Given set of expressions, derive the set of columns that the expressions dependents on.

Note, the returned dependent columns is a set without duplication. For example, given expression
concat(col1, upper(co1), upper(col2)), the result will be {col1, col2} even if col1 has
occurred in the given expression twice.
"""
result = set()
for exp in expressions:
if exp is not None:
Expand All @@ -48,6 +55,23 @@ def derive_dependent_columns(
return result


def derive_dependent_columns_with_duplication(
*expressions: "Optional[Expression]",
) -> List[str]:
"""
Given set of expressions, derive the list of columns that the expression dependents on.

Note, the returned columns will have duplication if the column occurred more than once in
the given expression. For example, concat(col1, upper(co1), upper(col2)) will have result
[col1, col1, col2], where col1 occurred twice in the result.
"""
result = []
for exp in expressions:
if exp is not None:
result.extend(exp.dependent_column_names_with_duplication())
return result


class Expression:
"""Consider removing attributes, and adding properties and methods.
A subclass of Expression may have no child, one child, or multiple children.
Expand All @@ -68,6 +92,9 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
# TODO: consider adding it to __init__ or use cached_property.
return COLUMN_DEPENDENCY_EMPTY

def dependent_column_names_with_duplication(self) -> List[str]:
return []

@property
def pretty_name(self) -> str:
"""Returns a user-facing string representation of this expression's name.
Expand Down Expand Up @@ -143,6 +170,9 @@ def __init__(self, plan: "SnowflakePlan") -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return COLUMN_DEPENDENCY_DOLLAR

def dependent_column_names_with_duplication(self) -> List[str]:
return list(COLUMN_DEPENDENCY_DOLLAR)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return self.plan.cumulative_node_complexity
Expand All @@ -156,6 +186,9 @@ def __init__(self, expressions: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.expressions)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.expressions)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return sum_node_complexities(
Expand All @@ -172,6 +205,9 @@ def __init__(self, columns: Expression, values: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.columns, *self.values)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.columns, *self.values)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.IN
Expand Down Expand Up @@ -212,6 +248,9 @@ def __str__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return {self.name}

def dependent_column_names_with_duplication(self) -> List[str]:
return [self.name]

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.COLUMN
Expand All @@ -235,6 +274,13 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
else COLUMN_DEPENDENCY_ALL
)

def dependent_column_names_with_duplication(self) -> List[str]:
return (
derive_dependent_columns_with_duplication(*self.expressions)
if self.expressions
else [] # we currently do not handle * dependency
)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
complexity = {} if self.expressions else {PlanNodeCategory.COLUMN: 1}
Expand Down Expand Up @@ -278,6 +324,14 @@ def __hash__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return self._dependent_column_names

def dependent_column_names_with_duplication(self) -> List[str]:
return (
[]
if (self._dependent_column_names == COLUMN_DEPENDENCY_ALL)
or (self._dependent_column_names is None)
else list(self._dependent_column_names)
)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.COLUMN
Expand Down Expand Up @@ -371,6 +425,9 @@ def __init__(self, expr: Expression, pattern: Expression) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, self.pattern)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr, self.pattern)

@property
def plan_node_category(self) -> PlanNodeCategory:
# expr LIKE pattern
Expand Down Expand Up @@ -400,6 +457,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, self.pattern)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr, self.pattern)

@property
def plan_node_category(self) -> PlanNodeCategory:
# expr REG_EXP pattern
Expand All @@ -423,6 +483,9 @@ def __init__(self, expr: Expression, collation_spec: str) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr)

@property
def plan_node_category(self) -> PlanNodeCategory:
# expr COLLATE collate_spec
Expand All @@ -444,6 +507,9 @@ def __init__(self, expr: Expression, field: str) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr)

@property
def plan_node_category(self) -> PlanNodeCategory:
# the literal corresponds to the contribution from self.field
Expand All @@ -466,6 +532,9 @@ def __init__(self, expr: Expression, field: int) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr)

@property
def plan_node_category(self) -> PlanNodeCategory:
# the literal corresponds to the contribution from self.field
Expand Down Expand Up @@ -510,6 +579,9 @@ def sql(self) -> str:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.children)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.children)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION
Expand All @@ -525,6 +597,9 @@ def __init__(self, expr: Expression, order_by_cols: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, *self.order_by_cols)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr, *self.order_by_cols)

@property
def plan_node_category(self) -> PlanNodeCategory:
# expr WITHIN GROUP (ORDER BY cols)
Expand All @@ -549,13 +624,21 @@ def __init__(
self.branches = branches
self.else_value = else_value

def dependent_column_names(self) -> Optional[AbstractSet[str]]:
@property
def _child_expressions(self) -> List[Expression]:
exps = []
for exp_tuple in self.branches:
exps.extend(exp_tuple)
if self.else_value is not None:
exps.append(self.else_value)
return derive_dependent_columns(*exps)

return exps

def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self._child_expressions)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self._child_expressions)

@property
def plan_node_category(self) -> PlanNodeCategory:
Expand Down Expand Up @@ -602,6 +685,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.children)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.children)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION
Expand All @@ -617,6 +703,9 @@ def __init__(self, col: Expression, delimiter: str, is_distinct: bool) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.col)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.col)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.FUNCTION
Expand All @@ -636,6 +725,9 @@ def __init__(self, exprs: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.exprs)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.exprs)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return sum_node_complexities(
Expand Down
8 changes: 8 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/grouping_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
Expand All @@ -23,6 +24,9 @@ def __init__(self, group_by_exprs: List[Expression]) -> None:
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(*self.group_by_exprs)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(*self.group_by_exprs)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
Expand All @@ -45,6 +49,10 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
flattened_args = [exp for sublist in self.args for exp in sublist]
return derive_dependent_columns(*flattened_args)

def dependent_column_names_with_duplication(self) -> List[str]:
flattened_args = [exp for sublist in self.args for exp in sublist]
return derive_dependent_columns_with_duplication(*flattened_args)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
return sum_node_complexities(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import AbstractSet, Optional, Type
from typing import AbstractSet, List, Optional, Type

from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)


Expand Down Expand Up @@ -55,3 +56,6 @@ def __init__(

def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.child)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.child)
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
# Copyright (c) 2012-2024 Snowflake Computing Inc. All rights reserved.
#

from typing import AbstractSet, Dict, Optional
from typing import AbstractSet, Dict, List, Optional

from snowflake.snowpark._internal.analyzer.expression import (
Expression,
NamedExpression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
Expand Down Expand Up @@ -36,6 +37,9 @@ def __str__(self):
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.child)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.child)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
Expand Down
17 changes: 17 additions & 0 deletions src/snowflake/snowpark/_internal/analyzer/window_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from snowflake.snowpark._internal.analyzer.expression import (
Expression,
derive_dependent_columns,
derive_dependent_columns_with_duplication,
)
from snowflake.snowpark._internal.analyzer.query_plan_analysis_utils import (
PlanNodeCategory,
Expand Down Expand Up @@ -71,6 +72,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.lower, self.upper)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.lower, self.upper)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.LOW_IMPACT
Expand Down Expand Up @@ -102,6 +106,11 @@ def dependent_column_names(self) -> Optional[AbstractSet[str]]:
*self.partition_spec, *self.order_spec, self.frame_spec
)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(
*self.partition_spec, *self.order_spec, self.frame_spec
)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
# partition_spec order_by_spec frame_spec
Expand Down Expand Up @@ -138,6 +147,11 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.window_function, self.window_spec)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(
self.window_function, self.window_spec
)

@property
def plan_node_category(self) -> PlanNodeCategory:
return PlanNodeCategory.WINDOW
Expand Down Expand Up @@ -171,6 +185,9 @@ def __init__(
def dependent_column_names(self) -> Optional[AbstractSet[str]]:
return derive_dependent_columns(self.expr, self.default)

def dependent_column_names_with_duplication(self) -> List[str]:
return derive_dependent_columns_with_duplication(self.expr, self.default)

@property
def individual_node_complexity(self) -> Dict[PlanNodeCategory, int]:
# for func_name
Expand Down
Loading
Loading