Skip to content

Commit

Permalink
feat: add DuckDB: nw.nth, nw.sum_horizontal, nw.concat_str, group_by …
Browse files Browse the repository at this point in the history
…with drop_null_keys (#1832)
  • Loading branch information
MarcoGorelli authored Jan 20, 2025
1 parent 02544e4 commit 7e6c086
Show file tree
Hide file tree
Showing 14 changed files with 204 additions and 52 deletions.
14 changes: 7 additions & 7 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,19 +378,19 @@ def concat_str(
dtypes = import_dtypes_module(self._version)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series = (
s._native_series
for _expr in parsed_exprs
for s in _expr.cast(dtypes.String())(df)
)
compliant_series_list = [
s for _expr in parsed_exprs for s in _expr.cast(dtypes.String())(df)
]
null_handling = "skip" if ignore_nulls else "emit_null"
result_series = pc.binary_join_element_wise(
*series, separator, null_handling=null_handling
*(s._native_series for s in compliant_series_list),
separator,
null_handling=null_handling,
)
return [
ArrowSeries(
native_series=result_series,
name="",
name=compliant_series_list[0].name,
backend_version=self._backend_version,
version=self._version,
)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def func(df: DaskLazyFrame) -> list[dx.Series]:
init_value,
)

return [result]
return [result.rename(null_mask[0].name)]

return DaskExpr(
call=func,
Expand Down
4 changes: 0 additions & 4 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,6 @@ def _from_native_frame(self: Self, df: Any) -> Self:
def group_by(self: Self, *keys: str, drop_null_keys: bool) -> DuckDBGroupBy:
from narwhals._duckdb.group_by import DuckDBGroupBy

if drop_null_keys:
msg = "todo"
raise NotImplementedError(msg)

return DuckDBGroupBy(
compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys
)
Expand Down
26 changes: 26 additions & 0 deletions narwhals/_duckdb/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,32 @@ def func(_: DuckDBLazyFrame) -> list[duckdb.Expression]:
kwargs={},
)

@classmethod
def from_column_indices(
    cls: type[Self],
    *column_indices: int,
    backend_version: tuple[int, ...],
    version: Version,
) -> Self:
    """Build an expression selecting columns by 0-based position (``nth``)."""

    def select_by_position(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
        # Deferred import, matching the lazy-import pattern used elsewhere
        # in this module.
        from duckdb import ColumnExpression

        names = df.columns
        selected: list[duckdb.Expression] = []
        for index in column_indices:
            selected.append(ColumnExpression(names[index]))
        return selected

    return cls(
        select_by_position,
        depth=0,
        function_name="nth",
        root_names=None,
        output_names=None,
        returns_scalar=False,
        backend_version=backend_version,
        version=version,
        kwargs={},
    )

def _from_call(
self,
call: Callable[..., duckdb.Expression],
Expand Down
7 changes: 5 additions & 2 deletions narwhals/_duckdb/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def __init__(
keys: list[str],
drop_null_keys: bool, # noqa: FBT001
) -> None:
self._compliant_frame = compliant_frame
if drop_null_keys:
self._compliant_frame = compliant_frame.drop_nulls(subset=None)
else:
self._compliant_frame = compliant_frame
self._keys = keys

def agg(
Expand Down Expand Up @@ -46,7 +49,7 @@ def agg(
try:
return self._compliant_frame._from_native_frame(
self._compliant_frame._native_frame.aggregate(
agg_columns, group_expr=",".join(self._keys)
agg_columns, group_expr=",".join(f'"{key}"' for key in self._keys)
)
)
except ValueError as exc: # pragma: no cover
Expand Down
111 changes: 111 additions & 0 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from functools import reduce
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Literal
from typing import Sequence
from typing import cast
Expand Down Expand Up @@ -74,6 +75,83 @@ def concat(
)
return first._from_native_frame(res)

def concat_str(
    self,
    exprs: Iterable[IntoDuckDBExpr],
    *more_exprs: IntoDuckDBExpr,
    separator: str,
    ignore_nulls: bool,
) -> DuckDBExpr:
    """Horizontally concatenate string representations with `separator`.

    When ``ignore_nulls`` is False, the result is NULL as soon as any input
    is NULL; when True, NULL inputs (and the separator that would accompany
    them) are skipped.
    """
    parsed_exprs = [
        *parse_into_exprs(*exprs, namespace=self),
        *parse_into_exprs(*more_exprs, namespace=self),
    ]
    from duckdb import CaseExpression
    from duckdb import ConstantExpression
    from duckdb import FunctionExpression

    def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
        cols = [s for _expr in parsed_exprs for s in _expr(df)]
        # Derive the null mask from the already-evaluated columns instead of
        # evaluating every parsed expression a second time.
        null_mask = [s.isnull() for s in cols]
        first_column_name = get_column_name(df, cols[0])

        if not ignore_nulls:
            null_mask_result = reduce(lambda x, y: x | y, null_mask)
            # Interleave each column (cast to string) with the separator;
            # the last column gets no trailing separator.
            cols_separated = [
                y
                for x in [
                    (col.cast("string"),)
                    if i == len(cols) - 1
                    else (col.cast("string"), ConstantExpression(separator))
                    for i, col in enumerate(cols)
                ]
                for y in x
            ]
            result = CaseExpression(
                condition=~null_mask_result,
                value=FunctionExpression("concat", *cols_separated),
            )
        else:
            # NULL values become "" and the separator paired with a NULL
            # neighbour collapses to "" as well, so NULLs are skipped.
            init_value, *values = [
                CaseExpression(~nm, col.cast("string")).otherwise(
                    ConstantExpression("")
                )
                for col, nm in zip(cols, null_mask)
            ]
            separators = (
                CaseExpression(nm, ConstantExpression("")).otherwise(
                    ConstantExpression(separator)
                )
                for nm in null_mask[:-1]
            )
            result = reduce(
                lambda x, y: FunctionExpression("concat", x, y),
                (
                    FunctionExpression("concat", s, v)
                    for s, v in zip(separators, values)
                ),
                init_value,
            )

        return [result.alias(first_column_name)]

    return DuckDBExpr(
        call=func,
        depth=max(x._depth for x in parsed_exprs) + 1,
        function_name="concat_str",
        root_names=combine_root_names(parsed_exprs),
        output_names=reduce_output_names(parsed_exprs),
        returns_scalar=False,
        backend_version=self._backend_version,
        version=self._version,
        kwargs={
            "exprs": exprs,
            "more_exprs": more_exprs,
            "separator": separator,
            "ignore_nulls": ignore_nulls,
        },
    )

def all_horizontal(self, *exprs: IntoDuckDBExpr) -> DuckDBExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

Expand Down Expand Up @@ -158,6 +236,34 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
kwargs={"exprs": exprs},
)

def sum_horizontal(self, *exprs: IntoDuckDBExpr) -> DuckDBExpr:
    """Sum the given expressions element-wise, treating NULL inputs as 0."""
    from duckdb import CoalesceOperator
    from duckdb import ConstantExpression

    parsed_exprs = parse_into_exprs(*exprs, namespace=self)

    def horizontal_sum(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
        columns = [col for expr_ in parsed_exprs for col in expr_(df)]
        output_name = get_column_name(df, columns[0])
        # COALESCE(col, 0) so a NULL in any one input does not null the sum.
        total = CoalesceOperator(columns[0], ConstantExpression(0))
        for col in columns[1:]:
            total = total + CoalesceOperator(col, ConstantExpression(0))
        return [total.alias(output_name)]

    return DuckDBExpr(
        call=horizontal_sum,
        depth=max(x._depth for x in parsed_exprs) + 1,
        function_name="sum_horizontal",
        root_names=combine_root_names(parsed_exprs),
        output_names=reduce_output_names(parsed_exprs),
        returns_scalar=False,
        backend_version=self._backend_version,
        version=self._version,
        kwargs={"exprs": exprs},
    )

def when(
self,
*predicates: IntoDuckDBExpr,
Expand All @@ -173,6 +279,11 @@ def col(self, *column_names: str) -> DuckDBExpr:
*column_names, backend_version=self._backend_version, version=self._version
)

def nth(self, *column_indices: int) -> DuckDBExpr:
    """Return an expression selecting columns by 0-based index."""
    return DuckDBExpr.from_column_indices(
        *column_indices, backend_version=self._backend_version, version=self._version
    )

def lit(self, value: Any, dtype: DType | None) -> DuckDBExpr:
from duckdb import ConstantExpression

Expand Down
25 changes: 25 additions & 0 deletions narwhals/_spark_like/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,31 @@ def func(_: SparkLikeLazyFrame) -> list[Column]:
kwargs={},
)

@classmethod
def from_column_indices(
    cls: type[Self],
    *column_indices: int,
    backend_version: tuple[int, ...],
    version: Version,
) -> Self:
    """Build an expression selecting columns by 0-based position (``nth``)."""

    def select_by_position(df: SparkLikeLazyFrame) -> list[Column]:
        # Deferred import, matching the lazy-import pattern used elsewhere
        # in this module.
        from pyspark.sql import functions as F  # noqa: N812

        names = df.columns
        picked: list[Column] = []
        for index in column_indices:
            picked.append(F.col(names[index]))
        return picked

    return cls(
        select_by_position,
        depth=0,
        function_name="nth",
        root_names=None,
        output_names=None,
        returns_scalar=False,
        backend_version=backend_version,
        version=version,
        kwargs={},
    )

def _from_call(
self,
call: Callable[..., Column],
Expand Down
18 changes: 14 additions & 4 deletions narwhals/_spark_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def col(self, *column_names: str) -> SparkLikeExpr:
*column_names, backend_version=self._backend_version, version=self._version
)

def nth(self, *column_indices: int) -> SparkLikeExpr:
    """Return an expression selecting columns by 0-based index."""
    return SparkLikeExpr.from_column_indices(
        *column_indices, backend_version=self._backend_version, version=self._version
    )

def lit(self, value: object, dtype: DType | None) -> SparkLikeExpr:
if dtype is not None:
msg = "todo"
Expand Down Expand Up @@ -293,19 +298,24 @@ def concat_str(
]

def func(df: SparkLikeLazyFrame) -> list[Column]:
cols = (s.cast(StringType()) for _expr in parsed_exprs for s in _expr(df))
cols = [s for _expr in parsed_exprs for s in _expr(df)]
cols_casted = [s.cast(StringType()) for s in cols]
null_mask = [F.isnull(s) for _expr in parsed_exprs for s in _expr(df)]
first_column_name = get_column_name(df, cols[0])

if not ignore_nulls:
null_mask_result = reduce(lambda x, y: x | y, null_mask)
result = F.when(
~null_mask_result,
reduce(lambda x, y: F.format_string(f"%s{separator}%s", x, y), cols),
reduce(
lambda x, y: F.format_string(f"%s{separator}%s", x, y),
cols_casted,
),
).otherwise(F.lit(None))
else:
init_value, *values = [
F.when(~nm, col).otherwise(F.lit(""))
for col, nm in zip(cols, null_mask)
for col, nm in zip(cols_casted, null_mask)
]

separators = (
Expand All @@ -318,7 +328,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
init_value,
)

return [result]
return [result.alias(first_column_name)]

return SparkLikeExpr(
call=func,
Expand Down
2 changes: 0 additions & 2 deletions tests/expr_and_series/all_horizontal_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def test_allh_nth(
) -> None:
if "polars" in str(constructor) and POLARS_VERSION < (1, 0):
request.applymarker(pytest.mark.xfail)
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
"a": [False, False, True],
"b": [False, True, True],
Expand Down
16 changes: 9 additions & 7 deletions tests/expr_and_series/concat_str_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import narwhals.stable.v1 as nw
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data

Expand All @@ -27,7 +28,8 @@ def test_concat_str(
expected: list[str],
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor):
if "polars" in str(constructor) and POLARS_VERSION < (1, 0, 0):
# nth only available after 1.0
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = (
Expand All @@ -49,16 +51,16 @@ def test_concat_str(
assert_equal_data(result, {"full_sentence": expected})
result = (
df.select(
"a",
nw.col("a").alias("a_original"),
nw.concat_str(
nw.col("a") * 2,
nw.nth(0) * 2,
nw.col("b"),
nw.col("c"),
separator=" ",
ignore_nulls=ignore_nulls, # default behavior is False
).alias("full_sentence"),
),
)
.sort("a")
.select("full_sentence")
.sort("a_original")
.select("a")
)
assert_equal_data(result, {"full_sentence": expected})
assert_equal_data(result, {"a": expected})
2 changes: 0 additions & 2 deletions tests/expr_and_series/nth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ def test_nth(
expected: dict[str, list[int]],
request: pytest.FixtureRequest,
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
if "polars" in str(constructor) and POLARS_VERSION < (1, 0, 0):
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
Expand Down
Loading

0 comments on commit 7e6c086

Please sign in to comment.