From 997373e02b7e114c2637f087aeda61f91179aac6 Mon Sep 17 00:00:00 2001 From: sfc-gh-lninobrijaldo Date: Wed, 28 Aug 2024 11:42:50 -0500 Subject: [PATCH] Add method array_remove. (#2105) --- CHANGELOG.md | 12 ++++-- docs/source/snowpark/functions.rst | 1 + src/snowflake/snowpark/functions.py | 50 ++++++++++++++++++++++++ tests/integ/scala/test_function_suite.py | 29 ++++++++++++++ 4 files changed, 89 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99ac288a7a9..8f5a3daf78e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,9 +4,14 @@ ### Snowpark Python API Updates +### New Features + +- Added following new functions in `snowflake.snowpark.functions`: + - `array_remove` + - `ln` + #### Improvements -- Added support for function `functions.ln` - Added support for specifying the following to `DataFrameWriter.save_as_table`: - `enable_schema_evolution` - `data_retention_time` @@ -47,9 +52,10 @@ - Added limited support for the `Timedelta` type, including the following features. Snowpark pandas will raise `NotImplementedError` for unsupported `Timedelta` use cases. - supporting tracking the Timedelta type through `copy`, `cache_result`, `shift`, `sort_index`. - - converting non-timedelta to timedelta via `astype`. + - converting non-timedelta to timedelta via `astype`. + - `NotImplementedError` will be raised for the rest of methods that do not support `Timedelta`. - support for subtracting two timestamps to get a Timedelta. - - support indexing with Timedelta data columns. + - support indexing with Timedelta data columns. - support for adding or subtracting timestamps and `Timedelta`. - support for binary arithmetic between two `Timedelta` values. - Added support for index's arithmetic and comparison operators. diff --git a/docs/source/snowpark/functions.rst b/docs/source/snowpark/functions.rst index 100cb5470fc..9a381e5046a 100644 --- a/docs/source/snowpark/functions.rst +++ b/docs/source/snowpark/functions.rst @@ -43,6 +43,7 @@ Functions array_min array_position array_prepend + array_remove array_size array_slice array_sort diff --git a/src/snowflake/snowpark/functions.py b/src/snowflake/snowpark/functions.py index 8f8b156132c..58c2ab8518c 100644 --- a/src/snowflake/snowpark/functions.py +++ b/src/snowflake/snowpark/functions.py @@ -5333,6 +5333,56 @@ def array_append(array: ColumnOrName, element: ColumnOrName) -> Column: return builtin("array_append")(a, e) +def array_remove(array: ColumnOrName, element: ColumnOrLiteral) -> Column: + """Given a source ARRAY, returns an ARRAY with elements of the specified value removed. + + Args: + array: name of column containing array. + element: element to be removed from the array. If the element is a VARCHAR, it needs + to be casted into VARIANT data type. + + Examples:: + >>> from snowflake.snowpark.types import VariantType + >>> df = session.create_dataframe([([1, '2', 3.1, 1, 1],)], ['data']) + >>> df.select(array_remove(df.data, 1).alias("objects")).show() + ------------- + |"OBJECTS" | + ------------- + |[ | + | "2", | + | 3.1 | + |] | + ------------- + + + >>> df.select(array_remove(df.data, lit('2').cast(VariantType())).alias("objects")).show() + ------------- + |"OBJECTS" | + ------------- + |[ | + | 1, | + | 3.1, | + | 1, | + | 1 | + |] | + ------------- + + + >>> df.select(array_remove(df.data, None).alias("objects")).show() + ------------- + |"OBJECTS" | + ------------- + |NULL | + ------------- + + + See Also: + - `ARRAY `_ for more details on semi-structured arrays. + """ + a = _to_col_if_str(array, "array_remove") + return builtin("array_remove")(a, element) + + def array_cat(array1: ColumnOrName, array2: ColumnOrName) -> Column: """Returns the concatenation of two ARRAYs. diff --git a/tests/integ/scala/test_function_suite.py b/tests/integ/scala/test_function_suite.py index a322c7d34b8..98b2bdbfeef 100644 --- a/tests/integ/scala/test_function_suite.py +++ b/tests/integ/scala/test_function_suite.py @@ -36,6 +36,7 @@ array_intersection, array_position, array_prepend, + array_remove, array_size, array_slice, array_to_string, @@ -2823,6 +2824,34 @@ def test_array_append(session): ) +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="array_remove is not yet supported in local testing mode.", +) +def test_array_remove(session): + Utils.check_answer( + [ + Row("[\n 2,\n 3\n]"), + Row("[\n 6,\n 7\n]"), + ], + TestData.array1(session).select( + array_remove(array_remove(col("arr1"), lit(1)), lit(8)) + ), + sort=False, + ) + + Utils.check_answer( + [ + Row("[\n 2,\n 3\n]"), + Row("[\n 6,\n 7\n]"), + ], + TestData.array1(session).select( + array_remove(array_remove(col("arr1"), 1), lit(8)) + ), + sort=False, + ) + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="array_cat is not yet supported in local testing mode.",