diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e823a177fa..97059747450 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ #### Bug Fixes - Fixed a bug that causes output of GroupBy.aggregate's columns to be ordered incorrectly. +- Fixed a bug where `DataFrame.describe` on a frame with duplicate columns of differing dtypes could cause an error or incorrect results. #### Improvements diff --git a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py index b35a5b4cb63..2c8ebbc8797 100644 --- a/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py +++ b/src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py @@ -10002,7 +10002,61 @@ def describe( # max NaN 3.0 NaN sorted_percentiles = sorted(percentiles) dtypes = self.dtypes - query_compiler = self + # If we operate on the original frame's labels, then if two columns have the same name but + # different one is `object` and one is numeric,, the JOIN behavior of SnowflakeQueryCompiler.concat + # will produce incorrect results. For example, consider the following dataframe, where an + # `object` column and `int64` column both share the label "a": + # +---+-----+---+-----+ + # | a | a | b | c | + # +---+-----+---+-----+ + # | 1 | 'x' | 3 | 'i' | + # +---+-----+---+-----+ + # | 2 | 'y' | 4 | 'j' | + # +---+-----+---+-----+ + # | 3 | 'x' | 5 | 'j' | + # +---+-----+---+-----+ + # For all `object` columns in the frame, we will generate a query compiler with the computed + # `top`/`freq` statistics. Similarly, for the numeric columns we will generate a query compiler + # containing the `std`, `min`/`max`, and other numeric statistics: + # OBJECT QUERY COMPILER NUMERIC QUERY COMPILER + # +------+-----+-----+ +-----+-----+-----+ + # | | a | c | | | a | b | + # +------+-----+-----+ +-----+-----+-----+ + # | top | 'x' | 'j' | | min | 1 | 3 | + # +------+-----+-----+ +-----+-----+-----+ (additional aggregations omitted) + # | freq | 2 | 2 | | max | 3 | 5 | + # +------+-----+-----+ +-----+-----+-----+ + # We `concat` these two query compilers (+ an additional one for the `count` statistic computed + # for all columns). Numeric columns will have NULL values for the `top` and `freq` statistics, + # and object columns will have NULL values for `min`, `max`, etc. This is accomplished by + # the `join="outer"` parameter, but it will still erroneously try to combine the aggregations + # of the object and numeric columns that share a label. + # To circumvent this, we relabel all columns with a simple integer index, and restore the + # correct labels at the very end after `concat`. + # The end result (before restoring the original pandas labels) should look something like this + # (many rows omitted for brevity): + # Column mapping: {0: "a", 1: "a", 2: "b", 3: "c"} + # +------+-----+-----+ +-----+-----+-----+ + # | | 1 | 3 | | | 0 | 2 | + # +------+-----+-----+ +-----+-----+-----+ + # | top | 'x' | 'j' | -- CONCAT -- | min | 1 | 3 | + # +------+-----+-----+ +-----+-----+-----+ + # | freq | 2 | 2 | | max | 3 | 5 | + # +------+-----+-----+ +-----+-----+-----+ + # = + # +------+-----+------+-----+------+ + # | | 0 | 1 | 2 | 3 | + # +------+-----+------+-----+------+ + # | top | NaN | 'x' | NaN | 'j' | + # +------+-----+------+-----+------+ + # | freq | NaN | 2 | NaN | 2 | + # +------+-----+------+-----+------+ + # | min | 1 | None | 3 | None | + # +------+-----+------+-----+------+ + # | max | 3 | None | 5 | None | + # +------+-----+------+-----+------+ + original_columns = self.columns + query_compiler = self.set_columns(list(range(len(self.columns)))) internal_frame = query_compiler._modin_frame # Compute count for all columns regardless of dtype query_compilers_to_concat = [ @@ -10326,10 +10380,14 @@ def get_qcs_for_numeric_and_datetime_cols( assert ( len(query_compilers_to_concat) > 1 ), "must have more than one QC to concat" - return query_compilers_to_concat[0].concat( - other=query_compilers_to_concat[1:], - axis=0, - join="outer", + return ( + query_compilers_to_concat[0].concat( + other=query_compilers_to_concat[1:], + axis=0, + join="outer", + ) + # Restore the original pandas labels + .set_columns(original_columns) ) def sample( diff --git a/tests/integ/modin/frame/test_describe.py b/tests/integ/modin/frame/test_describe.py index 2d4bb437e0c..a9668c5794f 100644 --- a/tests/integ/modin/frame/test_describe.py +++ b/tests/integ/modin/frame/test_describe.py @@ -308,26 +308,13 @@ def test_describe_multiindex(index, columns, include, expected_union_count): ) -DUP_COL_FAIL_REASON = "SNOW-1019479: describe on frames with mixed object/number columns with the same name fails" - - @pytest.mark.parametrize( "include, exclude, expected_union_count", [ (None, None, 7), - pytest.param( - "all", - None, - 0, - marks=pytest.mark.xfail(strict=True, reason=DUP_COL_FAIL_REASON), - ), + ("all", None, 12), (np.number, None, 7), - pytest.param( - None, - float, - 0, - marks=pytest.mark.xfail(strict=True, reason=DUP_COL_FAIL_REASON), - ), + (None, float, 10), (object, None, 5), (None, object, 7), (int, float, 5), @@ -346,6 +333,21 @@ def test_describe_duplicate_columns(include, exclude, expected_union_count): ) +def test_describe_duplicate_columns_mixed(): + # Test that describing a frame where there are multiple columns (including ones with numeric data + # but `object` dtype) that share the same label is correct. + data = [[5, 0, 1.0], [6, 3, 4.0]] + + def helper(df): + # Convert first column to `object` dtype + df = df.astype({0: object}) + df.columns = ["a"] * 3 + return df.describe() + + with SqlCounter(query_count=1, union_count=7): + eval_snowpark_pandas_result(*create_test_dfs(data), lambda df: helper(df)) + + @sql_count_checker( query_count=3, union_count=21,