Skip to content

Commit

Permalink
validates class instances in typed dict
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolfix committed Mar 12, 2024
1 parent 6604289 commit a992134
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
23 changes: 19 additions & 4 deletions dlt/common/validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import functools
import inspect
from typing import Callable, Any, Type
from typing_extensions import get_type_hints, get_args

Expand Down Expand Up @@ -38,11 +39,10 @@ def validate_dict(
filter_f (TFilterFunc, optional): A function to filter keys in `doc`. It should
return `True` for keys to be kept. Defaults to a function that keeps all keys.
validator_f (TCustomValidator, optional): A function to perform additional validation
for types not covered by this function. It should return `True` if the validation passes.
for types not covered by this function. It should return `True` if the validation passes
or raise DictValidationException on validation error. For types it cannot validate, it
should return False to allow chaining.
Defaults to a function that rejects all such types.
filter_required (TFilterFunc, optional): A function to filter out required fields, useful
for testing historic versions of dict that might now have certain fields yet.
Raises:
DictValidationException: If there are missing required fields, unexpected fields,
type mismatches or unvalidated types in `doc` compared to `spec`.
Expand Down Expand Up @@ -162,8 +162,23 @@ def verify_prop(pk: str, pv: Any, t: Any) -> None:
elif t is Any:
# pass everything with any type
pass
elif inspect.isclass(t) and isinstance(pv, t):
# allow instances of classes
pass
else:
type_name = getattr(t, "__name__", str(t))
pv_type_name = getattr(type(pv), "__name__", str(type(pv)))
# try to apply special validator
if not validator_f(path, pk, pv, t):
# type `t` cannot be validated by validator_f
if inspect.isclass(t):
if not isinstance(pv, t):
raise DictValidationException(
f"In {path}: field {pk} expect class {type_name} but got instance of"
f" {pv_type_name}",
path,
pk,
)
# TODO: when Python 3.9 and earlier support is
# dropped, just __name__ can be used
type_name = getattr(t, "__name__", str(t))
Expand Down
31 changes: 31 additions & 0 deletions tests/common/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import yaml
from typing import Callable, List, Literal, Mapping, Sequence, TypedDict, TypeVar, Optional, Union

from dlt.common import Decimal
from dlt.common.exceptions import DictValidationException
from dlt.common.schema.typing import TStoredSchema, TColumnSchema
from dlt.common.schema.utils import simple_regex_validator
Expand All @@ -18,6 +19,14 @@
TTableHintTemplate = Union[TDynHintType, TFunHintTemplate[TDynHintType]]


class ClassTest:
a: str


class SubClassTest(ClassTest):
b: str


class TDict(TypedDict):
field: TLiteral

Expand All @@ -41,6 +50,7 @@ class TTestRecord(TypedDict):
f_literal_optional: Optional[TLiteral]
f_seq_literal: Sequence[Optional[TLiteral]]
f_optional_union: Optional[Union[TLiteral, TDict]]
f_class: ClassTest


TEST_COL: TColumnSchema = {"name": "col1", "data_type": "bigint", "nullable": False}
Expand Down Expand Up @@ -70,6 +80,7 @@ class TTestRecord(TypedDict):
"f_literal_optional": "dos",
"f_seq_literal": ["uno", "dos", "tres"],
"f_optional_union": {"field": "uno"},
"f_class": SubClassTest(),
}


Expand Down Expand Up @@ -275,6 +286,26 @@ def f(item: Union[TDataItem, TDynHintType]) -> TDynHintType:
)


def test_class() -> None:
class TTestRecordInvalidClass(TypedDict):
prop: SubClassTest

# prop must be SubClassTest or derive from it. not the case below
test_item_1 = {"prop": ClassTest()}
with pytest.raises(DictValidationException):
validate_dict(TTestRecordInvalidClass, test_item_1, path=".")

# unions are accepted
class TTestRecordClassUnion(TypedDict):
prop: Union[SubClassTest, ClassTest]

validate_dict(TTestRecordClassUnion, test_item_1, path=".")

test_item_2 = {"prop": Decimal(1)}
with pytest.raises(DictValidationException):
validate_dict(TTestRecordClassUnion, test_item_2, path=".")


# def test_union_merge() -> None:
# """Overriding fields is simply illegal in TypedDict"""
# class EndpointResource(TypedDict, total=False):
Expand Down

0 comments on commit a992134

Please sign in to comment.