From c834904516cce8c22390c76b905b8cfd03b6b026 Mon Sep 17 00:00:00 2001 From: Varnika Budati Date: Tue, 6 Aug 2024 16:41:39 -0700 Subject: [PATCH] SNOW-1569896 Refactor index to take in DataFrame/Series (#2020) --- src/snowflake/snowpark/modin/pandas/base.py | 6 ++++- .../snowpark/modin/plugin/extensions/index.py | 11 +++++++-- tests/integ/modin/index/test_index_methods.py | 23 +++++++++++++++++++ 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/snowflake/snowpark/modin/pandas/base.py b/src/snowflake/snowpark/modin/pandas/base.py index b1e4535a207..35039c695be 100644 --- a/src/snowflake/snowpark/modin/pandas/base.py +++ b/src/snowflake/snowpark/modin/pandas/base.py @@ -667,7 +667,11 @@ def _get_index(self): The union of all indexes across the partitions. """ # TODO: SNOW-1119855: Modin upgrade - modin.pandas.base.BasePandasDataset - return self._query_compiler.index + from snowflake.snowpark.modin.plugin.extensions.index import Index + + if self._query_compiler.is_multiindex(): + return self._query_compiler.index + return Index(data=self) index = property(_get_index, _set_index) diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 876fe470956..c0f793554dd 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -35,6 +35,7 @@ from pandas.core.dtypes.common import pandas_dtype from snowflake.snowpark.modin.pandas import DataFrame, Series +from snowflake.snowpark.modin.pandas.base import BasePandasDataset from snowflake.snowpark.modin.pandas.utils import try_convert_index_to_native from snowflake.snowpark.modin.plugin._internal.telemetry import TelemetryMeta from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import ( @@ -50,7 +51,11 @@ class Index(metaclass=TelemetryMeta): def __init__( self, - data: ArrayLike | SnowflakeQueryCompiler | None = None, + data: ArrayLike + | modin.pandas.DataFrame + | Series + | SnowflakeQueryCompiler + | None = None, dtype: str | np.dtype | ExtensionDtype | None = None, copy: bool = False, name: object = None, @@ -63,7 +68,7 @@ def __init__( Parameters ---------- - data : array-like (1-dimensional) + data : array-like (1-dimensional), modin.pandas.Series, modin.pandas.DataFrame, SnowflakeQueryCompiler, optional dtype : str, numpy.dtype, or ExtensionDtype, optional Data type for the output Index. If not specified, this will be inferred from `data`. @@ -92,6 +97,8 @@ def __init__( >>> pd.Index([1, 2, 3], dtype="uint8") Index([1, 2, 3], dtype='int64') """ + self._parent = data if isinstance(data, BasePandasDataset) else None + data = data._query_compiler if isinstance(data, BasePandasDataset) else data if isinstance(data, SnowflakeQueryCompiler): qc = data else: diff --git a/tests/integ/modin/index/test_index_methods.py b/tests/integ/modin/index/test_index_methods.py index 37aa942680f..ce0fc05f80b 100644 --- a/tests/integ/modin/index/test_index_methods.py +++ b/tests/integ/modin/index/test_index_methods.py @@ -3,6 +3,7 @@ # import modin.pandas as pd +import pandas as native_pd import pytest from numpy.testing import assert_equal from pandas._libs import lib @@ -15,6 +16,7 @@ ) from tests.integ.modin.sql_counter import SqlCounter, sql_count_checker from tests.integ.modin.utils import ( + assert_frame_equal, assert_index_equal, assert_series_equal, assert_snowpark_pandas_equals_to_pandas_without_dtypecheck, @@ -341,3 +343,24 @@ def test_has_duplicates(index): with SqlCounter(query_count=1): snow_index = pd.Index(index) assert index.has_duplicates == snow_index.has_duplicates + + +@sql_count_checker(query_count=6) +def test_index_parent(): + """ + Check whether the parent field in Index is updated properly. + """ + native_idx1 = native_pd.Index(["A", "B"], name="xyz") + native_idx2 = native_pd.Index(["A", "B", "D", "E", "G", "H"], name="CFI") + + # DataFrame case. + df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1) + snow_idx = df.index + assert_frame_equal(snow_idx._parent, df) + assert_index_equal(snow_idx, native_idx1) + + # Series case. + s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx") + snow_idx = s.index + assert_series_equal(snow_idx._parent, s) + assert_index_equal(snow_idx, native_idx2)