Skip to content

Commit

Permalink
SNOW 1359041: Lazy Index Constructor, to_pandas() and len (#1729)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-samdesai1 authored Jun 14, 2024
1 parent 133ff6e commit 3fdec44
Show file tree
Hide file tree
Showing 56 changed files with 975 additions and 572 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@

- Added support for named aggregations in `DataFrame.aggregate` and `Series.aggregate` with `axis=0`.
- `pd.read_csv` reads using the native pandas CSV parser, then uploads data to snowflake using parquet. This enables most of the parameters supported by `read_csv` including date parsing and numeric conversions. Uploading via parquet is roughly twice as fast as uploading via CSV.
- Initial work to support an Index directly in Snowpark pandas. Currently, this class is a simple wrapper for a pandas index. Support for Index as a first-class component of Snowpark pandas is coming soon.
- Initial work to support an Index directly in Snowpark pandas. Support for Index as a first-class component of Snowpark pandas is coming soon.
- Added a lazy index constructor and support for `len`, `to_pandas()`, and `names`. For `df.index`, Snowpark pandas creates a lazy index object.
- For `df.columns`, Snowpark pandas supports a non-lazy version of an `Index` since the data is already stored locally.

## 1.18.0 (2024-05-28)

Expand Down
11 changes: 10 additions & 1 deletion src/snowflake/snowpark/modin/pandas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,15 +796,24 @@ def ensure_index(
>>> ensure_index(['a', 'b'])
Index(['a', 'b'], dtype='object')
# Snowpark pandas converts these tuples to lists
>>> ensure_index([('a', 'a'), ('b', 'c')])
Index([('a', 'a'), ('b', 'c')], dtype='object')
Index([['a', 'a'], ['b', 'c']], dtype='object')
>>> ensure_index([['a', 'a'], ['b', 'c']])
MultiIndex([('a', 'b'),
('a', 'c')],
)
"""
# if we have an index object already, simply copy it if required and return
if isinstance(index_like, (pandas.MultiIndex, pd.Index)):
if copy:
index_like = index_like.copy()
return index_like

if isinstance(index_like, pd.Series):
return pd.Index(index_like.values)

if isinstance(index_like, list):
# if we have a non-empty list that is multi dimensional, convert this to a multi-index and return
if len(index_like) and lib.is_all_arraylike(index_like):
Expand Down
4 changes: 2 additions & 2 deletions src/snowflake/snowpark/modin/plugin/_internal/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,19 +388,19 @@ def data_columns_index(self) -> "pd.Index":
# otherwise, when labels are tuples (e.g., [("A", "a"), ("B", "b")]),
# a MultiIndex will be created incorrectly
tupleize_cols=False,
# setting is_lazy as false because we want to store the columns locally
convert_to_lazy=False,
)

@property
def index_columns_index(self) -> native_pd.Index:
def index_columns_pandas_index(self) -> native_pd.Index:
"""
Get pandas index. The method eagerly pulls the values from Snowflake because index requires the values to be
filled
Returns:
The index (row labels) of the DataFrame.
"""

index_values = snowpark_to_pandas_helper(
self.ordered_dataframe.select(
self.index_column_snowflake_quoted_identifiers
Expand Down
131 changes: 98 additions & 33 deletions src/snowflake/snowpark/modin/plugin/_internal/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from pandas._typing import ArrayLike, DtypeObj, NaPosition, Self
from pandas.core.arrays import ExtensionArray
from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.indexes.frozen import FrozenList

from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native
from snowflake.snowpark.modin.plugin.utils.error_message import (
Expand All @@ -43,7 +42,7 @@
class Index:
def __init__(
self,
# Any should be replaced with SnowflakeQueryCompiler when possible (linter won't allow it now)
# TODO: SNOW-1481037 : Fix typehints for index constructor, set_query_compiler and set_local_index
data: ArrayLike | Any = None,
dtype: str | np.dtype | ExtensionDtype | None = None,
copy: bool = False,
Expand Down Expand Up @@ -86,29 +85,86 @@ def __init__(
>>> pd.Index(list('abc'))
Index(['a', 'b', 'c'], dtype='object')
# Snowpark pandas only supports signed integers so cast to uint won't work
>>> pd.Index([1, 2, 3], dtype="uint8")
Index([1, 2, 3], dtype='uint8')
Index([1, 2, 3], dtype='int64')
"""
self.is_lazy = convert_to_lazy
if self.is_lazy:
self.set_query_compiler(
data=data,
dtype=dtype,
copy=copy,
name=name,
tupleize_cols=tupleize_cols,
)
else:
self.set_local_index(
data=data,
dtype=dtype,
copy=copy,
name=name,
tupleize_cols=tupleize_cols,
)

def set_query_compiler(
self,
# TODO: SNOW-1481037 : Fix typehints for index constructor, set_query_compiler and set_local_index
data: ArrayLike | Any = None,
dtype: str | np.dtype | ExtensionDtype | None = None,
copy: bool = False,
name: object = None,
tupleize_cols: bool = True,
) -> None:
"""
Helper method to find and save query compiler when index should be lazy
"""
from snowflake.snowpark.modin.pandas.dataframe import DataFrame
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)

# TODO: SNOW-1359041: Switch to lazy index implementation
self.is_lazy = convert_to_lazy
if isinstance(data, native_pd.Index):
self._index = data
elif isinstance(data, Index):
self._index = data.to_pandas()
elif isinstance(data, SnowflakeQueryCompiler):
self._index = data._modin_frame.index_columns_index
if isinstance(data, SnowflakeQueryCompiler):
qc = data
else:
qc = DataFrame(
native_pd.Index(
data=data,
dtype=dtype,
copy=copy,
name=name,
tupleize_cols=tupleize_cols,
).to_frame()
)._query_compiler
self._query_compiler = qc

def set_local_index(
self,
# TODO: SNOW-1481037 : Fix typehints for index constructor, set_query_compiler and set_local_index
data: ArrayLike | Any = None,
dtype: str | np.dtype | ExtensionDtype | None = None,
copy: bool = False,
name: object = None,
tupleize_cols: bool = True,
) -> None:
"""
Helper method to create and save local index when index should not be lazy
"""
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)

