diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2abe83c37ab..590871fbda9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@
 - Flattened generated SQL when `DataFrame.filter()` or `DataFrame.order_by()` is followed by a projection statement (e.g. `DataFrame.select()`, `DataFrame.with_column()`).
 - Added support for creating Dynamic Tables `(in Private Preview)` using `Dataframe.create_or_replace_dynamic_table`
 - Added an optional argument `params` in `session.sql()` to support binding variables. Note that this is not supported in stored procedure yet.
+- Added support for `functions.substring_index`
 
 ### Bug Fixes
 
diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py
index 8bef281be15..5717f1e3df3 100644
--- a/src/snowflake/snowpark/functions.py
+++ b/src/snowflake/snowpark/functions.py
@@ -2289,6 +2289,64 @@ def substring(
     return builtin("substring")(s, p, length)
 
 
+def substring_index(
+    text: ColumnOrName, delim: ColumnOrLiteralStr, count: int
+) -> Column:
+    """
+    Returns the substring from string ``text`` before ``count`` occurrences of the delimiter ``delim``.
+    If ``count`` is positive, everything to the left of the final delimiter (counting from left) is
+    returned. If ``count`` is negative, everything to the right of the final delimiter (counting from the
+    right) is returned. If ``count`` is zero, returns empty string.
+    If ``abs(count)`` exceeds the number of delimited parts in ``text``, every part is kept.
+
+    Args:
+        text: The column (or column name) holding the string to split.
+        delim: The delimiter, given either as a string literal or as a column.
+        count: How many delimited parts to keep; a Python ``int`` evaluated on the client.
+
+    Example 1::
+        >>> df = session.create_dataframe(
+        ...     ["a.b.c.d"],
+        ...     schema=["S"],
+        ... ).select(substring_index(col("S"), ".", 2).alias("result"))
+        >>> df.show()
+        ------------
+        |"RESULT"  |
+        ------------
+        |a.b       |
+        ------------
+        <BLANKLINE>
+
+    Example 2::
+        >>> df = session.create_dataframe(
+        ...     [["a.b.c.d", "."]],
+        ...     schema=["S", "delimiter"],
+        ... ).select(substring_index(col("S"), col("delimiter"), 2).alias("result"))
+        >>> df.show()
+        ------------
+        |"RESULT"  |
+        ------------
+        |a.b       |
+        ------------
+        <BLANKLINE>
+    """
+    s = _to_col_if_str(text, "substring_index")
+    # NOTE(review): STRTOK_TO_ARRAY appears to treat each character of ``delim`` as a
+    # separator, so multi-character delimiters may not round-trip - confirm.
+    strtok_array = builtin("strtok_to_array")(s, delim)
+    if count >= 0:
+        start, end = 0, count
+    else:
+        token_count = builtin("array_size")(strtok_array)
+        # ARRAY_SLICE treats a negative position as counting back from the end,
+        # so clamp at 0; otherwise abs(count) > token_count drops leading parts.
+        start, end = builtin("greatest")(token_count + count, 0), token_count
+    return builtin("array_to_string")(
+        builtin("array_slice")(strtok_array, start, end),
+        delim,
+    )
+
+
 def regexp_count(
     subject: ColumnOrName,
     pattern: ColumnOrLiteralStr,
diff --git a/tests/integ/test_function.py b/tests/integ/test_function.py
index 2f95170f94f..6a1ea278bb8 100644
--- a/tests/integ/test_function.py
+++ b/tests/integ/test_function.py
@@ -121,6 +121,7 @@
     strtok_to_array,
     struct,
     substring,
+    substring_index,
     to_array,
     to_binary,
     to_char,
@@ -523,6 +524,42 @@ def test_basic_string_operations(session):
     assert "'REVERSE' expected Column or str, got: " in str(ex_info)
 
 
+def test_substring_index(session):
+    """test calling substring_index with delimiter as string"""
+    df = session.create_dataframe([[0, "a.b.c.d"], [1, ""], [2, None]], ["id", "s"])
+    # substring_index when count is positive
+    respos = df.select(substring_index("s", ".", 2), "id").order_by("id").collect()
+    assert respos[0][0] == "a.b"
+    assert respos[1][0] == ""
+    assert respos[2][0] is None
+    # substring_index when count is negative
+    resneg = df.select(substring_index("s", ".", -3), "id").order_by("id").collect()
+    assert resneg[0][0] == "b.c.d"
+    assert resneg[1][0] == ""
+    assert resneg[2][0] is None
+    # substring_index when count is 0, result should be empty string
+    reszero = df.select(substring_index("s", ".", 0), "id").order_by("id").collect()
+    assert reszero[0][0] == ""
+    assert reszero[1][0] == ""
+    assert reszero[2][0] is None
+    # substring_index when abs(count) exceeds the number of parts, every part is kept
+    resover = df.select(substring_index("s", ".", -5), "id").order_by("id").collect()
+    assert resover[0][0] == "a.b.c.d"
+    assert resover[1][0] == ""
+    assert resover[2][0] is None
+
+
+def test_substring_index_col(session):
+    """test calling substring_index with delimiter as column"""
+    df = session.create_dataframe([["a,b,c,d", ","]], ["s", "delimiter"])
+    res = df.select(substring_index(col("s"), df["delimiter"], 2)).collect()
+    assert res[0][0] == "a,b"
+    res = df.select(substring_index(col("s"), col("delimiter"), 3)).collect()
+    assert res[0][0] == "a,b,c"
+    reslit = df.select(substring_index("s", lit(","), -3)).collect()
+    assert reslit[0][0] == "b,c,d"
+
+
 def test_bitshiftright(session):
     # Create a dataframe
     data = [(65504), (1), (4)]