Skip to content

Commit

Permalink
SNOW-1458137 Tests to verify set_index and reset_index work as ex…
Browse files Browse the repository at this point in the history
…pected (#2138)

<!---
Please answer these questions before creating your pull request. Thanks!
--->

1. Which Jira issue is this PR addressing? Make sure that there is an
accompanying issue to your PR.

   <!---
   In this section, please add a Snowflake Jira issue number.
   
Note that if a corresponding GitHub issue exists, you should still
include
   the Snowflake Jira issue number. For example, for GitHub issue
#1400, you should
   add "SNOW-1335071" here.
    --->

   Fixes SNOW-1458137

2. Fill out the following pre-review checklist:

- [x] I am adding a new automated test(s) to verify correctness of my
new code
- [ ] If this test skips Local Testing mode, I'm requesting review from
@snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
- [ ] If this is a new feature/behavior, I'm adding the Local Testing
parity changes.

3. Please describe how your code solves the related issue.

Verifying that `df.index = new_index` is implemented correctly.
  • Loading branch information
sfc-gh-vbudati authored Aug 27, 2024
1 parent 7a156a4 commit 0a9bbc7
Show file tree
Hide file tree
Showing 10 changed files with 75 additions and 248 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1705,17 +1705,9 @@ def set_index(
if not any(isinstance(k, SnowflakeQueryCompiler) for k in keys):
return self.set_index_from_columns(keys, drop=drop, append=append)

self_num_rows = self.get_axis_len(axis=0)
new_qc = self
for key in keys:
if isinstance(key, SnowflakeQueryCompiler):
assert (
len(key._modin_frame.data_column_pandas_labels) == 1
), "need to be a series"
if key.get_axis_len(0) != self_num_rows:
raise ValueError(
f"Length mismatch: Expected {self_num_rows} rows, received array of length {key.get_axis_len(0)}"
)
new_qc = new_qc.set_index_from_series(key, append)
else:
new_qc = new_qc.set_index_from_columns([key], drop, append)
Expand Down
133 changes: 10 additions & 123 deletions tests/integ/modin/frame/test_axis.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_index(test_df):


@pytest.mark.parametrize("test_df", test_dfs)
@sql_count_checker(query_count=8, join_count=3)
@sql_count_checker(query_count=3, join_count=3)
def test_set_and_assign_index(test_df):
def assign_index(df, keys):
df.index = keys
Expand Down Expand Up @@ -289,7 +289,7 @@ def test_duplicate_labels_assignment():
native_pd.DataFrame({"A": [3.14, 1.414, 1.732], "B": [9.8, 1.0, 0]}),
"rows",
[None] * 3,
5,
3,
2,
],
[ # Labels is a MultiIndex from tuples.
Expand All @@ -299,14 +299,14 @@ def test_duplicate_labels_assignment():
[("r0", "rA", "rR"), ("r1", "rB", "rS"), ("r2", "rC", "rT")],
names=["Courses", "Fee", "Random"],
),
7,
3,
6,
],
[
native_pd.DataFrame({"A": ["foo", "bar", 3], "B": [4, "baz", 6]}),
0,
{1: "c", 2: "b", 3: "a"},
5,
3,
2,
],
[
Expand All @@ -326,7 +326,7 @@ def test_duplicate_labels_assignment():
),
0,
['"row 1"', "row 2"],
5,
3,
2,
],
[
Expand All @@ -339,7 +339,7 @@ def test_duplicate_labels_assignment():
),
"rows",
list(range(10)),
5,
3,
2,
],
[
Expand Down Expand Up @@ -367,7 +367,7 @@ def test_duplicate_labels_assignment():
native_pd.MultiIndex.from_product(
[["NJ", "CA"], ["temp", "precip"]], names=["number", "color"]
),
6,
3,
4,
],
# Set columns.
Expand Down Expand Up @@ -720,105 +720,6 @@ def test_duplicate_labels_assignment():
]


