diff --git a/src/snowflake/snowpark/_internal/ast.py b/src/snowflake/snowpark/_internal/ast.py
index 437fceaf8a9..ce40e0174ed 100644
--- a/src/snowflake/snowpark/_internal/ast.py
+++ b/src/snowflake/snowpark/_internal/ast.py
@@ -27,18 +27,17 @@ def expr_to_dataframe_expr(expr):
     return dfe


-# Map from python type to its corresponding IR field. IR fields below all have the 'v' attribute.
+# Map from Python type to its corresponding IR entity. The entities below all have the 'v' attribute.
 TYPE_TO_IR_TYPE_NAME = {
     bytes: "binary_val",
     bool: "bool_val",
+    datetime64: "date_val",
+    Decimal: "big_decimal_val",
+    float64: "float_64_val",
     int32: "int_32_val",
     int64: "int_64_val",
-    float64: "float_64_val",
-    Decimal: "big_decimal_val",
     str: "string_val",
-    slice: "slice_val",
     Timestamp: "timestamp_val",
-    datetime64: "date_val",
 }


@@ -49,7 +48,7 @@ def ast_expr_from_python_val(expr, val):

     Parameters
     ----------
-    expr : IR expression object
+    expr : IR entity protobuf builder
     val : Python value that needs to be converted to IR expression.
     """
     if val is None:
@@ -62,12 +61,16 @@
     if isinstance(val, Callable):
         expr.fn_val.params = signature(val).parameters
         expr.fn_val.body = val
+    elif isinstance(val, slice):
+        expr.slice_val.start = val.start
+        expr.slice_val.stop = val.stop
+        expr.slice_val.step = val.step
+    elif not isinstance(val, Series) and is_list_like(val):
+        # Checking that val is not a Series since Series objects are considered list-like.
+        expr.list_val.vs = val
     elif isinstance(val, Series):
-        # Checking Series before the list-like type since Series are considered to be list-like.
         expr.series_val.ref = val
-    elif is_list_like(val):
-        expr.list_val.vs = val
-    if isinstance(val, DataFrame):
+    elif isinstance(val, DataFrame):
         expr.dataframe_val.ref = val
     else:
         ir_type_name = TYPE_TO_IR_TYPE_NAME[val_type]
diff --git a/src/snowflake/snowpark/column.py b/src/snowflake/snowpark/column.py
index 864da395229..2d024aec29a 100644
--- a/src/snowflake/snowpark/column.py
+++ b/src/snowflake/snowpark/column.py
@@ -666,15 +666,19 @@ def __repr__(self):

     def as_(self, alias: str) -> "Column":
         """Returns a new renamed Column. Alias of :func:`name`."""
-        return self.name(alias)
+        ast = proto.SpColumnExpr()
+        ast.sp_column_alias.variant_is_as = True
+        return self.name(alias, ast)

     def alias(self, alias: str) -> "Column":
         """Returns a new renamed Column. Alias of :func:`name`."""
-        return self.name(alias)
+        ast = proto.SpColumnExpr()
+        ast.sp_column_alias.variant_is_as = False
+        return self.name(alias, ast)

-    def name(self, alias: str) -> "Column":
+    def name(self, alias: str, ast=None) -> "Column":
         """Returns a new renamed Column."""
-        return Column(Alias(self._expression, quote_name(alias)))
+        return Column(Alias(self._expression, quote_name(alias)), ast=ast)

     def over(self, window: Optional[WindowSpec] = None) -> "Column":
         """
diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py
index 2bf0c679cc2..3cf41078ea4 100644
--- a/src/snowflake/snowpark/functions.py
+++ b/src/snowflake/snowpark/functions.py
@@ -250,11 +250,11 @@ def col(df_alias: str, col_name: str) -> Column:
     ...  # pragma: no cover


-def col(name1: str, name2: Optional[str] = None) -> Column:
+def col(name1: str, name2: Optional[str] = None, ast=None) -> Column:
     if name2 is None:
-        return Column(name1)
+        return Column(name1, ast=ast)
     else:
-        return Column(name1, name2)
+        return Column(name1, name2, ast=ast)


 @overload
diff --git a/src/snowflake/snowpark/modin/pandas/dataframe.py b/src/snowflake/snowpark/modin/pandas/dataframe.py
index 5774b25f723..85ef6654165 100644
--- a/src/snowflake/snowpark/modin/pandas/dataframe.py
+++ b/src/snowflake/snowpark/modin/pandas/dataframe.py
@@ -150,6 +150,7 @@ def __init__(
         dtype=None,
         copy=None,
         query_compiler=None,
+        ast_stmt=None,
     ) -> None:
         # TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
         # Siblings are other dataframes that share the same query compiler. We
@@ -158,6 +159,11 @@
         self._siblings = []

+        if not ast_stmt:
+            ast_stmt = pd.session._ast_batch.assign()
+        self._ast_id = ast_stmt.var_id.bitfield1
+        self._ast_stmt = ast_stmt
+
         # Engine.subscribe(_update_engine)
         if isinstance(data, (DataFrame, Series)):
             self._query_compiler = data._query_compiler.copy()
@@ -255,11 +261,6 @@
         else:
             self._query_compiler = query_compiler

-    def _get_ast_id(self):
-        return (
-            self._query_compiler._modin_frame.ordered_dataframe._dataframe_ref.snowpark_dataframe._ast_id
-        )
-
     def __repr__(self):
         """
         Return a string representation for a particular ``DataFrame``.
@@ -2869,7 +2870,10 @@ def __setattr__(self, key, value):
         # - `_siblings`, which Modin initializes before it appears in __dict__
         # - `_cache`, which pandas.cache_readonly uses to cache properties
         # before it appears in __dict__.
-        if key in ("_query_compiler", "_siblings", "_cache") or key in self.__dict__:
+        if (
+            key in ("_query_compiler", "_siblings", "_cache", "_ast_id", "_ast_stmt")
+            or key in self.__dict__
+        ):
             pass
         elif key in self and key not in dir(self):
             self.__setitem__(key, value)
diff --git a/src/snowflake/snowpark/modin/pandas/indexing.py b/src/snowflake/snowpark/modin/pandas/indexing.py
index 50c292cdd2b..5aaae34c876 100644
--- a/src/snowflake/snowpark/modin/pandas/indexing.py
+++ b/src/snowflake/snowpark/modin/pandas/indexing.py
@@ -1053,7 +1053,7 @@ def __getitem__(
         # IR changes! Add iloc get nodes to AST.
         stmt = pd.session._ast_batch.assign()
         ast = stmt.expr
-        ast.pd_dataframe_i_loc.df.var_id.bitfield1 = self.df._get_ast_id()
+        ast.pd_dataframe_i_loc.df.var_id.bitfield1 = self.df._ast_id
         # Map python built-ins (functions, scalars, lists, slices, etc.) to AST expr and emit Ref nodes for dataframes,
         # series, and indexes.
         ast_expr_from_python_val(ast.pd_dataframe_i_loc.rows, row_loc)
@@ -1067,7 +1067,7 @@
         # Convert all scalar, list-like, and indexer row_loc to a Series object to get a query compiler object.
         if is_scalar(row_loc):
-            row_loc = pd.Series([row_loc])
+            row_loc = pd.Series([row_loc], ast_stmt=stmt)
         elif is_list_like(row_loc):
             if hasattr(row_loc, "dtype"):
                 dtype = row_loc.dtype
@@ -1076,7 +1076,7 @@
                 dtype = float
             else:
                 dtype = None
-            row_loc = pd.Series(row_loc, dtype=dtype)
+            row_loc = pd.Series(row_loc, dtype=dtype, ast_stmt=stmt)

         # Check whether the row and column input is of numeric dtype.
         self._validate_numeric_get_key_values(row_loc, original_row_loc)
diff --git a/src/snowflake/snowpark/modin/pandas/series.py b/src/snowflake/snowpark/modin/pandas/series.py
index 8b022d5a1f9..66698d14a9b 100644
--- a/src/snowflake/snowpark/modin/pandas/series.py
+++ b/src/snowflake/snowpark/modin/pandas/series.py
@@ -121,6 +121,7 @@ def __init__(
         copy=False,
         fastpath=False,
         query_compiler=None,
+        ast_stmt=None,
     ) -> None:
         # TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
         # Siblings are other dataframes that share the same query compiler. We
@@ -130,6 +131,11 @@
         # modified:
         # Engine.subscribe(_update_engine)

+        if not ast_stmt:
+            ast_stmt = pd.session._ast_batch.assign()
+        self._ast_id = ast_stmt.var_id.bitfield1
+        self._ast_stmt = ast_stmt
+
         if isinstance(data, type(self)):
             query_compiler = data._query_compiler.copy()
             if index is not None:
@@ -196,7 +202,9 @@ def _set_name(self, name):
         else:
             columns = [name]
         self._update_inplace(
-            new_query_compiler=self._query_compiler.set_columns(columns)
+            new_query_compiler=self._query_compiler.set_columns(
+                columns, ast_stmt=self._ast_stmt
+            )
         )

     name = property(_get_name, _set_name)
diff --git a/src/snowflake/snowpark/modin/plugin/_internal/frame.py b/src/snowflake/snowpark/modin/plugin/_internal/frame.py
index 64f76afb4e5..c40917ba2d9 100644
--- a/src/snowflake/snowpark/modin/plugin/_internal/frame.py
+++ b/src/snowflake/snowpark/modin/plugin/_internal/frame.py
@@ -11,6 +11,7 @@
 from pandas._typing import IndexLabel
 from pandas.core.dtypes.common import is_object_dtype

+import snowflake.snowpark._internal.proto.ast_pb2 as proto
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
 )
@@ -802,7 +803,9 @@
         )

     def rename_snowflake_identifiers(
-        self, old_to_new_identifiers: dict[str, str]
+        self,
+        old_to_new_identifiers: dict[str, str],
+        ast_stmt: Optional[proto.Expr] = None,
     ) -> "InternalFrame":
         """
         Rename columns for underlying ordered dataframe.
@@ -847,7 +850,8 @@
                 # retain the original column
                 select_list.append(old_id)
             else:
-                select_list.append(col(old_id).as_(new_id))
+                ast = ast_stmt.expr if ast_stmt is not None else None
+                select_list.append(col(old_id, ast=ast).as_(new_id))
             # if the old column is part of the ordering or row position columns, retains the column
             # as part of the projected columns.
             if old_id in ordering_and_row_position_columns:
diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
index a02bf4e108c..743d7186dc9 100644
--- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
+++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py
@@ -59,6 +59,7 @@
 from pandas.io.formats.format import format_percentiles
 from pandas.io.formats.printing import PrettyDict

+import snowflake.snowpark._internal.proto.ast_pb2 as proto
 import snowflake.snowpark.modin.pandas as pd
 from snowflake.snowpark._internal.analyzer.analyzer_utils import (
     quote_name_without_upper_casing,
 )
@@ -1250,7 +1251,11 @@ def columns(self) -> "pd.Index":
         # TODO SNOW-837664: add more tests for df.columns
         return self._modin_frame.data_columns_index

-    def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
+    def set_columns(
+        self,
+        new_pandas_labels: Axes,
+        ast_stmt: Optional[proto.Expr] = None,
+    ) -> "SnowflakeQueryCompiler":
         """
         Set pandas column labels with the new column labels
@@ -1288,7 +1293,8 @@
         )

         renamed_frame = self._modin_frame.rename_snowflake_identifiers(
-            renamed_quoted_identifier_mapping
+            renamed_quoted_identifier_mapping,
+            ast_stmt=ast_stmt,
         )

         new_internal_frame = InternalFrame.create(
diff --git a/tests/thin-client/modin-steel-thread.py b/tests/thin-client/modin-steel-thread.py
index ad758b6c8a9..44687d64c5c 100644
--- a/tests/thin-client/modin-steel-thread.py
+++ b/tests/thin-client/modin-steel-thread.py
@@ -12,5 +12,5 @@
     [[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14], [15, 16, 17, 18, 19, 20, 21]],
     columns=["A", "B", "C", "D", "E", "F", "G"],
 )
-df = df.iloc[3, 4]
-df.show()
+result = df.iloc[2, 2]
+print(result)
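
Note on the reordered dispatch in ast_expr_from_python_val: slice moved out of TYPE_TO_IR_TYPE_NAME into an explicit branch, and the list-like branch now excludes Series up front because pandas treats Series objects as list-like. A minimal standalone sketch of that ordering, with a plain dict standing in for the ast_pb2 protobuf builder (encode_val and its field names are hypothetical, for illustration only):

    from pandas import Series
    from pandas.api.types import is_list_like

    def encode_val(val):
        # Toy mirror of the branch order used in ast_expr_from_python_val.
        expr = {}
        if isinstance(val, slice):
            # Slices now carry explicit start/stop/step instead of a slice_val map entry.
            expr["slice_val"] = {"start": val.start, "stop": val.stop, "step": val.step}
        elif not isinstance(val, Series) and is_list_like(val):
            # Exclude Series first: is_list_like(Series([...])) is True.
            expr["list_val"] = list(val)
        elif isinstance(val, Series):
            expr["series_val"] = val
        else:
            expr["scalar_val"] = val  # stands in for the TYPE_TO_IR_TYPE_NAME lookup
        return expr

    print(encode_val(slice(1, 5)))       # slice_val with start=1, stop=5, step=None
    print(encode_val([1, 2, 3]))         # list_val
    print(encode_val(Series([1, 2])))    # series_val
    print(encode_val(7))                 # scalar_val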
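
Note on the _ast_id plumbing: rather than digging the id out of the underlying Snowpark DataFrame (the removed _get_ast_id chain), every modin DataFrame/Series now claims an assign statement at construction time and stores its id, so later operations such as iloc reference the producing statement directly. A toy sketch of that flow; ToyAssign, ToyBatch, and the dict-valued expr are invented stand-ins and assume nothing about the real _ast_batch or protobuf classes:

    import itertools

    class ToyAssign:
        # Hypothetical stand-in for the protobuf assign statement (var_id.bitfield1).
        def __init__(self, var_id):
            self.var_id = var_id
            self.expr = {}

    class ToyBatch:
        # Hypothetical stand-in for pd.session._ast_batch.
        def __init__(self):
            self._ids = itertools.count(1)
            self.statements = []

        def assign(self):
            stmt = ToyAssign(next(self._ids))
            self.statements.append(stmt)
            return stmt

    batch = ToyBatch()

    class ToyDataFrame:
        def __init__(self, ast_stmt=None):
            # Mirrors the new __init__ logic: claim an assign statement unless one is passed in.
            if ast_stmt is None:
                ast_stmt = batch.assign()
            self._ast_id = ast_stmt.var_id
            self._ast_stmt = ast_stmt

        def iloc_get(self, rows, cols):
            # Mirrors iloc.__getitem__: the new node refers to its input frame by id.
            stmt = batch.assign()
            stmt.expr = {"pd_dataframe_i_loc": {"df": self._ast_id, "rows": rows, "cols": cols}}
            return stmt

    df = ToyDataFrame()
    print(df.iloc_get(2, 2).expr)  # {'pd_dataframe_i_loc': {'df': 1, 'rows': 2, 'cols': 2}}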