diff --git a/apischema/dataclasses/__init__.py b/apischema/dataclasses/__init__.py index 9b6ce1cb..34020e34 100644 --- a/apischema/dataclasses/__init__.py +++ b/apischema/dataclasses/__init__.py @@ -1,5 +1,12 @@ import sys -from dataclasses import is_dataclass, replace as replace_ +from dataclasses import ( # type: ignore + Field, + is_dataclass, + replace as replace_, + _FIELDS, + _FIELD_CLASSVAR, +) +from typing import Mapping, Type if sys.version_info <= (3, 7): is_dataclass_ = is_dataclass @@ -16,3 +23,12 @@ def replace(*args, **changes): if hasattr(obj, FIELDS_SET_ATTR): set_fields(result, *fields_set(obj), *changes, overwrite=True) return result + + +def fields_items(cls: Type) -> Mapping[str, Field]: + assert is_dataclass(cls) + return { + name: field + for name, field in getattr(cls, _FIELDS).items() + if field._field_type != _FIELD_CLASSVAR + } diff --git a/apischema/dataclasses/cache.py b/apischema/dataclasses/cache.py index 20bda965..abd3ced8 100644 --- a/apischema/dataclasses/cache.py +++ b/apischema/dataclasses/cache.py @@ -33,6 +33,7 @@ Deserialization, Serialization, ) +from apischema.dataclasses import fields_items from apischema.dependencies import DependentRequired from apischema.metadata.keys import ( ALIAS_METADATA, @@ -135,7 +136,7 @@ def _from_aggregate(aggregate_cache: AggregateFieldCache) -> FieldCache: for field in aggregate_fields: metadata = field.base_field.metadata if MERGED_METADATA in metadata: - merged_fields.append((frozenset(_merged_aliases(field.type)), field)) + merged_fields.append((_deserialization_merged_aliases(field.type), field)) else: pattern = metadata[PROPERTIES_METADATA] if pattern is not None: @@ -343,24 +344,23 @@ def _serialization( ) -def _merged_aliases(cls: Type) -> Iterable[str]: +def _deserialization_merged_aliases(cls: Type) -> AbstractSet[str]: + """Return all aliases used in cls deserialization.""" assert dataclasses.is_dataclass(cls) types = get_type_hints(cls, include_extras=True) - for field in dataclasses.fields(cls): + result: Set[str] = set() + for field in fields_items(cls).values(): + if not field.init: + continue if MERGED_METADATA in field.metadata: # No need to check overlapping here because it will be checked # when merged dataclass will be cached - yield from _merged_aliases(types[field.name]) + result |= _deserialization_merged_aliases(types[field.name]) elif PROPERTIES_METADATA in field.metadata: raise TypeError("Merged dataclass cannot have properties field") else: - yield field.metadata.get(ALIAS_METADATA, field.name) - - -def _check_fields_overlap(present: Set[str], other: AbstractSet[str]): - if present & other: - raise TypeError(f"Merged fields {present & other} overlap") - present.update(other) + result.add(field.metadata.get(ALIAS_METADATA, field.name)) + return result def _update_dependencies(cls: AnyType, all_fields: Mapping[str, Field]): @@ -390,8 +390,8 @@ def _update_dependencies(cls: AnyType, all_fields: Mapping[str, Field]): def _filter_by_kind(field_list: Iterable[F], kind: FieldKind) -> Sequence[F]: - fields = [(elt, elt[1] if isinstance(elt, tuple) else elt) for elt in field_list] - return [elt for elt, field in fields if field.kind != kind] + fields = [elt[1] if isinstance(elt, tuple) else elt for elt in field_list] + return [elt for elt, field in zip(field_list, fields) if field.kind != kind] @dataclasses.dataclass @@ -421,10 +421,7 @@ def cache_fields(cls: Type): types = get_type_hints(cls, include_extras=True) lists = FieldLists(cls) all_fields: Dict[str, Field] = {} - all_merged_aliases: Set[str] = set() - for field in getattr(cls, dataclasses._FIELDS).values(): # type: ignore - if field._field_type == dataclasses._FIELD_CLASSVAR: # type: ignore - continue + for field in fields_items(cls).values(): metadata = field.metadata if SKIP_METADATA in metadata: continue @@ -492,8 +489,7 @@ def cache_fields(cls: Type): raise TypeError( f"{error_prefix}Merged field must have a dataclass type" ) - merged_aliases: AbstractSet[str] = frozenset(_merged_aliases(type_)) - _check_fields_overlap(all_merged_aliases, merged_aliases) + merged_aliases = _deserialization_merged_aliases(type_) lists.merged.append((merged_aliases, new_field)) elif PROPERTIES_METADATA in metadata: if any(key in metadata for key in INCOMPATIBLE_WITH_PROPERTIES): @@ -505,7 +501,6 @@ def cache_fields(cls: Type): lists.pattern.append((pattern, new_field)) else: lists.normal.append(new_field) - _check_fields_overlap(all_merged_aliases, all_fields.keys()) _update_dependencies(cls, all_fields) _deserialization_fields[cls] = lists.remove_kind(FieldKind.NO_INIT) _aggregate_serialization_fields[cls] = _to_aggregate( diff --git a/apischema/validation/mock.py b/apischema/validation/mock.py index 06738e6a..04e62d36 100644 --- a/apischema/validation/mock.py +++ b/apischema/validation/mock.py @@ -1,14 +1,9 @@ -from dataclasses import ( # type: ignore - Field, - MISSING, - _FIELDS, - _FIELD_CLASSVAR, - dataclass, -) +from dataclasses import MISSING, dataclass from functools import partial from types import FunctionType, MethodType from typing import Any, Mapping, Optional, TYPE_CHECKING, Type, TypeVar +from apischema.dataclasses import fields_items from apischema.fields import FIELDS_SET_ATTR from apischema.utils import get_default @@ -37,15 +32,12 @@ def __getattribute__(self, name: str) -> Any: if name in fields: return fields[name] cls = super().__getattribute__("cls") - cls_fields: Mapping[str, Field] = getattr(cls, _FIELDS) + cls_fields = fields_items(cls) if name in cls_fields: - if cls_fields[name]._field_type == _FIELD_CLASSVAR: # type: ignore - return getattr(cls, name) - else: - try: - return get_default(cls_fields[name]) - except NotImplementedError: - raise NonTrivialDependency(name) + try: + return get_default(cls_fields[name]) + except NotImplementedError: + raise NonTrivialDependency(name) from None if name == "__class__": return cls if name == "__dict__": @@ -54,11 +46,8 @@ def __getattribute__(self, name: str) -> Any: **{ name: get_default(field) for name, field in cls_fields.items() - if field._field_type != _FIELD_CLASSVAR # type: ignore - and ( - field.default is not MISSING - or field.default_factory is not MISSING # type: ignore - ) + if field.default is not MISSING + or field.default_factory is not MISSING # type: ignore }, FIELDS_SET_ATTR: set(fields), } diff --git a/docs/changelog.md b/docs/changelog.md index 90dcac73..0f08026b 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,5 +1,9 @@ # Changelog +## 0.7.4 + +- Add `InitVar` support for merged dataclasses. + ## 0.7.3 - Fix bugs in settings global default conversion and coercer assignation. diff --git a/setup.py b/setup.py index f20b5d33..e1595d0c 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="apischema", - version="0.7.3", + version="0.7.4", url="https://github.com/wyfo/apischema", author="Joseph Perez", author_email="joperez@hotmail.fr",