Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new parameter return_bool to validate dataframe methods (fix linting) #267

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions quinn/dataframe_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class DataFrameProhibitedColumnError(ValueError):

def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], return_bool: bool = False) -> Union[None, bool]:
"""Validate the presence of column names in a DataFrame.

:param df: A spark DataFrame.
:type df: DataFrame
:param required_col_names: List of the required column names for the DataFrame.
Expand All @@ -48,13 +47,13 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], r
"""
all_col_names = df.columns
missing_col_names = [x for x in required_col_names if x not in all_col_names]

if missing_col_names:
error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}"
if return_bool:
return False
raise DataFrameMissingColumnError(error_message)

return True if return_bool else None


Expand All @@ -65,7 +64,6 @@ def validate_schema(
return_bool: bool = False,
) -> Union[None, bool]:
"""Function that validate if a given DataFrame has a given StructType as its schema.

:param df: DataFrame to validate
:type df: DataFrame
:param required_schema: StructType required for the DataFrame
Expand All @@ -90,19 +88,20 @@ def validate_schema(
x.nullable = None

missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields]

if missing_struct_fields:
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"
error_message = (
f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"
)
if return_bool:
return False
raise DataFrameMissingStructFieldError(error_message)

return True if return_bool else None


def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str], return_bool: bool = False) -> Union[None, bool]:
"""Validate that none of the prohibited column names are present among specified DataFrame columns.

:param df: DataFrame containing columns to be checked.
:param prohibited_col_names: List of prohibited column names.
:param return_bool: If True, return a boolean instead of raising an exception.
Expand All @@ -113,11 +112,11 @@ def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str],
"""
all_col_names = df.columns
extra_col_names = [x for x in all_col_names if x in prohibited_col_names]

if extra_col_names:
error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}"
if return_bool:
return False
raise DataFrameProhibitedColumnError(error_message)
return True if return_bool else None

return True if return_bool else None
7 changes: 3 additions & 4 deletions tests/test_dataframe_validator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
from pyspark.sql.types import StructType, StructField, StringType, LongType
import semver

import quinn
from .spark import spark

Expand All @@ -21,7 +20,7 @@ def it_does_nothing_if_all_required_columns_are_present_and_return_bool_is_false
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
quinn.validate_presence_of_columns(source_df, ["name"], False)

def it_returns_false_if_a_required_column_is_missing_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
Expand Down Expand Up @@ -66,7 +65,7 @@ def it_does_nothing_when_the_schema_matches_and_return_bool_is_false():
]
)
quinn.validate_schema(source_df, required_schema, return_bool = False)

def it_returns_false_when_struct_field_is_missing_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
Expand Down Expand Up @@ -118,7 +117,7 @@ def it_does_nothing_when_no_unallowed_columns_are_present_and_return_bool_is_fal
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
quinn.validate_absence_of_columns(source_df, ["favorite_color"], False)

def it_returns_false_when_a_unallowed_column_is_present_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
Expand Down
Loading