Skip to content

Commit

Permalink
Add method array_remove. (#2105)
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-lninobrijaldo authored Aug 28, 2024
1 parent a7c6820 commit 997373e
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 3 deletions.
12 changes: 9 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,14 @@

### Snowpark Python API Updates

#### New Features

- Added the following new functions in `snowflake.snowpark.functions`:
- `array_remove`
- `ln`

#### Improvements

- Added support for function `functions.ln`
- Added support for specifying the following to `DataFrameWriter.save_as_table`:
- `enable_schema_evolution`
- `data_retention_time`
Expand Down Expand Up @@ -47,9 +52,10 @@

- Added limited support for the `Timedelta` type, including the following features. Snowpark pandas will raise `NotImplementedError` for unsupported `Timedelta` use cases.
- support for tracking the `Timedelta` type through `copy`, `cache_result`, `shift`, `sort_index`.
- converting non-timedelta to timedelta via `astype`.
- converting non-timedelta to timedelta via `astype`.
- `NotImplementedError` will be raised for the rest of methods that do not support `Timedelta`.
- support for subtracting two timestamps to get a Timedelta.
- support indexing with Timedelta data columns.
- support indexing with Timedelta data columns.
- support for adding or subtracting timestamps and `Timedelta`.
- support for binary arithmetic between two `Timedelta` values.
- Added support for index's arithmetic and comparison operators.
Expand Down
1 change: 1 addition & 0 deletions docs/source/snowpark/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Functions
array_min
array_position
array_prepend
array_remove
array_size
array_slice
array_sort
Expand Down
50 changes: 50 additions & 0 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5333,6 +5333,56 @@ def array_append(array: ColumnOrName, element: ColumnOrName) -> Column:
return builtin("array_append")(a, e)


def array_remove(array: ColumnOrName, element: ColumnOrLiteral) -> Column:
    """Returns an ARRAY built from the source ARRAY with all elements matching ``element`` removed.

    Args:
        array: Column, or name of the column, that holds the source array.
        element: Value to remove from the array. A VARCHAR value has to be cast to the
            VARIANT data type before it is passed in (see the second example).

    Examples::

        >>> from snowflake.snowpark.types import VariantType
        >>> df = session.create_dataframe([([1, '2', 3.1, 1, 1],)], ['data'])
        >>> df.select(array_remove(df.data, 1).alias("objects")).show()
        -------------
        |"OBJECTS"  |
        -------------
        |[          |
        |  "2",     |
        |  3.1      |
        |]          |
        -------------
        <BLANKLINE>
        >>> df.select(array_remove(df.data, lit('2').cast(VariantType())).alias("objects")).show()
        -------------
        |"OBJECTS"  |
        -------------
        |[          |
        |  1,       |
        |  3.1,     |
        |  1,       |
        |  1        |
        |]          |
        -------------
        <BLANKLINE>
        >>> df.select(array_remove(df.data, None).alias("objects")).show()
        -------------
        |"OBJECTS"  |
        -------------
        |NULL       |
        -------------
        <BLANKLINE>

    See Also:
        - `ARRAY <https://docs.snowflake.com/en/sql-reference/data-types-semistructured#label-data-type-array>`_ for more details on semi-structured arrays.
    """
    # Only the array argument is normalised to a Column; ``element`` is handed to the
    # builtin untouched so both Column and plain literal values are accepted.
    source_array = _to_col_if_str(array, "array_remove")
    remove_fn = builtin("array_remove")
    return remove_fn(source_array, element)


def array_cat(array1: ColumnOrName, array2: ColumnOrName) -> Column:
"""Returns the concatenation of two ARRAYs.
Expand Down
29 changes: 29 additions & 0 deletions tests/integ/scala/test_function_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
array_intersection,
array_position,
array_prepend,
array_remove,
array_size,
array_slice,
array_to_string,
Expand Down Expand Up @@ -2823,6 +2824,34 @@ def test_array_append(session):
)


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="array_remove is not yet supported in local testing mode.",
)
def test_array_remove(session):
    """Chained array_remove calls on ``arr1`` drop the values 1 and 8.

    The same expectation is checked twice: once with the first removal target
    wrapped in ``lit`` and once passed as a plain Python literal.
    """
    for first_target in (lit(1), 1):
        # Rebuilt on every pass so each check gets its own expected rows.
        expected_rows = [
            Row("[\n 2,\n 3\n]"),
            Row("[\n 6,\n 7\n]"),
        ]
        actual_df = TestData.array1(session).select(
            array_remove(array_remove(col("arr1"), first_target), lit(8))
        )
        Utils.check_answer(expected_rows, actual_df, sort=False)


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="array_cat is not yet supported in local testing mode.",
Expand Down

0 comments on commit 997373e

Please sign in to comment.