Skip to content

Commit

Permalink
Merge branch 'main' into mvashishtha/SNOW-1664064/avoid-setting-with-…
Browse files Browse the repository at this point in the history
…copy-warning-for-timedelta
  • Loading branch information
sfc-gh-mvashishtha authored Sep 17, 2024
2 parents b1a0ea5 + e93cd68 commit a22403f
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
- Added support for `by`, `left_by`, and `right_by` for `pd.merge_asof`.

#### Bug Fixes
- Fixed a bug where an `Index` object created from a `Series`/`DataFrame` incorrectly updates the `Series`/`DataFrame`'s index name after an inplace update has been applied to the original `Series`/`DataFrame`.
- Suppressed an unhelpful `SettingWithCopyWarning` that sometimes appeared when printing `Timedelta` columns.


## 1.22.1 (2024-09-11)
This is a re-release of 1.22.0. Please refer to the 1.22.0 release notes for detailed release content.

Expand Down
61 changes: 48 additions & 13 deletions src/snowflake/snowpark/modin/plugin/extensions/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,35 @@
}


class IndexParent:
def __init__(self, parent: DataFrame | Series) -> None:
"""
Initialize the IndexParent object.
IndexParent is used to keep track of the parent object that the Index is a part of.
It tracks the parent object and the parent object's query compiler at the time of creation.
Parameters
----------
parent : DataFrame or Series
The parent object that the Index is a part of.
"""
assert isinstance(parent, (DataFrame, Series))
self._parent = parent
self._parent_qc = parent._query_compiler

def check_and_update_parent_qc_index_names(self, names: list) -> None:
"""
Update the Index and its parent's index names if the query compiler associated with the parent is
different from the original query compiler recorded, i.e., an inplace update has been applied to the parent.
"""
if self._parent._query_compiler is self._parent_qc:
new_query_compiler = self._parent_qc.set_index_names(names)
self._parent._update_inplace(new_query_compiler=new_query_compiler)
# Update the query compiler after naming operation.
self._parent_qc = new_query_compiler


class Index(metaclass=TelemetryMeta):

