Skip to content

Commit

Permalink
improve coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 22, 2024
1 parent d96269c commit f334589
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
11 changes: 8 additions & 3 deletions narwhals/pandas_like/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ def agg_pandas( # noqa: PLR0913,PLR0915
for output_name, named_agg in simple_aggregations.items():
aggs[named_agg[0]].append(named_agg[1])
name_mapping[f"{named_agg[0]}_{named_agg[1]}"] = output_name
result_simple = grouped.agg(aggs)
try:
result_simple = grouped.agg(aggs)
except AttributeError as exc:
raise RuntimeError(
"Failed to aggregated - does your aggregation function return a scalar?"
) from exc
result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns]
result_simple = result_simple.rename(columns=name_mapping).reset_index()
else:
Expand Down Expand Up @@ -149,11 +154,11 @@ def func(df: Any) -> Any:
if implementation == "pandas":
import pandas as pd

if parse_version(pd.__version__) < parse_version("2.2.0"):
if parse_version(pd.__version__) < parse_version("2.2.0"): # pragma: no cover
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)
else:
else: # pragma: no cover
result_complex = grouped.apply(func)

if result_simple is not None and not complex_aggs:
Expand Down
8 changes: 8 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,11 @@ def test_unique(df_raw: Any) -> None:
result = nw.to_native(df.unique("b").sort("b"))
expected = {"a": [1, 2], "b": [4, 6], "z": [7.0, 9.0]}
compare_dicts(result, expected)


@pytest.mark.parametrize("df_raw", [df_pandas_na, df_lazy_na])
def test_drop_nulls(df_raw: Any) -> None:
df = nw.LazyFrame(df_raw)
result = nw.to_native(df.select(nw.col("a").drop_nulls()))
expected = {"a": [3, 2]}
compare_dicts(result, expected)
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ def compare_dicts(result: dict[str, Any], expected: dict[str, Any]) -> None:
for key in expected:
for lhs, rhs in zip(result[key], expected[key]):
if isinstance(lhs, float):
assert abs(lhs - rhs) < 1e-6
assert abs(lhs - rhs) < 1e-6, (lhs, rhs)
else:
assert lhs == rhs
assert lhs == rhs, (lhs, rhs)

0 comments on commit f334589

Please sign in to comment.