diff --git a/catalystwan/api/templates/feature_template.py b/catalystwan/api/templates/feature_template.py index 5bae8d12..c6d851c5 100644 --- a/catalystwan/api/templates/feature_template.py +++ b/catalystwan/api/templates/feature_template.py @@ -12,19 +12,18 @@ from catalystwan.api.templates.device_variable import DeviceVariable from catalystwan.utils.device_model import DeviceModel +from catalystwan.utils.dict import FlattenedDictValue, flatten_dict +from catalystwan.utils.feature_template.find_template_values import find_template_values from catalystwan.utils.pydantic_field import get_extra_field if TYPE_CHECKING: from catalystwan.session import ManagerSession - from catalystwan.utils.feature_template import FlattenedTemplateValue class FeatureTemplateValidator(BaseModel, ABC): @model_validator(mode="before") @classmethod - def map_fields(cls, values: Union[Any, Dict[str, Union[List[FlattenedTemplateValue], Any]]]): - from catalystwan.utils.feature_template import FlattenedTemplateValue - + def map_fields(cls, values: Union[Any, Dict[str, Union[List[FlattenedDictValue], Any]]]): if not isinstance(values, dict): return values for field_name, field_info in cls.model_fields.items(): @@ -39,7 +38,7 @@ def map_fields(cls, values: Union[Any, Dict[str, Union[List[FlattenedTemplateVal continue data_path = get_extra_field(field_info, "data_path", []) value = values.pop(payload_name) - if value and isinstance(value, list) and all([isinstance(v, FlattenedTemplateValue) for v in value]): + if value and isinstance(value, list) and all([isinstance(v, FlattenedDictValue) for v in value]): for template_value in value: if template_value.data_path == data_path: values[field_name] = template_value.value @@ -114,7 +113,7 @@ def get(cls, session: ManagerSession, name: str) -> FeatureTemplate: Returns: FeatureTemplate: filed out feature template model """ - from catalystwan.utils.feature_template import choose_model, find_template_values, flatten_template_definition + from catalystwan.utils.feature_template.choose_model import choose_model template_info = ( session.api.templates._get_feature_templates(summary=False).filter(name=name).single_or_default() @@ -127,7 +126,7 @@ def get(cls, session: ManagerSession, name: str) -> FeatureTemplate: values_from_template_definition = find_template_values( template_definition_as_dict, device_specific_variables=device_specific_variables ) - flattened_values = flatten_template_definition(values_from_template_definition) + flattened_values = flatten_dict(values_from_template_definition) return feature_template_model( template_name=template_info.name, diff --git a/catalystwan/api/templates/models/cisco_omp_model.py b/catalystwan/api/templates/models/cisco_omp_model.py index 2a65412a..baebfc8f 100644 --- a/catalystwan/api/templates/models/cisco_omp_model.py +++ b/catalystwan/api/templates/models/cisco_omp_model.py @@ -34,7 +34,7 @@ class Route(str, Enum): class IPv4Advertise(FeatureTemplateValidator): protocol: IPv4AdvertiseProtocol - route: Route + route: Optional[Route] = None class IPv6AdvertiseProtocol(str, Enum): diff --git a/catalystwan/api/templates/models/supported.py b/catalystwan/api/templates/models/supported.py index 43985bed..fc0ad1f2 100644 --- a/catalystwan/api/templates/models/supported.py +++ b/catalystwan/api/templates/models/supported.py @@ -5,6 +5,7 @@ from catalystwan.api.templates.models.cisco_bfd_model import CiscoBFDModel from catalystwan.api.templates.models.cisco_logging_model import CiscoLoggingModel from catalystwan.api.templates.models.cisco_ntp_model import CiscoNTPModel +from catalystwan.api.templates.models.cisco_omp_model import CiscoOMPModel from catalystwan.api.templates.models.cisco_ospf import CiscoOSPFModel from catalystwan.api.templates.models.cisco_secure_internet_gateway import CiscoSecureInternetGatewayModel from catalystwan.api.templates.models.cisco_snmp_model import CiscoSNMPModel @@ -31,4 +32,5 @@ "cisco_snmp": CiscoSNMPModel, "cisco_system": CiscoSystemModel, "cisco_secure_internet_gateway": CiscoSecureInternetGatewayModel, + "cisco_omp": CiscoOMPModel, } diff --git a/catalystwan/tests/templates/test_chose_model.py b/catalystwan/tests/templates/test_chose_model.py index af24a47c..fdf220a7 100644 --- a/catalystwan/tests/templates/test_chose_model.py +++ b/catalystwan/tests/templates/test_chose_model.py @@ -9,7 +9,7 @@ from catalystwan.api.templates.models.omp_vsmart_model import OMPvSmart from catalystwan.api.templates.models.security_vsmart_model import SecurityvSmart from catalystwan.api.templates.models.system_vsmart_model import SystemVsmart -from catalystwan.utils.feature_template import choose_model +from catalystwan.utils.feature_template.choose_model import choose_model class TestChooseModel(unittest.TestCase): diff --git a/catalystwan/utils/dict.py b/catalystwan/utils/dict.py index 393f7431..4bcb3e9f 100644 --- a/catalystwan/utils/dict.py +++ b/catalystwan/utils/dict.py @@ -1,5 +1,9 @@ # Copyright 2023 Cisco Systems, Inc. and its affiliates +from typing import Any, Dict, List + +from pydantic import BaseModel + def merge(a, b, path=None): if path is None: @@ -15,3 +19,39 @@ def merge(a, b, path=None): else: a[key] = b[key] return a + + +class FlattenedDictValue(BaseModel): + value: Any + data_path: List[str] + + +def flatten_dict(original_dict: Dict[str, Any]) -> Dict[str, List[FlattenedDictValue]]: + """ + Flattens a dictionary. + Each key corresponds to a list of FlattenedDictValue, allowing us to handle repeated keys in nesting. + """ + + def get_flattened_dict( + original_dict: Dict[str, Any], + flattened_dict: Dict[str, List[FlattenedDictValue]] = {}, + path: List[str] = [], + ): + for key, value in original_dict.items(): + if isinstance(value, dict): + get_flattened_dict(value, flattened_dict, path=path + [key]) + else: + if key not in flattened_dict: + flattened_dict[key] = [] + if isinstance(value, list) and all([isinstance(v, dict) for v in value]): + flattened_value = FlattenedDictValue( + value=[get_flattened_dict(v, {}) for v in value], data_path=path + ) + flattened_dict[key].append(flattened_value) + else: + flattened_dict[key].append(FlattenedDictValue(value=value, data_path=path)) + return flattened_dict + + flattened_dict: Dict[str, List[FlattenedDictValue]] = {} + get_flattened_dict(original_dict, flattened_dict) + return flattened_dict diff --git a/catalystwan/utils/feature_template/choose_model.py b/catalystwan/utils/feature_template/choose_model.py new file mode 100644 index 00000000..05bb8648 --- /dev/null +++ b/catalystwan/utils/feature_template/choose_model.py @@ -0,0 +1,28 @@ +from typing import Any + +from catalystwan.api.templates.models.supported import available_models +from catalystwan.exceptions import TemplateTypeError + + +def choose_model(type_value: str) -> Any: + """Chooses correct model based on provided type + + With provided type of feature template searches supported by catalystwan models + and returns correct for given type of feature template class. + + Args: + type_value: type of feature template + + Returns: + model + + Raises: + TemplateTypeError: Raises when the model is not supported by catalystwan. + """ + if type_value not in available_models: + for model in available_models.values(): + if model.type == type_value: # type: ignore + return model + raise TemplateTypeError(f"Feature template type '{type_value}' is not supported.") + + return available_models[type_value] diff --git a/catalystwan/utils/feature_template/find_template_values.py b/catalystwan/utils/feature_template/find_template_values.py new file mode 100644 index 00000000..b88667ca --- /dev/null +++ b/catalystwan/utils/feature_template/find_template_values.py @@ -0,0 +1,77 @@ +from typing import Dict, List, Optional, Union + +from catalystwan.api.templates.device_variable import DeviceVariable + + +def find_template_values( + template_definition: dict, + templated_values: dict = {}, + target_key: str = "vipType", + target_key_value_to_ignore: str = "ignore", + target_key_for_template_value: str = "vipValue", + device_specific_variables: Optional[Dict[str, DeviceVariable]] = None, + path: Optional[List[str]] = None, +) -> Dict[str, Union[str, list, dict]]: + """Based on provided template definition generates a dictionary with template fields and values + + Args: + template_definition: template definition provided as dict + templated_values: dictionary, empty at the beginning and filed out with names of fields as keys + and values of those fields as values + target_key: name of the key specifying if field is used in template, defaults to 'vipType' + target_key_value_to_ignore: value of the target key indicating + that field is not used in template, defaults to 'ignore' + target_key_for_template_value: name of the key specifying value of field used in template, + defaults to 'vipValue' + path: a list of keys indicating current path, defaults to None + Returns: + templated_values: dictionary containing template fields as key and values assigned to those fields as values + """ + if path is None: + path = [] + + # if value object is reached, try to extract the value + if target_key in template_definition: + if template_definition[target_key] == target_key_value_to_ignore: + return templated_values + + value = template_definition[target_key] + template_value = template_definition[target_key_for_template_value] + + field_key = path[-1] + # TODO: Handle nested DeviceVariable + if value == "variableName" and (device_specific_variables is not None): + device_specific_variables[field_key] = DeviceVariable(name=template_definition["vipVariableName"]) + elif template_definition["vipObjectType"] != "tree": + current_nesting = get_nested_dict(templated_values, path[:-1]) + current_nesting[field_key] = template_value + elif isinstance(template_value, dict): + find_template_values( + value, templated_values, device_specific_variables=device_specific_variables, path=path + ) + elif isinstance(template_value, list): + current_nesting = get_nested_dict(templated_values, path[:-1]) + current_nesting[field_key] = [] + for item in template_value: + current_nesting[field_key].append( + find_template_values(item, {}, device_specific_variables=device_specific_variables) + ) + + return templated_values + + # iterate the dict to extract values and assign them to their fields + for key, value in template_definition.items(): + if isinstance(value, dict) and value != target_key_value_to_ignore: + find_template_values( + value, templated_values, device_specific_variables=device_specific_variables, path=path + [key] + ) + return templated_values + + +def get_nested_dict(d: dict, path: List[str], populate: bool = True): + current_dict = d + for path_key in path: + if path_key not in current_dict and populate: + current_dict[path_key] = {} + current_dict = current_dict[path_key] + return current_dict