Skip to content

Commit

Permalink
SNOW-1573193: Remove local index from pd.Index (#2031)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-nkumar authored Aug 5, 2024
1 parent 6e7b721 commit 5e360be
Show file tree
Hide file tree
Showing 16 changed files with 62 additions and 305 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
- Fixed a bug in `Index.to_frame` where the result frame's column name may be wrong where name is unspecified.
- Fixed a bug where some Index docstrings are ignored.

### Behavior change
- `Dataframe.columns` now returns native pandas Index object instead of Snowpark Index object.

## 1.20.0 (2024-07-17)

Expand Down
2 changes: 1 addition & 1 deletion src/snowflake/snowpark/modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def _repr_html_(self): # pragma: no cover
else:
return result

def _get_columns(self) -> pd.Index:
def _get_columns(self) -> pandas.Index:
"""
Get the columns for this Snowpark pandas ``DataFrame``.
Expand Down
7 changes: 2 additions & 5 deletions src/snowflake/snowpark/modin/plugin/_internal/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
last_value,
max as max_,
)
from snowflake.snowpark.modin import pandas as pd
from snowflake.snowpark.modin.plugin._internal.ordered_dataframe import (
OrderedDataFrame,
OrderingColumn,
Expand Down Expand Up @@ -378,7 +377,7 @@ def is_unnamed_series(self) -> bool:
)

@property
def data_columns_index(self) -> "pd.Index":
def data_columns_index(self) -> native_pd.Index:
"""
Returns Snowpark pandas Index object for column index (df.columns).
Note this object will still hold an internal pandas index (i.e., not lazy) to avoid unnecessary pulling data from Snowflake.
Expand All @@ -389,15 +388,13 @@ def data_columns_index(self) -> "pd.Index":
names=self.data_column_pandas_index_names,
)
else:
return pd.Index(
return native_pd.Index(
self.data_column_pandas_labels,
name=self.data_column_pandas_index_names[0],
# setting tupleize_cols=False to avoid creating a MultiIndex
# otherwise, when labels are tuples (e.g., [("A", "a"), ("B", "b")]),
# a MultiIndex will be created incorrectly
tupleize_cols=False,
# setting is_lazy as false because we want to store the columns locally
convert_to_lazy=False,
)

def index_columns_pandas_index(self, **kwargs: Any) -> native_pd.Index:
Expand Down
15 changes: 7 additions & 8 deletions src/snowflake/snowpark/modin/plugin/_internal/indexing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,8 +776,8 @@ def _extract_loc_set_col_info(
label for label in columns if label not in frame_data_columns
]
columns = [label for label in columns if label in frame_data_columns]
before = frame_data_columns.to_pandas().value_counts()
after = union_data_columns.to_pandas().value_counts()
before = frame_data_columns.value_counts()
after = union_data_columns.value_counts()
frame_data_col_labels = frame_data_columns.tolist()
for label in after.index:
if label in frame_data_columns:
Expand Down Expand Up @@ -872,7 +872,9 @@ def get_valid_col_positions_from_col_labels(
)
)
)
col_loc = pd.Index(col_loc, convert_to_lazy=False)
col_loc = col_loc.index
if isinstance(col_loc, pd.Index):
col_loc = col_loc.to_pandas()
# get the position of the selected labels
return [pos for pos, label in enumerate(columns) if label in col_loc]
else:
Expand Down Expand Up @@ -939,11 +941,8 @@ def get_valid_col_positions_from_col_labels(
# Convert col_loc to Index with object dtype since _get_indexer_strict() converts None values in lists to
# np.nan. This does not filter columns with label None and errors. Not using np.array(col_loc) as the key since
# np.array(["A", 12]) turns into array(['A', '12'].
col_loc = pd.Index(
[label for label in col_loc if label in columns],
dtype=object,
# we do not convert to lazy because we are using this index as columns
convert_to_lazy=False,
col_loc = native_pd.Index(
[label for label in col_loc if label in columns], dtype=object
)

# `Index._get_indexer_strict` returns position index from label index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,7 @@ def cache_result(self) -> "SnowflakeQueryCompiler":
return SnowflakeQueryCompiler(self._modin_frame.persist_to_temporary_table())

@property
def columns(self) -> "pd.Index":
def columns(self) -> native_pd.Index:
"""
Get pandas column labels.

Expand Down Expand Up @@ -2399,17 +2399,7 @@ def _reindex_axis_1(
limit = kwargs.get("limit", None)
tolerance = kwargs.get("tolerance", None)
fill_value = kwargs.get("fill_value", np.nan) # type: ignore[arg-type]
# Currently, our error checking relies on the column axis being eager (i.e. stored
# locally as a pandas Index, rather than pushed down to the database). This allows
# us to have parity with native pandas for things like monotonicity checks. If
# our columns are no longer eagerly stored, we would no longer be able to rely
# on pandas for these error checks, and the behaviour of reindex would change.
# This change is user-facing, so we should catch this in CI first, which we can
# by having this assert here, as a sentinel.
assert (
not self.columns.is_lazy
), "`reindex` with axis=1 failed on error checking."
self.columns.to_pandas().reindex(labels, method, level, limit, tolerance)
self.columns.reindex(labels, method, level, limit, tolerance)
data_column_pandas_labels = []
data_column_snowflake_quoted_identifiers = []
modin_frame = self._modin_frame
Expand Down
Loading

0 comments on commit 5e360be

Please sign in to comment.