diff --git a/catalystwan/api/configuration_groups/parcel.py b/catalystwan/api/configuration_groups/parcel.py index ce53e3ea..a156f8b2 100644 --- a/catalystwan/api/configuration_groups/parcel.py +++ b/catalystwan/api/configuration_groups/parcel.py @@ -1,7 +1,7 @@ # Copyright 2023 Cisco Systems, Inc. and its affiliates from enum import Enum -from typing import Any, Dict, Generic, Literal, Optional, TypeVar, get_origin +from typing import Any, Dict, Generic, Literal, Optional, Tuple, TypeVar, get_origin from pydantic import ( AliasPath, @@ -9,11 +9,13 @@ ConfigDict, Field, PrivateAttr, + SerializationInfo, SerializerFunctionWrapHandler, model_serializer, ) from catalystwan.exceptions import CatalystwanException +from catalystwan.models.common import VersionedField T = TypeVar("T") @@ -38,7 +40,7 @@ class _ParcelBase(BaseModel): _parcel_data_key: str = PrivateAttr(default="data") @model_serializer(mode="wrap") - def envelope_parcel_data(self, handler: SerializerFunctionWrapHandler) -> Dict[str, Any]: + def envelope_parcel_data(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> Dict[str, Any]: """ serializes model fields with respect to field validation_alias, sub-classing parcel fields can be defined like following: @@ -50,16 +52,23 @@ def envelope_parcel_data(self, handler: SerializerFunctionWrapHandler) -> Dict[s model_dict = handler(self) model_dict[self._parcel_data_key] = {} remove_keys = [] + replaced_keys: Dict[str, Tuple[str, str]] = {} + # enveloping for key in model_dict.keys(): field_info = self.model_fields.get(key) if field_info and isinstance(field_info.validation_alias, AliasPath): aliases = field_info.validation_alias.convert_to_aliases() if aliases and aliases[0] == self._parcel_data_key and len(aliases) == 2: model_dict[self._parcel_data_key][aliases[1]] = model_dict[key] + replaced_keys[key] = (self._parcel_data_key, str(aliases[1])) remove_keys.append(key) for key in remove_keys: del model_dict[key] + + # versioned field update + model_dict = VersionedField.dump(self.model_fields, model_dict, info, replaced_keys) + return model_dict @classmethod diff --git a/catalystwan/models/common.py b/catalystwan/models/common.py index 6a1565ff..db272abf 100644 --- a/catalystwan/models/common.py +++ b/catalystwan/models/common.py @@ -1,12 +1,12 @@ # Copyright 2023 Cisco Systems, Inc. and its affiliates from dataclasses import InitVar, dataclass, field -from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, Iterator, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Union from uuid import UUID from packaging.specifiers import SpecifierSet # type: ignore from packaging.version import Version # type: ignore -from pydantic import PlainSerializer, SerializationInfo +from pydantic import PlainSerializer, SerializationInfo, ValidationInfo from pydantic.fields import FieldInfo from pydantic.functional_validators import BeforeValidator from typing_extensions import Annotated @@ -41,29 +41,54 @@ def __post_init__(self, versions): self.versions_set = SpecifierSet(versions) @staticmethod - def update_model_fields( - model_fields: Dict[str, FieldInfo], model_dict: Dict[str, Any], serialization_info: SerializationInfo + def model_iterate( + model_fields: Dict[str, FieldInfo], info: Union[SerializationInfo, ValidationInfo] + ) -> Iterator[Tuple[str, FieldInfo, "VersionedField"]]: + """Itrerates over model fields that matches a version given in context (Serialization info or ValidationInfo) + + Yields: + Tuple[str, FieldInfo, VersionedField]: a tuple containing field name, FieldInfo and VersionedField + """ + if info.context is not None: + api_version: Optional[Version] = info.context.get("api_version") + if api_version is not None: + for field_name, field_info in model_fields.items(): + versioned_fields = [meta for meta in field_info.metadata if isinstance(meta, VersionedField)] + for versioned_field in versioned_fields: + if api_version in versioned_field.versions_set: + yield (field_name, field_info, versioned_field) + + @staticmethod + def dump( + model_fields: Dict[str, FieldInfo], + model_dict: Dict[str, Any], + info: SerializationInfo, + replaced_keys: Optional[Mapping[str, Tuple[Optional[str], str]]] = None, ) -> Dict[str, Any]: """To be reused in methods decorated with pydantic.model_serializer Args: model_fields (Dict[str, FieldInfo]): obtained from BaseModel class model_dict (Dict[str, Any]): obtained from serialized BaseModel instance serialization_info (SerializationInfo): passed from serializer + replaced_keys (Dict[str, Tuple(Optional[str], str)]): field names that were replaced + previously during serialization + (Tuple represent path and new field name - this currently supports up to 1 level deep alias path only) Returns: Dict[str, Any]: model_dict with updated field names according to matching runtime version """ - if serialization_info.context is not None: - api_version: Optional[Version] = serialization_info.context.get("api_version") - if api_version is not None: - for field_name, field_info in model_fields.items(): - versioned_fields = [meta for meta in field_info.metadata if isinstance(meta, VersionedField)] - for versioned_field in versioned_fields: - if api_version in versioned_field.versions_set: - current_field_name = field_info.serialization_alias or field_info.alias or field_name - if model_dict.get(current_field_name) is not None: - model_dict[versioned_field.serialization_alias] = model_dict[current_field_name] - del model_dict[current_field_name] + for field_name, field_info, versioned_field in VersionedField.model_iterate(model_fields, info): + current_field_name = field_info.serialization_alias or field_info.alias or field_name + new_field_name = versioned_field.serialization_alias + if current_field_name in model_dict: + model_dict[new_field_name] = model_dict[current_field_name] + del model_dict[current_field_name] + elif replaced_keys is not None: + if current_field_path := replaced_keys.get(current_field_name): + path, name = current_field_path + dict_ = model_dict[path] if path is not None else model_dict + dict_[new_field_name] = dict_[name] + del dict_[name] return model_dict diff --git a/catalystwan/tests/test_endpoints.py b/catalystwan/tests/test_endpoints.py index b40511f8..74638bec 100644 --- a/catalystwan/tests/test_endpoints.py +++ b/catalystwan/tests/test_endpoints.py @@ -890,7 +890,7 @@ class Payload(BaseModel): @model_serializer(mode="wrap") def serialize(self, handler, info): - return VersionedField.update_model_fields(self.model_fields, handler(self), info) + return VersionedField.dump(self.model_fields, handler(self), info) class ExampleAPI(APIEndpoints): @request("POST", "/v1/data") diff --git a/catalystwan/tests/test_models_common.py b/catalystwan/tests/test_models_common.py index 8139e302..bc0c7ccb 100644 --- a/catalystwan/tests/test_models_common.py +++ b/catalystwan/tests/test_models_common.py @@ -28,7 +28,7 @@ class VersionedFieldsModel(BaseModel): @model_serializer(mode="wrap") def dump(self, handler: SerializerFunctionWrapHandler, info: SerializationInfo) -> Dict[str, Any]: - return VersionedField.update_model_fields(self.model_fields, handler(self), info) + return VersionedField.dump(self.model_fields, handler(self), info) class Payload(BaseModel):