Skip to content

Commit

Permalink
[SNOW-743111] Add substring_index support under sqlfunction (#792)
Browse files Browse the repository at this point in the history
* add substring_index support

* add change log

* fix format

* add support for col as delimiter

* fix doc

* add change

* Empty-Commit

---------

Co-authored-by: Afroz Alam <[email protected]>
  • Loading branch information
sfc-gh-yzou and sfc-gh-aalam authored Apr 22, 2023
1 parent 7023363 commit c252967
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
- Flattened generated SQL when `DataFrame.filter()` or `DataFrame.order_by()` is followed by a projection statement (e.g. `DataFrame.select()`, `DataFrame.with_column()`).
- Added support for creating Dynamic Tables `(in Private Preview)` using `Dataframe.create_or_replace_dynamic_table`
- Added an optional argument `params` in `session.sql()` to support binding variables. Note that this is not supported in stored procedure yet.
- Added support for `functions.substring_index`

### Bug Fixes

Expand Down
47 changes: 47 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2289,6 +2289,53 @@ def substring(
return builtin("substring")(s, p, length)


def substring_index(
text: ColumnOrName, delim: ColumnOrLiteralStr, count: int
) -> Column:
"""
Returns the substring from string ``text`` before ``count`` occurrences of the delimiter ``delim``.
If ``count`` is positive, everything to the left of the final delimiter (counting from left) is
returned. If ``count`` is negative, everything to the right of the final delimiter (counting from the
right) is returned. If ``count`` is zero, returns empty string.
Example 1::
>>> df = session.create_dataframe(
... ["a.b.c.d"],
... schema=["S"],
... ).select(substring_index(col("S"), ".", 2).alias("result"))
>>> df.show()
------------
|"RESULT" |
------------
|a.b |
------------
<BLANKLINE>
Example 2::
>>> df = session.create_dataframe(
... [["a.b.c.d", "."]],
... schema=["S", "delimiter"],
... ).select(substring_index(col("S"), col("delimiter"), 2).alias("result"))
>>> df.show()
------------
|"RESULT" |
------------
|a.b |
------------
<BLANKLINE>
"""
s = _to_col_if_str(text, "substring_index")
strtok_array = builtin("strtok_to_array")(s, delim)
return builtin("array_to_string")(
builtin("array_slice")(
strtok_array,
0 if count >= 0 else builtin("array_size")(strtok_array) + count,
count if count >= 0 else builtin("array_size")(strtok_array),
),
delim,
)


def regexp_count(
subject: ColumnOrName,
pattern: ColumnOrLiteralStr,
Expand Down
32 changes: 32 additions & 0 deletions tests/integ/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
strtok_to_array,
struct,
substring,
substring_index,
to_array,
to_binary,
to_char,
Expand Down Expand Up @@ -523,6 +524,37 @@ def test_basic_string_operations(session):
assert "'REVERSE' expected Column or str, got: <class 'list'>" in str(ex_info)


def test_substring_index(session):
"""test calling substring_index with delimiter as string"""
df = session.create_dataframe([[0, "a.b.c.d"], [1, ""], [2, None]], ["id", "s"])
# substring_index when count is positive
respos = df.select(substring_index("s", ".", 2), "id").order_by("id").collect()
assert respos[0][0] == "a.b"
assert respos[1][0] == ""
assert respos[2][0] is None
# substring_index when count is negative
resneg = df.select(substring_index("s", ".", -3), "id").order_by("id").collect()
assert resneg[0][0] == "b.c.d"
assert respos[1][0] == ""
assert respos[2][0] is None
# substring_index when count is 0, result should be empty string
reszero = df.select(substring_index("s", ".", 0), "id").order_by("id").collect()
assert reszero[0][0] == ""
assert respos[1][0] == ""
assert respos[2][0] is None


def test_substring_index_col(session):
"""test calling substring_index with delimiter as column"""
df = session.create_dataframe([["a,b,c,d", ","]], ["s", "delimiter"])
res = df.select(substring_index(col("s"), df["delimiter"], 2)).collect()
assert res[0][0] == "a,b"
res = df.select(substring_index(col("s"), col("delimiter"), 3)).collect()
assert res[0][0] == "a,b,c"
reslit = df.select(substring_index("s", lit(","), -3)).collect()
assert reslit[0][0] == "b,c,d"


def test_bitshiftright(session):
# Create a dataframe
data = [(65504), (1), (4)]
Expand Down

0 comments on commit c252967

Please sign in to comment.