Skip to content

Commit

Permalink
feat: Deleted post_process methods and added new DRY method
Browse files Browse the repository at this point in the history
  • Loading branch information
1101-1 committed Nov 28, 2023
1 parent acb7471 commit 711759a
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 126 deletions.
31 changes: 30 additions & 1 deletion plugins/azure/resoto_plugin_azure/resource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from concurrent.futures import Future
from datetime import datetime
from threading import Lock
from typing import Any, ClassVar, Dict, Optional, TypeVar, List, Type, Callable, cast
from typing import Any, ClassVar, Dict, Optional, TypeVar, List, Tuple, Type, Callable, Union, cast

from attr import define, field
from azure.core.utils import CaseInsensitiveDict
Expand Down Expand Up @@ -101,6 +101,35 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
# Default behavior: add resource to the namespace
pass

def fetch_resources(
self,
builder: GraphBuilder,
service: str,
api_version: str,
path: str,
path_parameters: List[str],
query_parameters: List[str],
compared_property: Callable[[Json], Union[List[str], str]],
binding_property: Callable[[Json], str],
) -> List[Tuple[Union[str, List[str]], str]]:
"""
Fetch additional resources from the Azure API for further connection using the connect_in_graph method.
Returns:
List[Tuple[Union[str, List[str]], str]]: A list of tuples containing information to compare and connect the retrieved resources.
"""
resources_api_spec = AzureApiSpec(
service=service,
version=api_version,
path=path,
path_parameters=path_parameters,
query_parameters=query_parameters,
access_path="value",
expect_array=True,
)

return [(compared_property(r), binding_property(r)) for r in builder.client.list(resources_api_spec)]

