
SNOW-1491306 Phase 0 AST for DataFrame first, sample, and random_split #1813

Merged
3 changes: 3 additions & 0 deletions src/snowflake/snowpark/_internal/ast.py
@@ -13,6 +13,9 @@
from google.protobuf.json_format import ParseDict

import snowflake.snowpark._internal.proto.ast_pb2 as proto
from snowflake.connector.arrow_context import ArrowConverterContext
from snowflake.connector.cursor import ResultMetadataV2
from snowflake.connector.result_batch import ArrowResultBatch
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark.exceptions import SnowparkSQLException

16 changes: 9 additions & 7 deletions src/snowflake/snowpark/_internal/ast_utils.py
@@ -21,9 +21,7 @@
MultipleExpression,
UnresolvedAttribute,
)
from snowflake.snowpark._internal.analyzer.unary_expression import (
Alias,
)
from snowflake.snowpark._internal.analyzer.unary_expression import Alias
from snowflake.snowpark._internal.type_utils import (
VALID_PYTHON_TYPES_FOR_LITERAL_VALUE,
ColumnOrLiteral,
@@ -238,7 +236,7 @@ def set_src_position(src: proto.SrcPosition) -> None:
if pos.lineno is not None:
src.start_line = pos.lineno
if pos.end_lineno is not None:
src.end_lineno = pos.end_lineno
src.end_line = pos.end_lineno
if pos.col_offset is not None:
src.start_column = pos.col_offset
if pos.end_col_offset is not None:
@@ -248,13 +246,17 @@ def set_src_position(src: proto.SrcPosition) -> None:
assignment_re = re.compile(r"^\s*([a-zA-Z_]\w*)\s*=.*$", re.DOTALL)


def with_src_position(expr_ast: proto.Expr, assign: Optional[proto.Assign] = None) -> proto.Expr:
def with_src_position(
expr_ast: proto.Expr, assign: Optional[proto.Assign] = None
) -> proto.Expr:
"""
Sets the src_position on the supplied Expr AST node and returns it.
N.B. This function assumes it's always invoked from a public API, meaning that the caller's caller
is always the code of interest.
"""
frame = get_first_non_snowpark_stack_frame() # TODO: implement the assumption above to minimize overhead.
frame = (
get_first_non_snowpark_stack_frame()
) # TODO: implement the assumption above to minimize overhead.
source_line = frame.code_context[0].strip() if frame.code_context else ""

src = expr_ast.src
@@ -265,7 +267,7 @@ def with_src_position(expr_ast: proto.Expr, assign: Optional[proto.Assign] = Non
if pos.lineno is not None:
src.start_line = pos.lineno
if pos.end_lineno is not None:
src.end_lineno = pos.end_lineno
src.end_line = pos.end_lineno
if pos.col_offset is not None:
src.start_column = pos.col_offset
if pos.end_col_offset is not None:
1,177 changes: 593 additions & 584 deletions src/snowflake/snowpark/_internal/proto/ast_pb2.py

Large diffs are not rendered by default.

77 changes: 64 additions & 13 deletions src/snowflake/snowpark/dataframe.py
@@ -1375,7 +1375,9 @@ def drop(
return self.select(list(keep_col_names))

@df_api_usage
def filter(self, expr: ColumnOrSqlExpr, _ast_stmt: proto.Assign = None) -> "DataFrame":
def filter(
self, expr: ColumnOrSqlExpr, _ast_stmt: proto.Assign = None
) -> "DataFrame":
"""Filters rows based on the specified conditional expression (similar to WHERE
in SQL).

@@ -1926,7 +1928,9 @@ def unpivot(
return self._with_plan(unpivot_plan)

@df_api_usage
def limit(self, n: int, offset: int = 0, _ast_stmt: proto.Assign = None) -> "DataFrame":
def limit(
self, n: int, offset: int = 0, _ast_stmt: proto.Assign = None
) -> "DataFrame":
"""Returns a new DataFrame that contains at most ``n`` rows from the current
DataFrame, skipping ``offset`` rows from the beginning (similar to LIMIT and OFFSET in SQL).

@@ -2900,7 +2904,10 @@ def _join_dataframes_internal(

@df_api_usage
def with_column(
self, col_name: str, col: Union[Column, TableFunctionCall]
self,
col_name: str,
col: Union[Column, TableFunctionCall],
ast_stmt: proto.Expr = None,
) -> "DataFrame":
"""
Returns a DataFrame with an additional column with the specified name
@@ -2942,11 +2949,14 @@ def with_column(
col_name: The name of the column to add or replace.
col: The :class:`Column` or :class:`table_function.TableFunctionCall` with single column output to add or replace.
"""
return self.with_columns([col_name], [col])
return self.with_columns([col_name], [col], ast_stmt=ast_stmt)

@df_api_usage
def with_columns(
self, col_names: List[str], values: List[Union[Column, TableFunctionCall]]
self,
col_names: List[str],
values: List[Union[Column, TableFunctionCall]],
ast_stmt: proto.Expr = None,
) -> "DataFrame":
"""Returns a DataFrame with additional columns with the specified names
``col_names``. The columns are computed by using the specified expressions
@@ -3039,7 +3049,7 @@ def with_columns(
]

# Put it all together
return self.select([*old_cols, *new_cols])
return self.select([*old_cols, *new_cols], _ast_stmt=ast_stmt)

@overload
def count(
@@ -3374,8 +3384,11 @@ def flatten(
_ast_stmt=stmt,
)

def _lateral(self, table_function: TableFunctionExpression, _ast_stmt: proto.Assign = None) -> "DataFrame":
def _lateral(
self, table_function: TableFunctionExpression, _ast_stmt: proto.Assign = None
) -> "DataFrame":
from snowflake.snowpark.mock._connection import MockServerConnection

if isinstance(self._session._conn, MockServerConnection):
return DataFrame(self._session, ast_stmt=_ast_stmt)

@@ -3387,7 +3400,9 @@ def _lateral(self, table_function: TableFunctionExpression, _ast_stmt: proto.Ass
]
common_col_names = [k for k, v in Counter(result_columns).items() if v > 1]
if len(common_col_names) == 0:
return DataFrame(self._session, Lateral(self._plan, table_function), ast_stmt=_ast_stmt)
return DataFrame(
self._session, Lateral(self._plan, table_function), ast_stmt=_ast_stmt
)
prefix = _generate_prefix("a")
child = self.select(
[
@@ -3402,7 +3417,9 @@ def _lateral(self, table_function: TableFunctionExpression, _ast_stmt: proto.Ass
],
ast_stmt=False, # Suppress AST generation for this SELECT.
)
return DataFrame(self._session, Lateral(child._plan, table_function), ast_stmt=_ast_stmt)
return DataFrame(
self._session, Lateral(child._plan, table_function), ast_stmt=_ast_stmt
)

def _show_string(self, n: int = 10, max_width: int = 50, **kwargs) -> str:
query = self._plan.queries[-1].sql.strip().lower()
@@ -3737,7 +3754,16 @@ def first(
results. If ``n`` is ``None``, it returns the first :class:`Row` of
results, or ``None`` if it does not exist.
"""
# AST.
stmt = self._session._ast_batch.assign()
ast = stmt.expr.sp_dataframe_first
if statement_params is not None:
ast.statement_params.append((k, v) for k, v in statement_params)
self.set_ast_ref(ast.df)
set_src_position(ast.src)
ast.block = block
if n is None:
ast.num = 1
df = self.limit(1)
add_api_call(df, "DataFrame.first")
result = df._internal_collect_with_tag(
@@ -3749,10 +3775,12 @@
elif not isinstance(n, int):
raise ValueError(f"Invalid type of argument passed to first(): {type(n)}")
elif n < 0:
ast.num = n
return self._internal_collect_with_tag(
statement_params=statement_params, block=block
)
else:
ast.num = n
df = self.limit(n)
add_api_call(df, "DataFrame.first")
return df._internal_collect_with_tag(
@@ -3775,6 +3803,17 @@ def sample(
a :class:`DataFrame` containing the sample of rows.
"""
DataFrame._validate_sample_input(frac, n)

# AST.
stmt = self._session._ast_batch.assign()
ast = stmt.expr.sp_dataframe_sample
if frac:
ast.probability_fraction.value = frac
if n:
ast.num.value = n
self.set_ast_ref(ast.df)
set_src_position(ast.src)

sample_plan = Sample(self._plan, probability_fraction=frac, row_count=n)
if self._select_statement:
return self._with_plan(
@@ -3783,9 +3822,10 @@
sample_plan, analyzer=self._session._analyzer
),
analyzer=self._session._analyzer,
)
),
ast_stmt=stmt,
)
return self._with_plan(sample_plan)
return self._with_plan(sample_plan, ast_stmt=stmt)

@staticmethod
def _validate_sample_input(frac: Optional[float] = None, n: Optional[int] = None):
@@ -4166,11 +4206,22 @@ def random_split(
2. When a weight or a normalized weight is less than ``1e-6``, the
corresponding split dataframe will be empty.
"""
# AST.
stmt = self._session._ast_batch.assign()
ast = stmt.expr.sp_dataframe_random_split
self.set_ast_ref(ast.df)
set_src_position(ast.src)
if not weights:
raise ValueError(
"weights can't be None or empty and must be positive numbers"
)
elif len(weights) == 1:
for w in weights:
ast.weights.append(w)
if seed:
ast.seed = seed
if statement_params:
ast.statement_params = statement_params
if len(weights) == 1:
return [self]
else:
for w in weights:
@@ -4179,7 +4230,7 @@

temp_column_name = random_name_for_temp_object(TempObjectType.COLUMN)
cached_df = self.with_column(
temp_column_name, abs_(random(seed)) % _ONE_MILLION
temp_column_name, abs_(random(seed)) % _ONE_MILLION, ast_stmt=stmt
).cache_result(statement_params=statement_params)
sum_weights = sum(weights)
normalized_cum_weights = [0] + [
7 changes: 5 additions & 2 deletions src/snowflake/snowpark/dataframe_na_functions.py
@@ -10,7 +10,10 @@
from typing import Dict, Optional, Union

import snowflake.snowpark
from snowflake.snowpark._internal.ast_utils import build_const_from_python_val, with_src_position
from snowflake.snowpark._internal.ast_utils import (
build_const_from_python_val,
with_src_position,
)
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import add_api_call, adjust_api_subcalls
from snowflake.snowpark._internal.type_utils import (
@@ -180,7 +183,7 @@ def drop(
subset = [subset]
elif not isinstance(subset, (list, tuple)):
raise TypeError("subset should be a list or tuple of column names")

# if thresh is not provided,
# drop a row if it contains any nulls when how == 'any',
# otherwise drop a row only if all its values are null.
9 changes: 5 additions & 4 deletions src/snowflake/snowpark/session.py
@@ -24,10 +24,10 @@
import cloudpickle
import pkg_resources

import snowflake.snowpark._internal.proto.ast_pb2 as proto
from snowflake.connector import ProgrammingError, SnowflakeConnection
from snowflake.connector.options import installed_pandas, pandas
from snowflake.connector.pandas_tools import write_pandas
import snowflake.snowpark._internal.proto.ast_pb2 as proto
from snowflake.snowpark._internal.analyzer import analyzer_utils
from snowflake.snowpark._internal.analyzer.analyzer import Analyzer
from snowflake.snowpark._internal.analyzer.analyzer_utils import result_scan_statement
@@ -1930,8 +1930,8 @@ def table_function(
expr = with_src_position(stmt.expr.apply_expr)
if isinstance(func_name, TableFunctionCall):
expr.fn.udtf.name = func_name.name
func_arguments = func.arguments
func_named_arguments = func.named_arguments
func_arguments = func_name.arguments
func_named_arguments = func_name.named_arguments
# TODO: func.{_over, _partition_by, _order_by, _aliases, _api_call_source}
elif isinstance(func_name, str):
expr.fn.udtf.name = func_name
@@ -3291,7 +3291,8 @@ def flatten(
return None
else:
self._conn.log_not_supported_error(
external_feature_name="Session.flatten", raise_error=NotImplementedError
external_feature_name="Session.flatten",
raise_error=NotImplementedError,
)
if isinstance(input, str):
input = col(input)
19 changes: 19 additions & 0 deletions tests/ast/data/df_first.test
@@ -0,0 +1,19 @@
## TEST CASE

df = session.table("test_table")

df1 = df.first(-5)

df2 = df.first(2)

df3 = df.first()

## EXPECTED OUTPUT

res1 = session.table("test_table")

res2 = res1.first(-5, True)

res3 = res1.first(2, True)

res4 = res1.first(1, True)
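
For reference, the expected output above follows directly from the branches added to DataFrame.first() earlier in this diff: a call with no argument records num = 1, and the block flag defaults to True, which is why df.first() unparses as first(1, True). A minimal illustrative sketch of that mapping (not the generated unparser code; the positional (num, block) ordering is inferred from the test file):

# Illustrative sketch only: mirrors the n-handling branches added to DataFrame.first().
def recorded_first_args(n=None, block=True):
    num = 1 if n is None else n  # df.first() records num = 1; otherwise num = n
    return num, block

assert recorded_first_args(-5) == (-5, True)  # res2 = res1.first(-5, True)
assert recorded_first_args(2) == (2, True)    # res3 = res1.first(2, True)
assert recorded_first_args() == (1, True)     # res4 = res1.first(1, True)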
17 changes: 17 additions & 0 deletions tests/ast/data/df_random_split.test
@@ -0,0 +1,17 @@
## TEST CASE
Contributor Author: This test will not pass right now since the local testing code has not implemented any "random" functionality yet. A NotImplementedError is raised.

Collaborator: Is this still true? One way or another, we should make sure that all (not disabled) checked in tests pass.

Collaborator: I have run into this in a number of recent APIs as well. I've been able to get around this by checking session._conn._suppress_not_implemented_error, which is a new connection property I've added. In many places, I just return None if suppression is enabled.

Contributor Author: @sfc-gh-azwiegincew I'm still unable to get this test to pass. Now, the NotImplementedError is not raised but instead a KeyError is raised:

def handle_function_expression(
    exp: FunctionExpression,
    input_data: Union[TableEmulator, ColumnEmulator],
    analyzer: "MockAnalyzer",
    expr_to_alias: Dict[str, str],
    current_row=None,
):
        ...
        try:
>           result = _MOCK_FUNCTION_IMPLEMENTATION_MAP[func_name](*to_pass_args)
E           KeyError: 'random'

../../src/snowflake/snowpark/mock/_plan.py:454: KeyError

What should I do to get around this? Should I try to catch the error and check if it's a KeyError with this message?

Collaborator: In random_split(), right after creating the AST, say something along the lines of if self._conn._suppress_not_implemented_error: return None
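
For context, a minimal sketch of that suggested guard as it might sit at the top of random_split() (the _suppress_not_implemented_error property comes from the earlier comment and is not part of this diff; whether it is reached via self._conn or self._session._conn inside a DataFrame method is assumed here):

# Hypothetical sketch of the reviewer's suggestion; not part of this PR.
stmt = self._session._ast_batch.assign()
ast = stmt.expr.sp_dataframe_random_split
self.set_ast_ref(ast.df)
set_src_position(ast.src)
# ... fill in ast.weights / ast.seed / statement params as in the diff above ...
if getattr(self._session._conn, "_suppress_not_implemented_error", False):
    # Local testing has no mock implementation of random(); the AST has already
    # been recorded, so skip the actual split.
    return None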


df = session.table("test_table")

weights = [0.1, 0.2, 0.3]

df2 = df.random_split(weights)

df3 = df.random_split(weights, seed=24)

## EXPECTED OUTPUT

res1 = session.table("test_table")

res2 = res1.random_split([0.1, 0.2, 0.3])

res4 = res1.random_split([0.1, 0.2, 0.3], 24)
15 changes: 15 additions & 0 deletions tests/ast/data/df_sample.test
@@ -0,0 +1,15 @@
## TEST CASE

df = session.table("test_table")

df = df.sample(n=3)

df = df.sample(frac=0.5)

## EXPECTED OUTPUT

res1 = session.table("test_table")

res2 = res1.sample(None, 3)

res3 = res2.sample(0.5, None)
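
As with the first() test above, the expected output here simply reflects which field the sample() change records: frac fills probability_fraction and n fills num, and the slot that was not supplied unparses as None. A rough sketch of that mapping (illustrative only; the argument order in the expected output is inferred from the test file):

# Illustrative sketch only: one of frac / n is recorded, the other stays unset.
def recorded_sample_args(frac=None, n=None):
    return frac, n

assert recorded_sample_args(n=3) == (None, 3)         # res2 = res1.sample(None, 3)
assert recorded_sample_args(frac=0.5) == (0.5, None)  # res3 = res2.sample(0.5, None)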