Skip to content

Commit

Permalink
SNOW-1569896 Refactor index to take in DataFrame/Series (#2020)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-vbudati authored Aug 6, 2024
1 parent e0f7d81 commit c834904
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/snowflake/snowpark/modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,11 @@ def _get_index(self):
The union of all indexes across the partitions.
"""
# TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset
return self._query_compiler.index
from snowflake.snowpark.modin.plugin.extensions.index import Index

if self._query_compiler.is_multiindex():
return self._query_compiler.index
return Index(data=self)

index = property(_get_index, _set_index)

Expand Down
11 changes: 9 additions & 2 deletions src/snowflake/snowpark/modin/plugin/extensions/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from pandas.core.dtypes.common import pandas_dtype

from snowflake.snowpark.modin.pandas import DataFrame, Series
from snowflake.snowpark.modin.pandas.base import BasePandasDataset
from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native
from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
Expand All @@ -50,7 +51,11 @@
class Index(metaclass=TelemetryMeta):
def __init__(
self,
data: ArrayLike | SnowflakeQueryCompiler | None = None,
data: ArrayLike
| modin.pandas.DataFrame
| Series
| SnowflakeQueryCompiler
| None = None,
dtype: str | np.dtype | ExtensionDtype | None = None,
copy: bool = False,
name: object = None,
Expand All @@ -63,7 +68,7 @@ def __init__(
Parameters
----------
data : array-like (1-dimensional)
data : array-like (1-dimensional), modin.pandas.Series, modin.pandas.DataFrame, SnowflakeQueryCompiler, optional
dtype : str, numpy.dtype, or ExtensionDtype, optional
Data type for the output Index. If not specified, this will be
inferred from `data`.
Expand Down Expand Up @@ -92,6 +97,8 @@ def __init__(
>>> pd.Index([1, 2, 3], dtype="uint8")
Index([1, 2, 3], dtype='int64')
"""
self._parent = data if isinstance(data, BasePandasDataset) else None
data = data._query_compiler if isinstance(data, BasePandasDataset) else data
if isinstance(data, SnowflakeQueryCompiler):
qc = data
else:
Expand Down
23 changes: 23 additions & 0 deletions tests/integ/modin/index/test_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#

import modin.pandas as pd
import pandas as native_pd
import pytest
from numpy.testing import assert_equal
from pandas._libs import lib
Expand All @@ -15,6 +16,7 @@
)
from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker
from tests.integ.modin.utils import (
assert_frame_equal,
assert_index_equal,
assert_series_equal,
assert_snowpark_pandas_equals_to_pandas_without_dtypecheck,
Expand Down Expand Up @@ -341,3 +343,24 @@ def test_has_duplicates(index):
with SqlCounter(query_count=1):
snow_index = pd.Index(index)
assert index.has_duplicates == snow_index.has_duplicates


@sql_count_checker(query_count=6)
def test_index_parent():
    """
    Verify that the Index returned by ``.index`` remembers, via its ``_parent``
    field, the DataFrame or Series it was obtained from.
    """
    frame_labels = native_pd.Index(["A", "B"], name="xyz")
    series_labels = native_pd.Index(["A", "B", "D", "E", "G", "H"], name="CFI")

    # DataFrame: the index's parent must equal the frame it was read from,
    # and the index itself must still match the native labels.
    frame = pd.DataFrame([[1, 2], [3, 4]], index=frame_labels)
    frame_index = frame.index
    assert_frame_equal(frame_index._parent, frame)
    assert_index_equal(frame_index, frame_labels)

    # Series: same two checks against the originating series.
    series = pd.Series([1, 2, 4, 5, 6, 7], index=series_labels, name="zyx")
    series_index = series.index
    assert_series_equal(series_index._parent, series)
    assert_index_equal(series_index, series_labels)

0 comments on commit c834904

Please sign in to comment.