@classmethod
def collect_resources(
cls: Type[AzureResourceType], builder: GraphBuilder, **kwargs: Any
Expand Down
205 changes: 80 additions & 125 deletions plugins/azure/resoto_plugin_azure/resource/network.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import ClassVar, Dict, Optional, List, Type, Tuple
from typing import Callable, ClassVar, Dict, Optional, List, Type

from attr import define, field

Expand Down Expand Up @@ -3145,33 +3145,25 @@ class AzureExpressRouteCircuit(AzureResource):
service_provider_provisioning_state: Optional[str] = field(default=None, metadata={'description': 'The ServiceProviderProvisioningState state of the resource.'}) # fmt: skip
azure_sku: Optional[AzureSku] = field(default=None, metadata={'description': 'Contains SKU in an ExpressRouteCircuit.'}) # fmt: skip
stag: Optional[int] = field(default=None, metadata={'description': 'The identifier of the circuit traffic. Outer tag for QinQ encapsulation.'}) # fmt: skip
_ids_and_names_in_resource: Optional[List[Tuple[str, str]]] = None

def post_process(self, graph_builder: GraphBuilder, source: Json) -> None:
def collect_ids_and_names() -> None:
resources_api_spec = AzureApiSpec(
service="network",
version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/ExpressRoutePortsLocations",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
access_path="value",
expect_array=True,
)

self._ids_and_names_in_resource = [
(r["name"], r["id"]) for r in graph_builder.client.list(resources_api_spec)
]

graph_builder.submit_work(collect_ids_and_names)

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
ids_and_names_in_resource = self.fetch_resources(
builder,
service="network",
api_version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
compared_property=lambda r: r["name"],
binding_property=lambda r: r["id"],
)

if route_port_id := self.express_route_port:
builder.add_edge(self, edge_type=EdgeType.default, clazz=AzureExpressRoutePort, id=route_port_id)
if (
(provider_properties := self.service_provider_properties)
and (location_name := provider_properties.peering_location)
and (names_and_ids := self._ids_and_names_in_resource)
and (names_and_ids := ids_and_names_in_resource)
):
for info in names_and_ids:
erplocation, erplocation_id = info
Expand Down Expand Up @@ -3849,29 +3841,20 @@ class AzureIpGroup(AzureResource):
firewalls: Optional[List[str]] = field(default=None, metadata={'description': 'List of references to Firewall resources that this IpGroups is associated with.'}) # fmt: skip
ip_addresses: Optional[List[str]] = field(default=None, metadata={'description': 'IpAddresses/IpAddressPrefixes in the IpGroups resource.'}) # fmt: skip
provisioning_state: Optional[str] = field(default=None, metadata={'description': 'The current provisioning state.'}) # fmt: skip
_virtual_networks: Optional[List[Tuple[List[str], str]]] = None

def post_process(self, graph_builder: GraphBuilder, source: Json) -> None:
def collect_vns() -> None:
resources_api_spec = AzureApiSpec(
service="network",
version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
access_path="value",
expect_array=True,
)

self._virtual_networks = [
(r["properties"]["addressSpace"]["addressPrefixes"], r["id"])
for r in graph_builder.client.list(resources_api_spec)
]

graph_builder.submit_work(collect_vns)

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
if (ip_addresses := self.ip_addresses) and (vns := self._virtual_networks):
virtual_networks = self.fetch_resources(
builder,
service="network",
api_version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
compared_property=lambda r: r["properties"]["addressSpace"]["addressPrefixes"],
binding_property=lambda r: r["id"],
)

if (ip_addresses := self.ip_addresses) and (vns := virtual_networks):
for ip_address in ip_addresses:
for info in vns:
vn_ips, vn_id = info
Expand Down Expand Up @@ -4274,32 +4257,24 @@ class AzureNetworkProfile(AzureResource):
etag: Optional[str] = field(default=None, metadata={'description': 'A unique read-only string that changes whenever the resource is updated.'}) # fmt: skip
provisioning_state: Optional[str] = field(default=None, metadata={'description': 'The current provisioning state.'}) # fmt: skip
resource_guid: Optional[str] = field(default=None, metadata={'description': 'The resource GUID property of the network profile resource.'}) # fmt: skip
_network_interfaces_and_vm_ids: Optional[List[Tuple[str, str]]] = None

def post_process(self, graph_builder: GraphBuilder, source: Json) -> None:
def collect_nis_and_vms() -> None:
resources_api_spec = AzureApiSpec(
service="network",
version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
access_path="value",
expect_array=True,
)

self._network_interfaces_and_vm_ids = [
(r["id"], r["properties"]["virtualMachine"]["id"])
for r in graph_builder.client.list(resources_api_spec)
if "properties" in r and "virtualMachine" in r["properties"]
]

graph_builder.submit_work(collect_nis_and_vms)

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
# Import placed inside the method due to circular import error resolution
from resoto_plugin_azure.resource.compute import AzureVirtualMachine # pylint: disable=import-outside-toplevel

network_interfaces_and_vm_ids = self.fetch_resources(
builder,
service="network",
api_version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
compared_property=lambda r: r["id"],
binding_property=lambda r: r["properties"]["virtualMachine"]["id"]
if "properties" in r and "virtualMachine" in r["properties"]
else "",
)

if container_nic := self.container_network_interface_configurations:
for container in container_nic:
if ip_configurations := container.ip_configurations:
Expand All @@ -4311,7 +4286,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
)

if (container_ni_ids := container.container_network_interfaces) and (
ni_ids_and_vm_ids := self._network_interfaces_and_vm_ids
ni_ids_and_vm_ids := network_interfaces_and_vm_ids
):
for ni_id in container_ni_ids:
for info in ni_ids_and_vm_ids:
Expand Down Expand Up @@ -4447,27 +4422,20 @@ class AzureNetworkVirtualAppliance(AzureResource):
virtual_appliance_nics: Optional[List[AzureVirtualApplianceNicProperties]] = field(default=None, metadata={'description': 'List of Virtual Appliance Network Interfaces.'}) # fmt: skip
virtual_appliance_sites: Optional[List[str]] = field(default=None, metadata={'description': 'List of references to VirtualApplianceSite.'}) # fmt: skip
virtual_hub: Optional[str] = field(default=None, metadata={"description": "Reference to another subresource."})
_vendors_in_resource: Optional[List[Tuple[str, str]]] = None

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
def collect_vendors() -> None:
resources_api_spec = AzureApiSpec(
service="network",
version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkVirtualApplianceSkus",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
access_path="value",
expect_array=True,
)

self._vendors_in_resource = [
(r["properties"]["vendor"], r["id"]) for r in builder.client.list(resources_api_spec)
]

builder.submit_work(collect_vendors)
vendors_in_resource = self.fetch_resources(
builder,
service="network",
api_version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkVirtualApplianceSkus",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
compared_property=lambda r: r["properties"]["vendor"],
binding_property=lambda r: r["id"],
)

