Skip to content

Commit

Permalink
feat: add mean_horizontal for DuckDB (#1846)
Browse files Browse the repository at this point in the history
  • Loading branch information
raisadz authored Jan 21, 2025
1 parent 8a1b663 commit 38fadec
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
28 changes: 28 additions & 0 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,34 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
kwargs={"exprs": exprs},
)

def mean_horizontal(self, *exprs: IntoDuckDBExpr) -> DuckDBExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
cols = [c for _expr in parsed_exprs for c in _expr(df)]
col_name = get_column_name(df, cols[0])
return [
(
reduce(
operator.add,
(CoalesceOperator(col, ConstantExpression(0)) for col in cols),
)
/ reduce(operator.add, (col.isnotnull().cast("int") for col in cols))
).alias(col_name)
]

return DuckDBExpr(
call=func,
depth=max(x._depth for x in parsed_exprs) + 1,
function_name="mean_horizontal",
root_names=combine_root_names(parsed_exprs),
output_names=reduce_output_names(parsed_exprs),
returns_scalar=False,
backend_version=self._backend_version,
version=self._version,
kwargs={"exprs": exprs},
)

def when(
self,
*predicates: IntoDuckDBExpr,
Expand Down
10 changes: 2 additions & 8 deletions tests/expr_and_series/mean_horizontal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,15 @@


@pytest.mark.parametrize("col_expr", [nw.col("a"), "a"])
def test_meanh(
constructor: Constructor, col_expr: Any, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_meanh(constructor: Constructor, col_expr: Any) -> None:
data = {"a": [1, 3, None, None], "b": [4, None, 6, None]}
df = nw.from_native(constructor(data))
result = df.select(horizontal_mean=nw.mean_horizontal(col_expr, nw.col("b")))
expected = {"horizontal_mean": [2.5, 3.0, 6.0, None]}
assert_equal_data(result, expected)


def test_meanh_all(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_meanh_all(constructor: Constructor) -> None:
data = {"a": [2, 4, 6], "b": [10, 20, 30]}
df = nw.from_native(constructor(data))
result = df.select(nw.mean_horizontal(nw.all()))
Expand Down

0 comments on commit 38fadec

Please sign in to comment.