Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1661142 Fix index name behavior #2274

Merged
merged 13 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

- Added support for `TimedeltaIndex.mean` method.

#### Bug Fixes

- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"index name when it is not supposed to." -> index name after an inplace updates have been applied to the original series/dataframe

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!



## 1.22.1 (2024-09-11)
This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content.
Expand Down
60 changes: 47 additions & 13 deletions src/snowflake/snowpark/modin/plugin/extensions/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,35 @@
}


class IndexParent:
def __init__(self, parent: DataFrame | Series) -> None:
    """
    Create a tracker for the object that owns an Index.

    The tracker stores the parent itself together with the query compiler
    the parent holds at the moment the tracker is constructed.

    Parameters
    ----------
    parent : DataFrame or Series
        The object whose index this tracker is attached to.
    """
    assert isinstance(parent, (DataFrame, Series))
    # Snapshot the query compiler now; it is stored next to the parent reference.
    recorded_qc = parent._query_compiler
    self._parent = parent
    self._parent_qc = recorded_qc

def check_and_update_parent_qc_index_names(self, names: list) -> None:
"""
Update the Index and its parent's index names if the parent's current query compiler matches
the recorded query compiler (`_parent_qc`).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here that "if the query compiler associated with the parent is not the same as the originally recorded one, an inplace update has been applied to the parent".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

"""
if self._parent._query_compiler is self._parent_qc:
new_query_compiler = self._parent_qc.set_index_names(names)
self._parent._update_inplace(new_query_compiler=new_query_compiler)
# Update the query compiler after naming operation.
self._parent_qc = new_query_compiler


class Index(metaclass=TelemetryMeta):

