From a9921347e06dd4a215a5c89d9c9b168e5cd279fe Mon Sep 17 00:00:00 2001
From: Marcin Rudolf
Date: Tue, 12 Mar 2024 15:51:58 +0100
Subject: [PATCH] validates class instances in typed dict

---
 dlt/common/validation.py        | 23 +++++++++++++++++++----
 tests/common/test_validation.py | 31 +++++++++++++++++++++++++++++++
 2 files changed, 50 insertions(+), 4 deletions(-)

diff --git a/dlt/common/validation.py b/dlt/common/validation.py
index 6bf1356aeb..4b54d6a29e 100644
--- a/dlt/common/validation.py
+++ b/dlt/common/validation.py
@@ -1,5 +1,6 @@
 import contextlib
 import functools
+import inspect
 from typing import Callable, Any, Type
 from typing_extensions import get_type_hints, get_args
 
@@ -38,11 +39,10 @@ def validate_dict(
         filter_f (TFilterFunc, optional): A function to filter keys in `doc`. It should
             return `True` for keys to be kept. Defaults to a function that keeps all keys.
         validator_f (TCustomValidator, optional): A function to perform additional validation
-            for types not covered by this function. It should return `True` if the validation passes.
+            for types not covered by this function. It should return `True` if the validation passes
+            or raise DictValidationException on validation error. For types it cannot validate, it
+            should return False to allow chaining.
             Defaults to a function that rejects all such types.
-        filter_required (TFilterFunc, optional): A function to filter out required fields, useful
-            for testing historic versions of dict that might now have certain fields yet.
-
     Raises:
         DictValidationException: If there are missing required fields, unexpected fields,
             type mismatches or unvalidated types in `doc` compared to `spec`.
@@ -162,8 +162,23 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None:
         elif t is Any:
             # pass everything with any type
             pass
+        elif inspect.isclass(t) and isinstance(pv, t):
+            # allow instances of classes
+            pass
         else:
+            type_name = getattr(t, "__name__", str(t))
+            pv_type_name = getattr(type(pv), "__name__", str(type(pv)))
+            # try to apply special validator
             if not validator_f(path, pk, pv, t):
+                # type `t` cannot be validated by validator_f
+                if inspect.isclass(t):
+                    if not isinstance(pv, t):
+                        raise DictValidationException(
+                            f"In {path}: field {pk} expect class {type_name} but got instance of"
+                            f" {pv_type_name}",
+                            path,
+                            pk,
+                        )
                 # TODO: when Python 3.9 and earlier support is
                 # dropped, just __name__ can be used
                 type_name = getattr(t, "__name__", str(t))
diff --git a/tests/common/test_validation.py b/tests/common/test_validation.py
index f7773fb89c..3297df1038 100644
--- a/tests/common/test_validation.py
+++ b/tests/common/test_validation.py
@@ -3,6 +3,7 @@
 import yaml
 from typing import Callable, List, Literal, Mapping, Sequence, TypedDict, TypeVar, Optional, Union
 
+from dlt.common import Decimal
 from dlt.common.exceptions import DictValidationException
 from dlt.common.schema.typing import TStoredSchema, TColumnSchema
 from dlt.common.schema.utils import simple_regex_validator
@@ -18,6 +19,14 @@
 TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]]
 
 
+class ClassTest:
+    a: str
+
+
+class SubClassTest(ClassTest):
+    b: str
+
+
 class TDict(TypedDict):
     field: TLiteral
 
@@ -41,6 +50,7 @@ class TTestRecord(TypedDict):
     f_literal_optional: Optional[TLiteral]
     f_seq_literal: Sequence[Optional[TLiteral]]
     f_optional_union: Optional[Union[TLiteral, TDict]]
+    f_class: ClassTest
 
 
 TEST_COL: TColumnSchema = {"name": "col1", "data_type": "bigint", "nullable": False}
@@ -70,6 +80,7 @@ class TTestRecord(TypedDict):
     "f_literal_optional": "dos",
     "f_seq_literal": ["uno", "dos", "tres"],
     "f_optional_union": {"field": "uno"},
+    "f_class": SubClassTest(),
 }
 
 
@@ -275,6 +286,26 @@ def f(item: Union[TDataItem, TDynHintType]) -> TDynHintType:
     )
 
 
+def test_class() -> None:
+    class TTestRecordInvalidClass(TypedDict):
+        prop: SubClassTest
+
+    # prop must be SubClassTest or derive from it. not the case below
+    test_item_1 = {"prop": ClassTest()}
+    with pytest.raises(DictValidationException):
+        validate_dict(TTestRecordInvalidClass, test_item_1, path=".")
+
+    # unions are accepted
+    class TTestRecordClassUnion(TypedDict):
+        prop: Union[SubClassTest, ClassTest]
+
+    validate_dict(TTestRecordClassUnion, test_item_1, path=".")
+
+    test_item_2 = {"prop": Decimal(1)}
+    with pytest.raises(DictValidationException):
+        validate_dict(TTestRecordClassUnion, test_item_2, path=".")
+
+
 # def test_union_merge() -> None:
 #     """Overriding fields is simply illegal in TypedDict"""
 #     class EndpointResource(TypedDict, total=False):