Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
sbasan committed Feb 22, 2024
1 parent 42a1be0 commit bf4bb0e
Show file tree
Hide file tree
Showing 30 changed files with 160 additions and 75 deletions.
27 changes: 22 additions & 5 deletions catalystwan/api/configuration_groups/parcel.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,23 @@
from enum import Enum
from typing import Any, Dict, Generic, Literal, Optional, TypeVar, get_origin

from pydantic import AliasPath, BaseModel, ConfigDict, Field, PrivateAttr, model_serializer
from pydantic import (
AliasPath,
BaseModel,
ConfigDict,
Field,
PrivateAttr,
SerializerFunctionWrapHandler,
model_serializer,
)

T = TypeVar("T")


class _ParcelBase(BaseModel):
model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True, populate_by_name=True)
model_config = ConfigDict(
extra="allow", arbitrary_types_allowed=True, populate_by_name=True, json_schema_mode_override="validation"
)
parcel_name: str = Field(
min_length=1,
max_length=128,
Expand All @@ -21,11 +31,18 @@ class _ParcelBase(BaseModel):
validation_alias="description",
description="Set the parcel description",
)
# data: Optional[Any] = None
_parcel_data_key: str = PrivateAttr(default="data")

@model_serializer(mode="wrap", when_used="json")
def envelope_parcel_data(self, handler) -> Dict[str, Any]:
@model_serializer(mode="wrap")
def envelope_parcel_data(self, handler: SerializerFunctionWrapHandler) -> Dict[str, Any]:
"""
serializes model fields with respect to field validation_alias,
sub-classing parcel fields can be defined like following:
>>> entries: List[SecurityZoneListEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))
"data" is default _parcel_data_key which must match validation_alias prefix,
this attribute can be overriden in sub-class when needed
"""
model_dict = handler(self)
model_dict[self._parcel_data_key] = {}
remove_keys = []
Expand Down
67 changes: 46 additions & 21 deletions catalystwan/endpoints/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
Sequence,
Set,
Tuple,
Type,
TypeVar,
Union,
runtime_checkable,
Expand Down Expand Up @@ -114,6 +115,47 @@ def json(cls) -> TypeSpecifier:
def model_union(cls, models: Sequence[type]) -> TypeSpecifier:
return TypeSpecifier(present=True, payload_union_model_types=models)

@classmethod
def resolve_nested_base_model_unions(
cls, annotation: Any, models_types: List[Union[Type[BaseModelV1], Type[BaseModelV2]]]
) -> List[Union[Type[BaseModelV1], Type[BaseModelV2]]]:
type_origin = get_origin(annotation)
if isclass(annotation):
try:
if issubclass(annotation, (BaseModelV1, BaseModelV2)):
return [annotation]
raise APIEndpointError(f"Expected: {PayloadType}")
except TypeError:
raise APIEndpointError(f"Expected: {PayloadType}")
# Check if Annnotated[Union[PayloadModelType, ...]], only unions of pydantic models allowed
elif type_origin == Annotated:
if annotated_origin := get_args(annotation):
if (len(annotated_origin) >= 1) and get_origin(annotated_origin[0]) == Union:
type_args = get_args(annotated_origin[0])
if all(isclass(t) for t in type_args) and all(
issubclass(t, (BaseModelV1, BaseModelV2)) for t in type_args
):
models_types.extend(list(type_args))
return models_types
else:
non_models = [t for t in type_args if not isclass(t)]
for non_model in non_models:
models_types.extend(cls.resolve_nested_base_model_unions(non_model, models_types))
return models_types

# Check if Union[PayloadModelType, ...], only unions of pydantic models allowed
elif type_origin == Union:
type_args = get_args(annotation)
if all(isclass(t) for t in type_args) and all(issubclass(t, (BaseModelV1, BaseModelV2)) for t in type_args):
models_types.extend(list(type_args))
return models_types
else:
non_models = [t for t in type_args if not isclass(t)]
for non_model in non_models:
models_types.extend(cls.resolve_nested_base_model_unions(non_model, models_types))
return models_types
raise APIEndpointError(f"Expected: {PayloadType}")


