diff --git a/catalystwan/api/templates/device_template/device_template.py b/catalystwan/api/templates/device_template/device_template.py index 85103afc..4cc6d95b 100644 --- a/catalystwan/api/templates/device_template/device_template.py +++ b/catalystwan/api/templates/device_template/device_template.py @@ -50,6 +50,26 @@ class DeviceTemplate(BaseModel): security_policy_id: str = Field(default="", alias="securityPolicyId") policy_id: str = Field(default="", alias="policyId") + def get_flattened_general_templates(self) -> List[GeneralTemplate]: + """ + Recursively flattens the general templates by removing the sub-templates + and returning a list of flattened templates. + + Returns: + A list of GeneralTemplate objects representing the flattened templates. + """ + + def flatten_general_templates(general_templates: List[GeneralTemplate]) -> List[GeneralTemplate]: + result = [] + for gt in general_templates: + sub_templates = gt.subTemplates + gt.subTemplates = [] + result.append(gt) + result.extend(flatten_general_templates(sub_templates)) + return result + + return flatten_general_templates(self.general_templates) + def generate_payload(self) -> str: env = Environment( loader=FileSystemLoader(self.payload_path.parent), diff --git a/catalystwan/tests/templates/test_device_template.py b/catalystwan/tests/templates/test_device_template.py new file mode 100644 index 00000000..937b117a --- /dev/null +++ b/catalystwan/tests/templates/test_device_template.py @@ -0,0 +1,48 @@ +import unittest + +from catalystwan.api.templates.device_template.device_template import DeviceTemplate, GeneralTemplate + + +class TestDeviceTemplate(unittest.TestCase): + def setUp(self): + self.device_template = DeviceTemplate( + template_name="DT-example", + template_description="DT-example", + device_role="None", + device_type="None", + security_policy_id="None", + policy_id="None", + generalTemplates=[ + GeneralTemplate( + name="1level", + templateId="1", + templateType="1", + subTemplates=[ + GeneralTemplate( + name="2level", + templateId="2", + templateType="2", + subTemplates=[GeneralTemplate(name="3level", templateId="3", templateType="3")], + ) + ], + ) + ], + ) + + def test_flatten_general_templates(self): + self.assertEqual( + self.device_template.get_flattened_general_templates(), + [ + GeneralTemplate( + name="1level", + templateId="1", + templateType="1", + ), + GeneralTemplate( + name="2level", + templateId="2", + templateType="2", + ), + GeneralTemplate(name="3level", templateId="3", templateType="3"), + ], + )