if isinstance(data, SnowflakeQueryCompiler):
index = data._modin_frame.index_columns_pandas_index
else:
self._index = native_pd.Index(
index = native_pd.Index(
data=data,
dtype=dtype,
copy=copy,
name=name,
tupleize_cols=tupleize_cols,
)
self._index = index

def is_lazy_check(func: Any) -> Any:
"""
Expand All @@ -134,6 +190,9 @@ def check_lazy(*args: Any, **kwargs: Any) -> Any:

# Remove the first argument in args, because it is `self` and we don't need it
args = args[1:]
args = tuple(try_convert_index_to_native(a) for a in args)
for k, v in kwargs.items():
kwargs[k] = try_convert_index_to_native(v)
returned_value = native_func(*args, **kwargs)

# If we return a native Index, we need to convert this to a modin index but keep it locally.
Expand Down Expand Up @@ -191,6 +250,8 @@ def to_pandas(self) -> native_pd.Index:
pandas Index
A native pandas Index representation of self
"""
if self.is_lazy:
return self._query_compiler._modin_frame.index_columns_pandas_index
return self._index

@property
Expand Down Expand Up @@ -281,12 +342,12 @@ def is_unique(self) -> bool:
True
>>> idx = pd.Index(["Watermelon", "Orange", "Apple",
... "Watermelon"]).astype("category")
... "Watermelon"])
>>> idx.is_unique
False
>>> idx = pd.Index(["Orange", "Apple",
... "Watermelon"]).astype("category")
... "Watermelon"])
>>> idx.is_unique
True
"""
Expand Down Expand Up @@ -320,12 +381,12 @@ def has_duplicates(self) -> bool:
False
>>> idx = pd.Index(["Watermelon", "Orange", "Apple",
... "Watermelon"]).astype("category")
... "Watermelon"])
>>> idx.has_duplicates
True
>>> idx = pd.Index(["Orange", "Apple",
... "Watermelon"]).astype("category")
... "Watermelon"])
>>> idx.has_duplicates
False
"""
Expand Down Expand Up @@ -431,6 +492,7 @@ def astype(self, dtype: Any, copy: bool = True) -> Index:
WarningMessage.index_to_pandas_warning("astype")
return Index(
self.to_pandas().astype(dtype=dtype, copy=copy),
dtype=dtype,
convert_to_lazy=self.is_lazy,
)