# Equivalent index type in native pandas
Expand Down Expand Up @@ -135,7 +164,7 @@ def __new__(
index = object.__new__(cls)
# Initialize the Index
index._query_compiler = query_compiler
# `_parent` keeps track of any Series or DataFrame that this Index is a part of.
# `_parent` keeps track of the parent object that this Index is a part of.
index._parent = None
return index

Expand Down Expand Up @@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any:
ErrorMessage.not_implemented(f"Index.{key} is not yet implemented")
raise err

def _set_parent(self, parent: Series | DataFrame) -> None:
    """
    Attach this Index to the Series or DataFrame it belongs to.

    The parent is wrapped in an ``IndexParent`` and stored on ``self._parent``.

    Parameters
    ----------
    parent : Series or DataFrame
        The object that this Index is a part of.
    """
    tracker = IndexParent(parent)
    self._parent = tracker

def _binary_ops(self, method: str, other: Any) -> Index:
if isinstance(other, Index):
other = other.to_series().reset_index(drop=True)
Expand Down Expand Up @@ -408,12 +448,6 @@ def __constructor__(self):
"""
return type(self)

def _set_parent(self, parent: Series | DataFrame):
"""
Set the parent object of the current Index to a given Series or DataFrame.
"""
self._parent = parent

@property
def values(self) -> ArrayLike:
"""
Expand Down Expand Up @@ -726,10 +760,10 @@ def name(self, value: Hashable) -> None:
if not is_hashable(value):
raise TypeError(f"{type(self).__name__}.name must be a hashable type")
self._query_compiler = self._query_compiler.set_index_names([value])
# Update the name of the parent's index only if the parent's current query compiler
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the doc for the `name` function to make it clear that the inplace replacement only works when no inplace update has been applied to the parent?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

# matches the recorded query compiler.
if self._parent is not None:
self._parent._update_inplace(
new_query_compiler=self._parent._query_compiler.set_index_names([value])
)
self._parent.check_and_update_parent_qc_index_names([value])

def _get_names(self) -> list[Hashable]:
"""
Expand All @@ -755,10 +789,10 @@ def _set_names(self, values: list) -> None:
if isinstance(values, Index):
values = values.to_list()
self._query_compiler = self._query_compiler.set_index_names(values)
# Update the name of the parent's index only if the parent's current query compiler
# matches the recorded query compiler.
if self._parent is not None:
self._parent._update_inplace(
new_query_compiler=self._parent._query_compiler.set_index_names(values)
)
self._parent.check_and_update_parent_qc_index_names(values)

names = property(fset=_set_names, fget=_get_names)

Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/index/test_datetime_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,13 @@ def test_index_parent():
# DataFrame case.
df = pd.DataFrame({"A": [1]}, index=native_idx1)
snow_idx = df.index
assert_frame_equal(snow_idx._parent, df)
assert_frame_equal(snow_idx._parent._parent, df)
assert_index_equal(snow_idx, native_idx1)

# Series case.
s = pd.Series([1, 2], index=native_idx2, name="zyx")
snow_idx = s.index
assert_series_equal(snow_idx._parent, s)
assert_series_equal(snow_idx._parent._parent, s)
assert_index_equal(snow_idx, native_idx2)


Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/index/test_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,13 @@ def test_index_parent():
# DataFrame case.
df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1)
snow_idx = df.index
assert_frame_equal(snow_idx._parent, df)
assert_frame_equal(snow_idx._parent._parent, df)
assert_index_equal(snow_idx, native_idx1)

# Series case.
s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx")
snow_idx = s.index
assert_series_equal(snow_idx._parent, s)
assert_series_equal(snow_idx._parent._parent, s)
assert_index_equal(snow_idx, native_idx2)


Expand Down
66 changes: 66 additions & 0 deletions tests/integ/modin/index/test_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,69 @@ def test_index_names_with_lazy_index():
),
inplace=True,
)


@sql_count_checker(query_count=1)
def test_index_names_replace_behavior():
    """
    Renaming a previously obtained Index must stop affecting its DataFrame
    once the DataFrame has been modified in place.
    """
    columns = {
        "A": [0, 1, 2, 3, 4, 4],
        "B": ["a", "b", "c", "d", "e", "f"],
    }
    labels = [1, 2, 3, 4, 5, 6]
    expected_df = native_pd.DataFrame(columns, native_pd.Index(labels, name="test"))
    actual_df = pd.DataFrame(columns, index=pd.Index(labels, name="test"))

    # Hold on to the index objects before anything is changed.
    actual_index = actual_df.index
    expected_index = expected_df.index

    # First rename: the frame has not been modified, so it follows the index.
    actual_index.name = "test2"
    expected_index.name = "test2"
    assert actual_index.name == expected_index.name == "test2"
    assert actual_df.index.name == expected_df.index.name == "test2"

    # An inplace operation makes the frame refer to a different query compiler.
    actual_df.dropna(inplace=True)
    expected_df.dropna(inplace=True)

    # Second rename: only the detached index objects pick up the new name.
    actual_index.name = "test3"
    expected_index.name = "test3"
    assert actual_index.name == expected_index.name == "test3"
    assert actual_df.index.name == expected_df.index.name == "test2"


@sql_count_checker(query_count=1)
def test_index_names_multiple_renames():
    """
    Renaming the Index of an unmodified DataFrame repeatedly must update the
    DataFrame's index name every time.
    """
    columns = {
        "A": [0, 1, 2, 3, 4, 4],
        "B": ["a", "b", "c", "d", "e", "f"],
    }
    labels = [1, 2, 3, 4, 5, 6]
    expected_df = native_pd.DataFrame(columns, native_pd.Index(labels, name="test"))
    actual_df = pd.DataFrame(columns, index=pd.Index(labels, name="test"))

    # Hold on to the index objects of both frames.
    actual_index = actual_df.index
    expected_index = expected_df.index

    # First rename is visible on both the index and its frame.
    actual_index.name = "test2"
    expected_index.name = "test2"
    assert actual_index.name == expected_index.name == "test2"
    assert actual_df.index.name == expected_df.index.name == "test2"

    # A second rename through the same index object still reaches the frame.
    actual_index.name = "test3"
    expected_index.name = "test3"
    assert actual_index.name == expected_index.name == "test3"
    assert actual_df.index.name == expected_df.index.name == "test3"
Loading