From 405e3d51024cfe64dabc27d23b7f8bc07ccc6ae3 Mon Sep 17 00:00:00 2001 From: Paulo Octavio Date: Thu, 3 Oct 2024 17:52:21 -0300 Subject: [PATCH 1/3] Added new parameter return_bool to validate dataframe methods --- quinn/dataframe_validator.py | 55 +++++++++++++++++++++++++----------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/quinn/dataframe_validator.py b/quinn/dataframe_validator.py index 54004850..8867091a 100644 --- a/quinn/dataframe_validator.py +++ b/quinn/dataframe_validator.py @@ -14,7 +14,7 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union if TYPE_CHECKING: from pyspark.sql import DataFrame @@ -33,31 +33,37 @@ class DataFrameProhibitedColumnError(ValueError): """Raise this when a DataFrame includes prohibited columns.""" -def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) -> None: +def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], return_bool: bool = False) -> Union[None, bool]: """Validate the presence of column names in a DataFrame. - :param df: A spark DataFrame. - :type df: DataFrame` + :type df: DataFrame :param required_col_names: List of the required column names for the DataFrame. - :type required_col_names: :py:class:`list` of :py:class:`str` - :return: None. + :type required_col_names: list[str] + :param return_bool: If True, return a boolean instead of raising an exception. + :type return_bool: bool + :return: None if return_bool is False, otherwise a boolean indicating if validation passed. :raises DataFrameMissingColumnError: if any of the requested column names are - not present in the DataFrame. + not present in the DataFrame and return_bool is False. """ all_col_names = df.columns missing_col_names = [x for x in required_col_names if x not in all_col_names] - error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}" + if missing_col_names: + error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}" + if return_bool: + return False raise DataFrameMissingColumnError(error_message) + return True if return_bool else None + def validate_schema( df: DataFrame, required_schema: StructType, ignore_nullable: bool = False, -) -> None: + return_bool: bool = False, +) -> Union[None, bool]: """Function that validate if a given DataFrame has a given StructType as its schema. - :param df: DataFrame to validate :type df: DataFrame :param required_schema: StructType required for the DataFrame @@ -65,9 +71,11 @@ def validate_schema( :param ignore_nullable: (Optional) A flag for if nullable fields should be ignored during validation :type ignore_nullable: bool, optional - + :param return_bool: If True, return a boolean instead of raising an exception. + :type return_bool: bool + :return: None if return_bool is False, otherwise a boolean indicating if validation passed. :raises DataFrameMissingStructFieldError: if any StructFields from the required - schema are not included in the DataFrame schema + schema are not included in the DataFrame schema and return_bool is False. """ _all_struct_fields = copy.deepcopy(df.schema) _required_schema = copy.deepcopy(required_schema) @@ -80,22 +88,35 @@ def validate_schema( x.nullable = None missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields] - error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}" if missing_struct_fields: + error_message = ( + f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}" + ) + if return_bool: + return False raise DataFrameMissingStructFieldError(error_message) + return True if return_bool else None -def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str]) -> None: - """Validate that none of the prohibited column names are present among specified DataFrame columns. +def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str], return_bool: bool = False) -> Union[None, bool]: + """Validate that none of the prohibited column names are present among specified DataFrame columns. :param df: DataFrame containing columns to be checked. :param prohibited_col_names: List of prohibited column names. + :param return_bool: If True, return a boolean instead of raising an exception. + :type return_bool: bool + :return: None if return_bool is False, otherwise a boolean indicating if validation passed. :raises DataFrameProhibitedColumnError: If the prohibited column names are - present among the specified DataFrame columns. + present among the specified DataFrame columns and return_bool is False. """ all_col_names = df.columns extra_col_names = [x for x in all_col_names if x in prohibited_col_names] - error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}" + if extra_col_names: + error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}" + if return_bool: + return False raise DataFrameProhibitedColumnError(error_message) + + return True if return_bool else None From 6b1758b08489befeba510538fd10889c109f83e5 Mon Sep 17 00:00:00 2001 From: Paulo Octavio Date: Thu, 3 Oct 2024 17:59:16 -0300 Subject: [PATCH 2/3] Update tests --- tests/test_dataframe_validator.py | 75 +++++++++++++++++++++++++------ 1 file changed, 61 insertions(+), 14 deletions(-) diff --git a/tests/test_dataframe_validator.py b/tests/test_dataframe_validator.py index 19ac7b06..56118d31 100644 --- a/tests/test_dataframe_validator.py +++ b/tests/test_dataframe_validator.py @@ -1,30 +1,41 @@ import pytest from pyspark.sql.types import StructType, StructField, StringType, LongType import semver - import quinn from .spark import spark def describe_validate_presence_of_columns(): - def it_raises_if_a_required_column_is_missing(): + def it_raises_if_a_required_column_is_missing_and_return_bool_is_false(): data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) with pytest.raises(quinn.DataFrameMissingColumnError) as excinfo: - quinn.validate_presence_of_columns(source_df, ["name", "age", "fun"]) + quinn.validate_presence_of_columns(source_df, ["name", "age", "fun"], False) assert ( excinfo.value.args[0] == "The ['fun'] columns are not included in the DataFrame with the following columns ['name', 'age']" ) - def it_does_nothing_if_all_required_columns_are_present(): + def it_does_nothing_if_all_required_columns_are_present_and_return_bool_is_false(): + data = [("jose", 1), ("li", 2), ("luisa", 3)] + source_df = spark.createDataFrame(data, ["name", "age"]) + quinn.validate_presence_of_columns(source_df, ["name"], False) + + def it_returns_false_if_a_required_column_is_missing_and_return_bool_is_true(): data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) - quinn.validate_presence_of_columns(source_df, ["name"]) + result = quinn.validate_presence_of_columns(source_df, ["name", "age", "fun"], True) + assert result is False + + def it_returns_true_if_all_required_columns_are_present_and_return_bool_is_true(): + data = [("jose", 1), ("li", 2), ("luisa", 3)] + source_df = spark.createDataFrame(data, ["name", "age"]) + result = quinn.validate_presence_of_columns(source_df, ["name"], True) + assert result is True def describe_validate_schema(): - def it_raises_when_struct_field_is_missing1(): + def it_raises_when_struct_field_is_missing_and_return_bool_is_false(): data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) required_schema = StructType( @@ -34,7 +45,7 @@ def it_raises_when_struct_field_is_missing1(): ] ) with pytest.raises(quinn.DataFrameMissingStructFieldError) as excinfo: - quinn.validate_schema(source_df, required_schema) + quinn.validate_schema(source_df, required_schema, return_bool = False) current_spark_version = semver.Version.parse(spark.version) spark_330 = semver.Version.parse("3.3.0") @@ -44,7 +55,7 @@ def it_raises_when_struct_field_is_missing1(): expected_error_message = "The [StructField(city,StringType,true)] StructFields are not included in the DataFrame with the following StructFields StructType(List(StructField(name,StringType,true),StructField(age,LongType,true)))" # noqa assert excinfo.value.args[0] == expected_error_message - def it_does_nothing_when_the_schema_matches(): + def it_does_nothing_when_the_schema_matches_and_return_bool_is_false(): data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) required_schema = StructType( @@ -53,7 +64,31 @@ def it_does_nothing_when_the_schema_matches(): StructField("age", LongType(), True), ] ) - quinn.validate_schema(source_df, required_schema) + quinn.validate_schema(source_df, required_schema, return_bool = False) + + def it_returns_false_when_struct_field_is_missing_and_return_bool_is_true(): + data = [("jose", 1), ("li", 2), ("luisa", 3)] + source_df = spark.createDataFrame(data, ["name", "age"]) + required_schema = StructType( + [ + StructField("name", StringType(), True), + StructField("city", StringType(), True), + ] + ) + result = quinn.validate_schema(source_df, required_schema, return_bool = True) + assert result is False + + def it_returns_true_when_the_schema_matches_and_return_bool_is_true(): + data = [("jose", 1), ("li", 2), ("luisa", 3)] + source_df = spark.createDataFrame(data, ["name", "age"]) + required_schema = StructType( + [ + StructField("name", StringType(), True), + StructField("age", LongType(), True), + ] + ) + result = quinn.validate_schema(source_df, required_schema, return_bool = True) + assert result is True def nullable_column_mismatches_are_ignored(): data = [("jose", 1), ("li", 2), ("luisa", 3)] @@ -64,21 +99,33 @@ def nullable_column_mismatches_are_ignored(): StructField("age", LongType(), False), ] ) - quinn.validate_schema(source_df, required_schema, ignore_nullable=True) + quinn.validate_schema(source_df, required_schema, ignore_nullable=True, return_bool = False) def describe_validate_absence_of_columns(): - def it_raises_when_a_unallowed_column_is_present(): + def it_raises_when_a_unallowed_column_is_present_and_return_bool_is_false(): data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) with pytest.raises(quinn.DataFrameProhibitedColumnError) as excinfo: - quinn.validate_absence_of_columns(source_df, ["age", "cool"]) + quinn.validate_absence_of_columns(source_df, ["age", "cool"], False) assert ( excinfo.value.args[0] == "The ['age'] columns are not allowed to be included in the DataFrame with the following columns ['name', 'age']" # noqa ) - def it_does_nothing_when_no_unallowed_columns_are_present(): + def it_does_nothing_when_no_unallowed_columns_are_present_and_return_bool_is_false(): + data = [("jose", 1), ("li", 2), ("luisa", 3)] + source_df = spark.createDataFrame(data, ["name", "age"]) + quinn.validate_absence_of_columns(source_df, ["favorite_color"], False) + + def it_returns_false_when_a_unallowed_column_is_present_and_return_bool_is_true(): + data = [("jose", 1), ("li", 2), ("luisa", 3)] + source_df = spark.createDataFrame(data, ["name", "age"]) + result = quinn.validate_absence_of_columns(source_df, ["age", "cool"], True) + assert result is False + + def it_returns_true_when_no_unallowed_columns_are_present_and_return_bool_is_true(): data = [("jose", 1), ("li", 2), ("luisa", 3)] source_df = spark.createDataFrame(data, ["name", "age"]) - quinn.validate_absence_of_columns(source_df, ["favorite_color"]) + result = quinn.validate_absence_of_columns(source_df, ["favorite_color"], True) + assert result is True \ No newline at end of file From ab90811185a700a33e630c387d29f989a3f543fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Paulo=20Oct=C3=A1vio=20Ara=C3=BAjo=20de=20Paula?= Date: Fri, 4 Oct 2024 01:31:43 -0300 Subject: [PATCH 3/3] Fix whitespaces and repeated code --- quinn/dataframe_validator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/quinn/dataframe_validator.py b/quinn/dataframe_validator.py index 0bae3ce4..8867091a 100644 --- a/quinn/dataframe_validator.py +++ b/quinn/dataframe_validator.py @@ -53,8 +53,6 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], r if return_bool: return False raise DataFrameMissingColumnError(error_message) - - return True if return_bool else None return True if return_bool else None @@ -98,8 +96,6 @@ def validate_schema( if return_bool: return False raise DataFrameMissingStructFieldError(error_message) - - return True if return_bool else None return True if return_bool else None @@ -116,6 +112,7 @@ def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str], """ all_col_names = df.columns extra_col_names = [x for x in all_col_names if x in prohibited_col_names] + if extra_col_names: error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}" if return_bool: