Skip to content

Commit

Permalink
SNOW-1105953 Fix Local Testing implementation of count_distinct (#1304)
Browse files Browse the repository at this point in the history
* Fix bug

* Add changelog entry

* Update tests/integ/scala/test_dataframe_aggregate_suite.py

Co-authored-by: Jamison Rose <[email protected]>

---------

Co-authored-by: Jamison Rose <[email protected]>
  • Loading branch information
sfc-gh-stan and sfc-gh-jrose authored Mar 13, 2024
1 parent 2e1f67a commit e695d3b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
### Bug Fixes

- Fixed a bug in Local Testing's implementation of LEFT ANTI and LEFT SEMI joins where rows with null values are dropped.
- Fixed a bug in Local Testing's implementation of `count_distinct`.
- Fixed a bug in Local Testing's implementation where VARIANT columns raise errors at `DataFrame.collect`.

### Deprecations:
Expand Down
25 changes: 6 additions & 19 deletions src/snowflake/snowpark/mock/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,26 +218,13 @@ def mock_count_distinct(*cols: ColumnEmulator) -> ColumnEmulator:
we iterate over each row and then over each column to check whether a NULL value exists; if any column is NULL,
we do not count that row.
"""
dict_data = {}
df = TableEmulator()
for i in range(len(cols)):
dict_data[f"temp_col_{i}"] = cols[i]
rows = len(cols[0])
temp_table = TableEmulator(dict_data, index=[i for i in range(len(cols[0]))])
temp_table = temp_table.reset_index()
to_drop_index = set()
for col in cols:
for i in range(rows):
if col[col.index[i]] is None:
to_drop_index.add(i)
break
temp_table = temp_table.drop(index=list(to_drop_index))
temp_table = temp_table.drop_duplicates(subset=list(dict_data.keys()))
count_column = temp_table.count()
if isinstance(count_column, ColumnEmulator):
count_column.sf_type = ColumnType(LongType(), False)
return ColumnEmulator(
data=round(count_column, 5), sf_type=ColumnType(LongType(), False)
)
df[cols[i].name] = cols[i]
df = df.dropna()
combined = df[df.columns].apply(lambda row: tuple(row), axis=1).dropna()
res = combined.nunique()
return ColumnEmulator(data=res, sf_type=ColumnType(LongType(), False))


@patch("median")
Expand Down
32 changes: 32 additions & 0 deletions tests/integ/scala/test_dataframe_aggregate_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,38 @@ def test_multiple_column_distinct_count(session):
res.sort(key=lambda x: x[0])
assert res == [Row("a", 2), Row("x", 1)]

aa = [
"a1",
"a1",
"a1",
"a1",
"a2",
"a2",
"a2",
"a3",
"a3",
"a4",
]

bb = [
"b1",
"b2",
"b2",
"b3",
"b1",
"b2",
"b5",
"b1",
"b4",
"b1",
]

df = session.create_dataframe(list(zip(aa, bb)), ["a", "b"])

assert df.group_by("a").agg(count_distinct("b").alias("C")).select(
avg("C").alias("C")
).collect() == [Row(2.25)]


@pytest.mark.localtest
def test_zero_count(session):
Expand Down

0 comments on commit e695d3b

Please sign in to comment.