Skip to content

Commit

Permalink
SNOW-1618623 lazy index: refactor qc as input for index constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-azhan committed Aug 9, 2024
1 parent 70de08a commit 3c915f2
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 32 deletions.
6 changes: 5 additions & 1 deletion src/snowflake/snowpark/modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,8 +670,12 @@ def _get_index(self):
from snowflake.snowpark.modin.plugin.extensions.index import Index

if self._query_compiler.is_multiindex():
# Lazy multiindex is not supported
return self._query_compiler.index
return Index(data=self)

idx = Index(query_compiler=self._query_compiler)
idx._parent = self
return idx

index = property(_get_index, _set_index)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,7 @@ def index(self) -> Union["pd.Index", native_pd.MultiIndex]:
The index (row labels) of the DataFrame.
"""
if self.is_multiindex():
# Lazy multiindex is not supported
return self._modin_frame.index_columns_pandas_index()
else:
return pd.Index(query_compiler=self)
Expand Down
14 changes: 9 additions & 5 deletions src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from __future__ import annotations

import modin
import numpy as np
import pandas as native_pd
from pandas._libs import lib
Expand Down Expand Up @@ -69,7 +70,7 @@ def __new__(cls, *args, **kwargs):

def __init__(
self,
data: ArrayLike | SnowflakeQueryCompiler | None = None,
data: ArrayLike | native_pd.Index | modin.pandas.Sereis | None = None,
freq: Frequency | lib.NoDefault = _CONSTRUCTOR_DEFAULTS["freq"],
tz=_CONSTRUCTOR_DEFAULTS["tz"],
normalize: bool | lib.NoDefault = _CONSTRUCTOR_DEFAULTS["normalize"],
Expand All @@ -80,13 +81,14 @@ def __init__(
dtype: Dtype | None = _CONSTRUCTOR_DEFAULTS["dtype"],
copy: bool = _CONSTRUCTOR_DEFAULTS["copy"],
name: Hashable | None = _CONSTRUCTOR_DEFAULTS["name"],
query_compiler: SnowflakeQueryCompiler = None,
) -> None:
"""
Immutable ndarray-like of datetime64 data.
Parameters
----------
data : array-like (1-dimensional) or snowflake query compiler
data : array-like (1-dimensional), pandas.Index, modin.pandas.Series, optional
Datetime-like data to construct index with.
freq : str or pandas offset object, optional
One of pandas date offset strings or corresponding objects. The string
Expand Down Expand Up @@ -123,16 +125,18 @@ def __init__(
Make a copy of input ndarray.
name : label, default None
Name to be stored in the index.
query_compiler : SnowflakeQueryCompiler, optional
A query compiler object to create the ``Index`` from.
Examples
--------
>>> idx = pd.DatetimeIndex(["1/1/2020 10:00:00+00:00", "2/1/2020 11:00:00+00:00"], tz="America/Los_Angeles")
>>> idx
DatetimeIndex(['2020-01-01 02:00:00-08:00', '2020-02-01 03:00:00-08:00'], dtype='datetime64[ns, America/Los_Angeles]', freq=None)
"""
if isinstance(data, SnowflakeQueryCompiler):
if query_compiler:
# Raise error if underlying type is not a TimestampType.
current_dtype = data.index_dtypes[0]
current_dtype = query_compiler.index_dtypes[0]
if not current_dtype == np.dtype("datetime64[ns]"):
raise ValueError(
"DatetimeIndex can only be created from a query compiler with TimestampType."
Expand All @@ -149,4 +153,4 @@ def __init__(
"copy": copy,
"name": name,
}
self._init_index(data, _CONSTRUCTOR_DEFAULTS, **kwargs)
self._init_index(data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs)
61 changes: 38 additions & 23 deletions src/snowflake/snowpark/modin/plugin/extensions/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,19 +63,20 @@ class Index(metaclass=TelemetryMeta):

def __new__(
cls,
data: ArrayLike | SnowflakeQueryCompiler | None = None,
data: ArrayLike | native_pd.Index | Series | None = None,
dtype: str | np.dtype | ExtensionDtype | None = _CONSTRUCTOR_DEFAULTS["dtype"],
copy: bool = _CONSTRUCTOR_DEFAULTS["copy"],
name: object = _CONSTRUCTOR_DEFAULTS["name"],
tupleize_cols: bool = _CONSTRUCTOR_DEFAULTS["tupleize_cols"],
query_compiler: SnowflakeQueryCompiler = None,
) -> Index:
"""
Override __new__ method to control new instance creation of Index.
Depending on data type, it will create a Index or DatetimeIndex instance.
Parameters
----------
data : array-like (1-dimensional)
data : array-like (1-dimensional), pandas.Index, modin.pandas.Series, optional
dtype : str, numpy.dtype, or ExtensionDtype, optional
Data type for the output Index. If not specified, this will be
inferred from `data`.
Expand All @@ -86,7 +87,8 @@ def __new__(
Name to be stored in the index.
tupleize_cols : bool (default: True)
When True, attempt to create a MultiIndex if possible.
query_compiler : SnowflakeQueryCompiler, optional
A query compiler object to create the ``Index`` from.
Returns
-------
New instance of Index or DatetimeIndex.
Expand All @@ -96,23 +98,25 @@ def __new__(
DatetimeIndex,
)

orig_data = data
data = data._query_compiler if isinstance(data, BasePandasDataset) else data

if isinstance(data, SnowflakeQueryCompiler):
dtype = data.index_dtypes[0]
if query_compiler:
dtype = query_compiler.index_dtypes[0]
if dtype == np.dtype("datetime64[ns]"):
return DatetimeIndex(orig_data)
return object.__new__(cls)
return DatetimeIndex(query_compiler=query_compiler)
elif isinstance(data, BasePandasDataset):
if data.ndim != 1:
raise ValueError("Index data must be 1 - dimensional")
dtype = data.dtype
if dtype == np.dtype("datetime64[ns]"):
return DatetimeIndex(data, dtype, copy, name, tupleize_cols)
else:
index = native_pd.Index(data, dtype, copy, name, tupleize_cols)
if isinstance(index, native_pd.DatetimeIndex):
return DatetimeIndex(orig_data)
return object.__new__(cls)
return DatetimeIndex(data)
return object.__new__(cls)

def __init__(
self,
data: ArrayLike | modin.pandas.DataFrame | Series | None = None,
data: ArrayLike | native_pd.Index | Series | None = None,
dtype: str | np.dtype | ExtensionDtype | None = _CONSTRUCTOR_DEFAULTS["dtype"],
copy: bool = _CONSTRUCTOR_DEFAULTS["copy"],
name: object = _CONSTRUCTOR_DEFAULTS["name"],
Expand All @@ -126,7 +130,7 @@ def __init__(
Parameters
----------
data : array-like (1-dimensional), modin.pandas.Series, modin.pandas.DataFrame, optional
data : array-like (1-dimensional), pandas.Index, modin.pandas.Series, optional
dtype : str, numpy.dtype, or ExtensionDtype, optional
Data type for the output Index. If not specified, this will be
inferred from `data`.
Expand Down Expand Up @@ -166,7 +170,7 @@ def __init__(

def _init_index(
self,
data: ArrayLike | SnowflakeQueryCompiler | None,
data: ArrayLike | native_pd.Index | Series | None,
ctor_defaults: dict,
query_compiler: SnowflakeQueryCompiler = None,
**kwargs: Any,
Expand All @@ -179,14 +183,23 @@ def _init_index(
), f"Non-default argument '{arg_name}={arg_value}' when constructing Index with query compiler"
self._query_compiler = query_compiler
elif isinstance(data, BasePandasDataset):
self._parent = data
self._query_compiler = data._query_compiler.drop(
columns=data._query_compiler.columns
if data.ndim != 1:
raise ValueError("Index data must be 1 - dimensional")
series_has_no_name = data.name is None
idx = (
data.to_frame().set_index(0 if series_has_no_name else data.name).index
)
if series_has_no_name:
idx.name = None
self._query_compiler = idx._query_compiler
else:
self._query_compiler = DataFrame(
index=self._NATIVE_INDEX_TYPE(data=data, **kwargs)
)._query_compiler
if len(self._query_compiler.columns):
self._query_compiler = self._query_compiler.drop(
columns=self._query_compiler.columns
)

def __getattr__(self, key: str) -> Any:
"""
Expand Down Expand Up @@ -410,7 +423,7 @@ def unique(self, level: Hashable | None = None) -> Index:
f"Too many levels: Index has only 1 level, {level} is not a valid level number."
)
return self.__constructor__(
data=self._query_compiler.groupby_agg(
query_compiler=self._query_compiler.groupby_agg(
by=self._query_compiler.get_index_names(axis=0),
agg_func={},
axis=0,
Expand Down Expand Up @@ -510,9 +523,9 @@ def astype(self, dtype: str | type | ExtensionDtype, copy: bool = True) -> Index
DatetimeIndex,
)

return DatetimeIndex(data=new_query_compiler)
return DatetimeIndex(query_compiler=new_query_compiler)

return Index(data=new_query_compiler)
return Index(query_compiler=new_query_compiler)

@property
def name(self) -> Hashable:
Expand Down Expand Up @@ -861,7 +874,9 @@ def copy(
False
"""
WarningMessage.ignored_argument(operation="copy", argument="deep", message="")
return self.__constructor__(self._query_compiler.copy(), name=name)
return self.__constructor__(
query_compiler=self._query_compiler.copy(), name=name
)

@index_not_implemented()
def delete(self) -> None:
Expand Down Expand Up @@ -1877,7 +1892,7 @@ def sort_values(
key=key,
include_indexer=return_indexer,
)
index = self.__constructor__(res)
index = self.__constructor__(query_compiler=res)
if return_indexer:
# When `return_indexer` is True, `res` is a query compiler with one index column
# and one data column.
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/index/test_datetime_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_datetime_index_construction_negative():
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
msg = "DatetimeIndex can only be created from a query compiler with TimestampType"
with pytest.raises(ValueError, match=msg):
pd.DatetimeIndex(df._query_compiler)
pd.DatetimeIndex(query_compiler=df._query_compiler)


@sql_count_checker(query_count=0)
Expand All @@ -72,7 +72,7 @@ def test_non_default_args(kwargs):
value = list(kwargs.values())[0]
msg = f"Non-default argument '{name}={value}' when constructing Index with query compiler"
with pytest.raises(AssertionError, match=msg):
pd.DatetimeIndex(data=idx._query_compiler, **kwargs)
pd.DatetimeIndex(query_compiler=idx._query_compiler, **kwargs)


@sql_count_checker(query_count=6)
Expand Down
19 changes: 18 additions & 1 deletion tests/integ/modin/index/test_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,4 +382,21 @@ def test_non_default_args(kwargs):
value = list(kwargs.values())[0]
msg = f"Non-default argument '{name}={value}' when constructing Index with query compiler"
with pytest.raises(AssertionError, match=msg):
pd.Index(data=idx._query_compiler, **kwargs)
pd.Index(query_compiler=idx._query_compiler, **kwargs)


@sql_count_checker(query_count=2)
def test_create_index_from_series():
idx = pd.Index(pd.Series([5, 6]))
assert_index_equal(idx, native_pd.Index([5, 6]))

idx = pd.Index(pd.Series([5, 6], name="abc"))
assert_index_equal(idx, native_pd.Index([5, 6], name="abc"))


@sql_count_checker(query_count=0)
def test_create_index_from_df_negative():
with pytest.raises(ValueError):
pd.Index(pd.DataFrame([[1, 2], [3, 4]]))
with pytest.raises(ValueError):
pd.DatetimeIndex(pd.DataFrame([[1, 2], [3, 4]]))

0 comments on commit 3c915f2

Please sign in to comment.