From 19801757db7e398bdc40aa3b3da9fcd884a58b67 Mon Sep 17 00:00:00 2001
From: 1101-1 <70093559+1101-1@users.noreply.github.com>
Date: Fri, 11 Oct 2024 13:58:43 +0500
Subject: [PATCH] [azure][fix] Reimplement resource type collection of compute, psql, mysql and ml services (#2234)

---
 plugins/azure/fix_plugin_azure/collector.py      |   6 +-
 .../fix_plugin_azure/resource/compute.py         | 144 +++++++++++-------
 .../resource/machinelearning.py                  |  93 ++++++++---
 plugins/azure/fix_plugin_azure/resource/mysql.py |  89 ++++++++---
 .../fix_plugin_azure/resource/postgresql.py      |  96 ++++++++----
 plugins/azure/test/collector_test.py             |   4 +-
 6 files changed, 299 insertions(+), 133 deletions(-)

diff --git a/plugins/azure/fix_plugin_azure/collector.py b/plugins/azure/fix_plugin_azure/collector.py
index fbd649b039..647c1e617f 100644
--- a/plugins/azure/fix_plugin_azure/collector.py
+++ b/plugins/azure/fix_plugin_azure/collector.py
@@ -269,11 +269,11 @@ def remove_usage_zero_value() -> None:
         rm_leaf_nodes(AzureComputeVirtualMachineSize, AzureLocation)
         rm_leaf_nodes(AzureNetworkExpressRoutePortsLocation, AzureSubscription)
         rm_leaf_nodes(AzureNetworkVirtualApplianceSku, AzureSubscription)
-        rm_leaf_nodes(AzureComputeDiskType, AzureSubscription)
+        rm_leaf_nodes(AzureComputeDiskType, (AzureSubscription, AzureLocation))  # type: ignore
         rm_leaf_nodes(AzureMachineLearningVirtualMachineSize, AzureLocation)
         rm_leaf_nodes(AzureStorageSku, AzureLocation)
-        rm_leaf_nodes(AzureMysqlServerType, AzureSubscription)
-        rm_leaf_nodes(AzurePostgresqlServerType, AzureSubscription)
+        rm_leaf_nodes(AzureMysqlServerType, AzureLocation)
+        rm_leaf_nodes(AzurePostgresqlServerType, AzureLocation)
         rm_leaf_nodes(AzureCosmosDBLocation, AzureLocation, check_pred=False)
         rm_leaf_nodes(AzureLocation, check_pred=False)
         rm_leaf_nodes(AzureComputeDiskTypePricing, AzureSubscription)
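
Note: the collector hunk above reparents the pruning of orphaned size/type nodes from AzureSubscription to AzureLocation (or both, for disk types), since these nodes are now created per location. A minimal sketch of what an rm_leaf_nodes-style pass does; the graph type and helper signature here are assumptions for illustration, not the plugin's actual API:

    import networkx as nx

    def rm_leaf_nodes(graph: nx.DiGraph, leaf_cls: type, pred_cls: type | tuple, check_pred: bool = True) -> None:
        # Drop nodes of leaf_cls that have no successors and, if check_pred is
        # set, whose predecessors are all instances of the given class(es).
        doomed = [
            node
            for node in graph.nodes
            if isinstance(node, leaf_cls)
            and graph.out_degree(node) == 0
            and (not check_pred or all(isinstance(p, pred_cls) for p in graph.predecessors(node)))
        ]
        graph.remove_nodes_from(doomed)
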
diff --git a/plugins/azure/fix_plugin_azure/resource/compute.py b/plugins/azure/fix_plugin_azure/resource/compute.py
index 9d5d6d6c87..7901bf3f1e 100644
--- a/plugins/azure/fix_plugin_azure/resource/compute.py
+++ b/plugins/azure/fix_plugin_azure/resource/compute.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
 import logging
 from datetime import datetime
 from typing import ClassVar, Dict, Optional, List, Any, Type
@@ -899,7 +900,9 @@ def build_custom_disk_size(
         return premium_ssd_v2_object
 
     @staticmethod
-    def create_unique_disk_sizes(collected_disks: List[MicrosoftResourceType], builder: GraphBuilder) -> None:
+    def create_unique_disk_sizes(
+        collected_disks: List[MicrosoftResourceType], builder: GraphBuilder, location: str
+    ) -> None:
         disk_sizes: List[Json] = []
         seen_hashes = set()  # Set to keep track of unique hashes
         for disk in collected_disks:
@@ -907,7 +910,6 @@ def create_unique_disk_sizes(collected_disks: List[MicrosoftResourceType], build
                 continue
             if (
                 (volume_type := disk.volume_type)
-                and (location := disk.location)
                 and (size := disk.volume_size)
                 and (iops := disk.volume_iops)
                 and (throughput := disk.volume_throughput)
@@ -1046,15 +1048,22 @@ class AzureComputeDisk(MicrosoftResource, BaseVolume):
     tier_name: Optional[str] = field(default=None, metadata={"description": "The sku tier."})
 
     @classmethod
-    def collect_resources(
-        cls: Type[MicrosoftResourceType], builder: GraphBuilder, **kwargs: Any
-    ) -> List[MicrosoftResourceType]:
+    def collect_resources(cls, builder: GraphBuilder, **kwargs: Any) -> List["AzureComputeDisk"]:
         log.debug(f"[Azure:{builder.account.id}] Collecting {cls.__name__} with ({kwargs})")
+        if not issubclass(cls, MicrosoftResource):
+            return []
         if spec := cls.api_spec:
             items = builder.client.list(spec, **kwargs)
             collected = cls.collect(items, builder)
-            # Create additional custom disk sizes for disks with Ultra SSD or Premium SSD v2 types
-            AzureComputeDiskType.create_unique_disk_sizes(collected, builder)
+            disks_by_location = defaultdict(list)
+            for disk in collected:
+                if disk_location := getattr(disk, "location", None):
+                    disks_by_location[disk_location].append(disk)
+            for d_loc, disks in disks_by_location.items():
+                # Collect disk types for the disks in this location
+                AzureComputeDisk._collect_disk_types(builder, d_loc)
+                # Create additional custom disk sizes for disks with Ultra SSD or Premium SSD v2 types
+                AzureComputeDiskType.create_unique_disk_sizes(disks, builder, d_loc)
             if builder.config.collect_usage_metrics:
                 try:
                     cls.collect_usage_metrics(builder, collected)
@@ -1063,33 +1072,32 @@ def collect_resources(
             return collected
         return []
 
-    def post_process(self, graph_builder: GraphBuilder, source: Json) -> None:
-        if location := self.location:
-
-            def collect_disk_types() -> None:
-                log.debug(f"[Azure:{graph_builder.account.id}] Collecting AzureComputeDiskType")
-                product_names = {
-                    "Standard SSD Managed Disks",
-                    "Premium SSD Managed Disks",
-                    "Standard HDD Managed Disks",
-                }
-                sku_items = []
-                for product_name in product_names:
-                    api_spec = AzureResourceSpec(
-                        service="compute",
-                        version="2023-01-01-preview",
-                        path=f"https://prices.azure.com/api/retail/prices?$filter=productName eq '{product_name}' and armRegionName eq '{location}' and unitOfMeasure eq '1/Month' and serviceFamily eq 'Storage' and type eq 'Consumption' and isPrimaryMeterRegion eq true",
-                        path_parameters=[],
-                        query_parameters=["api-version"],
-                        access_path="Items",
-                        expect_array=True,
-                    )
+    @staticmethod
+    def _collect_disk_types(graph_builder: GraphBuilder, location: str) -> None:
+        def collect_disk_types() -> None:
+            log.debug(f"[Azure:{graph_builder.account.id}] Collecting AzureComputeDiskType")
+            product_names = {
+                "Standard SSD Managed Disks",
+                "Premium SSD Managed Disks",
+                "Standard HDD Managed Disks",
+            }
+            sku_items = []
+            for product_name in product_names:
+                api_spec = AzureResourceSpec(
+                    service="compute",
+                    version="2023-01-01-preview",
+                    path=f"https://prices.azure.com/api/retail/prices?$filter=productName eq '{product_name}' and armRegionName eq '{location}' and unitOfMeasure eq '1/Month' and serviceFamily eq 'Storage' and type eq 'Consumption' and isPrimaryMeterRegion eq true",
+                    path_parameters=[],
+                    query_parameters=["api-version"],
+                    access_path="Items",
+                    expect_array=True,
+                )
 
-                    items = graph_builder.client.list(api_spec)
-                    sku_items.extend(items)
-                AzureComputeDiskType.collect(sku_items, graph_builder)
+                items = graph_builder.client.list(api_spec)
+                sku_items.extend(items)
+            AzureComputeDiskType.collect(sku_items, graph_builder)
 
-            graph_builder.submit_work(service_name, collect_disk_types)
+        graph_builder.submit_work(service_name, collect_disk_types)
 
     @classmethod
     def collect_usage_metrics(
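
Note: the hunks above replace the per-disk post_process hook with one pass that buckets the collected disks by location, so the price-sheet and custom-size lookups run once per region instead of once per disk. A minimal sketch of the grouping step, with a stand-in Disk class (the real code uses AzureComputeDisk):

    from collections import defaultdict
    from dataclasses import dataclass
    from typing import Dict, List, Optional

    @dataclass
    class Disk:  # stand-in for AzureComputeDisk
        name: str
        location: Optional[str]

    def group_by_location(disks: List[Disk]) -> Dict[str, List[Disk]]:
        by_location: Dict[str, List[Disk]] = defaultdict(list)
        for disk in disks:
            if loc := getattr(disk, "location", None):
                by_location[loc].append(disk)  # disks without a location are skipped
        return by_location

    disks = [Disk("d1", "westeurope"), Disk("d2", "eastus"), Disk("d3", "westeurope"), Disk("d4", None)]
    print(dict(group_by_location(disks)))
    # one downstream API call per key: westeurope, eastus
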
access_path="value", - expect_array=True, - ) - items = graph_builder.client.list(api_spec) - if not items: - return - # Set location for further connect_in_graph method - for item in items: - item["location"] = location - AzureComputeVirtualMachineSize.collect(items, graph_builder) + @classmethod + def collect_resources(cls, builder: GraphBuilder, **kwargs: Any) -> List["AzureComputeVirtualMachineBase"]: + log.debug(f"[Azure:{builder.account.id}] Collecting {cls.__name__} with ({kwargs})") - graph_builder.submit_work(service_name, collect_vm_sizes) + if not issubclass(cls, MicrosoftResource): + return [] - graph_builder.submit_work(service_name, collect_instance_status) + if spec := cls.api_spec: + items = builder.client.list(spec, **kwargs) + collected = cls.collect(items, builder) + + unique_locations = set(getattr(vm, "location") for vm in collected if getattr(vm, "location")) + + for location in unique_locations: + log.debug(f"Processing virtual machines in location: {location}") + + # Collect VM sizes for the VM in this location + AzureComputeVirtualMachineBase._collect_vm_sizes(builder, location) + + if builder.config.collect_usage_metrics: + try: + cls.collect_usage_metrics(builder, collected) + except Exception as e: + log.warning(f"Failed to collect usage metrics for {cls.__name__}: {e}") + + return collected + + return [] + + @staticmethod + def _collect_vm_sizes(graph_builder: GraphBuilder, location: str) -> None: + def collect_vm_sizes() -> None: + api_spec = AzureResourceSpec( + service="compute", + version="2023-03-01", + path=f"/subscriptions/{{subscriptionId}}/providers/Microsoft.Compute/locations/{location}/vmSizes", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + access_path="value", + expect_array=True, + ) + items = graph_builder.client.list(api_spec) + if not items: + return + + # Set location for further connect_in_graph method + for item in items: + item["location"] = location + + AzureComputeVirtualMachineSize.collect(items, graph_builder) + + graph_builder.submit_work(service_name, collect_vm_sizes) @classmethod def collect_usage_metrics( diff --git a/plugins/azure/fix_plugin_azure/resource/machinelearning.py b/plugins/azure/fix_plugin_azure/resource/machinelearning.py index a01a305148..528d18d203 100644 --- a/plugins/azure/fix_plugin_azure/resource/machinelearning.py +++ b/plugins/azure/fix_plugin_azure/resource/machinelearning.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import defaultdict import logging from datetime import datetime from typing import Any, ClassVar, Dict, Optional, List, Tuple, Type @@ -553,32 +554,78 @@ class AzureMachineLearningCompute(MicrosoftResource): system_data: Optional[AzureSystemData] = field(default=None, metadata={'description': 'Metadata pertaining to creation and last modification of the resource.'}) # fmt: skip location: Optional[str] = field(default=None, metadata={'description': 'The geo-location where the resource lives'}) # fmt: skip - def post_process(self, graph_builder: GraphBuilder, source: Json) -> None: - if location := self.location: + @classmethod + def collect_resources(cls, builder: GraphBuilder, **kwargs: Any) -> List["AzureMachineLearningCompute"]: + log.debug(f"[Azure:{builder.account.id}] Collecting {cls.__name__} with ({kwargs})") - def collect_vm_sizes() -> None: - api_spec = AzureResourceSpec( - service="machinelearningservices", - version="2024-04-01", - path="/subscriptions/{subscriptionId}/providers/Microsoft.MachineLearningServices/locations/" - + 
f"{location}/vmSizes", - path_parameters=["subscriptionId"], - query_parameters=["api-version"], - access_path="value", - expect_array=True, - ) - items = graph_builder.client.list(api_spec) - if not items: - return - for item in items: - item["location"] = location - collected = AzureMachineLearningVirtualMachineSize.collect(items, graph_builder) - for _ in collected: - if (properties := self.properties) and (vm_size := properties.get("vmSize")): - graph_builder.add_edge(self, clazz=AzureMachineLearningVirtualMachineSize, name=vm_size) + if not issubclass(cls, MicrosoftResource): + return [] + + if spec := cls.api_spec: + items = builder.client.list(spec, **kwargs) + collected = cls.collect(items, builder) + + resources_by_location = defaultdict(list) - graph_builder.submit_work(service_name, collect_vm_sizes) + for compute_resource in collected: + location = getattr(compute_resource, "location", None) + if location: + resources_by_location[location].append(compute_resource) + # Process each unique location + for location, compute_resources in resources_by_location.items(): + log.debug(f"Processing compute resources in location: {location}") + + # Collect VM sizes for the compute resources in this location + cls._collect_vm_sizes(builder, location, compute_resources) + + if builder.config.collect_usage_metrics: + try: + cls.collect_usage_metrics(builder, collected) + except Exception as e: + log.warning(f"Failed to collect usage metrics for {cls.__name__}: {e}") + + return collected + + return [] + + @staticmethod + def _collect_vm_sizes( + graph_builder: GraphBuilder, location: str, compute_resources: List["AzureMachineLearningCompute"] + ) -> None: + def collect_vm_sizes() -> None: + api_spec = AzureResourceSpec( + service="machinelearningservices", + version="2024-04-01", + path=f"/subscriptions/{{subscriptionId}}/providers/Microsoft.MachineLearningServices/locations/{location}/vmSizes", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + access_path="value", + expect_array=True, + ) + items = graph_builder.client.list(api_spec) + + if not items: + return + + # Set location for further connect_in_graph method + for item in items: + item["location"] = location + + # Collect the virtual machine sizes + collected_vm_sizes = AzureMachineLearningVirtualMachineSize.collect(items, graph_builder) + + for compute_resource in compute_resources: + vm_size = (compute_resource.properties or {}).get("vmSize") + if vm_size: + for size in collected_vm_sizes: + if size.name == vm_size: + graph_builder.add_edge(compute_resource, node=size) + break + + graph_builder.submit_work(service_name, collect_vm_sizes) + + def post_process(self, graph_builder: GraphBuilder, source: Json) -> None: if resource_id := self.id: def collect_nodes() -> None: diff --git a/plugins/azure/fix_plugin_azure/resource/mysql.py b/plugins/azure/fix_plugin_azure/resource/mysql.py index c786837030..d9e928ec34 100644 --- a/plugins/azure/fix_plugin_azure/resource/mysql.py +++ b/plugins/azure/fix_plugin_azure/resource/mysql.py @@ -723,32 +723,71 @@ def post_process(self, graph_builder: GraphBuilder, source: Json) -> None: else: self.volume_encrypted = False - if (location := self.location) and (sku := self.server_sku) and (version := self.version): - - def collect_capabilities() -> None: - api_spec = AzureResourceSpec( - service="mysql", - version="2023-12-30", - path="/subscriptions/{subscriptionId}/providers/Microsoft.DBforMySQL/locations/" - + f"{location}/capabilities", - path_parameters=["subscriptionId"], 
- query_parameters=["api-version"], - access_path="value", - expect_array=True, + @classmethod + def collect_resources(cls, builder: GraphBuilder, **kwargs: Any) -> List["AzureMysqlServer"]: + log.debug(f"[Azure:{builder.account.id}] Collecting {cls.__name__} with ({kwargs})") + + if not issubclass(cls, MicrosoftResource): + return [] + + if spec := cls.api_spec: + items = builder.client.list(spec, **kwargs) + collected = cls.collect(items, builder) + + # Group the collected resources by location, sku, and version + unique_servers = set() + for server in collected: + location = getattr(server, "location", None) + sku = getattr(server, "server_sku", None) + version = getattr(server, "version", None) + if location and sku and version: + sku_name = sku.name + sku_tier = sku.tier + unique_servers.add((location, sku_name, sku_tier, version)) + + for location, sku_name, sku_tier, version in unique_servers: + log.debug( + f"Processing servers in location: {location}, SKU: {sku_name}, Tier: {sku_tier}, Version: {version}" ) - items = graph_builder.client.list(api_spec) - if not items: - return - for item in items: - # Set location for further connect_in_graph method - item["location"] = location - # Set sku name and tier for SKUs filtering - item["expected_sku_name"] = sku.name - item["expected_sku_tier"] = sku.tier - item["expected_version"] = version - AzureMysqlServerType.collect(items, graph_builder) - - graph_builder.submit_work(service_name, collect_capabilities) + + # Collect MySQL server types for the servers in this group + AzureMysqlServer._collect_mysql_server_types(builder, location, sku_name, sku_tier, version) + + if builder.config.collect_usage_metrics: + try: + cls.collect_usage_metrics(builder, collected) + except Exception as e: + log.warning(f"Failed to collect usage metrics for {cls.__name__} in {location}: {e}") + + return collected + + return [] + + @staticmethod + def _collect_mysql_server_types( + builder: GraphBuilder, location: str, sku_name: str, sku_tier: str, version: str + ) -> None: + def collect_capabilities() -> None: + api_spec = AzureResourceSpec( + service="mysql", + version="2023-12-30", + path=f"/subscriptions/{{subscriptionId}}/providers/Microsoft.DBforMySQL/locations/{location}/capabilities", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + access_path="value", + expect_array=True, + ) + items = builder.client.list(api_spec) + if not items: + return + for item in items: + item["location"] = location + item["expected_sku_name"] = sku_name + item["expected_sku_tier"] = sku_tier + item["expected_version"] = version + AzureMysqlServerType.collect(items, builder) + + builder.submit_work(service_name, collect_capabilities) def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: if ( diff --git a/plugins/azure/fix_plugin_azure/resource/postgresql.py b/plugins/azure/fix_plugin_azure/resource/postgresql.py index 289dd8f116..4f5df0f35e 100644 --- a/plugins/azure/fix_plugin_azure/resource/postgresql.py +++ b/plugins/azure/fix_plugin_azure/resource/postgresql.py @@ -605,34 +605,76 @@ def post_process(self, graph_builder: GraphBuilder, source: Json) -> None: else: self.volume_encrypted = False - if (server_location := self.location) and (sku := self.server_sku) and (version := self.version): - - def collect_capabilities() -> None: - - api_spec = AzureResourceSpec( - service="postgresql", - version="2022-12-01", - path="/subscriptions/{subscriptionId}/providers/Microsoft.DBforPostgreSQL/locations/" - + 
f"{server_location}/capabilities", - path_parameters=["subscriptionId"], - query_parameters=["api-version"], - access_path="value", - expect_array=True, - expected_error_codes={"InternalServerError": None}, + @classmethod + def collect_resources(cls, builder: GraphBuilder, **kwargs: Any) -> List["AzurePostgresqlServer"]: + log.debug(f"[Azure:{builder.account.id}] Collecting {cls.__name__} with ({kwargs})") + + if not issubclass(cls, MicrosoftResource): + return [] + + if spec := cls.api_spec: + items = builder.client.list(spec, **kwargs) + collected = cls.collect(items, builder) + + unique_servers = set() + + for server in collected: + location = getattr(server, "location", None) + sku = getattr(server, "server_sku", None) + version = getattr(server, "version", None) + + if location and sku and version: + sku_name = sku.name + sku_tier = sku.tier + unique_servers.add((location, sku_name, sku_tier, version)) + + for location, sku_name, sku_tier, version in unique_servers: + log.debug( + f"Processing PostgreSQL servers in location: {location}, SKU: {sku_name}, Tier: {sku_tier}, Version: {version}" ) - items = graph_builder.client.list(api_spec) - if not items: - return - for item in items: - # Set location for further connect_in_graph method - item["location"] = server_location - # Set sku name and tier for SKUs filtering - item["expected_sku_name"] = sku.name - item["expected_sku_tier"] = sku.tier - item["expected_version"] = version - AzurePostgresqlServerType.collect(items, graph_builder) - - graph_builder.submit_work(service_name, collect_capabilities) + + # Collect PostgreSQL server types for the servers in this group + AzurePostgresqlServer._collect_postgresql_server_types(builder, location, sku_name, sku_tier, version) + + if builder.config.collect_usage_metrics: + try: + cls.collect_usage_metrics(builder, collected) + except Exception as e: + log.warning(f"Failed to collect usage metrics for {cls.__name__} in {location}: {e}") + + return collected + + return [] + + @staticmethod + def _collect_postgresql_server_types( + graph_builder: GraphBuilder, server_location: str, sku_name: str, sku_tier: str, version: str + ) -> None: + def collect_capabilities() -> None: + api_spec = AzureResourceSpec( + service="postgresql", + version="2022-12-01", + path=f"/subscriptions/{{subscriptionId}}/providers/Microsoft.DBforPostgreSQL/locations/{server_location}/capabilities", + path_parameters=["subscriptionId"], + query_parameters=["api-version"], + access_path="value", + expect_array=True, + expected_error_codes={"InternalServerError": None}, + ) + + items = graph_builder.client.list(api_spec) + if not items: + return + + for item in items: + item["location"] = server_location + item["expected_sku_name"] = sku_name + item["expected_sku_tier"] = sku_tier + item["expected_version"] = version + + AzurePostgresqlServerType.collect(items, graph_builder) + + graph_builder.submit_work(service_name, collect_capabilities) def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None: if ( diff --git a/plugins/azure/test/collector_test.py b/plugins/azure/test/collector_test.py index 405b7f9c7d..f92231b6f6 100644 --- a/plugins/azure/test/collector_test.py +++ b/plugins/azure/test/collector_test.py @@ -48,8 +48,8 @@ def test_collect( config, Cloud(id="azure"), azure_subscription, credentials, core_feedback, filter_unused_resources=False ) subscription_collector.collect() - assert len(subscription_collector.graph.nodes) == 952 - assert len(subscription_collector.graph.edges) == 1341 + assert 
diff --git a/plugins/azure/test/collector_test.py b/plugins/azure/test/collector_test.py
index 405b7f9c7d..f92231b6f6 100644
--- a/plugins/azure/test/collector_test.py
+++ b/plugins/azure/test/collector_test.py
@@ -48,8 +48,8 @@ def test_collect(
         config, Cloud(id="azure"), azure_subscription, credentials, core_feedback, filter_unused_resources=False
     )
     subscription_collector.collect()
-    assert len(subscription_collector.graph.nodes) == 952
-    assert len(subscription_collector.graph.edges) == 1341
+    assert len(subscription_collector.graph.nodes) == 883
+    assert len(subscription_collector.graph.edges) == 1272
     graph_collector = MicrosoftGraphOrganizationCollector(
         config, Cloud(id="azure"), MicrosoftGraphOrganization(id="test", name="test"), credentials, core_feedback
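
Note: both test expectations drop by exactly 69 (952 - 883 nodes, 1341 - 1272 edges), consistent with the per-location deduplication above removing 69 redundant size/type nodes, each attached by a single edge.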