if (nva := self.nva_sku) and (nva_vendor := nva.vendor) and (vendors := self._vendors_in_resource):
if (nva := self.nva_sku) and (nva_vendor := nva.vendor) and (vendors := vendors_in_resource):
for info in vendors:
vendor_name, nvasku_id = info
if vendor_name == nva_vendor:
Expand Down Expand Up @@ -4550,28 +4518,20 @@ class AzureNetworkWatcher(AzureResource):
etag: Optional[str] = field(default=None, metadata={'description': 'A unique read-only string that changes whenever the resource is updated.'}) # fmt: skip
properties: Optional[str] = field(default=None, metadata={"description": "The network watcher properties."})
location: Optional[str] = field(default=None, metadata={"description": "Resource location."})
_locations_and_ids_in_vn: Optional[List[Tuple[str, str]]] = None

def post_process(self, graph_builder: GraphBuilder, source: Json) -> None:
def collect_vns_location() -> None:
resources_api_spec = AzureApiSpec(
service="network",
version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
access_path="value",
expect_array=True,
)

self._locations_and_ids_in_vn = [
(r["location"], r["id"]) for r in graph_builder.client.list(resources_api_spec)
]

graph_builder.submit_work(collect_vns_location)

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
if (nw_location := self.location) and (vns_info := self._locations_and_ids_in_vn):
locations_and_ids_in_vn = self.fetch_resources(
builder,
"network",
"2023-05-01",
"/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks",
["subscriptionId"],
["api-version"],
lambda r: r["location"],
lambda r: r["id"],
)

if (nw_location := self.location) and (vns_info := locations_and_ids_in_vn):
for info in vns_info:
vn_location, vn_id = info
if vn_location == nw_location:
Expand Down Expand Up @@ -5035,30 +4995,25 @@ class AzureVirtualHub(AzureResource):
virtual_router_ips: Optional[List[str]] = field(default=None, metadata={"description": "VirtualRouter IPs."})
virtual_wan: Optional[str] = field(default=None, metadata={"description": "Reference to another subresource."})
vpn_gateway: Optional[str] = field(default=None, metadata={"description": "Reference to another subresource."})
_p_ip_addresses_ip_c_ids: Optional[List[Tuple[str, str]]] = None

def post_process(self, graph_builder: GraphBuilder, source: Json) -> None:
def collect_ip_info() -> None:
resources_api_spec = AzureApiSpec(
service="network",
version="2023-07-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
access_path="value",
expect_array=True,
)

self._p_ip_addresses_ip_c_ids = [
(ip_config["properties"]["publicIPAddress"]["id"], nic["id"])
for nic in graph_builder.client.list(resources_api_spec)
for ip_config in nic.get("properties", {}).get("ipConfigurations", [])
if "publicIPAddress" in ip_config.get("properties", {})
]

graph_builder.submit_work(collect_ip_info)

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
compared_property: Callable[[Json], List[str]] = lambda r: [ # pylint: disable=unnecessary-lambda-assignment
ip_config["properties"]["publicIPAddress"]["id"]
for ip_config in r.get("properties", {}).get("ipConfigurations", [])
if "publicIPAddress" in ip_config.get("properties", {})
]

p_ip_addresses_ip_c_ids = self.fetch_resources(
builder,
service="network",
api_version="2023-05-01",
path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/networkInterfaces",
path_parameters=["subscriptionId"],
query_parameters=["api-version"],
compared_property=compared_property,
binding_property=lambda r: r["id"],
)

if er_gateway_id := self.express_route_gateway:
builder.add_edge(
self, edge_type=EdgeType.default, reverse=True, clazz=AzureExpressRouteGateway, id=er_gateway_id
Expand All @@ -5067,7 +5022,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
builder.add_edge(self, edge_type=EdgeType.default, reverse=True, clazz=AzureVpnGateway, id=vpn_gateway_id)
if vw_id := self.virtual_wan:
builder.add_edge(self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualWAN, id=vw_id)
if (ip_config_ids := self.ip_configuration_ids) and (p_ip_a_and_ip_conf_ids := self._p_ip_addresses_ip_c_ids):
if (ip_config_ids := self.ip_configuration_ids) and (p_ip_a_and_ip_conf_ids := p_ip_addresses_ip_c_ids):
for ip_config_id in ip_config_ids:
for info in p_ip_a_and_ip_conf_ids:
p_ip_address_id, collected_ip_conf_id = info
Expand Down

0 comments on commit 711759a

Please sign in to comment.