
SNOW-1491306 Phase 0 AST for DataFrame first, sample, and random_split #1813

Merged
1,179 changes: 595 additions & 584 deletions src/snowflake/snowpark/_internal/proto/ast_pb2.py

Large diffs are not rendered by default.

57 changes: 49 additions & 8 deletions src/snowflake/snowpark/dataframe.py
@@ -2898,7 +2898,10 @@ def _join_dataframes_internal(

     @df_api_usage
     def with_column(
-        self, col_name: str, col: Union[Column, TableFunctionCall]
+        self,
+        col_name: str,
+        col: Union[Column, TableFunctionCall],
+        ast_stmt: proto.Expr = None,
     ) -> "DataFrame":
         """
         Returns a DataFrame with an additional column with the specified name
@@ -2940,11 +2943,14 @@ def with_column(
             col_name: The name of the column to add or replace.
             col: The :class:`Column` or :class:`table_function.TableFunctionCall` with single column output to add or replace.
         """
-        return self.with_columns([col_name], [col])
+        return self.with_columns([col_name], [col], ast_stmt=ast_stmt)

     @df_api_usage
     def with_columns(
-        self, col_names: List[str], values: List[Union[Column, TableFunctionCall]]
+        self,
+        col_names: List[str],
+        values: List[Union[Column, TableFunctionCall]],
+        ast_stmt: proto.Expr = None,
     ) -> "DataFrame":
         """Returns a DataFrame with additional columns with the specified names
         ``col_names``. The columns are computed by using the specified expressions
@@ -3037,7 +3043,7 @@ def with_columns(
         ]

         # Put it all together
-        return self.select([*old_cols, *new_cols])
+        return self.select([*old_cols, *new_cols], _ast_stmt=ast_stmt)

     @overload
     def count(
@@ -3712,7 +3718,16 @@ def first(
         results. If ``n`` is ``None``, it returns the first :class:`Row` of
         results, or ``None`` if it does not exist.
         """
+        # AST.
+        stmt = self._session._ast_batch.assign()
+        ast = stmt.expr.sp_dataframe_first
+        if statement_params is not None:
+            for k, v in statement_params.items():
+                ast.statement_params.append((k, v))
+        self.set_ast_ref(ast.df)
+        set_src_position(ast.src)
+        ast.block = block
         if n is None:
+            ast.num = 1
             df = self.limit(1)
             add_api_call(df, "DataFrame.first")
             result = df._internal_collect_with_tag(
@@ -3724,10 +3739,12 @@ def first(
         elif not isinstance(n, int):
             raise ValueError(f"Invalid type of argument passed to first(): {type(n)}")
         elif n < 0:
+            ast.num = n
             return self._internal_collect_with_tag(
                 statement_params=statement_params, block=block
             )
         else:
+            ast.num = n
             df = self.limit(n)
             add_api_call(df, "DataFrame.first")
             return df._internal_collect_with_tag(
@@ -3750,6 +3767,18 @@ def sample(
             a :class:`DataFrame` containing the sample of rows.
         """
         DataFrame._validate_sample_input(frac, n)

+        # AST.
+        stmt = self._session._ast_batch.assign()
+        if frac is not None:
+            ast = stmt.expr.sp_dataframe_sample__double
+            ast.probability_fraction = frac
+        else:
+            ast = stmt.expr.sp_dataframe_sample__long
+            ast.num = n
+        self.set_ast_ref(ast.df)
+        set_src_position(ast.src)
+
         sample_plan = Sample(self._plan, probability_fraction=frac, row_count=n)
         if self._select_statement:
             return self._with_plan(
@@ -3758,9 +3787,10 @@ def sample(
                         sample_plan, analyzer=self._session._analyzer
                     ),
                     analyzer=self._session._analyzer,
-                )
+                ),
+                ast_stmt=stmt,
             )
-        return self._with_plan(sample_plan)
+        return self._with_plan(sample_plan, ast_stmt=stmt)

@staticmethod
def _validate_sample_input(frac: Optional[float] = None, n: Optional[int] = None):
@@ -4141,11 +4171,22 @@ def random_split(
         2. When a weight or a normalized weight is less than ``1e-6``, the
            corresponding split dataframe will be empty.
         """
+        # AST.
+        stmt = self._session._ast_batch.assign()
+        ast = stmt.expr.sp_dataframe_random_split
+        self.set_ast_ref(ast.df)
+        set_src_position(ast.src)
         if not weights:
             raise ValueError(
                 "weights can't be None or empty and must be positive numbers"
             )
-        elif len(weights) == 1:
+        for w in weights:
+            ast.weights.append(w)
+        if seed is not None:
+            ast.seed = seed
+        if statement_params is not None:
+            for k, v in statement_params.items():
+                ast.statement_params.append((k, v))
+        if len(weights) == 1:
             return [self]
         else:
             for w in weights:
@@ -4154,7 +4195,7 @@ def random_split(

         temp_column_name = random_name_for_temp_object(TempObjectType.COLUMN)
         cached_df = self.with_column(
-            temp_column_name, abs_(random(seed)) % _ONE_MILLION
+            temp_column_name, abs_(random(seed)) % _ONE_MILLION, ast_stmt=stmt
         ).cache_result(statement_params=statement_params)
         sum_weights = sum(weights)
         normalized_cum_weights = [0] + [
19 changes: 19 additions & 0 deletions tests/ast/data/df_first.test
@@ -0,0 +1,19 @@
## TEST CASE

df = session.table("test_table")

df1 = df.first(-5)

df2 = df.first(2)

df3 = df.first()

## EXPECTED OUTPUT

res1 = session.table("test_table")

res2 = res1.first(-5, True)

res3 = res1.first(2, True)

res4 = res1.first(1, True)
17 changes: 17 additions & 0 deletions tests/ast/data/df_random_split.test
@@ -0,0 +1,17 @@
## TEST CASE
Contributor Author:
This test will not pass right now since the Local Testing code has not implemented any "random" functionality yet. A NotImplementedError is raised.

Collaborator:
Is this still true? One way or another, we should make sure that all (not disabled) checked-in tests pass.

Collaborator:
I have run into this in a number of recent APIs as well. I've been able to get around this by checking session._conn._suppress_not_implemented_error, which is a new connection property I've added. In many places, I just return None if suppression is enabled.
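A minimal sketch of that guard, assuming the _suppress_not_implemented_error connection property described above (the name and placement come from this thread, not necessarily from the merged code):

    # Hypothetical early-return guard for APIs the local testing framework
    # cannot emulate yet. The AST for the call has already been recorded by
    # this point, so skipping execution loses nothing for AST tests.
    if getattr(self._session._conn, "_suppress_not_implemented_error", False):
        return None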

Contributor Author:
@sfc-gh-azwiegincew I'm still unable to get this test to pass. Now, the NotImplementedError is not raised but instead a KeyError is raised:

    def handle_function_expression(
        exp: FunctionExpression,
        input_data: Union[TableEmulator, ColumnEmulator],
        analyzer: "MockAnalyzer",
        expr_to_alias: Dict[str, str],
        current_row=None,
    ):
        ...
        try:
    >       result = _MOCK_FUNCTION_IMPLEMENTATION_MAP[func_name](*to_pass_args)
    E       KeyError: 'random'

    ../../src/snowflake/snowpark/mock/_plan.py:454: KeyError

What should I do to get around this? Should I try to catch the error and check if it's a KeyError with this message?

Collaborator:
In random_split(), right after creating the AST, say something along the lines of if self._conn._suppress_not_implemented_error: return None
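As a sketch, the suggested placement inside random_split() would look roughly like this (the guard name and exact position follow the comment above; the merged code may differ):

    def random_split(self, weights, seed=None, *, statement_params=None):
        # ... build the sp_dataframe_random_split AST node as in the diff above ...
        if self._session._conn._suppress_not_implemented_error:
            # Local testing has no random() implementation yet; the AST has
            # already been captured, so skip execution entirely.
            return None
        # ... existing validation and split logic ...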


df = session.table("test_table")

weights = [0.1, 0.2, 0.3]

df2 = df.random_split(weights)

df3 = df.random_split(weights, seed=24)

## EXPECTED OUTPUT

res1 = session.table("test_table")

res2 = res1.random_split([0.1, 0.2, 0.3])

res4 = res1.random_split([0.1, 0.2, 0.3], 24)
15 changes: 15 additions & 0 deletions tests/ast/data/df_sample.test
@@ -0,0 +1,15 @@
## TEST CASE

df = session.table("test_table")

df = df.sample(n=3)

df = df.sample(frac=0.5)

## EXPECTED OUTPUT

res1 = session.table("test_table")

res2 = res1.sample(3)

res3 = res2.sample(0.5)