Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
fix validation errors during collecting policies
Browse files Browse the repository at this point in the history
  • Loading branch information
sbasan committed Mar 5, 2024
1 parent 03ab25c commit 7dbabe5
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 56 deletions.
30 changes: 29 additions & 1 deletion catalystwan/models/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from typing import Dict, List, Literal, Set, Tuple
from typing import Dict, List, Literal, Sequence, Set, Tuple, Union
from uuid import UUID

from pydantic import PlainSerializer
from pydantic.functional_validators import BeforeValidator
from typing_extensions import Annotated


def check_fields_exclusive(values: Dict, field_names: Set[str], at_least_one: bool = False) -> bool:
Expand Down Expand Up @@ -46,6 +51,25 @@ def check_any_of_exclusive_field_sets(values: Dict, field_sets: List[Tuple[Set[s
raise ValueError(f"One of {all_sets_field_names} must be assigned")


IntStr = Annotated[
int,
PlainSerializer(lambda x: str(x), return_type=str, when_used="json-unless-none"),
BeforeValidator(lambda x: int(x)),
]


def str_as_uuid_list(val: Union[str, Sequence[UUID]]) -> Sequence[UUID]:
if isinstance(val, str):
return [UUID(uuid_) for uuid_ in val.split()]
return val


def str_as_str_list(val: Union[str, Sequence[str]]) -> Sequence[str]:
if isinstance(val, str):
return [s for s in val.split()]
return val


InterfaceType = Literal[
"Ethernet",
"FastEthernet",
Expand Down Expand Up @@ -114,3 +138,7 @@ def check_any_of_exclusive_field_sets(values: Dict, field_sets: List[Tuple[Set[s
"SC15",
"SC16",
]

ICMPMessageType = Literal[
"echo", "echo-reply", "unreachable", "net-unreachable", "host-unreachable", "protocol-unreachable"
]
4 changes: 2 additions & 2 deletions catalystwan/models/policy/definitions/access_control_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def match_high_plp(self) -> None:
def match_protocols(self, protocols: Set[int]) -> None:
self._insert_match(ProtocolEntry.from_protocol_set(protocols))

def match_source_data_prefix_list(self, data_prefix_list_id: UUID) -> None:
self._insert_match(SourceDataPrefixListEntry(ref=data_prefix_list_id))
def match_source_data_prefix_list(self, data_prefix_lists: List[UUID]) -> None:
self._insert_match(SourceDataPrefixListEntry(ref=data_prefix_lists))

def match_source_ip(self, networks: List[IPv4Network]) -> None:
self._insert_match(SourceIPEntry.from_ipv4_networks(networks))
Expand Down
4 changes: 2 additions & 2 deletions catalystwan/models/policy/definitions/device_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class DeviceAccessPolicySequence(PolicyDefinitionSequenceBase):
def match_device_access_protocol(self, port: DeviceAccessProtocol) -> None:
self._insert_match(DestinationPortEntry.from_port_set_and_ranges(ports={port}))

def match_source_data_prefix_list(self, data_prefix_list_id: UUID) -> None:
self._insert_match(SourceDataPrefixListEntry(ref=data_prefix_list_id))
def match_source_data_prefix_list(self, data_prefix_lists: List[UUID]) -> None:
self._insert_match(SourceDataPrefixListEntry(ref=data_prefix_lists))

def match_source_ip(self, networks: List[IPv4Network]) -> None:
self._insert_match(SourceIPEntry.from_ipv4_networks(networks))
Expand Down
57 changes: 26 additions & 31 deletions catalystwan/models/policy/definitions/qos_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

from catalystwan.models.common import IntStr
from catalystwan.models.policy.policy_definition import PolicyDefinitionBase

QoSScheduling = Literal[
Expand All @@ -17,11 +18,17 @@


class QoSScheduler(BaseModel):
queue: str
class_map_ref: Union[UUID, Literal[""]] = Field(serialization_alias="classMapRef", validation_alias="classMapRef")
bandwidth_percent: str = Field("1", serialization_alias="bandwidthPercent", validation_alias="bandwidthPercent")
buffer_percent: str = Field("1", serialization_alias="bufferPercent", validation_alias="bufferPercent")
burst: Optional[str] = None
queue: IntStr = Field(ge=0, le=8)
class_map_ref: Optional[UUID] = Field(
default=None, serialization_alias="classMapRef", validation_alias="classMapRef"
)
bandwidth_percent: IntStr = Field(
default=1, ge=1, le=100, serialization_alias="bandwidthPercent", validation_alias="bandwidthPercent"
)
buffer_percent: IntStr = Field(
default=1, ge=1, le=100, serialization_alias="bufferPercent", validation_alias="bufferPercent"
)
burst: Optional[IntStr] = Field(default=None, ge=5000, le=10_000_000)
scheduling: QoSScheduling = "wrr"
drops: QoSDropType = "tail-drop"
temp_key_values: Optional[str] = Field(
Expand All @@ -31,35 +38,23 @@ class QoSScheduler(BaseModel):
@staticmethod
def get_default_control_scheduler() -> "QoSScheduler":
return QoSScheduler(
queue="0",
class_map_ref="",
bandwidth_percent="100",
buffer_percent="100",
burst="15000",
queue=0,
bandwidth_percent=100,
buffer_percent=100,
burst=15000,
scheduling="llq",
drops="tail-drop",
)

model_config = ConfigDict(populate_by_name=True)

@field_validator("queue")
@classmethod
def check_queue(cls, queue_str: str):
assert 0 <= int(queue_str) <= 7
return queue_str

@field_validator("bandwidth_percent", "buffer_percent")
@classmethod
def check_bandwidth_and_buffer_percent(cls, percent_str: str):
assert 1 <= int(percent_str) <= 100
return percent_str

@field_validator("burst")
@field_validator("class_map_ref", mode="before")
@classmethod
def check_burst(cls, burst_val: Union[str, None]):
if burst_val is not None:
assert 5000 <= int(burst_val) <= 10_000_000
return burst_val
def check_optional_class_map_ref(cls, class_map_ref: Union[str, None]):
# None and "" indicates missing value, both can be found in server responses
if not class_map_ref:
return None
return class_map_ref


class QoSMapDefinition(BaseModel):
Expand All @@ -84,11 +79,11 @@ def add_scheduler(
) -> None:
self.definition.qos_schedulers.append(
QoSScheduler(
queue=str(queue),
queue=queue,
class_map_ref=class_map_ref,
bandwidth_percent=str(bandwidth),
buffer_percent=str(buffer),
burst=str(burst) if burst is not None else None,
bandwidth_percent=bandwidth,
buffer_percent=buffer,
burst=burst,
scheduling=scheduling,
drops=drops,
)
Expand Down
35 changes: 20 additions & 15 deletions catalystwan/models/policy/definitions/traffic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import ConfigDict, Field
from typing_extensions import Annotated

from catalystwan.models.common import ServiceChainNumber, TLOCColor
from catalystwan.models.common import ICMPMessageType, ServiceChainNumber, TLOCColor
from catalystwan.models.policy.lists_entries import EncapType
from catalystwan.models.policy.policy_definition import (
AppListEntry,
Expand All @@ -24,6 +24,7 @@
DSCPEntry,
FallBackToRoutingAction,
ForwardingClassEntry,
ICMPMessageEntry,
LocalTLOCListEntry,
LocalTLOCListEntryValue,
LogAction,
Expand Down Expand Up @@ -64,24 +65,25 @@

TrafficDataPolicySequenceEntry = Annotated[
Union[
AppListEntry,
DestinationDataIPv6PrefixListEntry,
DestinationDataPrefixListEntry,
DestinationIPEntry,
DestinationPortEntry,
DestinationRegionEntry,
DNSAppListEntry,
DNSEntry,
DSCPEntry,
ICMPMessageEntry,
PacketLengthEntry,
PLPEntry,
ProtocolEntry,
DSCPEntry,
SourceDataIPv6PrefixListEntry,
SourceDataPrefixListEntry,
SourceIPEntry,
SourcePortEntry,
DestinationIPEntry,
DestinationPortEntry,
TCPEntry,
DNSEntry,
TrafficToEntry,
SourceDataPrefixListEntry,
DestinationDataPrefixListEntry,
SourceDataIPv6PrefixListEntry,
DestinationDataIPv6PrefixListEntry,
DestinationRegionEntry,
DNSAppListEntry,
AppListEntry,
],
Field(discriminator="field"),
]
Expand All @@ -98,7 +100,7 @@ class TrafficDataPolicySequenceMatch(Match):


class TrafficDataPolicySequence(PolicyDefinitionSequenceBase):
sequence_type: Literal["data"] = Field(
sequence_type: Literal["applicationFirewall", "qos", "serviceChaining", "trafficEngineering", "data"] = Field(
default="data", serialization_alias="sequenceType", validation_alias="sequenceType"
)
match: TrafficDataPolicySequenceMatch = TrafficDataPolicySequenceMatch()
Expand All @@ -120,6 +122,9 @@ def match_dns_response(self) -> None:
def match_dscp(self, dscp: int) -> None:
self._insert_match(DSCPEntry(value=str(dscp)))

def match_icmp(self, icmp_message_types: List[ICMPMessageType]) -> None:
self._insert_match(ICMPMessageEntry(value=icmp_message_types))

def match_packet_length(self, packet_lengths: Tuple[int, int]) -> None:
self._insert_match(PacketLengthEntry.from_range(packet_lengths))

Expand All @@ -132,8 +137,8 @@ def match_high_plp(self) -> None:
def match_protocols(self, protocols: Set[int]) -> None:
self._insert_match(ProtocolEntry.from_protocol_set(protocols))

def match_source_data_prefix_list(self, data_prefix_list_id: UUID) -> None:
self._insert_match(SourceDataPrefixListEntry(ref=data_prefix_list_id))
def match_source_data_prefix_list(self, data_prefix_lists: List[UUID]) -> None:
self._insert_match(SourceDataPrefixListEntry(ref=data_prefix_lists))

def match_source_ip(self, networks: List[IPv4Network]) -> None:
self._insert_match(SourceIPEntry.from_ipv4_networks(networks))
Expand Down
4 changes: 2 additions & 2 deletions catalystwan/models/policy/definitions/zone_based_firewall.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ def match_protocol_names(self, names: Set[str], protocol_map: Dict[str, Applicat
def match_protocol_name_list(self, protocol_name_list_id: UUID) -> None:
self._insert_match(ProtocolNameListEntry(ref=protocol_name_list_id))

def match_source_data_prefix_list(self, data_prefix_list_id: UUID) -> None:
self._insert_match(SourceDataPrefixListEntry(ref=data_prefix_list_id))
def match_source_data_prefix_list(self, data_prefix_lists: List[UUID]) -> None:
self._insert_match(SourceDataPrefixListEntry(ref=data_prefix_lists))

def match_source_fqdn(self, fqdn: str) -> None:
self._insert_match(SourceFQDNEntry(value=fqdn))
Expand Down
23 changes: 20 additions & 3 deletions catalystwan/models/policy/policy_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,17 @@
from typing import Any, Dict, List, MutableSequence, Optional, Protocol, Sequence, Set, Tuple, Union
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, RootModel, model_validator
from pydantic import BaseModel, ConfigDict, Field, RootModel, field_validator, model_validator
from typing_extensions import Annotated, Literal

from catalystwan.models.common import ServiceChainNumber, TLOCColor, check_fields_exclusive
from catalystwan.models.common import (
ICMPMessageType,
ServiceChainNumber,
TLOCColor,
check_fields_exclusive,
str_as_str_list,
str_as_uuid_list,
)
from catalystwan.models.misc.application_protocols import ApplicationProtocol
from catalystwan.models.policy.lists_entries import EncapType
from catalystwan.typed_list import DataSequence
Expand Down Expand Up @@ -485,9 +492,18 @@ def from_nat_vpn(fallback: bool, vpn: int = 0) -> "NATVPNEntry":
return NATVPNEntry(root=[UseVPNEntry(value=str(vpn))])


class ICMPMessageEntry(BaseModel):
field: Literal["icmpMessage"] = "icmpMessage"
value: List[ICMPMessageType]

_value = field_validator("value", mode="before")(str_as_str_list)


class SourceDataPrefixListEntry(BaseModel):
field: Literal["sourceDataPrefixList"] = "sourceDataPrefixList"
ref: UUID
ref: List[UUID]

_ref = field_validator("ref", mode="before")(str_as_uuid_list)


class SourceDataIPv6PrefixListEntry(BaseModel):
Expand Down Expand Up @@ -847,6 +863,7 @@ class ActionSet(BaseModel):
DSCPEntry,
ExpandedCommunityListEntry,
GroupIDEntry,
ICMPMessageEntry,
NextHeaderEntry,
OMPTagEntry,
OriginatorEntry,
Expand Down

0 comments on commit 7dbabe5

Please sign in to comment.