Expand All @@ -452,24 +514,26 @@ def name(self) -> Hashable:
>>> idx.name
'x'
"""
# TODO: SNOW-1458122 implement name
WarningMessage.index_to_pandas_warning("name")
return self.to_pandas().name
return self.names[0] if self.names else None

@name.setter
def name(self, value: Hashable) -> None:
"""
Set Index name.
"""
WarningMessage.index_to_pandas_warning("name")
self.to_pandas().name = value
if self.is_lazy:
self._query_compiler = self._query_compiler.set_index_names([value])
else:
self._index.name = value

def _get_names(self) -> FrozenList:
def _get_names(self) -> list[Hashable]:
"""
Get names of index
"""
WarningMessage.index_to_pandas_warning("_get_names")
return self.to_pandas()._get_names()
if self.is_lazy:
return self._query_compiler.get_index_names()
else:
return self.to_pandas().names

def _set_names(self, values: list) -> None:
"""
Expand All @@ -484,8 +548,10 @@ def _set_names(self, values: list) -> None:
------
TypeError if each name is not hashable.
"""
WarningMessage.index_to_pandas_warning("_set_names")
self.to_pandas()._set_names(values)
if self.is_lazy:
self._query_compiler = self._query_compiler.set_index_names(values)
else:
self._index.names = values

names = property(fset=_set_names, fget=_get_names)

Expand Down Expand Up @@ -937,9 +1003,11 @@ def equals(self, other: Any) -> bool:
>>> int64_idx = pd.Index([1, 2, 3], dtype='int64')
>>> int64_idx
Index([1, 2, 3], dtype='int64')
# Snowpark pandas only supports signed integers so cast to uint won't work
>>> uint64_idx = pd.Index([1, 2, 3], dtype='uint64')
>>> uint64_idx
Index([1, 2, 3], dtype='uint64')
Index([1, 2, 3], dtype='int64')
>>> int64_idx.equals(uint64_idx)
True
"""
Expand Down Expand Up @@ -1835,15 +1903,13 @@ def get_indexer_for(self, target: Any) -> Any:
Examples
--------
# Snowpark pandas converts np.nan, pd.NA, pd.NaT to None
>>> idx = pd.Index([np.nan, 'var1', np.nan])
>>> idx.get_indexer_for([np.nan])
array([0, 2])
"""
WarningMessage.index_to_pandas_warning("get_indexer_for")
ret = self.to_pandas().get_indexer_for(target=target)
# if isinstance(ret, native_pd.Index):
# return Index(ret, convert_to_lazy=self.is_lazy)
return ret
return self.to_pandas().get_indexer_for(target=target)

@is_lazy_check
def _get_indexer_strict(self, key: Any, axis_name: str) -> tuple[Index, np.ndarray]:
Expand Down Expand Up @@ -2084,8 +2150,7 @@ def __len__(self) -> int:
"""
Return the length of the Index as an int.
"""
WarningMessage.index_to_pandas_warning("__len__")
return self.to_pandas().__len__()
return self._query_compiler.get_axis_len(0)

@is_lazy_check
def __getitem__(self, key: Any) -> np.ndarray | None | Index:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,7 @@ def get_valid_col_positions_from_col_labels(
)
)
)
col_loc = col_loc.index
col_loc = pd.Index(col_loc, convert_to_lazy=False)
# get the position of the selected labels
return [pos for pos, label in enumerate(columns) if label in col_loc]
else:
Expand Down Expand Up @@ -939,7 +939,10 @@ def get_valid_col_positions_from_col_labels(
# np.nan. This does not filter columns with label None and errors. Not using np.array(col_loc) as the key since
# np.array(["A", 12]) turns into array(['A', '12'].
col_loc = pd.Index(
[label for label in col_loc if label in columns], dtype=object
[label for label in col_loc if label in columns],
dtype=object,
# we do not convert to lazy because we are using this index as columns
convert_to_lazy=False,
)

# `Index._get_indexer_strict` returns position index from label index
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def transpose_empty_df(

return SnowflakeQueryCompiler.from_pandas(
native_pd.DataFrame(
columns=original_frame.index_columns_index,
columns=original_frame.index_columns_pandas_index,
index=try_convert_index_to_native(original_frame.data_columns_index),
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
)
from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.common import is_dict_like, is_list_like, pandas_dtype
from pandas.core.indexes.base import ensure_index
from pandas.io.formats.format import format_percentiles
from pandas.io.formats.printing import PrettyDict

Expand Down Expand Up @@ -1266,9 +1267,9 @@ def set_columns(self, new_pandas_labels: Axes) -> "SnowflakeQueryCompiler":
a new `SnowflakeQueryCompiler` with updated column labels
"""
# new_pandas_names should be able to convert into an index which is consistent to pandas df.columns behavior
from snowflake.snowpark.modin.pandas.utils import ensure_index
from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native

new_pandas_labels = ensure_index(new_pandas_labels)
new_pandas_labels = ensure_index(try_convert_index_to_native(new_pandas_labels))
if len(new_pandas_labels) != len(self._modin_frame.data_column_pandas_labels):
raise ValueError(
"Length mismatch: Expected axis has {} elements, new values have {} elements".format(
Expand Down Expand Up @@ -1500,14 +1501,14 @@ def shift(
@property
def index(self) -> Union["pd.Index", native_pd.MultiIndex]:
"""
Get pandas index. The method eagerly pulls the values from Snowflake because index requires the values to be
filled
Get index. If MultiIndex, the method eagerly pulls the values from Snowflake because index requires the values to be
filled and returns a pandas MultiIndex. If not MultiIndex, create a modin index and pass it self

Returns:
The index (row labels) of the DataFrame.
"""
if self.is_multiindex():
return self._modin_frame.index_columns_index
return self._modin_frame.index_columns_pandas_index
else:
return pd.Index(self)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def contains():
>>> ind = pd.Index(['Mouse', 'dog', 'house and parrot', '23.0', np.NaN])
>>> ind.str.contains('23', regex=False)
Index([False, False, False, True, nan], dtype='object')
Index([False, False, False, True, None], dtype='object')
Specifying case sensitivity using case.
Expand Down
Loading

0 comments on commit 3fdec44

Please sign in to comment.