Skip to content

Commit

Permalink
Finally got steel-thread working.
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-vbudati committed Jun 11, 2024
1 parent feffb46 commit 7445b0d
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 33 deletions.
23 changes: 13 additions & 10 deletions src/snowflake/snowpark/_internal/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,17 @@ def expr_to_dataframe_expr(expr):
return dfe


# Map from python type to its corresponding IR entity. The entities below all have the 'v' attribute.
# NOTE(review): the pasted diff interleaved old and new rows, leaving duplicate
# keys (float64, Decimal, datetime64); this is the deduplicated post-change map.
TYPE_TO_IR_TYPE_NAME = {
    bytes: "binary_val",
    bool: "bool_val",
    datetime64: "date_val",
    Decimal: "big_decimal_val",
    float64: "float_64_val",
    int32: "int_32_val",
    int64: "int_64_val",
    str: "string_val",
    slice: "slice_val",
    Timestamp: "timestamp_val",
}


Expand All @@ -49,7 +48,7 @@ def ast_expr_from_python_val(expr, val):
Parameters
----------
expr : IR expression object
expr : IR entity protobuf builder
val : Python value that needs to be converted to IR expression.
"""
if val is None:
Expand All @@ -62,12 +61,16 @@ def ast_expr_from_python_val(expr, val):
if isinstance(val, Callable):
expr.fn_val.params = signature(val).parameters
expr.fn_val.body = val
if isinstance(val, slice):
expr.slice_val.start = val.start
expr.slice_val.stop = val.stop
expr.slice_val.step = val.step
elif not isinstance(val, Series) and is_list_like(val):
# Checking that val is not a Series since Series objects are considered list-like.
expr.list_val.vs = val
elif isinstance(val, Series):
# Checking Series before the list-like type since Series are considered to be list-like.
expr.series_val.ref = val
elif is_list_like(val):
expr.list_val.vs = val
if isinstance(val, DataFrame):
elif isinstance(val, DataFrame):
expr.dataframe_val.ref = val
else:
ir_type_name = TYPE_TO_IR_TYPE_NAME[val_type]
Expand Down
12 changes: 8 additions & 4 deletions src/snowflake/snowpark/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,15 +666,19 @@ def __repr__(self):

def as_(self, alias: str) -> "Column":
    """Returns a new renamed Column. Alias of :func:`name`.

    Builds a ``SpColumnExpr`` AST node with ``variant_is_as`` set so the IR
    records that the user called ``as_`` (vs. ``alias``/``name``), then
    delegates to :func:`name`.
    """
    # The stale pre-change line `return self.name(alias)` left in by the diff
    # paste made everything below unreachable; removed.
    ast = proto.SpColumnExpr()
    ast.sp_column_alias.variant_is_as = True
    return self.name(alias, ast)

def alias(self, alias: str) -> "Column":
"""Returns a new renamed Column. Alias of :func:`name`."""
return self.name(alias)
ast = proto.SpColumnExpr()
ast.sp_column_alias.variant_is_as = False
return self.name(alias, ast)

def name(self, alias: str, ast=None) -> "Column":
    """Returns a new renamed Column.

    Parameters
    ----------
    alias : str
        New column name; quoted via ``quote_name`` before wrapping in ``Alias``.
    ast : optional
        Pre-built AST expression (e.g. from :func:`as_`/:func:`alias`) to
        attach to the resulting Column; defaults to ``None`` so existing
        callers are unaffected.
    """
    # The diff paste duplicated the old signature/return; this is the
    # post-change version that threads `ast` through to the Column ctor.
    return Column(Alias(self._expression, quote_name(alias)), ast=ast)

def over(self, window: Optional[WindowSpec] = None) -> "Column":
"""
Expand Down
6 changes: 3 additions & 3 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,11 @@ def col(df_alias: str, col_name: str) -> Column:
... # pragma: no cover


