Skip to content

Commit

Permalink
create new class for parent
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-vbudati committed Sep 12, 2024
1 parent 41e69e2 commit 6e00c73
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2262,7 +2262,7 @@ def _get_index(self):
return self._query_compiler.index

idx = Index(query_compiler=self._query_compiler)
idx._set_parent(self)
idx._parent.set_parent(self)
return idx


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)
from snowflake.snowpark.modin.plugin.extensions.index import Index
from snowflake.snowpark.modin.plugin.extensions.index import Index, IndexParent
from snowflake.snowpark.modin.plugin.utils.error_message import (
datetime_index_not_implemented,
)
Expand Down Expand Up @@ -162,8 +162,7 @@ def __new__(
query_compiler = query_compiler.series_to_datetime(include_index=True)
index._query_compiler = query_compiler
# `_parent` keeps track of any Series or DataFrame that this Index is a part of.
index._parent = None
index._parent_qc = None
index._parent = IndexParent()
return index

def __init__(
Expand Down
65 changes: 43 additions & 22 deletions src/snowflake/snowpark/modin/plugin/extensions/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,41 @@
}


class IndexParent:
def __init__(self) -> None:
"""
Initialize the IndexParent object.
IndexParent is used to keep track of the parent object that the Index is a part of.
It tracks the parent object and the parent object's query compiler at the time of creation.
"""
self._parent = None
self._parent_qc = None

def set_parent(self, parent: Series | DataFrame) -> None:
"""
Set the parent object and its query compiler.
Parameters
----------
parent : Series or DataFrame
The parent object that the Index is a part of.
"""
self._parent = parent
self._parent_qc = parent._query_compiler

def check_and_update_parent_qc_index_names(self, names: list) -> None:
"""
Update the Index and its parent's index names if the parent's current query compiler matches
the recorded query compiler (`_parent_qc`).
"""
if self._parent is not None and self._parent._query_compiler is self._parent_qc:
new_query_compiler = self._parent_qc.set_index_names(names)
self._parent._update_inplace(new_query_compiler=new_query_compiler)
# Update the query compiler after naming operation.
self._parent_qc = new_query_compiler


class Index(metaclass=TelemetryMeta):

# Equivalent index type in native pandas
Expand Down Expand Up @@ -135,11 +170,8 @@ def __new__(
index = object.__new__(cls)
# Initialize the Index
index._query_compiler = query_compiler
# `_parent` keeps track of any Series or DataFrame that this Index is a part of.
# `_parent_qc` keeps track of the original query compiler of the parent object.
# These fields are used with the name APIs.
index._parent = None
index._parent_qc = None
# `_parent` keeps track of the parent object that this Index is a part of.
index._parent = IndexParent()
return index

def __init__(
Expand Down Expand Up @@ -411,13 +443,6 @@ def __constructor__(self):
"""
return type(self)

def _set_parent(self, parent: Series | DataFrame):
"""
Set the parent object of the current Index to a given Series or DataFrame.
"""
self._parent = parent
self._parent_qc = parent._query_compiler

@property
def values(self) -> ArrayLike:
"""
Expand Down Expand Up @@ -731,11 +756,9 @@ def name(self, value: Hashable) -> None:
raise TypeError(f"{type(self).__name__}.name must be a hashable type")
self._query_compiler = self._query_compiler.set_index_names([value])
# Update the name of the parent's index only if the parent's current query compiler
# matches the recorded query compiler (_parent_qc).
if self._parent is not None and self._parent_qc is self._parent._query_compiler:
self._parent._update_inplace(
new_query_compiler=self._parent._query_compiler.set_index_names([value])
)
# matches the recorded query compiler.
if self._parent is not None:
self._parent.check_and_update_parent_qc_index_names([value])

def _get_names(self) -> list[Hashable]:
"""
Expand All @@ -762,11 +785,9 @@ def _set_names(self, values: list) -> None:
values = values.to_list()
self._query_compiler = self._query_compiler.set_index_names(values)
# Update the name of the parent's index only if the parent's current query compiler
# matches the recorded query compiler (_parent_qc).
if self._parent is not None and self._parent_qc is self._parent._query_compiler:
self._parent._update_inplace(
new_query_compiler=self._parent._query_compiler.set_index_names(values)
)
# matches the recorded query compiler.
if self._parent is not None:
self._parent.check_and_update_parent_qc_index_names(values)

names = property(fset=_set_names, fget=_get_names)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
from snowflake.snowpark.modin.plugin.compiler.snowflake_query_compiler import (
SnowflakeQueryCompiler,
)
from snowflake.snowpark.modin.plugin.extensions.index import Index
from snowflake.snowpark.modin.plugin.extensions.index import Index, IndexParent
from snowflake.snowpark.modin.plugin.utils.error_message import (
timedelta_index_not_implemented,
)
Expand Down Expand Up @@ -118,8 +118,7 @@ def __new__(
data, _CONSTRUCTOR_DEFAULTS, query_compiler, **kwargs
)
# `_parent` keeps track of any Series or DataFrame that this Index is a part of.
tdi._parent = None
tdi._parent_qc = None
tdi._parent = IndexParent()
return tdi

def __init__(
Expand Down
40 changes: 34 additions & 6 deletions tests/integ/modin/index/test_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,10 @@ def test_index_names_replace_behavior():
"""
Check that the index name of a DataFrame cannot be updated after the DataFrame has been modified.
"""
data = (
{
"A": [0, 1, 2, 3, 4, 4],
"B": ["a", "b", "c", "d", "e", "f"],
},
)
data = {
"A": [0, 1, 2, 3, 4, 4],
"B": ["a", "b", "c", "d", "e", "f"],
}
idx = [1, 2, 3, 4, 5, 6]
native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))
Expand All @@ -389,3 +387,33 @@ def test_index_names_replace_behavior():
# Compare the names. Changing the index name should not change the DataFrame's index name.
assert snow_index.name == native_index.name == "test3"
assert snow_df.index.name == native_df.index.name == "test2"


@sql_count_checker(query_count=1)
def test_index_names_multiple_renames():
"""
Check that the index name of a DataFrame can be renamed any number of times.
"""
data = {
"A": [0, 1, 2, 3, 4, 4],
"B": ["a", "b", "c", "d", "e", "f"],
}
idx = [1, 2, 3, 4, 5, 6]
native_df = native_pd.DataFrame(data, native_pd.Index(idx, name="test"))
snow_df = pd.DataFrame(data, index=pd.Index(idx, name="test"))

# Get a reference to the index of the DataFrames.
snow_index = snow_df.index
native_index = native_df.index

# Change and compare the names.
snow_index.name = "test2"
native_index.name = "test2"
assert snow_index.name == native_index.name == "test2"
assert snow_df.index.name == native_df.index.name == "test2"

# Change the names again and compare.
snow_index.name = "test3"
native_index.name = "test3"
assert snow_index.name == native_index.name == "test3"
assert snow_df.index.name == native_df.index.name == "test3"

0 comments on commit 6e00c73

Please sign in to comment.