@dataclass
class APIEndpointRequestMeta:
Expand Down Expand Up @@ -451,27 +493,10 @@ def specify_payload_type(self) -> TypeSpecifier:
and issubclass(type_args[0], (BaseModelV1, BaseModelV2))
):
return TypeSpecifier(True, type_origin, type_args[0], None, False, is_optional)
# Check if Annnotated[Union[PayloadModelType, ...]], only unions of pydantic models allowed
elif type_origin == Annotated:
if annotated_origin := get_args(annotation):
if (len(annotated_origin) >= 1) and get_origin(annotated_origin[0]) == Union:
if (
(type_args := get_args(annotated_origin[0]))
and all(isclass(t) for t in type_args)
and all(issubclass(t, (BaseModelV1, BaseModelV2)) for t in type_args)
):
return TypeSpecifier.model_union(models=list(type_args))
# Check if Union[PayloadModelType, ...], only unions of pydantic models allowed
elif type_origin == Union:
if (
(type_args := get_args(annotation))
and all(isclass(t) for t in type_args)
and all(issubclass(t, (BaseModelV1, BaseModelV2)) for t in type_args)
):
return TypeSpecifier.model_union(models=list(type_args))
raise APIEndpointError(f"Expected: {PayloadType} but found payload {annotation}")
else:
raise APIEndpointError(f"Expected: {PayloadType} but found payload {annotation}")
else:
models = TypeSpecifier.resolve_nested_base_model_unions(annotation, [])
return TypeSpecifier.model_union(models)
raise APIEndpointError(f"'payload' param must be annotated with supported type: {PayloadType}")