def col(name1: str, name2: Optional[str] = None, ast=None) -> Column:
    """Return a :class:`Column` for the given name(s).

    Parameters
    ----------
    name1 : str
        Column name, or the dataframe alias when ``name2`` is given.
    name2 : Optional[str]
        Column name qualified by the ``name1`` alias, if provided.
    ast : optional
        AST expression to attach to the Column; defaults to ``None`` so
        existing callers are unaffected.
    """
    # The diff paste duplicated the old signature and returns (the old
    # `return Column(name1)` shadowed the new ast-threading returns); this is
    # the reconstructed post-change function.
    if name2 is None:
        return Column(name1, ast=ast)
    else:
        return Column(name1, name2, ast=ast)


@overload
Expand Down
16 changes: 10 additions & 6 deletions src/snowflake/snowpark/modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(
dtype=None,
copy=None,
query_compiler=None,
ast_stmt=None,
) -> None:
# TODO: SNOW-1063346: Modin upgrade - modin.pandas.DataFrame functions
# Siblings are other dataframes that share the same query compiler. We
Expand All @@ -158,6 +159,11 @@ def __init__(

self._siblings = []

if not ast_stmt:
ast_stmt = pd.session._ast_batch.assign()
self._ast_id = ast_stmt.var_id.bitfield1
self._ast_stmt = ast_stmt

# Engine.subscribe(_update_engine)
if isinstance(data, (DataFrame, Series)):
self._query_compiler = data._query_compiler.copy()
Expand Down Expand Up @@ -255,11 +261,6 @@ def __init__(
else:
self._query_compiler = query_compiler

def _get_ast_id(self):
    # Fetch the AST id of the underlying Snowpark dataframe by walking the
    # internal wrapper chain: query compiler -> modin frame -> ordered
    # dataframe -> dataframe ref -> snowpark dataframe.
    # NOTE(review): this method was deleted by the commit in favor of a
    # `_ast_id` attribute assigned in `__init__`.
    return (
        self._query_compiler._modin_frame.ordered_dataframe._dataframe_ref.snowpark_dataframe._ast_id
    )

def __repr__(self):
"""
Return a string representation for a particular ``DataFrame``.
Expand Down Expand Up @@ -2869,7 +2870,10 @@ def __setattr__(self, key, value):
# - `_siblings`, which Modin initializes before it appears in __dict__
# - `_cache`, which pandas.cache_readonly uses to cache properties
# before it appears in __dict__.
if key in ("_query_compiler", "_siblings", "_cache") or key in self.__dict__:
if (
key in ("_query_compiler", "_siblings", "_cache", "_ast_id", "_ast_stmt")
or key in self.__dict__
):
pass
elif key in self and key not in dir(self):
self.__setitem__(key, value)
Expand Down
6 changes: 3 additions & 3 deletions src/snowflake/snowpark/modin/pandas/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,7 +1053,7 @@ def __getitem__(
# IR changes! Add iloc get nodes to AST.
stmt = pd.session._ast_batch.assign()
ast = stmt.expr
ast.pd_dataframe_i_loc.df.var_id.bitfield1 = self.df._get_ast_id()
ast.pd_dataframe_i_loc.df.var_id.bitfield1 = self.df._ast_id
# Map python built-ins (functions, scalars, lists, slices, etc.) to AST expr and emit Ref nodes for dataframes,
# series, and indexes.
ast_expr_from_python_val(ast.pd_dataframe_i_loc.rows, row_loc)
Expand All @@ -1067,7 +1067,7 @@ def __getitem__(

# Convert all scalar, list-like, and indexer row_loc to a Series object to get a query compiler object.
if is_scalar(row_loc):
row_loc = pd.Series([row_loc])
row_loc = pd.Series([row_loc], ast_stmt=stmt)
elif is_list_like(row_loc):
if hasattr(row_loc, "dtype"):
dtype = row_loc.dtype
Expand All @@ -1076,7 +1076,7 @@ def __getitem__(
dtype = float
else:
dtype = None
row_loc = pd.Series(row_loc, dtype=dtype)
row_loc = pd.Series(row_loc, dtype=dtype, ast_stmt=stmt)

# Check whether the row and column input is of numeric dtype.
self._validate_numeric_get_key_values(row_loc, original_row_loc)
Expand Down
10 changes: 9 additions & 1 deletion src/snowflake/snowpark/modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def __init__(
copy=False,
fastpath=False,
query_compiler=None,
ast_stmt=None,
) -> None:
# TODO: SNOW-1063347: Modin upgrade - modin.pandas.Series functions
# Siblings are other dataframes that share the same query compiler. We
Expand All @@ -130,6 +131,11 @@ def __init__(
# modified:
# Engine.subscribe(_update_engine)

if not ast_stmt:
ast_stmt = pd.session._ast_batch.assign()
self._ast_id = ast_stmt.var_id.bitfield1
self._ast_stmt = ast_stmt

if isinstance(data, type(self)):
query_compiler = data._query_compiler.copy()
if index is not None:
Expand Down Expand Up @@ -196,7 +202,9 @@ def _set_name(self, name):
else:
columns = [name]
self._update_inplace(
new_query_compiler=self._query_compiler.set_columns(columns)
new_query_compiler=self._query_compiler.set_columns(
columns, ast_stmt=self._ast_stmt
)
)

name = property(_get_name, _set_name)
Expand Down
8 changes: 6 additions & 2 deletions src/snowflake/snowpark/modin/plugin/_internal/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pandas._typing import IndexLabel
from pandas.core.dtypes.common import is_object_dtype

import snowflake.snowpark._internal.proto.ast_pb2 as proto
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
quote_name_without_upper_casing,
)
Expand Down Expand Up @@ -802,7 +803,9 @@ def project_columns(
)

def rename_snowflake_identifiers(
self, old_to_new_identifiers: dict[str, str]
self,
old_to_new_identifiers: dict[str, str],
ast_stmt: Optional[proto.Expr] = None,
) -> "InternalFrame":
"""
Rename columns for underlying ordered dataframe.
Expand Down Expand Up @@ -847,7 +850,8 @@ def rename_snowflake_identifiers(
# retain the original column
select_list.append(old_id)
else:
select_list.append(col(old_id).as_(new_id))
ast = ast_stmt.expr if ast_stmt is not None else None
select_list.append(col(old_id, ast=ast).as_(new_id))
# if the old column is part of the ordering or row position columns, retains the column
# as part of the projected columns.
if old_id in ordering_and_row_position_columns:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from pandas.io.formats.format import format_percentiles
from pandas.io.formats.printing import PrettyDict

import snowflake.snowpark._internal.proto.ast_pb2 as proto
import snowflake.snowpark.modin.pandas as pd
from snowflake.snowpark._internal.analyzer.analyzer_utils import (
quote_name_without_upper_casing,
Expand Down Expand Up @@ -1250,7 +1251,11 @@ def columns(self) -> "pd.Index":
# TODO SNOW-837664: add more tests for df.columns
return self._modin_frame.data_columns_index

def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
def set_columns(
self,
new_pandas_labels: Axes,
ast_stmt: Optional[proto.Expr] = None,
) -> "SnowflakeQueryCompiler":
"""
Set pandas column labels with the new column labels

Expand Down Expand Up @@ -1288,7 +1293,8 @@ def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
)

renamed_frame = self._modin_frame.rename_snowflake_identifiers(
renamed_quoted_identifier_mapping
renamed_quoted_identifier_mapping,
ast_stmt=ast_stmt,
)

new_internal_frame = InternalFrame.create(
Expand Down
4 changes: 2 additions & 2 deletions tests/thin-client/modin-steel-thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
[[1, 2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13, 14], [15, 16, 17, 18, 19, 20, 21]],
columns=["A", "B", "C", "D", "E", "F", "G"],
)
df = df.iloc[3, 4]
df.show()
result = df.iloc[2, 2]
print(result)

0 comments on commit 7445b0d

Please sign in to comment.