diff --git a/crates/polars-plan/src/dsl/function_expr/strings.rs b/crates/polars-plan/src/dsl/function_expr/strings.rs index 205ecb6f0527..46d4b524c14a 100644 --- a/crates/polars-plan/src/dsl/function_expr/strings.rs +++ b/crates/polars-plan/src/dsl/function_expr/strings.rs @@ -466,6 +466,7 @@ fn extract_many( ascii_case_insensitive: bool, overlapping: bool, ) -> PolarsResult { + _check_same_length(s, "extract_many")?; let ca = s[0].str()?; let patterns = &s[1]; @@ -524,6 +525,7 @@ pub(super) fn len_bytes(s: &Column) -> PolarsResult { #[cfg(feature = "regex")] pub(super) fn contains(s: &[Column], literal: bool, strict: bool) -> PolarsResult { + _check_same_length(s, "contains")?; let ca = s[0].str()?; let pat = s[1].str()?; ca.contains_chunked(pat, literal, strict) @@ -532,6 +534,7 @@ pub(super) fn contains(s: &[Column], literal: bool, strict: bool) -> PolarsResul #[cfg(feature = "regex")] pub(super) fn find(s: &[Column], literal: bool, strict: bool) -> PolarsResult { + _check_same_length(s, "find")?; let ca = s[0].str()?; let pat = s[1].str()?; ca.find_chunked(pat, literal, strict) @@ -539,6 +542,7 @@ pub(super) fn find(s: &[Column], literal: bool, strict: bool) -> PolarsResult PolarsResult { + _check_same_length(s, "ends_with")?; let ca = &s[0].str()?.as_binary(); let suffix = &s[1].str()?.as_binary(); @@ -546,6 +550,7 @@ pub(super) fn ends_with(s: &[Column]) -> PolarsResult { } pub(super) fn starts_with(s: &[Column]) -> PolarsResult { + _check_same_length(s, "starts_with")?; let ca = s[0].str()?; let prefix = s[1].str()?; Ok(ca.starts_with_chunked(prefix).into_column()) @@ -579,6 +584,7 @@ pub(super) fn pad_end(s: &Column, length: usize, fill_char: char) -> PolarsResul #[cfg(feature = "string_pad")] pub(super) fn zfill(s: &[Column]) -> PolarsResult { + _check_same_length(s, "zfill")?; let ca = s[0].str()?; let length_s = s[1].strict_cast(&DataType::UInt64)?; let length = length_s.u64()?; @@ -586,30 +592,35 @@ pub(super) fn zfill(s: &[Column]) -> PolarsResult { } pub(super) fn strip_chars(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_chars")?; let ca = s[0].str()?; let pat_s = &s[1]; ca.strip_chars(pat_s).map(|ok| ok.into_column()) } pub(super) fn strip_chars_start(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_chars_start")?; let ca = s[0].str()?; let pat_s = &s[1]; ca.strip_chars_start(pat_s).map(|ok| ok.into_column()) } pub(super) fn strip_chars_end(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_chars_end")?; let ca = s[0].str()?; let pat_s = &s[1]; ca.strip_chars_end(pat_s).map(|ok| ok.into_column()) } pub(super) fn strip_prefix(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_prefix")?; let ca = s[0].str()?; let prefix = s[1].str()?; Ok(ca.strip_prefix(prefix).into_column()) } pub(super) fn strip_suffix(s: &[Column]) -> PolarsResult { + _check_same_length(s, "strip_suffix")?; let ca = s[0].str()?; let suffix = s[1].str()?; Ok(ca.strip_suffix(suffix).into_column()) @@ -1023,11 +1034,17 @@ fn _ensure_lengths(s: &[Column]) -> bool { .all(|series| series.len() == 1 || series.len() == len) } -pub(super) fn str_slice(s: &[Column]) -> PolarsResult { +fn _check_same_length(s: &[Column], fn_name: &str) -> Result<(), PolarsError> { polars_ensure!( _ensure_lengths(s), - ComputeError: "all series in `str_slice` should have equal or unit length", + ComputeError: "all series in `str.{}()` should have equal or unit length", + fn_name ); + Ok(()) +} + +pub(super) fn str_slice(s: &[Column]) -> PolarsResult { + _check_same_length(s, "slice")?; let ca = s[0].str()?; let offset = &s[1]; let length = &s[2]; @@ -1035,20 +1052,14 @@ pub(super) fn str_slice(s: &[Column]) -> PolarsResult { } pub(super) fn str_head(s: &[Column]) -> PolarsResult { - polars_ensure!( - _ensure_lengths(s), - ComputeError: "all series in `str_head` should have equal or unit length", - ); + _check_same_length(s, "head")?; let ca = s[0].str()?; let n = &s[1]; Ok(ca.str_head(n)?.into_column()) } pub(super) fn str_tail(s: &[Column]) -> PolarsResult { - polars_ensure!( - _ensure_lengths(s), - ComputeError: "all series in `str_tail` should have equal or unit length", - ); + _check_same_length(s, "tail")?; let ca = s[0].str()?; let n = &s[1]; Ok(ca.str_tail(n)?.into_column()) @@ -1092,6 +1103,7 @@ pub(super) fn json_decode( #[cfg(feature = "extract_jsonpath")] pub(super) fn json_path_match(s: &[Column]) -> PolarsResult { + _check_same_length(s, "json_path_match")?; let ca = s[0].str()?; let pat = s[1].str()?; Ok(ca.json_path_match(pat)?.into_column()) diff --git a/py-polars/tests/unit/operations/namespaces/string/test_pad.py b/py-polars/tests/unit/operations/namespaces/string/test_pad.py index 06df899e41ff..9e076f0521c9 100644 --- a/py-polars/tests/unit/operations/namespaces/string/test_pad.py +++ b/py-polars/tests/unit/operations/namespaces/string/test_pad.py @@ -1,6 +1,9 @@ from __future__ import annotations +import pytest + import polars as pl +from polars.exceptions import ComputeError from polars.testing import assert_frame_equal @@ -88,6 +91,12 @@ def test_str_zfill_expr() -> None: assert_frame_equal(out, expected) +def test_str_zfill_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.zfill(pl.Series([1, 2]))) + + def test_pad_end_unicode() -> None: lf = pl.LazyFrame({"a": ["Café", "345", "東京", None]}) diff --git a/py-polars/tests/unit/operations/namespaces/string/test_string.py b/py-polars/tests/unit/operations/namespaces/string/test_string.py index 4649897644b4..f951629f15ad 100644 --- a/py-polars/tests/unit/operations/namespaces/string/test_string.py +++ b/py-polars/tests/unit/operations/namespaces/string/test_string.py @@ -54,6 +54,12 @@ def test_str_slice_expr() -> None: df.select(pl.col("a").str.slice(0, -1)) +def test_str_slice_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.slice(pl.Series([1, 2]))) + + @pytest.mark.parametrize( ("input", "n", "output"), [ @@ -115,6 +121,12 @@ def test_str_head_expr() -> None: assert_frame_equal(out, expected) +def test_str_head_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.head(pl.Series([1, 2]))) + + @pytest.mark.parametrize( ("input", "n", "output"), [ @@ -176,6 +188,12 @@ def test_str_tail_expr() -> None: assert_frame_equal(out, expected) +def test_str_tail_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.tail(pl.Series([1, 2]))) + + def test_str_slice_multibyte() -> None: ref = "你好世界" s = pl.Series([ref]) @@ -212,6 +230,12 @@ def test_str_contains() -> None: assert_series_equal(s.str.contains("mes"), expected) +def test_str_contains_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.contains(pl.Series(["a", "b"]))) # type: ignore [arg-type] + + def test_count_match_literal() -> None: s = pl.Series(["12 dbc 3xy", "cat\\w", "1zy3\\d\\d", None]) out = s.str.count_matches(r"\d", literal=True) @@ -338,6 +362,12 @@ def test_str_find_escaped_chars() -> None: ) +def test_str_find_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.find(pl.Series(["a", "b"]))) # type: ignore [arg-type] + + def test_hex_decode_return_dtype() -> None: data = {"a": ["68656c6c6f", "776f726c64"]} expr = pl.col("a").str.decode("hex") @@ -515,6 +545,12 @@ def test_str_strip_chars() -> None: assert_series_equal(s.str.strip_chars(" hwo"), expected) +def test_str_strip_chars_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.strip_chars(pl.Series(["a", "b"]))) + + def test_str_strip_chars_start() -> None: s = pl.Series([" hello ", "\t world"]) expected = pl.Series(["hello ", "world"]) @@ -527,6 +563,12 @@ def test_str_strip_chars_start() -> None: assert_series_equal(s.str.strip_chars_start("hw "), expected) +def test_str_strip_chars_start_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.strip_chars_start(pl.Series(["a", "b"]))) + + def test_str_strip_chars_end() -> None: s = pl.Series([" hello ", "world\t "]) expected = pl.Series([" hello", "world"]) @@ -539,6 +581,12 @@ def test_str_strip_chars_end() -> None: assert_series_equal(s.str.strip_chars_end("odl \t"), expected) +def test_str_strip_chars_end_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.strip_chars_end(pl.Series(["a", "b"]))) + + def test_str_strip_whitespace() -> None: s = pl.Series("a", ["trailing ", " leading", " both "]) @@ -579,6 +627,12 @@ def test_str_strip_prefix_suffix_expr() -> None: } +def test_str_strip_prefix_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.strip_prefix(pl.Series(["a", "b"]))) + + def test_str_strip_suffix() -> None: s = pl.Series(["foo:bar", "foo:barbar", "foo:foo", "bar", "", None]) expected = pl.Series(["foo:", "foo:bar", "foo:foo", "", "", None]) @@ -588,6 +642,12 @@ def test_str_strip_suffix() -> None: assert_series_equal(s.str.strip_suffix(pl.lit(None, dtype=pl.String)), expected) +def test_str_strip_suffix_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.strip_suffix(pl.Series(["a", "b"]))) + + def test_str_split() -> None: a = pl.Series("a", ["a, b", "a", "ab,c,de"]) for out in [a.str.split(","), pl.select(pl.lit(a).str.split(",")).to_series()]: @@ -730,6 +790,12 @@ def test_json_path_match() -> None: assert_frame_equal(out, expected) +def test_str_json_path_match_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.json_path_match(pl.Series(["a", "b"]))) + + def test_extract_regex() -> None: s = pl.Series( [ @@ -1799,6 +1865,12 @@ def test_extract_many() -> None: assert f2.to_list() == [[0], [0, 5]] +def test_str_extract_many_wrong_length() -> None: + df = pl.DataFrame({"num": ["-10", "-1", "0"]}) + with pytest.raises(ComputeError, match="should have equal or unit length"): + df.select(pl.col("num").str.extract_many(pl.Series(["a", "b"]))) + + def test_json_decode_raise_on_data_type_mismatch_13061() -> None: assert_series_equal( pl.Series(["null", "null"]).str.json_decode(infer_schema_length=1),