# Equivalent index type in native pandas
Expand Down Expand Up @@ -135,7 +164,7 @@ def __new__(
index = object.__new__(cls)
# Initialize the Index
index._query_compiler = query_compiler
# `_parent` keeps track of any Series or DataFrame that this Index is a part of.
# `_parent` keeps track of the parent object that this Index is a part of.
index._parent = None
return index

Expand Down Expand Up @@ -252,6 +281,17 @@ def __getattr__(self, key: str) -> Any:
ErrorMessage.not_implemented(f"Index.{key} is not yet implemented")
raise err

def _set_parent(self, parent: Series | DataFrame) -> None:
"""
Set the parent object and its query compiler.
Parameters
----------
parent : Series or DataFrame
The parent object that the Index is a part of.
"""
self._parent = IndexParent(parent)

def _binary_ops(self, method: str, other: Any) -> Index:
if isinstance(other, Index):
other = other.to_series().reset_index(drop=True)
Expand Down Expand Up @@ -408,12 +448,6 @@ def __constructor__(self):
"""
return type(self)

def _set_parent(self, parent: Series | DataFrame):
"""
Set the parent object of the current Index to a given Series or DataFrame.
"""
self._parent = parent

@property
def values(self) -> ArrayLike:
"""
Expand Down Expand Up @@ -726,10 +760,11 @@ def name(self, value: Hashable) -> None:
if not is_hashable(value):
raise TypeError(f"{type(self).__name__}.name must be a hashable type")
self._query_compiler = self._query_compiler.set_index_names([value])
# Update the name of the parent's index only if an inplace update is performed on
# the parent object, i.e., the parent's current query compiler matches the originally
# recorded query compiler.
if self._parent is not None:
self._parent._update_inplace(
new_query_compiler=self._parent._query_compiler.set_index_names([value])
)
self._parent.check_and_update_parent_qc_index_names([value])

def _get_names(self) -> list[Hashable]:
"""
Expand All @@ -755,10 +790,10 @@ def _set_names(self, values: list) -> None:
if isinstance(values, Index):
values = values.to_list()
self._query_compiler = self._query_compiler.set_index_names(values)
# Update the name of the parent's index only if the parent's current query compiler
# matches the recorded query compiler.
if self._parent is not None:
self._parent._update_inplace(
new_query_compiler=self._parent._query_compiler.set_index_names(values)
)
self._parent.check_and_update_parent_qc_index_names(values)

names = property(fset=_set_names, fget=_get_names)

Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/index/test_datetime_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,13 @@ def test_index_parent():
# DataFrame case.
df = pd.DataFrame({"A": [1]}, index=native_idx1)
snow_idx = df.index
assert_frame_equal(snow_idx._parent, df)
assert_frame_equal(snow_idx._parent._parent, df)
assert_index_equal(snow_idx, native_idx1)

# Series case.
s = pd.Series([1, 2], index=native_idx2, name="zyx")
snow_idx = s.index
assert_series_equal(snow_idx._parent, s)
assert_series_equal(snow_idx._parent._parent, s)
assert_index_equal(snow_idx, native_idx2)


Expand Down
4 changes: 2 additions & 2 deletions tests/integ/modin/index/test_index_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,13 +393,13 @@ def test_index_parent():
# DataFrame case.
df = pd.DataFrame([[1, 2], [3, 4]], index=native_idx1)
snow_idx = df.index
assert_frame_equal(snow_idx._parent, df)
assert_frame_equal(snow_idx._parent._parent, df)
assert_index_equal(snow_idx, native_idx1)

# Series case.
s = pd.Series([1, 2, 4, 5, 6, 7], index=native_idx2, name="zyx")
snow_idx = s.index
assert_series_equal(snow_idx._parent, s)
assert_series_equal(snow_idx._parent._parent, s)
assert_index_equal(snow_idx, native_idx2)


Expand Down
66 changes: 66 additions & 0 deletions tests/integ/modin/index/test_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,69 @@ def test_index_names_with_lazy_index():
),
inplace=True,
)


@sql_count_checker(query_count=1)
def test_index_names_replace_behavior():
"""
Check that the index name of a DataFrame cannot be updated after the DataFrame has been modified.
"""
data = {
"A": [0, 1, 2, 3, 4, 4],
"B": ["a", "b", "c", "d", "e", "f"],
}
idx = [1, 2, 3, 4, 5, 6]
native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))

# Get a reference to the index of the DataFrames.
snow_index = snow_df.index
native_index = native_df.index

# Change the names.
snow_index.name = "test2"
native_index.name = "test2"

# Compare the names.
assert snow_index.name == native_index.name == "test2"
assert snow_df.index.name == native_df.index.name == "test2"

# Change the query compiler the DataFrame is referring to, change the names.
snow_df.dropna(inplace=True)
native_df.dropna(inplace=True)
snow_index.name = "test3"
native_index.name = "test3"

# Compare the names. Changing the index name should not change the DataFrame's index name.
assert snow_index.name == native_index.name == "test3"
assert snow_df.index.name == native_df.index.name == "test2"


@sql_count_checker(query_count=1)
def test_index_names_multiple_renames():
"""
Check that the index name of a DataFrame can be renamed any number of times.
"""
data = {
"A": [0, 1, 2, 3, 4, 4],
"B": ["a", "b", "c", "d", "e", "f"],
}
idx = [1, 2, 3, 4, 5, 6]
native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))

# Get a reference to the index of the DataFrames.
snow_index = snow_df.index
native_index = native_df.index

# Change and compare the names.
snow_index.name = "test2"
native_index.name = "test2"
assert snow_index.name == native_index.name == "test2"
assert snow_df.index.name == native_df.index.name == "test2"

# Change the names again and compare.
snow_index.name = "test3"
native_index.name = "test3"
assert snow_index.name == native_index.name == "test3"
assert snow_df.index.name == native_df.index.name == "test3"

0 comments on commit a22403f

Please sign in to comment.