Skip to content

Commit

Permalink
fix: when/then/otherwise output name was not consistent across backends (#1833)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jan 20, 2025
1 parent 82da089 commit c676b68
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 22 deletions.
2 changes: 1 addition & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
except TypeError:
# `self._then_value` is a scalar and can't be converted to an expression
value_series = plx._create_series_from_scalar(
self._then_value, reference_series=condition
self._then_value, reference_series=condition.alias("literal")
)

condition_native, value_series_native = broadcast_series(
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,8 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:

if is_scalar:
_df = condition.to_frame("a")
_df["tmp"] = value_sequence[0]
value_series = _df["tmp"]
_df["literal"] = value_sequence[0]
value_series = _df["literal"]
else:
value_series = value_sequence

Expand Down
17 changes: 11 additions & 6 deletions narwhals/_duckdb/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,22 +248,27 @@ def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
value = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._then_value` is a scalar and can't be converted to an expression
value = ConstantExpression(self._then_value)
value = ConstantExpression(self._then_value).alias("literal")
value = cast("duckdb.Expression", value)
value_name = get_column_name(df, value)

if self._otherwise_value is None:
return [CaseExpression(condition=condition, value=value)]
return [CaseExpression(condition=condition, value=value).alias(value_name)]
try:
otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx)
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
return [
CaseExpression(condition=condition, value=value).otherwise(
ConstantExpression(self._otherwise_value)
)
CaseExpression(condition=condition, value=value)
.otherwise(ConstantExpression(self._otherwise_value))
.alias(value_name)
]
otherwise = otherwise_expr(df)[0]
return [CaseExpression(condition=condition, value=value).otherwise(otherwise)]
return [
CaseExpression(condition=condition, value=value)
.otherwise(otherwise)
.alias(value_name)
]

def then(self, value: DuckDBExpr | Any) -> DuckDBThen:
self._then_value = value
Expand Down
3 changes: 1 addition & 2 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,12 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
except TypeError:
# `self._then_value` is a scalar and can't be converted to an expression
value_series = plx._create_series_from_scalar(
self._then_value, reference_series=condition
self._then_value, reference_series=condition.alias("literal")
)

condition_native, value_series_native = broadcast_align_and_extract_native(
condition, value_series
)

if self._otherwise_value is None:
return [
value_series._from_native_series(
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def broadcast_align_and_extract_native(
s = rhs._native_series
return (
lhs._native_series,
s.__class__(s.iloc[0], index=lhs_index, dtype=s.dtype),
s.__class__(s.iloc[0], index=lhs_index, dtype=s.dtype, name=rhs.name),
)
if lhs.len() == 1:
# broadcast
Expand Down
17 changes: 7 additions & 10 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def test_when(constructor: Constructor) -> None:
"a_when": [3, None, None],
}
assert_equal_data(result, expected)
result = df.select(nw.when(nw.col("a") == 1).then(value=3))
expected = {
"literal": [3, None, None],
}
assert_equal_data(result, expected)


def test_when_otherwise(constructor: Constructor) -> None:
Expand Down Expand Up @@ -121,22 +126,14 @@ def test_otherwise_expression(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_when_then_otherwise_into_expr(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when_then_otherwise_into_expr(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e"))
expected = {"c": [7, 5, 6]}
assert_equal_data(result, expected)


def test_when_then_otherwise_lit_str(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_when_then_otherwise_lit_str(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z")))
expected = {"b": ["z", "b", "c"]}
Expand Down

0 comments on commit c676b68

Please sign in to comment.