Skip to content

Commit

Permalink
[MAINTENANCE] Have Expectation inherit from pydantic.BaseModel (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
cdkini authored Nov 20, 2023
1 parent 6bd8a63 commit 0d0338a
Show file tree
Hide file tree
Showing 20 changed files with 142 additions and 200 deletions.
4 changes: 2 additions & 2 deletions great_expectations/core/expectation_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,7 +1477,7 @@ def validate(
"""
expectation_impl: Type[Expectation] = self._get_expectation_impl()
# noinspection PyCallingNonCallable
return expectation_impl(self).validate(
return expectation_impl(meta=self.meta, **self.kwargs).validate_(
validator=validator,
runtime_configuration=runtime_configuration,
)
Expand All @@ -1491,7 +1491,7 @@ def metrics_validate(
):
expectation_impl: Type[Expectation] = self._get_expectation_impl()
# noinspection PyCallingNonCallable
return expectation_impl(self).metrics_validate(
return expectation_impl(meta=self.meta, **self.kwargs).metrics_validate(
metrics=metrics,
runtime_configuration=runtime_configuration,
execution_engine=execution_engine,
Expand Down
4 changes: 3 additions & 1 deletion great_expectations/core/expectation_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,9 @@ def _validate_expectation_configuration_before_adding(
):
try:
class_ = get_expectation_impl(expectation_configuration.expectation_type)
_ = class_(expectation_configuration) # Implicitly validates in constructor
_ = class_(
meta=expectation_configuration.meta, **expectation_configuration.kwargs
) # Implicitly validates in constructor
except (
gx_exceptions.ExpectationNotFoundError,
gx_exceptions.InvalidExpectationConfigurationError,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from numbers import Number
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Union

import numpy as np

Expand Down Expand Up @@ -182,7 +182,9 @@ class ExpectColumnQuantileValuesToBeBetween(ColumnAggregateExpectation):
validation_parameter_builder_configs: List[ParameterBuilderConfig] = [
quantile_value_ranges_estimator_parameter_builder_config,
]
default_profiler_config = RuleBasedProfilerConfig(
default_profiler_config: ClassVar[
RuleBasedProfilerConfig
] = RuleBasedProfilerConfig(
name="expect_column_quantile_values_to_be_between", # Convention: use "expectation_type" as profiler name.
config_version=1.0,
variables={},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, ClassVar, List, Optional

import great_expectations.exceptions as gx_exceptions
from great_expectations.core import (
Expand Down Expand Up @@ -160,7 +160,9 @@ class ExpectColumnValuesToBeBetween(ColumnMapExpectation):
column_min_range_estimator_parameter_builder_config,
column_max_range_estimator_parameter_builder_config,
]
default_profiler_config = RuleBasedProfilerConfig(
default_profiler_config: ClassVar[
RuleBasedProfilerConfig
] = RuleBasedProfilerConfig(
name="expect_column_values_to_be_between", # Convention: use "expectation_type" as profiler name.
config_version=1.0,
variables={},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, Dict, Optional
from typing import TYPE_CHECKING, ClassVar, Dict, Optional, Tuple

from great_expectations.compatibility.typing_extensions import override
from great_expectations.core.expectation_configuration import parse_result_format
Expand Down Expand Up @@ -90,7 +90,7 @@ class ExpectColumnValuesToNotBeNull(ColumnMapExpectation):
}

map_metric: ClassVar[str] = "column_values.nonnull"
args_keys: ClassVar[tuple[str, ...]] = ("column",)
args_keys: ClassVar[Tuple[str, ...]] = ("column",)

@override
def validate_configuration(
Expand Down
52 changes: 39 additions & 13 deletions great_expectations/expectations/expectation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
import traceback
import warnings
from abc import ABC, ABCMeta, abstractmethod
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from copy import deepcopy
from inspect import isabstract
Expand All @@ -37,6 +37,8 @@
from typing_extensions import ParamSpec

from great_expectations import __version__ as ge_version
from great_expectations.compatibility import pydantic
from great_expectations.compatibility.pydantic import ModelMetaclass
from great_expectations.compatibility.typing_extensions import override
from great_expectations.core._docs_decorators import (
deprecated_method_or_class,
Expand Down Expand Up @@ -265,7 +267,7 @@ def wrapper(


# noinspection PyMethodParameters
class MetaExpectation(ABCMeta):
class MetaExpectation(ModelMetaclass):
"""MetaExpectation registers Expectations as they are defined, adding them to the Expectation registry.
Any class inheriting from Expectation will be registered based on the value of the "expectation_type" class
Expand Down Expand Up @@ -297,7 +299,7 @@ def __new__(cls, clsname, bases, attrs):


@public_api
class Expectation(metaclass=MetaExpectation):
class Expectation(pydantic.BaseModel, metaclass=MetaExpectation):
"""Base class for all Expectations.
Expectation classes *must* have the following attributes set:
Expand Down Expand Up @@ -329,7 +331,13 @@ class Expectation(metaclass=MetaExpectation):
2. Data Docs rendering methods decorated with the @renderer decorator. See the
"""

version: ClassVar = ge_version
class Config:
arbitrary_types_allowed = True
extra = pydantic.Extra.allow

meta: Union[dict, None] = None

version: ClassVar[str] = ge_version
domain_keys: ClassVar[Tuple[str, ...]] = ()
success_keys: ClassVar[Tuple[str, ...]] = ()
runtime_keys: ClassVar[Tuple[str, ...]] = (
Expand All @@ -338,7 +346,7 @@ class Expectation(metaclass=MetaExpectation):
"result_format",
)
default_kwarg_values: ClassVar[
dict[str, bool | str | float | RuleBasedProfilerConfig | None]
dict[str, Union[bool, str, float, RuleBasedProfilerConfig, None]]
] = {
"include_config": True,
"catch_exceptions": False,
Expand All @@ -349,10 +357,27 @@ class Expectation(metaclass=MetaExpectation):
expectation_type: ClassVar[str]
examples: ClassVar[List[dict]] = []

def __init__(self, configuration: ExpectationConfiguration) -> None:
self._configuration = configuration
def __init__(self, meta: dict | None = None, **kwargs) -> None:
# Safety precaution to prevent old-style instantiation
if "configuration" in kwargs:
raise ValueError(
"Cannot directly pass configuration into Expectation constructor; please pass in individual success keys and domain kwargs."
)

super().__init__(**kwargs)

# Everything below is purely to maintain current validation logic but should be migrated to Pydantic validators
configuration = ExpectationConfiguration(
expectation_type=camel_to_snake(self.__class__.__name__),
kwargs=kwargs,
meta=meta,
)
self.validate_configuration(configuration)

# Currently only used in Validator.validate_expectation
# Once the V1 Validator is live, we can remove this and its related property
self._configuration = configuration

@classmethod
def is_abstract(cls) -> bool:
return isabstract(cls)
Expand Down Expand Up @@ -1239,8 +1264,9 @@ def validate_configuration(
except AssertionError as e:
raise InvalidExpectationConfigurationError(str(e))

# Renamed from validate due to collision with Pydantic method of the same name
@public_api
def validate( # noqa: PLR0913
def validate_( # noqa: PLR0913
self,
validator: Validator,
configuration: Optional[ExpectationConfiguration] = None,
Expand Down Expand Up @@ -2344,7 +2370,7 @@ class BatchExpectation(Expectation, ABC):
"condition_parser",
)
metric_dependencies: ClassVar[Tuple[str, ...]] = ()
domain_type: ClassVar = MetricDomainTypes.TABLE
domain_type: ClassVar[MetricDomainTypes] = MetricDomainTypes.TABLE
args_keys: ClassVar[Tuple[str, ...]] = ()

@override
Expand Down Expand Up @@ -2742,14 +2768,14 @@ class ColumnMapExpectation(BatchExpectation, ABC):
"""

map_metric: ClassVar[Optional[str]] = None
domain_keys: ClassVar[tuple[str, ...]] = (
domain_keys: ClassVar[Tuple[str, ...]] = (
"batch_id",
"table",
"column",
"row_condition",
"condition_parser",
)
domain_type: ClassVar = MetricDomainTypes.COLUMN
domain_type: ClassVar[MetricDomainTypes] = MetricDomainTypes.COLUMN
success_keys: ClassVar[Tuple[str, ...]] = ("mostly",)
default_kwarg_values = {
"row_condition": None,
Expand Down Expand Up @@ -3023,7 +3049,7 @@ class ColumnPairMapExpectation(BatchExpectation, ABC):
kwargs from the Expectation Configuration.
"""

map_metric = None
map_metric: ClassVar[Optional[str]] = None
domain_keys = (
"batch_id",
"table",
Expand Down Expand Up @@ -3295,7 +3321,7 @@ class MulticolumnMapExpectation(BatchExpectation, ABC):
kwargs from the Expectation Configuration.
"""

map_metric = None
map_metric: ClassVar[Optional[str]] = None
domain_keys = (
"batch_id",
"table",
Expand Down
4 changes: 3 additions & 1 deletion great_expectations/expectations/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,9 @@ def get_metric_kwargs(
}
if configuration:
expectation_impl = get_expectation_impl(configuration.expectation_type)
configuration_kwargs = expectation_impl(configuration).get_runtime_kwargs(
configuration_kwargs = expectation_impl(
**configuration.kwargs
).get_runtime_kwargs(
configuration=configuration, runtime_configuration=runtime_configuration
)
if len(metric_kwargs["metric_domain_keys"]) > 0:
Expand Down
8 changes: 5 additions & 3 deletions great_expectations/validator/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,15 +543,17 @@ def inst_expectation(*args: dict, **kwargs): # noqa: PLR0912
)

try:
expectation = expectation_impl(configuration)
expectation = expectation_impl(
meta=configuration.meta, **configuration.kwargs
)
"""Given an implementation and a configuration for any Expectation, returns its validation result"""

if not self.interactive_evaluation and not self._active_validation:
validation_result = ExpectationValidationResult(
expectation_config=copy.deepcopy(expectation.configuration)
)
else:
validation_result = expectation.validate(
validation_result = expectation.validate_(
validator=self,
evaluation_parameters=self._expectation_suite.evaluation_parameters,
data_context=self._data_context,
Expand Down Expand Up @@ -1124,7 +1126,7 @@ def _generate_metric_dependency_subgraphs_for_each_expectation_configuration(

expectation_impl = get_expectation_impl(evaluated_config.expectation_type)
validation_dependencies: ValidationDependencies = expectation_impl(
evaluated_config
**evaluated_config.kwargs
).get_validation_dependencies(
configuration=evaluated_config,
execution_engine=self._execution_engine,
Expand Down
28 changes: 14 additions & 14 deletions tests/expectations/metrics/test_map_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _expecation_configuration_to_validation_result_pandas(
expectation_configuration (ExpectationConfiguration): configuration that is being tested
"""
expectation = ExpectColumnValuesToBeInSet(expectation_configuration)
expectation = ExpectColumnValuesToBeInSet(**expectation_configuration.kwargs)
batch_definition = BatchDefinition(
datasource_name="pandas_datasource",
data_connector_name="runtime_data_connector",
Expand All @@ -160,7 +160,7 @@ def _expecation_configuration_to_validation_result_pandas(
batch,
],
)
result = expectation.validate(validator)
result = expectation.validate_(validator)
return result


Expand All @@ -176,7 +176,7 @@ def _expecation_configuration_to_validation_result_sql(
expectation_configuration (ExpectationConfiguration): configuration that is being tested
"""
expectation = ExpectColumnValuesToBeInSet(expectation_configuration)
expectation = ExpectColumnValuesToBeInSet(**expectation_configuration.kwargs)
sqlite_path = file_relative_path(__file__, "../../test_sets/metrics_test.db")
connection_string = f"sqlite:///{sqlite_path}"
engine = SqlAlchemyExecutionEngine(
Expand Down Expand Up @@ -236,7 +236,7 @@ def _expecation_configuration_to_validation_result_sql(
batch,
],
)
result = expectation.validate(validator)
result = expectation.validate_(validator)
return result


Expand Down Expand Up @@ -774,7 +774,7 @@ def test_include_unexpected_rows_without_explicit_result_format_raises_error(
},
)

expectation = ExpectColumnValuesToBeInSet(expectation_configuration)
expectation = ExpectColumnValuesToBeInSet(**expectation_configuration.kwargs)
batch_definition = BatchDefinition(
datasource_name="pandas_datasource",
data_connector_name="runtime_data_connector",
Expand All @@ -795,7 +795,7 @@ def test_include_unexpected_rows_without_explicit_result_format_raises_error(
],
)
with pytest.raises(ValueError):
expectation.validate(validator)
expectation.validate_(validator)


# Spark
Expand All @@ -814,7 +814,7 @@ def test_spark_single_column_complete_result_format(
},
},
)
expectation = ExpectColumnValuesToBeInSet(expectation_configuration)
expectation = ExpectColumnValuesToBeInSet(**expectation_configuration.kwargs)
batch_definition = BatchDefinition(
datasource_name="spark_datasource",
data_connector_name="runtime_data_connector",
Expand All @@ -834,7 +834,7 @@ def test_spark_single_column_complete_result_format(
batch,
],
)
result = expectation.validate(validator)
result = expectation.validate_(validator)
assert convert_to_json_serializable(result.result) == {
"element_count": 6,
"missing_count": 0,
Expand Down Expand Up @@ -871,7 +871,7 @@ def test_spark_single_column_complete_result_format_with_id_pk(
},
},
)
expectation = ExpectColumnValuesToBeInSet(expectation_configuration)
expectation = ExpectColumnValuesToBeInSet(**expectation_configuration.kwargs)
batch_definition = BatchDefinition(
datasource_name="spark_datasource",
data_connector_name="runtime_data_connector",
Expand All @@ -894,7 +894,7 @@ def test_spark_single_column_complete_result_format_with_id_pk(

# result_format configuration at ExpectationConfiguration-level will emit warning
with pytest.warns(UserWarning):
result = expectation.validate(validator)
result = expectation.validate_(validator)

assert convert_to_json_serializable(result.result) == {
"element_count": 6,
Expand Down Expand Up @@ -942,7 +942,7 @@ def test_spark_single_column_summary_result_format(
},
},
)
expectation = ExpectColumnValuesToBeInSet(expectation_configuration)
expectation = ExpectColumnValuesToBeInSet(**expectation_configuration.kwargs)
batch_definition = BatchDefinition(
datasource_name="spark_datasource",
data_connector_name="runtime_data_connector",
Expand All @@ -962,7 +962,7 @@ def test_spark_single_column_summary_result_format(
batch,
],
)
result = expectation.validate(validator)
result = expectation.validate_(validator)
assert convert_to_json_serializable(result.result) == {
"element_count": 6,
"missing_count": 0,
Expand Down Expand Up @@ -995,7 +995,7 @@ def test_spark_single_column_basic_result_format(
},
},
)
expectation = ExpectColumnValuesToBeInSet(expectation_configuration)
expectation = ExpectColumnValuesToBeInSet(**expectation_configuration.kwargs)
batch_definition = BatchDefinition(
datasource_name="spark_datasource",
data_connector_name="runtime_data_connector",
Expand All @@ -1015,7 +1015,7 @@ def test_spark_single_column_basic_result_format(
batch,
],
)
result = expectation.validate(validator)
result = expectation.validate_(validator)
assert convert_to_json_serializable(result.result) == {
"element_count": 6,
"missing_count": 0,
Expand Down
Loading

0 comments on commit 0d0338a

Please sign in to comment.