diff --git a/CHANGELOG.md b/CHANGELOG.md index 200e6318eb4..11dffa17a2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ ### Bug Fixes - Fixed a bug in Local Testing's implementation of LEFT ANTI and LEFT SEMI joins where rows with null values are dropped. +- Fixed a bug in Local Testing's implementation of `count_distinct`. - Fixed a bug in Local Testing's implementation where VARIANT columns raise errors at `DataFrame.collect`. ### Deprecations: diff --git a/src/snowflake/snowpark/mock/_functions.py b/src/snowflake/snowpark/mock/_functions.py index 709925b05e9..6a00fa86096 100644 --- a/src/snowflake/snowpark/mock/_functions.py +++ b/src/snowflake/snowpark/mock/_functions.py @@ -218,26 +218,13 @@ def mock_count_distinct(*cols: ColumnEmulator) -> ColumnEmulator: we iterate over each row and then each col to check if there exists NULL value, if the col is NULL, we do not count that row. """ - dict_data = {} + df = TableEmulator() for i in range(len(cols)): - dict_data[f"temp_col_{i}"] = cols[i] - rows = len(cols[0]) - temp_table = TableEmulator(dict_data, index=[i for i in range(len(cols[0]))]) - temp_table = temp_table.reset_index() - to_drop_index = set() - for col in cols: - for i in range(rows): - if col[col.index[i]] is None: - to_drop_index.add(i) - break - temp_table = temp_table.drop(index=list(to_drop_index)) - temp_table = temp_table.drop_duplicates(subset=list(dict_data.keys())) - count_column = temp_table.count() - if isinstance(count_column, ColumnEmulator): - count_column.sf_type = ColumnType(LongType(), False) - return ColumnEmulator( - data=round(count_column, 5), sf_type=ColumnType(LongType(), False) - ) + df[cols[i].name] = cols[i] + df = df.dropna() + combined = df[df.columns].apply(lambda row: tuple(row), axis=1).dropna() + res = combined.nunique() + return ColumnEmulator(data=res, sf_type=ColumnType(LongType(), False)) @patch("median") diff --git a/tests/integ/scala/test_dataframe_aggregate_suite.py b/tests/integ/scala/test_dataframe_aggregate_suite.py index 70258865e36..4b2e99f4999 100644 --- a/tests/integ/scala/test_dataframe_aggregate_suite.py +++ b/tests/integ/scala/test_dataframe_aggregate_suite.py @@ -910,6 +910,38 @@ def test_multiple_column_distinct_count(session): res.sort(key=lambda x: x[0]) assert res == [Row("a", 2), Row("x", 1)] + aa = [ + "a1", + "a1", + "a1", + "a1", + "a2", + "a2", + "a2", + "a3", + "a3", + "a4", + ] + + bb = [ + "b1", + "b2", + "b2", + "b3", + "b1", + "b2", + "b5", + "b1", + "b4", + "b1", + ] + + df = session.create_dataframe(list(zip(aa, bb)), ["a", "b"]) + + assert df.group_by("a").agg(count_distinct("b").alias("C")).select( + avg("C").alias("C") + ).collect() == [Row(2.25)] + @pytest.mark.localtest def test_zero_count(session):