Skip to content

Commit

Permalink
[SNOW-1514712] Support UDFs for Snowpark IR (#2199)
Browse files Browse the repository at this point in the history
Part I of supporting UDFs. This supports the `udf` decorator.

Other:
- Fixes in `udf_utils.py` the case where, in `get_opt_arg_defaults`, the
result of `arg_spec = inspect.getfullargspec(target_func)` contains
`None` defaults. (Yes, this actually happens.)
- Fixes in `conftest.py` the session initialization for the AST tests.
The previous init via the `Session` constructor was incorrect; the
builder pattern must be used, since otherwise the session is not
properly registered for `_get_active_session`.
- Changes `call_udf` to be encoded, more correctly, as a builtin
function call.
  • Loading branch information
sfc-gh-lspiegelberg authored Sep 3, 2024
1 parent 4b3cb8f commit c7271ab
Show file tree
Hide file tree
Showing 11 changed files with 1,011 additions and 693 deletions.
23 changes: 22 additions & 1 deletion src/snowflake/snowpark/_internal/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import json
import sys
import uuid
from typing import Any, Sequence, Tuple
from dataclasses import dataclass
from typing import Any, Callable, Sequence, Tuple

from google.protobuf.json_format import ParseDict

Expand Down Expand Up @@ -120,12 +121,21 @@ def decode_ast_response_from_snowpark(res: dict, session_parameters: Any) -> Any
)


@dataclass
class TrackedCallable:
    """Client-side record of a python callable tracked by an AstBatch."""

    # ID assigned by AstBatch.register_callable (dense, in registration order).
    var_id: int
    # The tracked function or lambda itself. Holding this reference keeps the
    # object alive, so its id() cannot be recycled while it is tracked.
    func: Callable


class AstBatch:
def __init__(self, session) -> None:
    """Create an AST batch bound to *session* and start the first request batch."""
    self._session = session
    # Monotonically increasing ID source, starting at 1.
    self._id_gen = itertools.count(start=1)
    self._init_batch()

    # Track callables in this dict (memory id -> TrackedCallable).
    self._callables = {}

def assign(self, symbol=None):
stmt = self._request.body.add()
# TODO: extended BindingId spec from the branch snowpark-ir.
Expand Down Expand Up @@ -161,3 +171,14 @@ def _init_batch(self):
self._request.client_language.python_language.version.label = releaselevel

self._request.client_ast_version = CLIENT_AST_VERSION

def register_callable(self, func: Callable) -> int:
    """Tracks client-side an actual callable and returns an ID.

    Registering the same callable object (identified by ``id(func)``) again
    returns the previously assigned ID instead of allocating a new one.

    Args:
        func: The function or lambda to track.

    Returns:
        A dense integer ID, assigned in registration order starting at 0.
    """
    k = id(func)

    # Membership test directly on the dict; `.keys()` is redundant in Python 3.
    if k in self._callables:
        return self._callables[k].var_id

    # IDs are dense: the next ID is the current number of tracked callables.
    # TrackedCallable stores func, keeping it alive so id(func) cannot be
    # reused by a different object while it is tracked.
    next_id = len(self._callables)
    self._callables[k] = TrackedCallable(var_id=next_id, func=func)
    return next_id
131 changes: 118 additions & 13 deletions src/snowflake/snowpark/_internal/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import sys
from functools import reduce
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
from types import ModuleType
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import snowflake.snowpark
import snowflake.snowpark._internal.proto.ast_pb2 as proto
Expand All @@ -32,6 +33,7 @@
ColumnOrSqlExpr,
)
from snowflake.snowpark._internal.utils import str_to_enum
from snowflake.snowpark.types import DataType

