SNOW-1672579 Encode DataFrame.to_snowpark_pandas (#2711)
1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue for your PR.

   Fixes SNOW-1672579

2. Fill out the following pre-review checklist:

- [x] I am adding new automated test(s) to verify the correctness of my new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
- [ ] I am adding new logging messages
- [ ] I am adding a new telemetry message
- [ ] I am adding new credentials
- [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.
- [x] I acknowledge that I have ensured that my changes are thread-safe. See the
[Thread-safe Developer
Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development) for more information.

3. Please describe how your code solves the related issue.

- Added AST encoding for `DataFrame.to_snowpark_pandas` (see the usage sketch after this description).
- I had to add some Local Testing functionality to help my expectation
test pass. I also needed to add some janky logic to create a temporary
read-only table in lieu of the table that is created by
`to_snowpark_pandas`.
- I updated the script that generates the relevant proto files so that it
creates them in the correct directory.
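
For context, here is a minimal usage sketch of the API whose AST encoding this PR adds. The session setup, `CONNECTION_PARAMETERS`, and column names are illustrative assumptions, not part of this change:

```python
# Sketch only: assumes a configured Snowflake connection and the Snowpark
# pandas plugin (install snowflake-snowpark-python[modin]).
import modin.pandas as pd  # noqa: F401
import snowflake.snowpark.modin.plugin  # noqa: F401 -- registers the plugin
from snowflake.snowpark import Session

# CONNECTION_PARAMETERS is a placeholder dict of account credentials.
session = Session.builder.configs(CONNECTION_PARAMETERS).create()
df = session.create_dataframe([[1, "a"], [2, "b"]], schema=["ID", "VAL"])

# With this change, the call below emits an SpToSnowparkPandas AST node
# capturing the source DataFrame, index_col, and columns, instead of
# raising NotImplementedError.
snowpandas_df = df.to_snowpark_pandas(index_col="ID", columns=["VAL"])
```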
sfc-gh-vbudati authored Dec 12, 2024
1 parent 3a66c84 commit 782de21
Showing 8 changed files with 307 additions and 61 deletions.
106 changes: 58 additions & 48 deletions src/snowflake/snowpark/_internal/proto/ast.proto
@@ -793,21 +793,22 @@ message Expr {
SpTableMerge sp_table_merge = 183;
SpTableSample sp_table_sample = 184;
SpTableUpdate sp_table_update = 185;
-SpWriteCopyIntoLocation sp_write_copy_into_location = 186;
-SpWriteCsv sp_write_csv = 187;
-SpWriteJson sp_write_json = 188;
-SpWritePandas sp_write_pandas = 189;
-SpWriteParquet sp_write_parquet = 190;
-SpWriteTable sp_write_table = 191;
-StoredProcedure stored_procedure = 192;
-StringVal string_val = 193;
-Sub sub = 194;
-TimeVal time_val = 195;
-TimestampVal timestamp_val = 196;
-TupleVal tuple_val = 197;
-Udaf udaf = 198;
-Udf udf = 199;
-Udtf udtf = 200;
+SpToSnowparkPandas sp_to_snowpark_pandas = 186;
+SpWriteCopyIntoLocation sp_write_copy_into_location = 187;
+SpWriteCsv sp_write_csv = 188;
+SpWriteJson sp_write_json = 189;
+SpWritePandas sp_write_pandas = 190;
+SpWriteParquet sp_write_parquet = 191;
+SpWriteTable sp_write_table = 192;
+StoredProcedure stored_procedure = 193;
+StringVal string_val = 194;
+Sub sub = 195;
+TimeVal time_val = 196;
+TimestampVal timestamp_val = 197;
+TupleVal tuple_val = 198;
+Udaf udaf = 199;
+Udf udf = 200;
+Udtf udtf = 201;
}
}

@@ -1075,26 +1076,27 @@ message HasSrcPosition {
SpTableMerge sp_table_merge = 192;
SpTableSample sp_table_sample = 193;
SpTableUpdate sp_table_update = 194;
-SpWindowSpecEmpty sp_window_spec_empty = 195;
-SpWindowSpecOrderBy sp_window_spec_order_by = 196;
-SpWindowSpecPartitionBy sp_window_spec_partition_by = 197;
-SpWindowSpecRangeBetween sp_window_spec_range_between = 198;
-SpWindowSpecRowsBetween sp_window_spec_rows_between = 199;
-SpWriteCopyIntoLocation sp_write_copy_into_location = 200;
-SpWriteCsv sp_write_csv = 201;
-SpWriteJson sp_write_json = 202;
-SpWritePandas sp_write_pandas = 203;
-SpWriteParquet sp_write_parquet = 204;
-SpWriteTable sp_write_table = 205;
-StoredProcedure stored_procedure = 206;
-StringVal string_val = 207;
-Sub sub = 208;
-TimeVal time_val = 209;
-TimestampVal timestamp_val = 210;
-TupleVal tuple_val = 211;
-Udaf udaf = 212;
-Udf udf = 213;
-Udtf udtf = 214;
+SpToSnowparkPandas sp_to_snowpark_pandas = 195;
+SpWindowSpecEmpty sp_window_spec_empty = 196;
+SpWindowSpecOrderBy sp_window_spec_order_by = 197;
+SpWindowSpecPartitionBy sp_window_spec_partition_by = 198;
+SpWindowSpecRangeBetween sp_window_spec_range_between = 199;
+SpWindowSpecRowsBetween sp_window_spec_rows_between = 200;
+SpWriteCopyIntoLocation sp_write_copy_into_location = 201;
+SpWriteCsv sp_write_csv = 202;
+SpWriteJson sp_write_json = 203;
+SpWritePandas sp_write_pandas = 204;
+SpWriteParquet sp_write_parquet = 205;
+SpWriteTable sp_write_table = 206;
+StoredProcedure stored_procedure = 207;
+StringVal string_val = 208;
+Sub sub = 209;
+TimeVal time_val = 210;
+TimestampVal timestamp_val = 211;
+TupleVal tuple_val = 212;
+Udaf udaf = 213;
+Udf udf = 214;
+Udtf udtf = 215;
}
}

@@ -1539,7 +1541,7 @@ message SpDataframeAlias {
SrcPosition src = 3;
}

-// sp-df-expr.ir:464
+// sp-df-expr.ir:470
message SpDataframeAnalyticsComputeLag {
repeated Expr cols = 1;
SpDataframeExpr df = 2;
@@ -1550,7 +1552,7 @@ message SpDataframeAnalyticsComputeLag {
SrcPosition src = 7;
}

-// sp-df-expr.ir:473
+// sp-df-expr.ir:479
message SpDataframeAnalyticsComputeLead {
repeated Expr cols = 1;
SpDataframeExpr df = 2;
@@ -1561,7 +1563,7 @@ message SpDataframeAnalyticsComputeLead {
SrcPosition src = 7;
}

-// sp-df-expr.ir:455
+// sp-df-expr.ir:461
message SpDataframeAnalyticsCumulativeAgg {
repeated Tuple_String_List_String aggs = 1;
SpDataframeExpr df = 2;
@@ -1572,7 +1574,7 @@ message SpDataframeAnalyticsCumulativeAgg {
SrcPosition src = 7;
}

-// sp-df-expr.ir:446
+// sp-df-expr.ir:452
message SpDataframeAnalyticsMovingAgg {
repeated Tuple_String_List_String aggs = 1;
SpDataframeExpr df = 2;
@@ -1583,7 +1585,7 @@ message SpDataframeAnalyticsMovingAgg {
repeated int64 window_sizes = 7;
}

-// sp-df-expr.ir:482
+// sp-df-expr.ir:488
message SpDataframeAnalyticsTimeSeriesAgg {
repeated Tuple_String_List_String aggs = 1;
SpDataframeExpr df = 2;
@@ -2330,21 +2332,21 @@ message SpMatchedClause {
}
}

-// sp-df-expr.ir:499
+// sp-df-expr.ir:505
message SpMergeDeleteWhenMatchedClause {
Expr condition = 1;
SrcPosition src = 2;
}

-// sp-df-expr.ir:503
+// sp-df-expr.ir:509
message SpMergeInsertWhenNotMatchedClause {
Expr condition = 1;
List_Expr insert_keys = 2;
List_Expr insert_values = 3;
SrcPosition src = 4;
}

-// sp-df-expr.ir:494
+// sp-df-expr.ir:500
message SpMergeUpdateWhenMatchedClause {
Expr condition = 1;
SrcPosition src = 2;
@@ -2490,7 +2492,7 @@ message SpTable {
SpTableVariant variant = 4;
}

-// sp-df-expr.ir:509
+// sp-df-expr.ir:515
message SpTableDelete {
bool block = 1;
Expr condition = 2;
@@ -2500,7 +2502,7 @@ message SpTableDropTable {
repeated Tuple_String_String statement_params = 6;
}

-// sp-df-expr.ir:517
+// sp-df-expr.ir:523
message SpTableDropTable {
VarId id = 1;
SrcPosition src = 2;
@@ -2521,7 +2523,7 @@ message SpTableMerge {
SrcPosition src = 4;
}

-// sp-df-expr.ir:521
+// sp-df-expr.ir:527
message SpTableMerge {
bool block = 1;
repeated SpMatchedClause clauses = 2;
@@ -2532,7 +2534,7 @@ message SpTableSample {
repeated Tuple_String_String statement_params = 7;
}

-// sp-df-expr.ir:530
+// sp-df-expr.ir:536
message SpTableSample {
SpDataframeExpr df = 1;
google.protobuf.Int64Value num = 2;
@@ -2542,7 +2544,7 @@ message SpTableUpdate {
SrcPosition src = 6;
}

-// sp-df-expr.ir:538
+// sp-df-expr.ir:544
message SpTableUpdate {
repeated Tuple_String_Expr assignments = 1;
bool block = 2;
@@ -2553,6 +2555,14 @@ message SpToSnowparkPandas {
repeated Tuple_String_String statement_params = 7;
}

+// sp-df-expr.ir:438
+message SpToSnowparkPandas {
+List_String columns = 1;
+SpDataframeExpr df = 2;
+List_String index_col = 3;
+SrcPosition src = 4;
+}

message SpType {
oneof variant {
SpColExprType sp_col_expr_type = 1;
1 change: 0 additions & 1 deletion src/snowflake/snowpark/column.py
@@ -1249,7 +1249,6 @@ def name(
expr = self._expression # Snowpark expression
if isinstance(expr, Alias):
expr = expr.child
-
ast_expr = None # Snowpark IR expression
if _emit_ast and self._ast is not None:
ast_expr = proto.Expr()
37 changes: 31 additions & 6 deletions src/snowflake/snowpark/dataframe.py
@@ -1241,22 +1241,47 @@ def to_snowpark_pandas(
# If snowflake.snowpark.modin.plugin was successfully imported, then modin.pandas is available
import modin.pandas as pd # isort: skip
# fmt: on

# AST.
+stmt = None
if _emit_ast:
-raise NotImplementedError(
-"TODO SNOW-1672579: Support Snowpark pandas API handover."
-)
+stmt = self._session._ast_batch.assign()
+ast = with_src_position(stmt.expr.sp_to_snowpark_pandas, stmt)
+self._set_ast_ref(ast.df)
+debug_check_missing_ast(self._ast_id, self)
+if index_col is not None:
+ast.index_col.list.extend(
+index_col if isinstance(index_col, list) else [index_col]
+)
+if columns is not None:
+ast.columns.list.extend(
+columns if isinstance(columns, list) else [columns]
+)

# create a temporary table out of the current snowpark dataframe
temporary_table_name = random_name_for_temp_object(
TempObjectType.TABLE
) # pragma: no cover
+ast_id = self._ast_id
+self._ast_id = None  # set the AST ID to None to prevent AST emission.
self.write.save_as_table(
-temporary_table_name, mode="errorifexists", table_type="temporary"
+temporary_table_name,
+mode="errorifexists",
+table_type="temporary",
+_emit_ast=False,
) # pragma: no cover
+self._ast_id = ast_id  # reset the AST ID.

snowpandas_df = pd.read_snowflake(
name_or_query=temporary_table_name, index_col=index_col, columns=columns
) # pragma: no cover

+if _emit_ast:
+# Set the Snowpark DataFrame AST ID to the AST ID of this pandas query.
+snowpandas_df._query_compiler._modin_frame.ordered_dataframe._dataframe_ref.snowpark_dataframe._ast_id = (
+stmt.var_id.bitfield1
+)

return snowpandas_df

def __getitem__(self, item: Union[str, Column, List, Tuple, int]):
@@ -3904,7 +3929,7 @@ def count(
return result[0][0] if block else result

@property
-def write(self, _emit_ast: bool = True) -> DataFrameWriter:
+def write(self) -> DataFrameWriter:
"""Returns a new :class:`DataFrameWriter` object that you can use to write the data in the :class:`DataFrame` to
a Snowflake database or a stage location
@@ -3925,7 +3950,7 @@ def write(self, _emit_ast: bool = True) -> DataFrameWriter:
"""

# AST.
-if _emit_ast and self._ast_id is not None:
+if self._ast_id is not None:
stmt = self._session._ast_batch.assign()
expr = with_src_position(stmt.expr.sp_dataframe_write, stmt)
self._set_ast_ref(expr.df)
7 changes: 5 additions & 2 deletions src/snowflake/snowpark/mock/_analyzer.py
@@ -828,10 +828,10 @@ def do_resolve_with_resolved_children(
)

if isinstance(logical_plan, Project):
-return logical_plan
+return MockExecutionPlan(logical_plan, self.session)

if isinstance(logical_plan, Filter):
-return logical_plan
+return MockExecutionPlan(logical_plan, self.session)

# Add a sample stop to the plan being built
if isinstance(logical_plan, Sample):
@@ -895,6 +895,9 @@ def do_resolve_with_resolved_children(
if isinstance(logical_plan, SnowflakeCreateTable):
return MockExecutionPlan(logical_plan, self.session)

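+# A plan that has already resolved to a SnowflakePlan is wrapped as-is for mock execution.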
+if isinstance(logical_plan, SnowflakePlan):
+return MockExecutionPlan(logical_plan, self.session)

if isinstance(logical_plan, Limit):
on_top_of_order_by = isinstance(
logical_plan.child, SnowflakePlan
16 changes: 16 additions & 0 deletions src/snowflake/snowpark/mock/_plan.py
@@ -135,6 +135,7 @@
CreateViewCommand,
Pivot,
Sample,
+Project,
)
from snowflake.snowpark._internal.type_utils import infer_type
from snowflake.snowpark._internal.utils import (
@@ -1289,6 +1290,8 @@ def aggregate_by_groups(cur_group: TableEmulator):
dtype=object,
)
return result_df
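+# Emulate a bare Project node by materializing each projected column directly.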
+if isinstance(source_plan, Project):
+return TableEmulator(ColumnEmulator(col) for col in source_plan.project_list)
if isinstance(source_plan, Join):
L_expr_to_alias = {}
R_expr_to_alias = {}
@@ -1450,6 +1453,19 @@ def outer_join(base_df):

obj_name_tuple = parse_table_name(entity_name)
obj_name = obj_name_tuple[-1]

+# Logic to create a read-only temp table for AST testing purposes.
+# Functions like to_snowpark_pandas create a clone of an existing table as a read-only table that is referenced
+# during testing.
+if "SNOWPARK_TEMP_TABLE" in obj_name and "READONLY" in obj_name:
+# Create the read-only temp table.
+entity_registry.write_table(
+obj_name,
+TableEmulator({"A": [1], "B": [1], "C": [1]}),
+SaveMode.IGNORE,
+)
+return entity_registry.read_table_if_exists(obj_name)

obj_schema = (
obj_name_tuple[-2]
if len(obj_name_tuple) > 1
9 changes: 5 additions & 4 deletions src/snowflake/snowpark/modin/plugin/_internal/utils.py
@@ -313,8 +313,9 @@ def _create_read_only_table(
)
# TODO (SNOW-1669224): pushing read only table creation down to snowpark for general usage
session.sql(
f"CREATE OR REPLACE {get_temp_type_for_object(use_scoped_temp_objects=use_scoped_temp_table, is_generated=True)} READ ONLY TABLE {readonly_table_name} CLONE {table_name}"
).collect(statement_params=statement_params)
f"CREATE OR REPLACE {get_temp_type_for_object(use_scoped_temp_objects=use_scoped_temp_table, is_generated=True)} READ ONLY TABLE {readonly_table_name} CLONE {table_name}",
_emit_ast=False,
).collect(statement_params=statement_params, _emit_ast=False)

return readonly_table_name

@@ -389,7 +390,7 @@ def create_ordered_dataframe_with_readonly_temp_table(
error_code=SnowparkPandasErrorCode.GENERAL_SQL_EXCEPTION.value,
) from ex
initial_ordered_dataframe = OrderedDataFrame(
-DataFrameReference(session.table(readonly_table_name))
+DataFrameReference(session.table(readonly_table_name, _emit_ast=False))
)
# generate a snowflake quoted identifier for row position column that can be used for aliasing
snowflake_quoted_identifiers = (
@@ -415,7 +416,7 @@
# with the created snowpark dataframe. In order to get the metadata column access in the created
# dataframe, we create dataframe through sql which access the corresponding metadata column.
dataframe_sql = f"SELECT {columns_to_select} FROM {readonly_table_name}"
-snowpark_df = session.sql(dataframe_sql)
+snowpark_df = session.sql(dataframe_sql, _emit_ast=False)

result_columns_quoted_identifiers = [
row_position_snowflake_quoted_identifier