
merge in recent server-side main
sfc-gh-lspiegelberg committed Sep 3, 2024
2 parents 2324f85 + 4b3cb8f commit e905df8
Showing 24 changed files with 1,559 additions and 779 deletions.
16 changes: 16 additions & 0 deletions src/snowflake/snowpark/_internal/ast_utils.py
@@ -178,6 +178,8 @@ def build_expr_from_python_val(expr_builder: proto.Expr, obj: Any) -> None:
elif isinstance(obj, snowflake.snowpark._internal.type_utils.DataType):
ast = with_src_position(expr_builder.sp_datatype_val)
obj._fill_ast(ast.datatype)
elif isinstance(obj, snowflake.snowpark._internal.analyzer.expression.Literal):
build_expr_from_python_val(expr_builder, obj.value)
else:
raise NotImplementedError("not supported type: %s" % type(obj))

@@ -729,6 +731,20 @@ def fill_sp_write_file(
build_expr_from_python_val(t._2, v)


def build_proto_from_pivot_values(
expr_builder: proto.SpPivotValue,
values: Optional[Union[Iterable["LiteralType"], "DataFrame"]], # noqa: F821
):
"""Helper function to encode Snowpark pivot values that are used in various pivot operations to AST."""
if not values:
return

if isinstance(values, snowflake.snowpark.dataframe.DataFrame):
expr_builder.sp_pivot_value__dataframe.v.id.bitfield1 = values._ast_id
else:
build_expr_from_python_val(expr_builder.sp_pivot_value__expr.v, values)


def build_proto_from_callable(
expr_builder: proto.SpCallable, func: Callable, ast_batch: Optional[AstBatch] = None
):
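Note on the new helper: pivot values are encoded either as a literal expression or as a reference to another DataFrame's AST node, depending on their type. Below is a minimal, self-contained sketch of that dispatch; the classes are stand-ins for proto.SpPivotValue and snowflake.snowpark.dataframe.DataFrame, not the real ones.

# --- illustrative sketch (not part of this commit's diff) ---
from dataclasses import dataclass, field
from typing import Any, Iterable, List, Optional, Union


@dataclass
class FakeDataFrame:
    _ast_id: int  # id assigned when the DataFrame's own AST statement was emitted


@dataclass
class FakePivotValue:
    dataframe_id: Optional[int] = None                        # sp_pivot_value__dataframe branch
    literal_values: List[Any] = field(default_factory=list)   # sp_pivot_value__expr branch


def encode_pivot_values(
    builder: FakePivotValue,
    values: Optional[Union[Iterable[Any], FakeDataFrame]],
) -> None:
    if not values:
        return  # no explicit values (dynamic pivot): leave the message empty
    if isinstance(values, FakeDataFrame):
        builder.dataframe_id = values._ast_id  # encode a reference, not the data
    else:
        builder.literal_values.extend(values)  # encode the literals themselves


pv = FakePivotValue()
encode_pivot_values(pv, ["JAN", "FEB", "MAR"])
assert pv.literal_values == ["JAN", "FEB", "MAR"]

pv2 = FakePivotValue()
encode_pivot_values(pv2, FakeDataFrame(_ast_id=7))
assert pv2.dataframe_id == 7
# --- end sketch ---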
1,310 changes: 651 additions & 659 deletions src/snowflake/snowpark/_internal/proto/ast_pb2.py

Large diffs are not rendered by default.

52 changes: 42 additions & 10 deletions src/snowflake/snowpark/dataframe.py
@@ -1064,6 +1064,7 @@ def to_df(
ast = with_src_position(stmt.expr.sp_dataframe_to_df, stmt)
ast.col_names.extend(col_names)
ast.variadic = is_variadic
self.set_ast_ref(ast.df)

new_cols = []
for attr, name in zip(self._output, col_names):
@@ -1325,6 +1326,7 @@ def select(
_, new_cols, alias_cols = _get_cols_after_join_table(
func_expr, self._plan, temp_join_plan
)

# when generating join table expression, we inculcate aliased column into the initial
# query like so,
#
@@ -1631,6 +1633,7 @@ def sort(
for c in _cols:
build_expr_from_snowpark_column_or_col_name(ast.cols.add(), c)
ast.cols_variadic = is_variadic
self.set_ast_ref(ast.df)

orders = []
# `ascending` is represented by Expr in the AST.
@@ -1686,11 +1689,16 @@ def sort(
SortOrder(exprs[idx], orders[idx] if orders else Ascending())
)

if self._select_statement:
return self._with_plan(
self._select_statement.sort(sort_exprs), ast_stmt=stmt
)
return self._with_plan(Sort(sort_exprs, self._plan), ast_stmt=stmt)
df = (
self._with_plan(self._select_statement.sort(sort_exprs))
if self._select_statement
else self._with_plan(Sort(sort_exprs, self._plan))
)

if _emit_ast:
df._ast_id = stmt.var_id.bitfield1

return df

@experimental(version="1.5.0")
def alias(self, name: str, _emit_ast: bool = True):
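The sort() rewrite drops the early returns that passed ast_stmt into _with_plan; the result is now built first and then tagged with the assigned statement's variable id, the same id that later set_ast_ref(ast.df) calls read back. A rough, self-contained sketch of that emit-then-tag flow follows; the Statement/Frame classes are simplified stand-ins, not the real AstBatch machinery.

# --- illustrative sketch (not part of this commit's diff) ---
import itertools
from dataclasses import dataclass, field
from typing import Optional

_ids = itertools.count(1)


@dataclass
class Statement:
    var_id: int = field(default_factory=lambda: next(_ids))
    payload: dict = field(default_factory=dict)


@dataclass
class Frame:
    _ast_id: Optional[int] = None


def sort_like(parent: Frame, emit_ast: bool = True) -> Frame:
    stmt = Statement() if emit_ast else None
    if stmt is not None:
        stmt.payload["df"] = parent._ast_id  # what self.set_ast_ref(ast.df) records
    child = Frame()                          # build the new plan without emitting more AST
    if emit_ast:
        child._ast_id = stmt.var_id          # mirrors df._ast_id = stmt.var_id.bitfield1
    return child


root = Frame(_ast_id=Statement().var_id)  # pretend the root dataframe already has an AST id
sorted_frame = sort_like(root)
assert sorted_frame._ast_id is not None and sorted_frame._ast_id != root._ast_id
# --- end sketch ---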
@@ -1731,6 +1739,7 @@ def alias(self, name: str, _emit_ast: bool = True):
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.sp_dataframe_alias, stmt)
ast.name = name
self.set_ast_ref(ast.df)

# TODO: Support alias in MockServerConnection.
from snowflake.snowpark.mock._connection import MockServerConnection
@@ -1830,6 +1839,7 @@ def agg(
for e in exprs:
build_expr_from_python_val(expr.exprs.args.add(), e)
expr.exprs.variadic = is_variadic
self.set_ast_ref(expr.df)

df = self.group_by(_emit_ast=False).agg(*exprs, _emit_ast=False)

@@ -1916,6 +1926,8 @@ def group_by(
)
for c in col_list:
build_expr_from_snowpark_column_or_col_name(expr.cols.args.add(), c)

expr.df.sp_dataframe_ref.id.bitfield1 = self._ast_id
else:
stmt = _ast_stmt

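group_by writes the parent reference inline instead of calling set_ast_ref, which makes clear what all the self.set_ast_ref(ast.df) additions in this diff are doing: recording the _ast_id of the DataFrame an operation was invoked on. The helper's body is not part of this diff; the sketch below is a hedged guess at its shape, inferred from the inline write above.

# --- illustrative sketch (not part of this commit's diff) ---
from types import SimpleNamespace


def set_ast_ref_guess(self_ast_id: int, df_expr) -> None:
    # Presumed body of DataFrame.set_ast_ref, inferred from
    # expr.df.sp_dataframe_ref.id.bitfield1 = self._ast_id in group_by above.
    df_expr.sp_dataframe_ref.id.bitfield1 = self_ast_id


# Stand-in for the nested protobuf message, just to exercise the shape:
df_expr = SimpleNamespace(
    sp_dataframe_ref=SimpleNamespace(id=SimpleNamespace(bitfield1=None))
)
set_ast_ref_guess(42, df_expr)
assert df_expr.sp_dataframe_ref.id.bitfield1 == 42
# --- end sketch ---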
@@ -2319,6 +2331,7 @@ def union(self, other: "DataFrame", _emit_ast: bool = True) -> "DataFrame":
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.sp_dataframe_union, stmt)
other.set_ast_ref(ast.other)
self.set_ast_ref(ast.df)

df = (
self._with_plan(
@@ -2369,6 +2382,7 @@ def union_all(self, other: "DataFrame", _emit_ast: bool = True) -> "DataFrame":
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.sp_dataframe_union_all, stmt)
other.set_ast_ref(ast.other)
self.set_ast_ref(ast.df)

df = (
self._with_plan(
@@ -2531,6 +2545,7 @@ def intersect(self, other: "DataFrame", _emit_ast: bool = True) -> "DataFrame":
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.sp_dataframe_intersect, stmt)
other.set_ast_ref(ast.other)
self.set_ast_ref(ast.df)

df = (
self._with_plan(
@@ -2579,6 +2594,7 @@ def except_(self, other: "DataFrame", _emit_ast: bool = True) -> "DataFrame":
stmt = self._session._ast_batch.assign()
ast = with_src_position(stmt.expr.sp_dataframe_except, stmt)
other.set_ast_ref(ast.other)
self.set_ast_ref(ast.df)

if self._select_statement:
df = self._with_plan(
@@ -3023,9 +3039,7 @@ def join(
elif isinstance(join_type, LeftAnti):
ast.join_type.sp_join_type__left_anti = True
elif isinstance(join_type, AsOf):
raise NotImplementedError(
"TODO SNOW-1638064: Add support for asof join to IR."
)
ast.join_type.sp_join_type__asof = True
else:
raise ValueError(f"Unsupported join type {join_type}")

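With sp_join_type__asof now populated instead of raising NotImplementedError, ASOF joins can round-trip through the AST. A hedged usage sketch follows; it assumes DataFrame.join accepts how="asof" together with a match_condition keyword, and that connection_parameters holds valid credentials (neither is shown in this diff).

# --- illustrative sketch (not part of this commit's diff) ---
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col

connection_parameters = {"account": "<account>", "user": "<user>", "password": "<password>"}
session = Session.builder.configs(connection_parameters).create()

quotes = session.create_dataframe(
    [("AAPL", 1, 184.2), ("AAPL", 5, 184.4)], schema=["symbol", "quote_ts", "price"]
)
trades = session.create_dataframe(
    [("AAPL", 3, 100)], schema=["symbol", "trade_ts", "qty"]
)

# Each trade is matched with the latest quote at or before its timestamp.
joined = trades.join(
    quotes,
    on="symbol",
    how="asof",
    match_condition=col("trade_ts") >= col("quote_ts"),
)
joined.show()
# --- end sketch ---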
@@ -3444,6 +3458,7 @@ def with_column(
expr = with_src_position(ast_stmt.expr.sp_dataframe_with_column, ast_stmt)
expr.col_name = col_name
build_expr_from_snowpark_column_or_table_fn(expr.col, col)
self.set_ast_ref(expr.df)

df = self.with_columns([col_name], [col], ast_stmt=ast_stmt, _emit_ast=False)

@@ -3558,6 +3573,7 @@ def with_columns(
expr.col_names.append(col_name)
for value in values:
build_expr_from_snowpark_column_or_table_fn(expr.values.add(), value)
self.set_ast_ref(expr.df)

# Put it all together
df = self.select([*old_cols, *new_cols], _ast_stmt=ast_stmt, _emit_ast=False)
@@ -3738,6 +3754,10 @@ def copy_into_table(
statement_params: Dictionary of statement level parameters to be set while executing this action.
copy_options: The kwargs that is used to specify the ``copyOptions`` of the ``COPY INTO <table>`` command.
"""

# TODO: This should be an eval operation, not an assign only as implemented here. Rather, the AST should be
# issued as query similar to collect().

# AST.
stmt = None
if _emit_ast:
@@ -3774,6 +3794,7 @@
entry = expr.copy_options.add()
entry._1 = k
build_expr_from_python_val(entry._2, copy_options[k])
self.set_ast_ref(expr.df)

# TODO: Support copy_into_table in MockServerConnection.
from snowflake.snowpark.mock._connection import MockServerConnection
@@ -4069,7 +4090,7 @@ def _show_string(self, n: int = 10, max_width: int = 50, **kwargs) -> str:
# Phase 0 code where string gets formatted.
if is_sql_select_statement(query):
result, meta = self._session._conn.get_result_and_metadata(
self.limit(n)._plan, **kwargs
self.limit(n, _emit_ast=False)._plan, **kwargs
)
else:
res, meta = self._session._conn.get_result_and_metadata(
@@ -5080,7 +5101,18 @@ def _explain_string(self) -> str:

def _resolve(self, col_name: str) -> Union[Expression, NamedExpression]:
normalized_col_name = quote_name(col_name)
cols = list(filter(lambda attr: attr.name == normalized_col_name, self._output))
cols = list(
filter(
lambda attr: quote_name(attr.name) == normalized_col_name, self._output
)
)

# Remove UnresolvedAttributes. This is an artifact of the analyzer for regular Snowpark and local test mode
# being largely incompatible and not adhering to the same protocols when defining input and output schemas.
cols = list(
filter(lambda attr: not isinstance(attr, UnresolvedAttribute), cols)
)

if len(cols) == 1:
return cols[0].with_name(normalized_col_name)
else:
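The _resolve change compares both sides in normalized (quoted) form and drops UnresolvedAttribute artifacts, so a column the plan reports as STATUS still matches a lookup for status or "STATUS". A tiny illustration of the mismatch the old comparison produced, using a simplified stand-in for quote_name (the real helper also handles already-quoted, case-sensitive identifiers):

# --- illustrative sketch (not part of this commit's diff) ---
def quote_name_simplified(name: str) -> str:
    # Rough stand-in for Snowpark's quote_name: keep already-quoted names,
    # otherwise uppercase and wrap in double quotes.
    if name.startswith('"') and name.endswith('"'):
        return name
    return f'"{name.upper()}"'


plan_output_names = ["STATUS", '"order id"']   # names as the plan's _output may carry them
requested = "status"
normalized = quote_name_simplified(requested)  # '"STATUS"'

# Old comparison: attr.name == normalized_col_name  ->  'STATUS' == '"STATUS"' is False.
assert plan_output_names[0] != normalized
# New comparison: normalize both sides first, as in the hunk above.
assert quote_name_simplified(plan_output_names[0]) == normalized
# --- end sketch ---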
29 changes: 22 additions & 7 deletions src/snowflake/snowpark/dataframe_analytics_functions.py
@@ -26,6 +26,7 @@
unix_timestamp,
year,
)
from snowflake.snowpark.types import IntegerType, StructField, StructType
from snowflake.snowpark.window import Window

# "s" (seconds), "m" (minutes), "h" (hours), "d" (days), "w" (weeks), "mm" (months), "y" (years)
@@ -544,6 +545,7 @@ def compute_lag(
ast.lags.extend(lags)
ast.group_by.extend(group_by)
ast.order_by.extend(order_by)
self._df.set_ast_ref(ast.df)

for c in cols:
for _lag in lags:
@@ -734,6 +736,7 @@ def time_series_agg(
ast.windows.extend(windows)
ast.group_by.extend(group_by)
ast.sliding_interval = sliding_interval
self._df.set_ast_ref(ast.df)

for window in windows:
for column, funcs in aggs.items():
@@ -752,7 +755,15 @@
isinstance(self._df._session._conn, MockServerConnection)
and self._df._session._conn._suppress_not_implemented_error
):
return self._df._session.createDataFrame([])
# TODO: Snowpark does not allow empty dataframes (no schema, no data). Have a dummy schema here.
ans = self._df._session.createDataFrame(
[],
schema=StructType([StructField("row", IntegerType())]),
_emit_ast=False,
)
if _emit_ast:
ans._ast_id = stmt.var_id.bitfield1
return ans

slide_duration, slide_unit = self._validate_and_extract_time_unit(
sliding_interval, "sliding_interval", allow_negative=False
@@ -779,14 +790,18 @@
window, "window"
)
# Perform self-join on DataFrame for aggregation within each group and time window.
left_df = sliding_windows_df.alias("A")
right_df = sliding_windows_df.alias("B")
left_df = sliding_windows_df.alias("A", _emit_ast=False)
right_df = sliding_windows_df.alias("B", _emit_ast=False)

for column in right_df.columns:
if column not in group_by:
right_df = right_df.with_column_renamed(column, f"{column}B")
right_df = right_df.with_column_renamed(
column, f"{column}B", _emit_ast=False
)

self_joined_df = left_df.join(right_df, on=group_by, how="leftouter")
self_joined_df = left_df.join(
right_df, on=group_by, how="leftouter", _emit_ast=False
)

window_frame = dateadd(
window_unit, lit(window_duration), f"{sliding_point_col}"
@@ -801,8 +816,8 @@

# Filter rows to include only those within the specified time window for aggregation.
self_joined_df = self_joined_df.filter(
col(f"{sliding_point_col}B") >= window_start
).filter(col(f"{sliding_point_col}B") <= window_end)
col(f"{sliding_point_col}B") >= window_start, _emit_ast=False
).filter(col(f"{sliding_point_col}B") <= window_end, _emit_ast=False)

# Peform final aggregations.
group_by_cols = group_by + [sliding_point_col]
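The pattern throughout this file: only the public entry point (here time_series_agg) emits an AST statement, while every internal alias/rename/join/filter call passes _emit_ast=False, so the client-side lineage contains a single node for the user-visible operation. A compact, self-contained sketch of that convention with stand-in functions:

# --- illustrative sketch (not part of this commit's diff) ---
from typing import List

EMITTED: List[str] = []


def _emit(op: str, emit_ast: bool) -> None:
    if emit_ast:
        EMITTED.append(op)


def _join(emit_ast: bool = True) -> None:
    _emit("join", emit_ast)


def _filter(emit_ast: bool = True) -> None:
    _emit("filter", emit_ast)


def time_series_agg_like(emit_ast: bool = True) -> None:
    _emit("time_series_agg", emit_ast)  # the one statement the client records
    _join(emit_ast=False)               # internal self-join: suppressed
    _filter(emit_ast=False)             # internal window filters: suppressed
    _filter(emit_ast=False)


time_series_agg_like()
assert EMITTED == ["time_series_agg"]  # internal plumbing never shows up in the AST
# --- end sketch ---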
1 change: 1 addition & 0 deletions src/snowflake/snowpark/dataframe_na_functions.py
@@ -179,6 +179,7 @@ def drop(
ast.subset.list.append(subset)
elif isinstance(subset, Iterable):
ast.subset.list.extend(subset)
self._df.set_ast_ref(ast.df)

# if subset is not provided, drop will be applied to all columns
if subset is None: