From bdbdc036b63fdacceb894465b5549b99284f46f0 Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Tue, 12 Mar 2024 10:14:41 -0700 Subject: [PATCH 1/3] Fix bug --- src/snowflake/snowpark/mock/_functions.py | 25 ++++----------- .../scala/test_dataframe_aggregate_suite.py | 32 +++++++++++++++++++ 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py index ea0c6bc5225..3c8af86329c 100644 --- a/src/snowflake/snowpark/mock/_functions.py +++ b/src/snowflake/snowpark/mock/_functions.py @@ -217,26 +217,13 @@ def mock_count_distinct(*cols: ColumnEmulator) -> ColumnEmulator: we iterate over each row and then each col to check if there exists NULL value, if the col is NULL, we do not count that row. """ - dict_data = {} + df = TableEmulator() for i in range(len(cols)): - dict_data[f"temp_col_{i}"] = cols[i] - rows = len(cols[0]) - temp_table = TableEmulator(dict_data, index=[i for i in range(len(cols[0]))]) - temp_table = temp_table.reset_index() - to_drop_index = set() - for col in cols: - for i in range(rows): - if col[col.index[i]] is None: - to_drop_index.add(i) - break - temp_table = temp_table.drop(index=list(to_drop_index)) - temp_table = temp_table.drop_duplicates(subset=list(dict_data.keys())) - count_column = temp_table.count() - if isinstance(count_column, ColumnEmulator): - count_column.sf_type = ColumnType(LongType(), False) - return ColumnEmulator( - data=round(count_column, 5), sf_type=ColumnType(LongType(), False) - ) + df[cols[i].name] = cols[i] + df = df.dropna() + combined = df[df.columns].apply(lambda row: tuple(row), axis=1).dropna() + res = combined.nunique() + return ColumnEmulator(data=res, sf_type=ColumnType(LongType(), False)) @patch("median") diff --git a/tests/integ/scala/test_dataframe_aggregate_suite.py b/tests/integ/scala/test_dataframe_aggregate_suite.py index 70258865e36..a6e6d5bb333 100644 --- a/tests/integ/scala/test_dataframe_aggregate_suite.py +++ b/tests/integ/scala/test_dataframe_aggregate_suite.py @@ -910,6 +910,38 @@ def test_multiple_column_distinct_count(session): res.sort(key=lambda x: x[0]) assert res == [Row("a", 2), Row("x", 1)] + aa = [ + "a1", + "a1", + "a1", + "a1", + "a2", + "a2", + "a2", + "a3", + "a3", + "a4", + ] + + bb = [ + "b1", + "b2", + "b2", + "b3", + "b1", + "b2", + "b5", + "b1", + "b4", + "b1", + ] + + df = session.create_dataframe([[a, b] for a, b in zip(aa, bb)], ["a", "b"]) + + assert df.group_by("a").agg(count_distinct("b").alias("C")).select( + avg("C").alias("C") + ).collect() == [Row(2.25)] + @pytest.mark.localtest def test_zero_count(session): From cadc177973178c334cd89e43ab8bed309fa775e8 Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Tue, 12 Mar 2024 10:19:42 -0700 Subject: [PATCH 2/3] Add changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8adbfeb0a9..87747e69137 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ ### Bug Fixes - Fixed a bug in Local Testing's implementation of LEFT ANTI and LEFT SEMI joins where rows with null values are dropped. +- Fixed a bug in Local Testing's implementation of `count_distinct`. ### Deprecations: From a93160ec36987ee4d16a9c39bd6410614d6af1ab Mon Sep 17 00:00:00 2001 From: Sophie Tan Date: Tue, 12 Mar 2024 16:08:10 -0700 Subject: [PATCH 3/3] Update tests/integ/scala/test_dataframe_aggregate_suite.py Co-authored-by: Jamison Rose --- tests/integ/scala/test_dataframe_aggregate_suite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integ/scala/test_dataframe_aggregate_suite.py b/tests/integ/scala/test_dataframe_aggregate_suite.py index a6e6d5bb333..4b2e99f4999 100644 --- a/tests/integ/scala/test_dataframe_aggregate_suite.py +++ b/tests/integ/scala/test_dataframe_aggregate_suite.py @@ -936,7 +936,7 @@ def test_multiple_column_distinct_count(session): "b1", ] - df = session.create_dataframe([[a, b] for a, b in zip(aa, bb)], ["a", "b"]) + df = session.create_dataframe(list(zip(aa, bb)), ["a", "b"]) assert df.group_by("a").agg(count_distinct("b").alias("C")).select( avg("C").alias("C")