# Invalid data which raises ValueError different from native pandas
# -----------------------------------------------------------------
# Format: df, axis, and invalid labels.
# - This data cover the negative case for DataFrame.set_axis() with invalid labels.
# - Invalid labels here consist of: passing None, too many values, too few values, empty list,
# Index, and MultiIndex objects as invalid labels for row-like axis.
TEST_DATA_FOR_DF_SET_AXIS_RAISES_VALUE_ERROR_DIFF_ERROR_MSG = [
# invalid row labels
[
native_pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}),
"index",
["a", "b", "c", "d"],
"Length mismatch: Expected 3 rows, received array of length 4",
],
[
native_pd.DataFrame(
{
"A": [None] * 3,
"B": [4, 5, 6],
"C": [4, 5, 6],
"D": [7, 8, 9],
"E": [-1, -2, -3],
}
),
0,
[99], # too few labels,
"Length mismatch: Expected 3 rows, received array of length 1",
],
[
native_pd.DataFrame(
{
"Artists": ["Monet", "Manet", "Gogh"],
"Museums": ["MoMA", "SoMA", "The High"],
}
),
"rows",
[None], # too few labels
"Length mismatch: Expected 3 rows, received array of length 1",
],
[
native_pd.DataFrame(
{"A": ["a", "b", "c", "d", "e"], "B": ["e", "d", "c", "b", "a"]}
),
"index",
native_pd.Index([11, 111, 1111, 11111, 111111, 11111111]), # too many labels
"Length mismatch: Expected 5 rows, received array of length 6",
],
[
native_pd.DataFrame({"foo": [None], "bar": [None], "baz": [None]}),
0,
native_pd.Index([None] * 10), # too many labels
"Length mismatch: Expected 1 rows, received array of length 10",
],
[
native_pd.DataFrame(
{"echo": ["echo", "echo"], "not an echo": [". . .", ". . ."]}
),
"rows",
native_pd.Index([]), # too few labels
"Length mismatch: Expected 2 rows, received array of length 0",
],
[ # Labels is a MultiIndex from tuples.
native_pd.DataFrame({"A": [1], -2515 / 135: [4]}),
"index",
native_pd.MultiIndex.from_tuples(
[("r0", "rA", "rR"), ("r1", "rB", "rS"), ("r2", "rC", "rT")],
names=["Courses", "Fee", "Random"],
), # too many labels
"Length mismatch: Expected 1 rows, received array of length 3",
],
[ # Labels is a MultiIndex from a DataFrame.
native_pd.DataFrame(
{
"A": [1, 2, 3, 4],
"B": [4, 5, 6, 7],
"C": [7, 8, 9, 10],
"D": [10, 11, 12, 13],
}
),
"rows",
native_pd.MultiIndex.from_frame(
native_pd.DataFrame(
[],
columns=["a", "b"],
),
), # too few labels
"Length mismatch: Expected 4 rows, received array of length 0",
],
[ # Labels is a MultiIndex from a product.
native_pd.DataFrame({1: [1], 2: [2], 3: [3], 4: [4], 5: [5], 6: [6]}),
0,
native_pd.MultiIndex.from_product(
[[0], ["green", "purple"]], names=["number", "color"]
), # too many labels
"Length mismatch: Expected 1 rows, received array of length 2",
],
]


@pytest.mark.parametrize(
"native_df, axis, labels, num_queries, num_joins", TEST_DATA_FOR_DF_SET_AXIS
)
Expand Down Expand Up @@ -865,20 +766,6 @@ def test_set_axis_df_raises_value_error(native_df, axis, labels):
)


@pytest.mark.parametrize(
"native_df, axis, labels, error_msg",
TEST_DATA_FOR_DF_SET_AXIS_RAISES_VALUE_ERROR_DIFF_ERROR_MSG,
)
def test_set_axis_df_raises_value_error_diff_error_msg(
native_df, axis, labels, error_msg
):
# Should raise a ValueError if the labels for row-like axis are invalid.
# The error messages do not match native pandas.
with SqlCounter(query_count=2):
with pytest.raises(ValueError, match=error_msg):
pd.DataFrame(native_df).set_axis(labels, axis=axis)


@pytest.mark.parametrize(
"native_df, axis, labels, error_msg", TEST_DATA_FOR_DF_SET_AXIS_RAISES_TYPE_ERROR
)
Expand All @@ -892,7 +779,7 @@ def test_set_axis_df_raises_type_error_diff_error_msg(
pd.DataFrame(native_df).set_axis(labels, axis=axis)


@sql_count_checker(query_count=3, join_count=1)
@sql_count_checker(query_count=1, join_count=1)
def test_df_set_axis_copy_true(caplog):
# Test that warning is raised when copy argument is used.
native_df = native_pd.DataFrame({"A": [1.25], "B": [3]})
Expand Down Expand Up @@ -933,11 +820,11 @@ def test_df_set_axis_with_quoted_index():
# check first that operation result is the same
snow_df = pd.DataFrame(data)
native_df = native_pd.DataFrame(data)
with SqlCounter(query_count=3):
with SqlCounter(query_count=1):
eval_snowpark_pandas_result(snow_df, native_df, helper)

# then, explicitly compare axes
with SqlCounter(query_count=1):
with SqlCounter(query_count=0):
ans = helper(snow_df)

native_ans = helper(native_df)
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/modin/frame/test_drop.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_drop_duplicate_columns(native_df, labels):

@pytest.mark.parametrize(
"labels, expected_query_count, expected_join_count",
[([], 3, 1), (1, 4, 2), (2, 4, 2), ([1, 2], 5, 3)],
[([], 1, 1), (1, 2, 2), (2, 2, 2), ([1, 2], 3, 3)],
)
def test_drop_duplicate_row_labels(
native_df, labels, expected_query_count, expected_join_count
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/modin/frame/test_drop_duplicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_drop_duplicates(subset, keep, ignore_index):
query_count = 1
join_count = 2
if ignore_index is True:
query_count += 2
query_count += 1
join_count += 3
with SqlCounter(query_count=query_count, join_count=join_count):
assert_frame_equal(
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/modin/frame/test_nlargest_nsmallest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_nlargest_nsmallest_large_n(snow_df, native_df, method):
)


@sql_count_checker(query_count=3)
@sql_count_checker(query_count=1)
def test_nlargest_nsmallest_overlapping_index_name(snow_df, native_df, method):
snow_df = snow_df.rename_axis("A")
native_df = native_df.rename_axis("A")
Expand Down
Loading

0 comments on commit 0a9bbc7

Please sign in to comment.