Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 49 remove exists forall #232

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions quinn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
approx_equal,
array_choice,
business_days_between,
exists,
forall,
is_false,
is_falsy,
is_not_in,
Expand Down
40 changes: 0 additions & 40 deletions quinn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable
from numbers import Number

from pyspark.sql import Column
Expand Down Expand Up @@ -84,45 +83,6 @@ def remove_non_word_characters(col: Column) -> Column:
return F.regexp_replace(col, "[^\\w\\s]+", "")


def exists(f: Callable[[Any], bool]) -> udf:
    """Build a UDF that tests whether any element of an array column satisfies ``f``.

    The returned UDF takes a column of type ``ArrayType(AnyType)`` and yields
    ``True`` when ``f`` returns a truthy value for at least one element,
    ``False`` when it does for none (which includes an empty array), and null
    when the array cell itself is null.

    :param f: Predicate applied to each element of the array; takes an element
        of any type and returns a boolean.
    :return: A user-defined function with a ``BooleanType`` result that accepts
        an array column.
    :rtype: UserDefinedFunction
    """

    def temp_udf(list_: list | None) -> bool | None:
        # A null array cell reaches the Python function as ``None``; iterating
        # it would raise ``TypeError`` on the executor, so propagate the null.
        if list_ is None:
            return None
        return any(f(element) for element in list_)

    return F.udf(temp_udf, BooleanType())


def forall(f: Callable[[Any], bool]) -> udf:
    """Build a UDF that tests whether every element of an array column satisfies ``f``.

    The returned UDF takes a column of type ``ArrayType(AnyType)`` and yields
    ``True`` when ``f`` returns a truthy value for all elements (which includes
    an empty array), ``False`` as soon as one element fails, and null when the
    array cell itself is null.

    :param f: Predicate applied to each element of the array; takes an element
        of any type and returns a boolean.
    :return: A user-defined function with a ``BooleanType`` result that accepts
        an array column.
    :rtype: UserDefinedFunction
    """

    def temp_udf(list_: list | None) -> bool | None:
        # A null array cell reaches the Python function as ``None``; iterating
        # it would raise ``TypeError`` on the executor, so propagate the null.
        if list_ is None:
            return None
        return all(f(element) for element in list_)

    return F.udf(temp_udf, BooleanType())


def multi_equals(value: Any) -> udf: # noqa: ANN401
"""Create a user-defined function that checks if all the given columns have the designated value.

Expand Down
40 changes: 0 additions & 40 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,46 +97,6 @@ def test_anti_trim():
chispa.assert_column_equality(actual_df, "words_anti_trimmed", "expected")


def test_exists():
    """``quinn.exists`` marks the arrays that hold at least one number above 5."""
    schema = StructType(
        [
            StructField("nums", ArrayType(IntegerType(), True), True),
            StructField("expected", BooleanType(), True),
        ]
    )
    rows = [
        ([1, 2, 3], False),
        ([4, 5, 6], True),
        ([10, 11, 12], True),
    ]
    source_df = spark.createDataFrame(rows, schema)

    # Build the UDF once, then apply it to the array column.
    has_num_above_five = quinn.exists(lambda n: n > 5)
    actual_df = source_df.withColumn(
        "any_num_greater_than_5", has_num_above_five(F.col("nums"))
    )

    chispa.assert_column_equality(actual_df, "any_num_greater_than_5", "expected")


def test_forall():
    """``quinn.forall`` marks the arrays whose numbers are all above 3."""
    schema = StructType(
        [
            StructField("nums", ArrayType(IntegerType(), True), True),
            StructField("expected", BooleanType(), True),
        ]
    )
    rows = [
        ([1, 2, 3], False),
        ([4, 5, 6], True),
        ([10, 11, 12], True),
    ]
    source_df = spark.createDataFrame(rows, schema)

    # Build the UDF once, then apply it to the array column.
    all_above_three = quinn.forall(lambda n: n > 3)
    actual_df = source_df.withColumn(
        "all_nums_greater_than_3", all_above_three(F.col("nums"))
    )

    chispa.assert_column_equality(actual_df, "all_nums_greater_than_3", "expected")


def test_multi_equals():
df = quinn.create_df(
spark,
Expand Down
Loading