Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
PrzeG committed Mar 5, 2024
1 parent b29afb8 commit b1f6982
Show file tree
Hide file tree
Showing 7 changed files with 155 additions and 9 deletions.
13 changes: 6 additions & 7 deletions catalystwan/api/templates/feature_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@

from catalystwan.api.templates.device_variable import DeviceVariable
from catalystwan.utils.device_model import DeviceModel
from catalystwan.utils.dict import FlattenedDictValue, flatten_dict
from catalystwan.utils.feature_template.find_template_values import find_template_values
from catalystwan.utils.pydantic_field import get_extra_field

if TYPE_CHECKING:
from catalystwan.session import ManagerSession
from catalystwan.utils.feature_template import FlattenedTemplateValue


class FeatureTemplateValidator(BaseModel, ABC):
@model_validator(mode="before")
@classmethod
def map_fields(cls, values: Union[Any, Dict[str, Union[List[FlattenedTemplateValue], Any]]]):
from catalystwan.utils.feature_template import FlattenedTemplateValue

def map_fields(cls, values: Union[Any, Dict[str, Union[List[FlattenedDictValue], Any]]]):
if not isinstance(values, dict):
return values
for field_name, field_info in cls.model_fields.items():
Expand All @@ -39,7 +38,7 @@ def map_fields(cls, values: Union[Any, Dict[str, Union[List[FlattenedTemplateVal
continue
data_path = get_extra_field(field_info, "data_path", [])
value = values.pop(payload_name)
if value and isinstance(value, list) and all([isinstance(v, FlattenedTemplateValue) for v in value]):
if value and isinstance(value, list) and all([isinstance(v, FlattenedDictValue) for v in value]):
for template_value in value:
if template_value.data_path == data_path:
values[field_name] = template_value.value
Expand Down Expand Up @@ -114,7 +113,7 @@ def get(cls, session: ManagerSession, name: str) -> FeatureTemplate:
Returns:
FeatureTemplate: filed out feature template model
"""
from catalystwan.utils.feature_template import choose_model, find_template_values, flatten_template_definition
from catalystwan.utils.feature_template.choose_model import choose_model

template_info = (
session.api.templates._get_feature_templates(summary=False).filter(name=name).single_or_default()
Expand All @@ -127,7 +126,7 @@ def get(cls, session: ManagerSession, name: str) -> FeatureTemplate:
values_from_template_definition = find_template_values(
template_definition_as_dict, device_specific_variables=device_specific_variables
)
flattened_values = flatten_template_definition(values_from_template_definition)
flattened_values = flatten_dict(values_from_template_definition)

return feature_template_model(
template_name=template_info.name,
Expand Down
2 changes: 1 addition & 1 deletion catalystwan/api/templates/models/cisco_omp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class Route(str, Enum):

class IPv4Advertise(FeatureTemplateValidator):
protocol: IPv4AdvertiseProtocol
route: Route
route: Optional[Route] = None


class IPv6AdvertiseProtocol(str, Enum):
Expand Down
2 changes: 2 additions & 0 deletions catalystwan/api/templates/models/supported.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from catalystwan.api.templates.models.cisco_bfd_model import CiscoBFDModel
from catalystwan.api.templates.models.cisco_logging_model import CiscoLoggingModel
from catalystwan.api.templates.models.cisco_ntp_model import CiscoNTPModel
from catalystwan.api.templates.models.cisco_omp_model import CiscoOMPModel
from catalystwan.api.templates.models.cisco_ospf import CiscoOSPFModel
from catalystwan.api.templates.models.cisco_secure_internet_gateway import CiscoSecureInternetGatewayModel
from catalystwan.api.templates.models.cisco_snmp_model import CiscoSNMPModel
Expand All @@ -31,4 +32,5 @@
"cisco_snmp": CiscoSNMPModel,
"cisco_system": CiscoSystemModel,
"cisco_secure_internet_gateway": CiscoSecureInternetGatewayModel,
"cisco_omp": CiscoOMPModel,
}
2 changes: 1 addition & 1 deletion catalystwan/tests/templates/test_chose_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from catalystwan.api.templates.models.omp_vsmart_model import OMPvSmart
from catalystwan.api.templates.models.security_vsmart_model import SecurityvSmart
from catalystwan.api.templates.models.system_vsmart_model import SystemVsmart
from catalystwan.utils.feature_template import choose_model
from catalystwan.utils.feature_template.choose_model import choose_model


class TestChooseModel(unittest.TestCase):
Expand Down
40 changes: 40 additions & 0 deletions catalystwan/utils/dict.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Copyright 2023 Cisco Systems, Inc. and its affiliates

from typing import Any, Dict, List

from pydantic import BaseModel


def merge(a, b, path=None):
if path is None:
Expand All @@ -15,3 +19,39 @@ def merge(a, b, path=None):
else:
a[key] = b[key]
return a


class FlattenedDictValue(BaseModel):
value: Any
data_path: List[str]


def flatten_dict(original_dict: Dict[str, Any]) -> Dict[str, List[FlattenedDictValue]]:
"""
Flattens a dictionary.
Each key corresponds to a list of FlattenedDictValue, allowing us to handle repeated keys in nesting.
"""

def get_flattened_dict(
original_dict: Dict[str, Any],
flattened_dict: Dict[str, List[FlattenedDictValue]] = {},
path: List[str] = [],
):
for key, value in original_dict.items():
if isinstance(value, dict):
get_flattened_dict(value, flattened_dict, path=path + [key])
else:
if key not in flattened_dict:
flattened_dict[key] = []
if isinstance(value, list) and all([isinstance(v, dict) for v in value]):
flattened_value = FlattenedDictValue(
value=[get_flattened_dict(v, {}) for v in value], data_path=path
)
flattened_dict[key].append(flattened_value)
else:
flattened_dict[key].append(FlattenedDictValue(value=value, data_path=path))
return flattened_dict

flattened_dict: Dict[str, List[FlattenedDictValue]] = {}
get_flattened_dict(original_dict, flattened_dict)
return flattened_dict
28 changes: 28 additions & 0 deletions catalystwan/utils/feature_template/choose_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any

from catalystwan.api.templates.models.supported import available_models
from catalystwan.exceptions import TemplateTypeError


def choose_model(type_value: str) -> Any:
"""Chooses correct model based on provided type
With provided type of feature template searches supported by catalystwan models
and returns correct for given type of feature template class.
Args:
type_value: type of feature template
Returns:
model
Raises:
TemplateTypeError: Raises when the model is not supported by catalystwan.
"""
if type_value not in available_models:
for model in available_models.values():
if model.type == type_value: # type: ignore
return model
raise TemplateTypeError(f"Feature template type '{type_value}' is not supported.")

return available_models[type_value]
77 changes: 77 additions & 0 deletions catalystwan/utils/feature_template/find_template_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import Dict, List, Optional, Union

from catalystwan.api.templates.device_variable import DeviceVariable


def find_template_values(
template_definition: dict,
templated_values: dict = {},
target_key: str = "vipType",
target_key_value_to_ignore: str = "ignore",
target_key_for_template_value: str = "vipValue",
device_specific_variables: Optional[Dict[str, DeviceVariable]] = None,
path: Optional[List[str]] = None,
) -> Dict[str, Union[str, list, dict]]:
"""Based on provided template definition generates a dictionary with template fields and values
Args:
template_definition: template definition provided as dict
templated_values: dictionary, empty at the beginning and filed out with names of fields as keys
and values of those fields as values
target_key: name of the key specifying if field is used in template, defaults to 'vipType'
target_key_value_to_ignore: value of the target key indicating
that field is not used in template, defaults to 'ignore'
target_key_for_template_value: name of the key specifying value of field used in template,
defaults to 'vipValue'
path: a list of keys indicating current path, defaults to None
Returns:
templated_values: dictionary containing template fields as key and values assigned to those fields as values
"""
if path is None:
path = []

# if value object is reached, try to extract the value
if target_key in template_definition:
if template_definition[target_key] == target_key_value_to_ignore:
return templated_values

value = template_definition[target_key]
template_value = template_definition[target_key_for_template_value]

field_key = path[-1]
# TODO: Handle nested DeviceVariable
if value == "variableName" and (device_specific_variables is not None):
device_specific_variables[field_key] = DeviceVariable(name=template_definition["vipVariableName"])
elif template_definition["vipObjectType"] != "tree":
current_nesting = get_nested_dict(templated_values, path[:-1])
current_nesting[field_key] = template_value
elif isinstance(template_value, dict):
find_template_values(
value, templated_values, device_specific_variables=device_specific_variables, path=path
)
elif isinstance(template_value, list):
current_nesting = get_nested_dict(templated_values, path[:-1])
current_nesting[field_key] = []
for item in template_value:
current_nesting[field_key].append(
find_template_values(item, {}, device_specific_variables=device_specific_variables)
)

return templated_values

# iterate the dict to extract values and assign them to their fields
for key, value in template_definition.items():
if isinstance(value, dict) and value != target_key_value_to_ignore:
find_template_values(
value, templated_values, device_specific_variables=device_specific_variables, path=path + [key]
)
return templated_values


def get_nested_dict(d: dict, path: List[str], populate: bool = True):
current_dict = d
for path_key in path:
if path_key not in current_dict and populate:
current_dict[path_key] = {}
current_dict = current_dict[path_key]
return current_dict

0 comments on commit b1f6982

Please sign in to comment.