diff --git a/plugins/azure/resoto_plugin_azure/azure_client.py b/plugins/azure/resoto_plugin_azure/azure_client.py index 394df20945..36335e9263 100644 --- a/plugins/azure/resoto_plugin_azure/azure_client.py +++ b/plugins/azure/resoto_plugin_azure/azure_client.py @@ -148,16 +148,12 @@ def _call(self, spec: AzureApiSpec, **kwargs: Any) -> List[Json]: params = case_insensitive_dict(kwargs.pop("params", {}) or {}) params["api-version"] = _SERIALIZER.query("api_version", spec.version, "str") # type: ignore + # Define url map + format_map_paths = {"subscriptionId": self.subscription_id, "location": self.location, **params} + format_map_paths.update({param: kwargs.pop(param, "") for param in spec.path_parameters if param in kwargs}) + # Construct url - path = spec.path.format_map( - { - "subscriptionId": self.subscription_id, - "location": self.location, - "resourceGroupName": kwargs.pop("resourceGroupName", ""), - "virtualNetworkName": kwargs.pop("virtualNetworkName", ""), - **params, - } - ) + path = spec.path.format_map(format_map_paths) url = self.client._client.format_url(path) # pylint: disable=protected-access # Construct and send request diff --git a/plugins/azure/resoto_plugin_azure/resource/base.py b/plugins/azure/resoto_plugin_azure/resource/base.py index cacbe387f6..d07ec44d76 100644 --- a/plugins/azure/resoto_plugin_azure/resource/base.py +++ b/plugins/azure/resoto_plugin_azure/resource/base.py @@ -44,6 +44,12 @@ class AzureResource(BaseResource): # Which API to call and what to expect in the result. api_spec: ClassVar[Optional[AzureApiSpec]] = None + def resource_subscription_id(self) -> str: + return self.extract_part("subscriptionId") + + def resource_resource_group_name(self) -> str: + return self.extract_part("resourceGroupName") + def extract_part(self, part: str) -> str: """ Extracts a specific part from a resource ID. @@ -66,18 +72,18 @@ def extract_part(self, part: str) -> str: id_parts = self.id.split("/") if part == "subscriptionId": - if "subscriptions" in id_parts: - index = id_parts.index("subscriptions") - return id_parts[index + 1] - else: + if "subscriptions" not in id_parts: raise ValueError(f"Id {self.id} does not have any subscriptionId info") + if index := id_parts.index("subscriptions"): + return id_parts[index + 1] + return "" elif part == "resourceGroupName": - if "resourceGroups" in id_parts: - index = id_parts.index("resourceGroups") - return id_parts[index + 1] - else: + if "resourceGroups" not in id_parts: raise ValueError(f"Id {self.id} does not have any resourceGroupName info") + if index := id_parts.index("resourceGroups"): + return id_parts[index + 1] + return "" else: raise ValueError(f"Value {part} does not have any cases to match") @@ -89,7 +95,7 @@ def delete(self, graph: Graph) -> bool: Returns: bool: True if the resource was successfully deleted; False otherwise. """ - subscription_id = self.extract_part("subscriptionId") + subscription_id = self.resource_subscription_id() return get_client(subscription_id).delete(self.id) def delete_tag(self, key: str) -> bool: @@ -98,7 +104,7 @@ def delete_tag(self, key: str) -> bool: This method removes a specific value from a tag associated with a subscription, while keeping the tag itself intact. The tag remains on the account, but the specified value will be deleted. """ - subscription_id = self.extract_part("subscriptionId") + subscription_id = self.resource_subscription_id() return get_client(subscription_id).delete_resource_tag(tag_name=key, resource_id=self.id) def update_tag(self, key: str, value: str) -> bool: @@ -107,7 +113,7 @@ def update_tag(self, key: str, value: str) -> bool: This method allows for the creation or update of a tag value associated with the specified tag name. The tag name must already exist for the operation to be successful. """ - subscription_id = self.extract_part("subscriptionId") + subscription_id = self.resource_subscription_id() return get_client(subscription_id).update_resource_tag(tag_name=key, tag_value=value, resource_id=self.id) def pre_process(self, graph_builder: GraphBuilder, source: Json) -> None: diff --git a/plugins/azure/resoto_plugin_azure/resource/network.py b/plugins/azure/resoto_plugin_azure/resource/network.py index 7f721ea5d4..176bfd2ee5 100644 --- a/plugins/azure/resoto_plugin_azure/resource/network.py +++ b/plugins/azure/resoto_plugin_azure/resource/network.py @@ -26,7 +26,7 @@ from resotolib.types import Json -def extract_vn_id(subnet_id: str) -> str: +def extract_virtual_network_id(subnet_id: str) -> str: """ Extracts {virtualNetworkName} value from a subnet ID and create virtual network ID @@ -1122,7 +1122,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: builder.add_edge( self, edge_type=EdgeType.default, reverse=True, clazz=AzureSubnet, id=subnet_id ) - vn_id = extract_vn_id(subnet_id) + vn_id = extract_virtual_network_id(subnet_id) builder.add_edge( self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualNetwork, id=vn_id ) @@ -1501,7 +1501,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: if ip_confs := self.firewall_ip_configurations: for ip_conf in ip_confs: if subnet := ip_conf.subnet: - vn_id = extract_vn_id(subnet) + vn_id = extract_virtual_network_id(subnet) builder.add_edge( self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualNetwork, id=vn_id ) @@ -2893,7 +2893,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: if ip_confs := network_interface.interface_ip_configurations: for ip_conf in ip_confs: if subnet := ip_conf._subnet_id: - vn_id = extract_vn_id(subnet) + vn_id = extract_virtual_network_id(subnet) builder.add_edge( self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualNetwork, id=vn_id ) @@ -4300,7 +4300,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: if ip_configurations := container.ip_configurations: for ip_configuration in ip_configurations: if subnet := ip_configuration._subnet_id: - vn_id = extract_vn_id(subnet) + vn_id = extract_virtual_network_id(subnet) builder.add_edge( self, edge_type=EdgeType.default, reverse=True, clazz=AzureVirtualNetwork, id=vn_id ) @@ -4541,14 +4541,14 @@ class AzureNetworkWatcher(AzureResource): def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: locations_and_ids_in_vn = self.fetch_resources( - builder, - "network", - "2023-05-01", - "/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks", - ["subscriptionId"], - ["api-version"], - lambda r: r["location"], - lambda r: r["id"], + builder=builder, + service="network", + api_version="2023-05-01", + path="/subscriptions/{subscriptionId}/providers/Microsoft.Network/virtualNetworks", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + compared_property=lambda r: r["location"], + binding_property=lambda r: r["id"], ) if (nw_location := self.location) and (vns_info := locations_and_ids_in_vn): @@ -5192,7 +5192,7 @@ def collect_subnets() -> None: access_path="value", expect_array=True, ) - resource_group_name = self.extract_part("resourceGroupName") + resource_group_name = self.resource_resource_group_name() virtual_network_name = self.name if self.name else "" items = graph_builder.client.list(