# This flag causes an explicit error to be raised if any Snowpark object instance is missing an AST or field, when this
# AST or field is required to populate the AST field of a different Snowpark object instance.
Expand Down Expand Up @@ -252,12 +254,11 @@ def build_builtin_fn_apply(

def build_udf_apply(
ast: proto.Expr,
udf_name: str,
udf_id: int,
*args: Tuple[Union[proto.Expr, Any]],
) -> None:
expr = with_src_position(ast.apply_expr)
_set_fn_name(udf_name, expr.fn.udf)
with_src_position(expr.fn.udf)
expr.fn.sp_fn_ref.id.bitfield1 = udf_id
build_fn_apply_args(ast, *args)


Expand Down Expand Up @@ -730,15 +731,6 @@ def fill_sp_write_file(
build_expr_from_python_val(t._2, v)


def build_proto_from_callable(
    expr_builder: proto.SpCallable, func: Callable, ast_batch: Optional[AstBatch] = None
):
    """Registers a python callable (i.e., a function or lambda) to the AstBatch and encodes it as SpCallable protobuf."""

    # TODO SNOW-1514712: This will be filled in as part of UDF ticket.
    # For now only the callable's name is captured; ast_batch is accepted but
    # unused until tracking is implemented.
    expr_builder.name = func.__name__


def build_proto_from_pivot_values(
expr_builder: proto.SpPivotValue,
values: Optional[Union[Iterable["LiteralType"], "DataFrame"]], # noqa: F821
Expand All @@ -751,3 +743,116 @@ def build_proto_from_pivot_values(
expr_builder.sp_pivot_value__dataframe.v.id.bitfield1 = values._ast_id
else:
build_expr_from_python_val(expr_builder.sp_pivot_value__expr.v, values)


def build_proto_from_callable(
    expr_builder: "proto.SpCallable",
    func: Callable,
    ast_batch: Optional["AstBatch"] = None,
):
    """Registers a python callable (i.e., a function or lambda) to the AstBatch and encodes it as SpCallable protobuf.

    Args:
        expr_builder: The SpCallable protobuf message to fill in.
        func: The function or lambda to encode.
            NOTE(review): build_udf may also pass a (str, str) tuple; a tuple
            has no __name__ and would fail below — confirm callers only pass
            callables here.
        ast_batch: When given, ``func`` is registered with the batch and the
            resulting ID is stored on the protobuf. When None, no ID is set.
    """
    udf_id = None
    if ast_batch is not None:
        udf_id = ast_batch.register_callable(func)
        # Only set the ID when the callable is actually tracked. Assigning
        # None to a protobuf scalar field would raise a TypeError.
        expr_builder.id = udf_id

    if callable(func) and func.__name__ == "<lambda>":
        # Won't be able to extract a name, unless there is <sym> = <lambda>;
        # use a string rep. The first tracked lambda (ID 0) keeps the plain
        # name; any later tracked lambda gets a unique ref name.
        expr_builder.name = "<lambda>" if not udf_id else f"<lambda [{udf_id}]>"
    else:
        # Use the actual function name. Note: We do not support different
        # scopes yet, need to be careful with this then.
        expr_builder.name = func.__name__


def build_udf(
    ast: proto.Udf,
    func: Union[Callable, Tuple[str, str]],
    return_type: Optional[DataType],
    input_types: Optional[List[DataType]],
    name: Optional[str],
    stage_location: Optional[str] = None,
    imports: Optional[List[Union[str, Tuple[str, str]]]] = None,
    packages: Optional[List[Union[str, ModuleType]]] = None,
    replace: bool = False,
    if_not_exists: bool = False,
    parallel: int = 4,
    max_batch_size: Optional[int] = None,
    strict: bool = False,
    secure: bool = False,
    external_access_integrations: Optional[List[str]] = None,
    secrets: Optional[Dict[str, str]] = None,
    immutable: bool = False,
    comment: Optional[str] = None,
    *,
    statement_params: Optional[Dict[str, str]] = None,
    source_code_display: bool = True,
    is_permanent: bool = False,
    session=None,
    **kwargs,
):
    """Helper function to encode UDF parameters (used in both regular and mock UDFRegistration).

    Fills the ``proto.Udf`` message ``ast`` in place from the arguments of a
    UDF registration call. Optional arguments that are ``None`` or empty leave
    their proto fields unset; plain bool/int flags are always written.

    Args:
        ast: The ``proto.Udf`` message to populate.
        func: The python callable — presumably also a (file path, function
            name) tuple for file-based registration; TODO confirm against
            UDFRegistration callers.
        return_type: Encoded via ``DataType._fill_ast`` when given.
        input_types: Each entry encoded via ``DataType._fill_ast`` when given.
        name: Name the UDF is registered under (see comment below).
        session: When given, its ``_ast_batch`` is used to track ``func``
            client-side.
        kwargs: Extra keyword arguments, encoded as key/expression pairs.
    """
    # This is the name the UDF is registered to. Not the name to display when unparsing, that name is captured in callable.

    if name is not None:
        _set_fn_name(name, ast)

    # TODO: to unparse/reference callables client-side - track them in ast_batch.
    build_proto_from_callable(
        ast.func, func, session._ast_batch if session is not None else None
    )

    if return_type is not None:
        return_type._fill_ast(ast.return_type)
    if input_types is not None and len(input_types) != 0:
        for input_type in input_types:
            input_type._fill_ast(ast.input_types.list.add())
    ast.is_permanent = is_permanent
    if stage_location is not None:
        ast.stage_location = stage_location
    if imports is not None and len(imports) != 0:
        # Imports are stage paths / (path, import name) tuples; each is encoded
        # through the same SpTableName representation used for table names.
        for import_ in imports:
            import_expr = proto.SpTableName()
            build_sp_table_name(import_expr, import_)
            ast.imports.append(import_expr)
    if packages is not None and len(packages) != 0:
        for package in packages:
            # Module objects cannot be encoded yet, only package name strings.
            if isinstance(package, ModuleType):
                raise NotImplementedError
            ast.packages.append(package)
    ast.replace = replace
    ast.if_not_exists = if_not_exists
    ast.parallel = parallel
    if max_batch_size is not None:
        # .value access: the proto field is a wrapper type, set only when given.
        ast.max_batch_size.value = max_batch_size

    if statement_params is not None and len(statement_params) != 0:
        # Encoded as repeated (_1=key, _2=value) pairs.
        for k, v in statement_params.items():
            t = ast.statement_params.add()
            t._1 = k
            t._2 = v

    ast.source_code_display = source_code_display
    ast.strict = strict
    ast.secure = secure
    if (
        external_access_integrations is not None
        and len(external_access_integrations) != 0
    ):
        ast.external_access_integrations.extend(external_access_integrations)
    if secrets is not None and len(secrets) != 0:
        # Encoded as repeated (_1=key, _2=value) pairs.
        for k, v in secrets.items():
            t = ast.secrets.add()
            t._1 = k
            t._2 = v
    ast.immutable = immutable
    if comment is not None:
        ast.comment.value = comment
    # Any remaining keyword arguments are encoded as key -> expression pairs.
    for k, v in kwargs.items():
        t = ast.kwargs.add()
        t._1 = k
        build_expr_from_python_val(t._2, v)
1,314 changes: 659 additions & 655 deletions src/snowflake/snowpark/_internal/proto/ast_pb2.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions src/snowflake/snowpark/_internal/proto/update-from-devvm.sh
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ set -euxo pipefail

SCRIPT_DIR=$(dirname "$0")

# Note: If changes are not reflected, run `bazel clean --expunge` first.

# Step 1: Build the python proto file from scratch via bazel
ssh $HOST "bash -c 'cd Snowflake/trunk;bazel build //Snowpark:ast && bazel build //Snowpark:py_proto'"

Expand Down
13 changes: 8 additions & 5 deletions src/snowflake/snowpark/_internal/udf_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def build_default_values_result(
input_types: List[DataType],
convert_python_str_to_object: bool,
) -> List[Optional[str]]:
num_optional_args = len(default_values)
num_optional_args = len(default_values) if default_values is not None else 0
num_positional_args = len(input_types) - num_optional_args
input_types_for_default_args = input_types[-num_optional_args:]
if convert_python_str_to_object:
Expand All @@ -406,10 +406,13 @@ def build_default_values_result(
for value, tp in zip(default_values, input_types_for_default_args)
]

default_values_to_sql_str = [
to_sql(value, datatype)
for value, datatype in zip(default_values, input_types_for_default_args)
]
if num_optional_args != 0:
default_values_to_sql_str = [
to_sql(value, datatype)
for value, datatype in zip(default_values, input_types_for_default_args)
]
else:
default_values_to_sql_str = []
return [None] * num_positional_args + default_values_to_sql_str

def get_opt_arg_defaults_from_callable():
Expand Down
4 changes: 4 additions & 0 deletions src/snowflake/snowpark/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -5149,6 +5149,10 @@ def _with_plan(self, plan, ast_stmt=None) -> "DataFrame":
"""
df = DataFrame(self._session, plan, ast_stmt=ast_stmt)
df._statement_params = self._statement_params

if ast_stmt is not None:
df._ast_id = ast_stmt.var_id.bitfield1

return df

def _get_column_names_from_column_or_name_list(
Expand Down
47 changes: 36 additions & 11 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@
build_expr_from_snowpark_column_or_python_val,
build_expr_from_snowpark_column_or_sql_str,
build_table_fn_apply,
build_udf_apply,
create_ast_for_column,
set_builtin_fn_alias,
snowpark_expression_to_ast,
Expand Down Expand Up @@ -7333,6 +7332,7 @@ def udf(
secrets: Optional[Dict[str, str]] = None,
immutable: bool = False,
comment: Optional[str] = None,
_emit_ast: bool = True,
**kwargs,
) -> Union[UserDefinedFunction, functools.partial]:
"""Registers a Python function as a Snowflake Python UDF and returns the UDF.
Expand Down Expand Up @@ -7494,6 +7494,7 @@ def udf(
session = snowflake.snowpark.session._get_sandbox_conditional_active_session(
session
)

if session is None:
udf_registration_method = UDFRegistration(session=session).register
else:
Expand Down Expand Up @@ -7521,6 +7522,7 @@ def udf(
secrets=secrets,
immutable=immutable,
comment=comment,
_emit_ast=_emit_ast,
**kwargs,
)
else:
Expand All @@ -7545,6 +7547,7 @@ def udf(
secrets=secrets,
immutable=immutable,
comment=comment,
_emit_ast=_emit_ast,
**kwargs,
)

Expand Down Expand Up @@ -8321,10 +8324,7 @@ def pandas_udtf(
)


def call_udf(
udf_name: str,
*args: ColumnOrLiteral,
) -> Column:
def call_udf(udf_name: str, *args: ColumnOrLiteral, _emit_ast: bool = True) -> Column:
"""Calls a user-defined function (UDF) by name.
Args:
Expand All @@ -8346,13 +8346,36 @@ def call_udf(
-------------------------------
<BLANKLINE>
"""
# AST
ast = proto.Expr()
build_udf_apply(ast, udf_name, *args)

validate_object_name(udf_name)

ast = None
# AST.
if _emit_ast:
args_list = parse_positional_args_to_list(*args)
ast = proto.Expr()
# Note: The type hint says ColumnOrLiteral, but in Snowpark sometimes arbitrary
# Python objects are passed.
build_builtin_fn_apply(
ast,
"call_udf",
*(
(udf_name,)
+ tuple(
snowpark_expression_to_ast(arg)
if isinstance(arg, Expression)
else arg
for arg in args_list
)
),
)

return _call_function(
udf_name, False, *args, api_call_source="functions.call_udf", _ast=ast
udf_name,
False,
*args,
api_call_source="functions.call_udf",
_ast=ast,
_emit_ast=_emit_ast,
)


Expand Down Expand Up @@ -8479,11 +8502,12 @@ def _call_function(
api_call_source: Optional[str] = None,
is_data_generator: bool = False,
_ast: proto.Expr = None,
_emit_ast: bool = True,
) -> Column:

args_list = parse_positional_args_to_list(*args)
ast = _ast
if ast is None:
if ast is None and _emit_ast:
ast = proto.Expr()
# Note: The type hint says ColumnOrLiteral, but in Snowpark sometimes arbitrary
# Python objects are passed.
Expand All @@ -8506,6 +8530,7 @@ def _call_function(
is_data_generator=is_data_generator,
),
ast=ast,
_emit_ast=_emit_ast,
)


Expand Down
Loading

0 comments on commit c7271ab

Please sign in to comment.