diff --git a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
index aeca9d6e305..d364c5f57fa 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/base_overrides.py
@@ -2262,7 +2262,7 @@ def _get_index(self):
         return self._query_compiler.index
 
     idx = Index(query_compiler=self._query_compiler)
-    idx._set_parent(self)
+    idx._parent.set_parent(self)
     return idx
 
 
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
index 7795d35c746..8f62e6e0dc6 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/datetime_index.py
@@ -46,7 +46,7 @@
 from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
     SnowflakeQueryCompiler,
 )
-from snowflake.snowpark.modin.plugin.extensions.index import Index
+from snowflake.snowpark.modin.plugin.extensions.index import Index, IndexParent
 from snowflake.snowpark.modin.plugin.utils.error_message import (
     datetime_index_not_implemented,
 )
@@ -162,8 +162,7 @@ def __new__(
             query_compiler = query_compiler.series_to_datetime(include_index=True)
         index._query_compiler = query_compiler
         # `_parent` keeps track of any Series or DataFrame that this Index is a part of.
-        index._parent = None
-        index._parent_qc = None
+        index._parent = IndexParent()
         return index
 
     def __init__(
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py
index 63290803480..a5e541c2229 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py
@@ -71,6 +71,49 @@
 }
 
 
+class IndexParent:
+    def __init__(self) -> None:
+        """
+        Initialize the IndexParent object.
+
+        IndexParent is used to keep track of the parent object that the Index is a part of.
+        It tracks the parent object and the query compiler the parent had when it was last recorded.
+        """
+        self._parent = None
+        self._parent_qc = None
+
+    def set_parent(self, parent: Series | DataFrame) -> None:
+        """
+        Set the parent object and its query compiler.
+
+        Parameters
+        ----------
+        parent : Series or DataFrame
+            The parent object that the Index is a part of.
+        """
+        self._parent = parent
+        self._parent_qc = parent._query_compiler
+
+    def check_and_update_parent_qc_index_names(self, names: list) -> None:
+        """
+        Update the parent's index names if the parent's current query compiler matches
+        the recorded query compiler (`_parent_qc`).
+
+        Nothing happens when no parent has been set, or when the parent's query
+        compiler is no longer the recorded one.
+
+        Parameters
+        ----------
+        names : list
+            The new index names to set on the parent.
+        """
+        if self._parent is not None and self._parent._query_compiler is self._parent_qc:
+            new_query_compiler = self._parent_qc.set_index_names(names)
+            self._parent._update_inplace(new_query_compiler=new_query_compiler)
+            # Record the new query compiler so that subsequent renames still reach the parent.
+            self._parent_qc = new_query_compiler
+
+
 class Index(metaclass=TelemetryMeta):
 
     # Equivalent index type in native pandas
@@ -135,11 +178,8 @@ def __new__(
         index = object.__new__(cls)
         # Initialize the Index
         index._query_compiler = query_compiler
-        # `_parent` keeps track of any Series or DataFrame that this Index is a part of.
-        # `_parent_qc` keeps track of the original query compiler of the parent object.
-        # These fields are used with the name APIs.
-        index._parent = None
-        index._parent_qc = None
+        # `_parent` keeps track of the parent object that this Index is a part of.
+        index._parent = IndexParent()
         return index
 
     def __init__(
@@ -411,13 +451,6 @@ def __constructor__(self):
         """
         return type(self)
 
-    def _set_parent(self, parent: Series | DataFrame):
-        """
-        Set the parent object of the current Index to a given Series or DataFrame.
-        """
-        self._parent = parent
-        self._parent_qc = parent._query_compiler
-
     @property
     def values(self) -> ArrayLike:
         """
@@ -731,11 +764,8 @@ def name(self, value: Hashable) -> None:
             raise TypeError(f"{type(self).__name__}.name must be a hashable type")
         self._query_compiler = self._query_compiler.set_index_names([value])
         # Update the name of the parent's index only if the parent's current query compiler
-        # matches the recorded query compiler (_parent_qc).
-        if self._parent is not None and self._parent_qc is self._parent._query_compiler:
-            self._parent._update_inplace(
-                new_query_compiler=self._parent._query_compiler.set_index_names([value])
-            )
+        # matches the recorded query compiler.
+        self._parent.check_and_update_parent_qc_index_names([value])
 
     def _get_names(self) -> list[Hashable]:
         """
@@ -762,11 +792,8 @@ def _set_names(self, values: list) -> None:
             values = values.to_list()
         self._query_compiler = self._query_compiler.set_index_names(values)
         # Update the name of the parent's index only if the parent's current query compiler
-        # matches the recorded query compiler (_parent_qc).
-        if self._parent is not None and self._parent_qc is self._parent._query_compiler:
-            self._parent._update_inplace(
-                new_query_compiler=self._parent._query_compiler.set_index_names(values)
-            )
+        # matches the recorded query compiler.
+        self._parent.check_and_update_parent_qc_index_names(values)
 
     names = property(fset=_set_names, fget=_get_names)
 
diff --git a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py
index 558e3bcee76..93efe292437 100644
--- a/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py
+++ b/src/snowflake/snowpark/modin/plugin/extensions/timedelta_index.py
@@ -41,7 +41,7 @@
 from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
     SnowflakeQueryCompiler,
 )
-from snowflake.snowpark.modin.plugin.extensions.index import Index
+from snowflake.snowpark.modin.plugin.extensions.index import Index, IndexParent
 from snowflake.snowpark.modin.plugin.utils.error_message import (
     timedelta_index_not_implemented,
 )
@@ -118,8 +118,7 @@ def __new__(
             data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs
         )
         # `_parent` keeps track of any Series or DataFrame that this Index is a part of.
-        tdi._parent = None
-        tdi._parent_qc = None
+        tdi._parent = IndexParent()
         return tdi
 
     def __init__(
diff --git a/tests/integ/modin/index/test_name.py b/tests/integ/modin/index/test_name.py
index 33bc59739ff..f915598c5f6 100644
--- a/tests/integ/modin/index/test_name.py
+++ b/tests/integ/modin/index/test_name.py
@@ -358,12 +358,10 @@ def test_index_names_replace_behavior():
     """
     Check that the index name of a DataFrame cannot be updated after the DataFrame has been modified.
     """
-    data = (
-        {
-            "A": [0, 1, 2, 3, 4, 4],
-            "B": ["a", "b", "c", "d", "e", "f"],
-        },
-    )
+    data = {
+        "A": [0, 1, 2, 3, 4, 4],
+        "B": ["a", "b", "c", "d", "e", "f"],
+    }
     idx = [1, 2, 3, 4, 5, 6]
     native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
     snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))
@@ -389,3 +387,33 @@ def test_index_names_replace_behavior():
     # Compare the names. Changing the index name should not change the DataFrame's index name.
     assert snow_index.name == native_index.name == "test3"
     assert snow_df.index.name == native_df.index.name == "test2"
+
+
+@sql_count_checker(query_count=1)
+def test_index_names_multiple_renames():
+    """
+    Check that the index name of a DataFrame can be renamed any number of times.
+    """
+    data = {
+        "A": [0, 1, 2, 3, 4, 4],
+        "B": ["a", "b", "c", "d", "e", "f"],
+    }
+    idx = [1, 2, 3, 4, 5, 6]
+    native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
+    snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))
+
+    # Get a reference to the index of the DataFrames.
+    snow_index = snow_df.index
+    native_index = native_df.index
+
+    # Change and compare the names.
+    snow_index.name = "test2"
+    native_index.name = "test2"
+    assert snow_index.name == native_index.name == "test2"
+    assert snow_df.index.name == native_df.index.name == "test2"
+
+    # Change the names again and compare.
+    snow_index.name = "test3"
+    native_index.name = "test3"
+    assert snow_index.name == native_index.name == "test3"
+    assert snow_df.index.name == native_df.index.name == "test3"