diff --git a/CHANGELOG.md b/CHANGELOG.md index a987df71d33..82dd5f7895b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,8 +15,10 @@ - Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`. #### Bug Fixes +- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`. - Suppressed an unhelpful `SettingWithCopyWarning` that sometimes appeared when printing `Timedelta` columns. + ## 1.22.1 (2024-09-11) This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content. diff --git a/src/snowflake/snowpark/modin/plugin/extensions/index.py b/src/snowflake/snowpark/modin/plugin/extensions/index.py index 643f6f5038e..b25bb481dc0 100644 --- a/src/snowflake/snowpark/modin/plugin/extensions/index.py +++ b/src/snowflake/snowpark/modin/plugin/extensions/index.py @@ -71,6 +71,35 @@ } +class IndexParent: + def __init__(self, parent: DataFrame | Series) -> None: + """ + Initialize the IndexParent object. + + IndexParent is used to keep track of the parent object that the Index is a part of. + It tracks the parent object and the parent object's query compiler at the time of creation. + + Parameters + ---------- + parent : DataFrame or Series + The parent object that the Index is a part of. + """ + assert isinstance(parent, (DataFrame, Series)) + self._parent = parent + self._parent_qc = parent._query_compiler + + def check_and_update_parent_qc_index_names(self, names: list) -> None: + """ + Update the Index and its parent's index names if the query compiler associated with the parent is + different from the original query compiler recorded, i.e., an inplace update has been applied to the parent. + """ + if self._parent._query_compiler is self._parent_qc: + new_query_compiler = self._parent_qc.set_index_names(names) + self._parent._update_inplace(new_query_compiler=new_query_compiler) + # Update the query compiler after naming operation. + self._parent_qc = new_query_compiler + + class Index(metaclass=TelemetryMeta): # Equivalent index type in native pandas @@ -135,7 +164,7 @@ def __new__( index = object.__new__(cls) # Initialize the Index index._query_compiler = query_compiler - # `_parent` keeps track of any Series or DataFrame that this Index is a part of. + # `_parent` keeps track of the parent object that this Index is a part of. index._parent = None return index @@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any: ErrorMessage.not_implemented(f"Index.{key} is not yet implemented") raise err + def _set_parent(self, parent: Series | DataFrame) -> None: + """ + Set the parent object and its query compiler. + + Parameters + ---------- + parent : Series or DataFrame + The parent object that the Index is a part of. + """ + self._parent = IndexParent(parent) + def _binary_ops(self, method: str, other: Any) -> Index: if isinstance(other, Index): other = other.to_series().reset_index(drop=True) @@ -408,12 +448,6 @@ def __constructor__(self): """ return type(self) - def _set_parent(self, parent: Series | DataFrame): - """ - Set the parent object of the current Index to a given Series or DataFrame. - """ - self._parent = parent - @property def values(self) -> ArrayLike: """ @@ -726,10 +760,11 @@ def name(self, value: Hashable) -> None: if not is_hashable(value): raise TypeError(f"{type(self).__name__}.name must be a hashable type") self._query_compiler = self._query_compiler.set_index_names([value]) + # Update the name of the parent's index only if an inplace update is performed on + # the parent object, i.e., the parent's current query compiler matches the originally + # recorded query compiler. if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names([value]) - ) + self._parent.check_and_update_parent_qc_index_names([value]) def _get_names(self) -> list[Hashable]: """ @@ -755,10 +790,10 @@ def _set_names(self, values: list) -> None: if isinstance(values, Index): values = values.to_list() self._query_compiler = self._query_compiler.set_index_names(values) + # Update the name of the parent's index only if the parent's current query compiler + # matches the recorded query compiler. if self._parent is not None: - self._parent._update_inplace( - new_query_compiler=self._parent._query_compiler.set_index_names(values) - ) + self._parent.check_and_update_parent_qc_index_names(values) names = property(fset=_set_names, fget=_get_names) diff --git a/tests/integ/modin/index/test_datetime_index_methods.py b/tests/integ/modin/index/test_datetime_index_methods.py index 793485f97d6..98d1a041c3b 100644 --- a/tests/integ/modin/index/test_datetime_index_methods.py +++ b/tests/integ/modin/index/test_datetime_index_methods.py @@ -142,13 +142,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame({"A": [1]}, index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. s = pd.Series([1, 2], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) diff --git a/tests/integ/modin/index/test_index_methods.py b/tests/integ/modin/index/test_index_methods.py index 8d0434915ac..6b33eb89889 100644 --- a/tests/integ/modin/index/test_index_methods.py +++ b/tests/integ/modin/index/test_index_methods.py @@ -393,13 +393,13 @@ def test_index_parent(): # DataFrame case. df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1) snow_idx = df.index - assert_frame_equal(snow_idx._parent, df) + assert_frame_equal(snow_idx._parent._parent, df) assert_index_equal(snow_idx, native_idx1) # Series case. s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx") snow_idx = s.index - assert_series_equal(snow_idx._parent, s) + assert_series_equal(snow_idx._parent._parent, s) assert_index_equal(snow_idx, native_idx2) diff --git a/tests/integ/modin/index/test_name.py b/tests/integ/modin/index/test_name.py index b916110f386..f915598c5f6 100644 --- a/tests/integ/modin/index/test_name.py +++ b/tests/integ/modin/index/test_name.py @@ -351,3 +351,69 @@ def test_index_names_with_lazy_index(): ), inplace=True, ) + + +@sql_count_checker(query_count=1) +def test_index_names_replace_behavior(): + """ + Check that the index name of a DataFrame cannot be updated after the DataFrame has been modified. + """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change the names. + snow_index.name = "test2" + native_index.name = "test2" + + # Compare the names. + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Change the query compiler the DataFrame is referring to, change the names. + snow_df.dropna(inplace=True) + native_df.dropna(inplace=True) + snow_index.name = "test3" + native_index.name = "test3" + + # Compare the names. Changing the index name should not change the DataFrame's index name. + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test2" + + +@sql_count_checker(query_count=1) +def test_index_names_multiple_renames(): + """ + Check that the index name of a DataFrame can be renamed any number of times. + """ + data = { + "A": [0, 1, 2, 3, 4, 4], + "B": ["a", "b", "c", "d", "e", "f"], + } + idx = [1, 2, 3, 4, 5, 6] + native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test")) + snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test")) + + # Get a reference to the index of the DataFrames. + snow_index = snow_df.index + native_index = native_df.index + + # Change and compare the names. + snow_index.name = "test2" + native_index.name = "test2" + assert snow_index.name == native_index.name == "test2" + assert snow_df.index.name == native_df.index.name == "test2" + + # Change the names again and compare. + snow_index.name = "test3" + native_index.name = "test3" + assert snow_index.name == native_index.name == "test3" + assert snow_df.index.name == native_df.index.name == "test3"