def check_params(self):
"""Checks params in decorated method definition
Expand Down
4 changes: 2 additions & 2 deletions catalystwan/models/configuration/config_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from pydantic import BaseModel, ConfigDict, Field

from catalystwan.api.configuration_groups.parcel import _ParcelBase
from catalystwan.endpoints.configuration_group import ConfigGroup
from catalystwan.models.configuration.feature_profile.common import FeatureProfileCreationPayload
from catalystwan.models.configuration.feature_profile.sdwan.policy_object import AnyPolicyObjectParcel
from catalystwan.models.policy import (
AnyPolicyDefinition,
AnyPolicyList,
Expand Down Expand Up @@ -56,6 +56,6 @@ class UX2Config(BaseModel):
feature_profiles: List[FeatureProfileCreationPayload] = Field(
default=[], serialization_alias="featureProfiles", validation_alias="featureProfiles"
)
profile_parcels: List[_ParcelBase] = Field(
profile_parcels: List[AnyPolicyObjectParcel] = Field(
default=[], serialization_alias="profileParcels", validation_alias="profileParcels"
)
Original file line number Diff line number Diff line change
Expand Up @@ -32,35 +32,42 @@
from .security.url import BaseURLListEntry, URLAllowParcel, URLBlockParcel
from .security.zone import SecurityZoneListEntry, SecurityZoneListParcel

AnyURLParcel = Annotated[
Union[
URLAllowParcel,
URLBlockParcel,
],
Field(discriminator="parcel_type"),
]

AnyPolicyObjectParcel = Annotated[
Union[
AppProbeParcel,
AnyURLParcel,
ApplicationListParcel,
AppProbeParcel,
ColorParcel,
DataPrefixParcel,
ExpandedCommunityParcel,
FowardingClassParcel,
FQDNDomainParcel,
GeoLocationListParcel,
IPSSignatureParcel,
IPv6DataPrefixParcel,
IPv6PrefixListParcel,
PrefixListParcel,
LocalDomainParcel,
PolicierParcel,
PreferredColorGroupParcel,
SLAClassParcel,
TlocParcel,
StandardCommunityParcel,
LocalDomainParcel,
FQDNDomainParcel,
IPSSignatureParcel,
URLAllowParcel,
URLBlockParcel,
SecurityPortParcel,
PrefixListParcel,
ProtocolListParcel,
GeoLocationListParcel,
SecurityZoneListParcel,
SecurityApplicationListParcel,
SecurityDataPrefixParcel,
SecurityPortParcel,
SecurityZoneListParcel,
SLAClassParcel,
StandardCommunityParcel,
TlocParcel,
],
Field(discriminator="type"),
Field(discriminator="type_"),
]

POLICY_OBJECT_PAYLOAD_ENDPOINT_MAPPING: Mapping[type, str] = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Literal

from pydantic import AliasPath, BaseModel, ConfigDict, Field, field_validator

Expand Down Expand Up @@ -26,6 +26,7 @@ class AppProbeEntry(BaseModel):


class AppProbeParcel(_ParcelBase):
type_: Literal["app-probe"] = Field(default="app-probe", exclude=True)
entries: List[AppProbeEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))

def add_fowarding_class(self, forwarding_class_name: str):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Union
from typing import List, Literal, Union

from pydantic import AliasPath, BaseModel, ConfigDict, Field

Expand All @@ -16,6 +16,7 @@ class ApplicationFamilyListEntry(BaseModel):


class ApplicationListParcel(_ParcelBase):
type_: Literal["app-list"] = Field(default="app-list", exclude=True)
entries: List[Union[ApplicationListEntry, ApplicationFamilyListEntry]] = Field(
default=[], validation_alias=AliasPath("data", "entries")
)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Literal

from pydantic import AliasPath, BaseModel, Field

Expand All @@ -11,6 +11,7 @@ class ColorEntry(BaseModel):


class ColorParcel(_ParcelBase):
type_: Literal["color"] = Field(default="color", exclude=True)
entries: List[ColorEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))

def add_color(self, color: TLOCColor):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ipaddress import IPv4Address, IPv4Network
from typing import List
from typing import List, Literal

from pydantic import AliasPath, BaseModel, ConfigDict, Field

Expand All @@ -20,6 +20,7 @@ def from_ipv4_network(ipv4_network: IPv4Network) -> "DataPrefixEntry":


class DataPrefixParcel(_ParcelBase):
type_: Literal["data-prefix"] = Field(default="data-prefix", exclude=True)
entries: List[DataPrefixEntry] = Field(default_factory=list, validation_alias=AliasPath("data", "entries"))

def add_data_prefix(self, ipv4_network: IPv4Network):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Literal

from pydantic import AliasPath, ConfigDict, Field, field_validator

from catalystwan.api.configuration_groups.parcel import Global, _ParcelBase, as_global


class ExpandedCommunityParcel(_ParcelBase):
type_: Literal["expanded-community"] = Field(default="expanded-community", exclude=True)
model_config = ConfigDict(populate_by_name=True)
expandedCommunityList: Global[list] = Field(
default=as_global([]),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Literal

from pydantic import AliasPath, BaseModel, Field, field_validator

Expand All @@ -16,6 +16,7 @@ def check_burst(cls, queue: Global):


class FowardingClassParcel(_ParcelBase):
type_: Literal["class"] = Field(default="class", exclude=True)
entries: List[FowardingClassQueueEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))

def add_queue(self, queue: int):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ipaddress import IPv6Address, IPv6Network
from typing import List
from typing import List, Literal

from pydantic import AliasPath, BaseModel, ConfigDict, Field

Expand All @@ -13,6 +13,7 @@ class IPv6DataPrefixEntry(BaseModel):


class IPv6DataPrefixParcel(_ParcelBase):
type_: Literal["data-ipv6-prefix"] = Field(default="data-ipv6-prefix", exclude=True)
entries: List[IPv6DataPrefixEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))

def add_prefix(self, ipv6_network: IPv6Network):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ipaddress import IPv6Address, IPv6Network
from typing import List
from typing import List, Literal

from pydantic import AliasPath, BaseModel, ConfigDict, Field

Expand All @@ -13,6 +13,7 @@ class IPv6PrefixListEntry(BaseModel):


class IPv6PrefixListParcel(_ParcelBase):
type_: Literal["ipv6-prefix"] = Field(default="ipv6-prefix", exclude=True)
entries: List[IPv6PrefixListEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))

def add_prefix(self, ipv6_network: IPv6Network):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Literal

from pydantic import AliasPath, BaseModel, ConfigDict, Field, field_validator

Expand Down Expand Up @@ -26,6 +26,7 @@ def check_rate(cls, rate_str: Global):


class PolicierParcel(_ParcelBase):
type_: Literal["policer"] = Field(default="policer", exclude=True)
entries: List[PolicierEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))

def add_entry(self, burst: int, exceed: PolicerExceedAction, rate: int):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List, Literal, Optional

from pydantic import AliasPath, BaseModel, ConfigDict, Field, model_validator

Expand Down Expand Up @@ -35,6 +35,7 @@ def check_passwords_match(self) -> "PreferredColorGroupEntry":


class PreferredColorGroupParcel(_ParcelBase):
type_: Literal["preferred-color-group"] = Field(default="preferred-color-group", exclude=True)
entries: List[PreferredColorGroupEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))

def add_primary(self, color_preference: List[TLOCColor], path_preference: PathPreference):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ipaddress import IPv4Address, IPv4Network
from typing import List
from typing import List, Literal

from pydantic import AliasPath, BaseModel, ConfigDict, Field

Expand All @@ -13,6 +13,7 @@ class PrefixListEntry(BaseModel):


class PrefixListParcel(_ParcelBase):
type_: Literal["prefix"] = Field(default="prefix", exclude=True)
entries: List[PrefixListEntry] = Field(default_factory=list, validation_alias=AliasPath("data", "entries"))

def add_prefix(self, ipv4_network: IPv4Network):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class SLAClassListEntry(BaseModel):


class SLAClassParcel(_ParcelBase):
type_: Literal["sla-class"] = Field(default="sla-class", exclude=True)
entries: List[SLAClassListEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))

def add_entry(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Literal

from pydantic import AliasPath, BaseModel, ConfigDict, Field

Expand All @@ -14,6 +14,7 @@ class StandardCommunityEntry(BaseModel):


class StandardCommunityParcel(_ParcelBase):
type_: Literal["standard-community"] = Field(default="standard-community", exclude=True)
entries: List[StandardCommunityEntry] = Field(default=[], validation_alias=AliasPath("data", "entries"))

def add_community(self, standard_community: WellKnownBGPCommunities):
Expand Down
Loading

0 comments on commit bf4bb0e

Please sign in to comment.