Skip to content

Commit

Permalink
SNOW-1019479: Fix describe with duplicate column names (#1742)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-joshi authored Jun 7, 2024
1 parent c0341ec commit a259aa9
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#### Bug Fixes

- Fixed a bug that causes output of GroupBy.aggregate's columns to be ordered incorrectly.
- Fixed a bug where `DataFrame.describe` on a frame with duplicate columns of differing dtypes could cause an error or incorrect results.

#### Improvements

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10002,7 +10002,61 @@ def describe(
# max NaN 3.0 NaN
sorted_percentiles = sorted(percentiles)
dtypes = self.dtypes
query_compiler = self
# If we operate on the original frame's labels, then if two columns have the same name but
# different dtypes (one is `object` and one is numeric), the JOIN behavior of SnowflakeQueryCompiler.concat
# will produce incorrect results. For example, consider the following dataframe, where an
# `object` column and `int64` column both share the label "a":
# +---+-----+---+-----+
# | a | a | b | c |
# +---+-----+---+-----+
# | 1 | 'x' | 3 | 'i' |
# +---+-----+---+-----+
# | 2 | 'y' | 4 | 'j' |
# +---+-----+---+-----+
# | 3 | 'x' | 5 | 'j' |
# +---+-----+---+-----+
# For all `object` columns in the frame, we will generate a query compiler with the computed
# `top`/`freq` statistics. Similarly, for the numeric columns we will generate a query compiler
# containing the `std`, `min`/`max`, and other numeric statistics:
# OBJECT QUERY COMPILER NUMERIC QUERY COMPILER
# +------+-----+-----+ +-----+-----+-----+
# | | a | c | | | a | b |
# +------+-----+-----+ +-----+-----+-----+
# | top | 'x' | 'j' | | min | 1 | 3 |
# +------+-----+-----+ +-----+-----+-----+ (additional aggregations omitted)
# | freq | 2 | 2 | | max | 3 | 5 |
# +------+-----+-----+ +-----+-----+-----+
# We `concat` these two query compilers (+ an additional one for the `count` statistic computed
# for all columns). Numeric columns will have NULL values for the `top` and `freq` statistics,
# and object columns will have NULL values for `min`, `max`, etc. This is accomplished by
# the `join="outer"` parameter, but it will still erroneously try to combine the aggregations
# of the object and numeric columns that share a label.
# To circumvent this, we relabel all columns with a simple integer index, and restore the
# correct labels at the very end after `concat`.
# The end result (before restoring the original pandas labels) should look something like this
# (many rows omitted for brevity):
# Column mapping: {0: "a", 1: "a", 2: "b", 3: "c"}
# +------+-----+-----+ +-----+-----+-----+
# | | 1 | 3 | | | 0 | 2 |
# +------+-----+-----+ +-----+-----+-----+
# | top | 'x' | 'j' | -- CONCAT -- | min | 1 | 3 |
# +------+-----+-----+ +-----+-----+-----+
# | freq | 2 | 2 | | max | 3 | 5 |
# +------+-----+-----+ +-----+-----+-----+
# =
# +------+-----+------+-----+------+
# | | 0 | 1 | 2 | 3 |
# +------+-----+------+-----+------+
# | top | NaN | 'x' | NaN | 'j' |
# +------+-----+------+-----+------+
# | freq | NaN | 2 | NaN | 2 |
# +------+-----+------+-----+------+
# | min | 1 | None | 3 | None |
# +------+-----+------+-----+------+
# | max | 3 | None | 5 | None |
# +------+-----+------+-----+------+
original_columns = self.columns
query_compiler = self.set_columns(list(range(len(self.columns))))
internal_frame = query_compiler._modin_frame
# Compute count for all columns regardless of dtype
query_compilers_to_concat = [
Expand Down Expand Up @@ -10326,10 +10380,14 @@ def get_qcs_for_numeric_and_datetime_cols(
assert (
len(query_compilers_to_concat) > 1
), "must have more than one QC to concat"
return query_compilers_to_concat[0].concat(
other=query_compilers_to_concat[1:],
axis=0,
join="outer",
return (
query_compilers_to_concat[0].concat(
other=query_compilers_to_concat[1:],
axis=0,
join="outer",
)
# Restore the original pandas labels
.set_columns(original_columns)
)

def sample(
Expand Down
32 changes: 17 additions & 15 deletions tests/integ/modin/frame/test_describe.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,26 +308,13 @@ def test_describe_multiindex(index, columns, include, expected_union_count):
)


DUP_COL_FAIL_REASON = "SNOW-1019479: describe on frames with mixed object/number columns with the same name fails"


@pytest.mark.parametrize(
"include, exclude, expected_union_count",
[
(None, None, 7),
pytest.param(
"all",
None,
0,
marks=pytest.mark.xfail(strict=True, reason=DUP_COL_FAIL_REASON),
),
("all", None, 12),
(np.number, None, 7),
pytest.param(
None,
float,
0,
marks=pytest.mark.xfail(strict=True, reason=DUP_COL_FAIL_REASON),
),
(None, float, 10),
(object, None, 5),
(None, object, 7),
(int, float, 5),
Expand All @@ -346,6 +333,21 @@ def test_describe_duplicate_columns(include, exclude, expected_union_count):
)


def test_describe_duplicate_columns_mixed():
    # Verify that `describe` is correct when several columns share a single label,
    # including a column that holds numeric data but carries the `object` dtype.
    rows = [[5, 0, 1.0], [6, 3, 4.0]]

    def describe_with_shared_label(frame):
        # Cast the column labelled 0 to `object` before relabelling, while the
        # labels are still distinct (presumably the default integer labels
        # produced by `create_test_dfs` -- confirm against that helper).
        relabelled = frame.astype({0: object})
        # Give every column the same label so the frame has duplicate names
        # with differing dtypes.
        relabelled.columns = ["a", "a", "a"]
        return relabelled.describe()

    # The whole describe is expected to run as one query containing 7 unions.
    with SqlCounter(query_count=1, union_count=7):
        eval_snowpark_pandas_result(
            *create_test_dfs(rows), describe_with_shared_label
        )


@sql_count_checker(
query_count=3,
union_count=21,
Expand Down

0 comments on commit a259aa9

Please sign in to comment.