From 711759a90c945e7b96a8de4ab9743650554c1236 Mon Sep 17 00:00:00 2001 From: Kirill Date: Tue, 28 Nov 2023 16:50:21 +0400 Subject: [PATCH] feat: Deleted post_process methods and added new DRY method --- .../resoto_plugin_azure/resource/base.py | 31 ++- .../resoto_plugin_azure/resource/network.py | 205 +++++++----------- 2 files changed, 110 insertions(+), 126 deletions(-) diff --git a/plugins/azure/resoto_plugin_azure/resource/base.py b/plugins/azure/resoto_plugin_azure/resource/base.py index b7f267bf25..f8cb53a773 100644 --- a/plugins/azure/resoto_plugin_azure/resource/base.py +++ b/plugins/azure/resoto_plugin_azure/resource/base.py @@ -4,7 +4,7 @@ from concurrent.futures import Future from datetime import datetime from threading import Lock -from typing import Any, ClassVar, Dict, Optional, TypeVar, List, Type, Callable, cast +from typing import Any, ClassVar, Dict, Optional, TypeVar, List, Tuple, Type, Callable, Union, cast from attr import define, field from azure.core.utils import CaseInsensitiveDict @@ -101,6 +101,35 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: # Default behavior: add resource to the namespace pass + def fetch_resources( + self, + builder: GraphBuilder, + service: str, + api_version: str, + path: str, + path_parameters: List[str], + query_parameters: List[str], + compared_property: Callable[[Json], Union[List[str], str]], + binding_property: Callable[[Json], str], + ) -> List[Tuple[Union[str, List[str]], str]]: + """ + Fetch additional resources from the Azure API for further connection using the connect_in_graph method. + + Returns: + List[Tuple[Union[str, List[str]], str]]: A list of tuples containing information to compare and connect the retrieved resources. + """ + resources_api_spec = AzureApiSpec( + service=service, + version=api_version, + path=path, + path_parameters=path_parameters, + query_parameters=query_parameters, + access_path="value", + expect_array=True, + ) + + return [(compared_property(r), binding_property(r)) for r in builder.client.list(resources_api_spec)] + @classmethod def collect_resources( cls: Type[AzureResourceType], builder: GraphBuilder, **kwargs: Any diff --git a/plugins/azure/resoto_plugin_azure/resource/network.py b/plugins/azure/resoto_plugin_azure/resource/network.py index bc03707a2c..a59ed4b1d7 100644 --- a/plugins/azure/resoto_plugin_azure/resource/network.py +++ b/plugins/azure/resoto_plugin_azure/resource/network.py @@ -1,4 +1,4 @@ -from typing import ClassVar, Dict, Optional, List, Type, Tuple +from typing import Callable, ClassVar, Dict, Optional, List, Type from attr import define, field @@ -3145,33 +3145,25 @@ class AzureExpressRouteCircuit(AzureResource): service_provider_provisioning_state: Optional[str] = field(default=None, metadata={'description': 'The ServiceProviderProvisioningState state of the resource.'}) # fmt: skip azure_sku: Optional[AzureSku] = field(default=None, metadata={'description': 'Contains SKU in an ExpressRouteCircuit.'}) # fmt: skip stag: Optional[int] = field(default=None, metadata={'description': 'The identifier of the circuit traffic. Outer tag for QinQ encapsulation.'}) # fmt: skip - _ids_and_names_in_resource: Optional[List[Tuple[str, str]]] = None - - def post_process(self, graph_builder: GraphBuilder, source: Json) -> None: - def collect_ids_and_names() -> None: - resources_api_spec = AzureApiSpec( - service="network", - version="2023-05-01", - path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/ExpressRoutePortsLocations", - path_parameters=["subscriptionId"], - query_parameters=["api-version"], - access_path="value", - expect_array=True, - ) - - self._ids_and_names_in_resource = [ - (r["name"], r["id"]) for r in graph_builder.client.list(resources_api_spec) - ] - - graph_builder.submit_work(collect_ids_and_names) def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: + ids_and_names_in_resource = self.fetch_resources( + builder, + service="network", + api_version="2023-05-01", + path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + compared_property=lambda r: r["name"], + binding_property=lambda r: r["id"], + ) + if route_port_id := self.express_route_port: builder.add_edge(self, edge_type=EdgeType.default, clazz=AzureExpressRoutePort, id=route_port_id) if ( (provider_properties := self.service_provider_properties) and (location_name := provider_properties.peering_location) - and (names_and_ids := self._ids_and_names_in_resource) + and (names_and_ids := ids_and_names_in_resource) ): for info in names_and_ids: erplocation, erplocation_id = info @@ -3849,29 +3841,20 @@ class AzureIpGroup(AzureResource): firewalls: Optional[List[str]] = field(default=None, metadata={'description': 'List of references to Firewall resources that this IpGroups is associated with.'}) # fmt: skip ip_addresses: Optional[List[str]] = field(default=None, metadata={'description': 'IpAddresses/IpAddressPrefixes in the IpGroups resource.'}) # fmt: skip provisioning_state: Optional[str] = field(default=None, metadata={'description': 'The current provisioning state.'}) # fmt: skip - _virtual_networks: Optional[List[Tuple[List[str], str]]] = None - - def post_process(self, graph_builder: GraphBuilder, source: Json) -> None: - def collect_vns() -> None: - resources_api_spec = AzureApiSpec( - service="network", - version="2023-05-01", - path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks", - path_parameters=["subscriptionId"], - query_parameters=["api-version"], - access_path="value", - expect_array=True, - ) - - self._virtual_networks = [ - (r["properties"]["addressSpace"]["addressPrefixes"], r["id"]) - for r in graph_builder.client.list(resources_api_spec) - ] - - graph_builder.submit_work(collect_vns) def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: - if (ip_addresses := self.ip_addresses) and (vns := self._virtual_networks): + virtual_networks = self.fetch_resources( + builder, + service="network", + api_version="2023-05-01", + path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + compared_property=lambda r: r["properties"]["addressSpace"]["addressPrefixes"], + binding_property=lambda r: r["id"], + ) + + if (ip_addresses := self.ip_addresses) and (vns := virtual_networks): for ip_address in ip_addresses: for info in vns: vn_ips, vn_id = info @@ -4274,32 +4257,24 @@ class AzureNetworkProfile(AzureResource): etag: Optional[str] = field(default=None, metadata={'description': 'A unique read-only string that changes whenever the resource is updated.'}) # fmt: skip provisioning_state: Optional[str] = field(default=None, metadata={'description': 'The current provisioning state.'}) # fmt: skip resource_guid: Optional[str] = field(default=None, metadata={'description': 'The resource GUID property of the network profile resource.'}) # fmt: skip - _network_interfaces_and_vm_ids: Optional[List[Tuple[str, str]]] = None - - def post_process(self, graph_builder: GraphBuilder, source: Json) -> None: - def collect_nis_and_vms() -> None: - resources_api_spec = AzureApiSpec( - service="network", - version="2023-05-01", - path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces", - path_parameters=["subscriptionId"], - query_parameters=["api-version"], - access_path="value", - expect_array=True, - ) - - self._network_interfaces_and_vm_ids = [ - (r["id"], r["properties"]["virtualMachine"]["id"]) - for r in graph_builder.client.list(resources_api_spec) - if "properties" in r and "virtualMachine" in r["properties"] - ] - - graph_builder.submit_work(collect_nis_and_vms) def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: # Import placed inside the method due to circular import error resolution from resoto_plugin_azure.resource.compute import AzureVirtualMachine # pylint: disable=import-outside-toplevel + network_interfaces_and_vm_ids = self.fetch_resources( + builder, + service="network", + api_version="2023-05-01", + path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + compared_property=lambda r: r["id"], + binding_property=lambda r: r["properties"]["virtualMachine"]["id"] + if "properties" in r and "virtualMachine" in r["properties"] + else "", + ) + if container_nic := self.container_network_interface_configurations: for container in container_nic: if ip_configurations := container.ip_configurations: @@ -4311,7 +4286,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: ) if (container_ni_ids := container.container_network_interfaces) and ( - ni_ids_and_vm_ids := self._network_interfaces_and_vm_ids + ni_ids_and_vm_ids := network_interfaces_and_vm_ids ): for ni_id in container_ni_ids: for info in ni_ids_and_vm_ids: @@ -4447,27 +4422,20 @@ class AzureNetworkVirtualAppliance(AzureResource): virtual_appliance_nics: Optional[List[AzureVirtualApplianceNicProperties]] = field(default=None, metadata={'description': 'List of Virtual Appliance Network Interfaces.'}) # fmt: skip virtual_appliance_sites: Optional[List[str]] = field(default=None, metadata={'description': 'List of references to VirtualApplianceSite.'}) # fmt: skip virtual_hub: Optional[str] = field(default=None, metadata={"description": "Reference to another subresource."}) - _vendors_in_resource: Optional[List[Tuple[str, str]]] = None def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: - def collect_vendors() -> None: - resources_api_spec = AzureApiSpec( - service="network", - version="2023-05-01", - path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkVirtualApplianceSkus", - path_parameters=["subscriptionId"], - query_parameters=["api-version"], - access_path="value", - expect_array=True, - ) - - self._vendors_in_resource = [ - (r["properties"]["vendor"], r["id"]) for r in builder.client.list(resources_api_spec) - ] - - builder.submit_work(collect_vendors) + vendors_in_resource = self.fetch_resources( + builder, + service="network", + api_version="2023-05-01", + path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkVirtualApplianceSkus", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + compared_property=lambda r: r["properties"]["vendor"], + binding_property=lambda r: r["id"], + ) - if (nva := self.nva_sku) and (nva_vendor := nva.vendor) and (vendors := self._vendors_in_resource): + if (nva := self.nva_sku) and (nva_vendor := nva.vendor) and (vendors := vendors_in_resource): for info in vendors: vendor_name, nvasku_id = info if vendor_name == nva_vendor: @@ -4550,28 +4518,20 @@ class AzureNetworkWatcher(AzureResource): etag: Optional[str] = field(default=None, metadata={'description': 'A unique read-only string that changes whenever the resource is updated.'}) # fmt: skip properties: Optional[str] = field(default=None, metadata={"description": "The network watcher properties."}) location: Optional[str] = field(default=None, metadata={"description": "Resource location."}) - _locations_and_ids_in_vn: Optional[List[Tuple[str, str]]] = None - - def post_process(self, graph_builder: GraphBuilder, source: Json) -> None: - def collect_vns_location() -> None: - resources_api_spec = AzureApiSpec( - service="network", - version="2023-05-01", - path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks", - path_parameters=["subscriptionId"], - query_parameters=["api-version"], - access_path="value", - expect_array=True, - ) - - self._locations_and_ids_in_vn = [ - (r["location"], r["id"]) for r in graph_builder.client.list(resources_api_spec) - ] - - graph_builder.submit_work(collect_vns_location) def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: - if (nw_location := self.location) and (vns_info := self._locations_and_ids_in_vn): + locations_and_ids_in_vn = self.fetch_resources( + builder, + "network", + "2023-05-01", + "/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks", + ["subscriptionId"], + ["api-version"], + lambda r: r["location"], + lambda r: r["id"], + ) + + if (nw_location := self.location) and (vns_info := locations_and_ids_in_vn): for info in vns_info: vn_location, vn_id = info if vn_location == nw_location: @@ -5035,30 +4995,25 @@ class AzureVirtualHub(AzureResource): virtual_router_ips: Optional[List[str]] = field(default=None, metadata={"description": "VirtualRouter IPs."}) virtual_wan: Optional[str] = field(default=None, metadata={"description": "Reference to another subresource."}) vpn_gateway: Optional[str] = field(default=None, metadata={"description": "Reference to another subresource."}) - _p_ip_addresses_ip_c_ids: Optional[List[Tuple[str, str]]] = None - - def post_process(self, graph_builder: GraphBuilder, source: Json) -> None: - def collect_ip_info() -> None: - resources_api_spec = AzureApiSpec( - service="network", - version="2023-07-01", - path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces", - path_parameters=["subscriptionId"], - query_parameters=["api-version"], - access_path="value", - expect_array=True, - ) - - self._p_ip_addresses_ip_c_ids = [ - (ip_config["properties"]["publicIPAddress"]["id"], nic["id"]) - for nic in graph_builder.client.list(resources_api_spec) - for ip_config in nic.get("properties", {}).get("ipConfigurations", []) - if "publicIPAddress" in ip_config.get("properties", {}) - ] - - graph_builder.submit_work(collect_ip_info) def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: + compared_property: Callable[[Json], List[str]] = lambda r: [ # pylint: disable=unnecessary-lambda-assignment + ip_config["properties"]["publicIPAddress"]["id"] + for ip_config in r.get("properties", {}).get("ipConfigurations", []) + if "publicIPAddress" in ip_config.get("properties", {}) + ] + + p_ip_addresses_ip_c_ids = self.fetch_resources( + builder, + service="network", + api_version="2023-05-01", + path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + compared_property=compared_property, + binding_property=lambda r: r["id"], + ) + if er_gateway_id := self.express_route_gateway: builder.add_edge( self, edge_type=EdgeType.default, reverse=True, clazz=AzureExpressRouteGateway, id=er_gateway_id @@ -5067,7 +5022,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: builder.add_edge(self, edge_type=EdgeType.default, reverse=True, clazz=AzureVpnGateway, id=vpn_gateway_id) if vw_id := self.virtual_wan: builder.add_edge(self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualWAN, id=vw_id) - if (ip_config_ids := self.ip_configuration_ids) and (p_ip_a_and_ip_conf_ids := self._p_ip_addresses_ip_c_ids): + if (ip_config_ids := self.ip_configuration_ids) and (p_ip_a_and_ip_conf_ids := p_ip_addresses_ip_c_ids): for ip_config_id in ip_config_ids: for info in p_ip_a_and_ip_conf_ids: p_ip_address_id, collected_ip_conf_id = info