diff --git a/plugins/aws/fix_plugin_aws/resource/base.py b/plugins/aws/fix_plugin_aws/resource/base.py index a270912bc4..2be419f72c 100644 --- a/plugins/aws/fix_plugin_aws/resource/base.py +++ b/plugins/aws/fix_plugin_aws/resource/base.py @@ -449,7 +449,7 @@ def __init__( client: AwsClient, executor: ExecutorQueue, core_feedback: CoreFeedback, - global_instance_types: Optional[Dict[Tuple[str, str], Any]] = None, + global_instance_types: Optional[Dict[str, Any]] = None, graph_nodes_access: Optional[RWLock] = None, graph_edges_access: Optional[RWLock] = None, last_run_started_at: Optional[datetime] = None, @@ -462,9 +462,7 @@ def __init__( self.client = client self.executor = executor self.name = f"AWS:{account.name}:{region.name}" - self.global_instance_types: Dict[Tuple[str, str], Any] = ( - global_instance_types if global_instance_types is not None else {} - ) + self.global_instance_types: Dict[str, Any] = global_instance_types if global_instance_types is not None else {} self.core_feedback = core_feedback self.graph_nodes_access = graph_nodes_access or RWLock() self.graph_edges_access = graph_edges_access or RWLock() @@ -659,7 +657,7 @@ def edges_of( @lru_cache(maxsize=None) def instance_type(self, region: AwsRegion, instance_type: str) -> Optional[Any]: - if (it := self.global_instance_types.get((region.id, instance_type))) is None: + if (it := self.global_instance_types.get(instance_type)) is None: return None # instance type not found price = value_in_path(cloud_instance_data, ["aws", instance_type, "pricing", region.id, "linux", "ondemand"]) diff --git a/plugins/aws/fix_plugin_aws/resource/ec2.py b/plugins/aws/fix_plugin_aws/resource/ec2.py index 064ee2591a..c0374f01fe 100644 --- a/plugins/aws/fix_plugin_aws/resource/ec2.py +++ b/plugins/aws/fix_plugin_aws/resource/ec2.py @@ -393,7 +393,7 @@ class AwsEc2InstanceType(AwsResource, BaseInstanceType): _kind_service: ClassVar[Optional[str]] = service_name _metadata: ClassVar[Dict[str, Any]] = {"icon": "type", "group": "compute"} _aws_metadata: ClassVar[Dict[str, Any]] = {"arn_tpl": "arn:{partition}:ec2:{region}:{account}:instance/{id}"} # fmt: skip - # api_spec defined in `collect_resource_types` method + # api_spec defined in `collect_resource_types` method and collected by AwsEc2Instance _reference_kinds: ClassVar[ModelReference] = { "successors": { "default": ["aws_ec2_instance"], @@ -491,7 +491,7 @@ def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> # Only "used" instance type will be stored in the graph # note: not all instance types are returned in any region. # we collect instance types in all regions and make the data unique in the builder - builder.global_instance_types[(builder.region.id, it.safe_name)] = it + builder.global_instance_types[it.safe_name] = it @classmethod def service_name(cls) -> Optional[str]: @@ -1407,41 +1407,15 @@ class AwsEc2Instance(EC2Taggable, AwsResource, BaseInstance): @classmethod def collect_resources(cls: Type[AwsResource], builder: GraphBuilder) -> None: - # Default behavior: in case the class has an ApiSpec, call the api and call collect. - log.debug(f"Collecting {cls.__name__} in region {builder.region.name}") - if spec := cls.api_spec: - try: - kwargs = spec.parameter or {} - items = builder.client.list( - aws_service=spec.service, - action=spec.api_action, - result_name=spec.result_property, - expected_errors=spec.expected_errors, - **kwargs, - ) - if not items: - return - ec2_instance_types = [] - checked_types = set() - for item in items: - for instance in item.get("Instances", []): - if instance_type := instance.get("InstanceType"): - if instance_type not in checked_types: - ec2_instance_types.append(instance_type) - checked_types.add(instance_type) - if ec2_instance_types: - builder.submit_work( - service_name, AwsEc2InstanceType.collect_resource_types, builder, ec2_instance_types - ) - cls.collect(items, builder) - except Boto3Error as e: - msg = f"Error while collecting {cls.__name__} in region {builder.region.name}: {e}" - builder.core_feedback.error(msg, log) - raise - except Exception as e: - msg = f"Error while collecting {cls.__name__} in region {builder.region.name}: {e}" - builder.core_feedback.info(msg, log) - raise + super().collect_resources(builder) # type: ignore # mypy bug: https://github.com/python/mypy/issues/12885 + ec2_instance_types = [] + checked_types = set() + for instance in builder.nodes(clazz=AwsEc2Instance): + if (instance_type := instance.instance_type) and instance_type not in checked_types: + ec2_instance_types.append(instance_type) + checked_types.add(instance_type) + if ec2_instance_types: + builder.submit_work(service_name, AwsEc2InstanceType.collect_resource_types, builder, ec2_instance_types) @classmethod def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None: @@ -4020,7 +3994,7 @@ class AwsEc2LaunchTemplate(EC2Taggable, AwsResource): # endregion resources: List[Type[AwsResource]] = [ - # AwsEc2InstanceType, Collected via AwsEc2Instance + AwsEc2InstanceType, AwsEc2ElasticIp, AwsEc2FlowLog, AwsEc2Host, diff --git a/plugins/aws/test/graphbuilder_test.py b/plugins/aws/test/graphbuilder_test.py index 61810251f3..f5a769c3bd 100644 --- a/plugins/aws/test/graphbuilder_test.py +++ b/plugins/aws/test/graphbuilder_test.py @@ -10,14 +10,14 @@ def test_instance_type(builder: GraphBuilder) -> None: instance_type = "m4.large" - builder.global_instance_types[(builder.region.id, instance_type)] = AwsEc2InstanceType(id=instance_type) + builder.global_instance_types[instance_type] = AwsEc2InstanceType(id=instance_type) m4l: AwsEc2InstanceType = builder.instance_type(builder.region, instance_type) # type: ignore assert m4l == builder.instance_type(builder.region, instance_type) assert m4l.ondemand_cost == value_in_path( cloud_instance_data, ["aws", instance_type, "pricing", builder.region.id, "linux", "ondemand"] ) eu_builder = builder.for_region(AwsRegion(id="eu-central-1")) - builder.global_instance_types[(eu_builder.region.id, instance_type)] = AwsEc2InstanceType(id=instance_type) + builder.global_instance_types[instance_type] = AwsEc2InstanceType(id=instance_type) m4l_eu: AwsEc2InstanceType = eu_builder.instance_type(eu_builder.region, instance_type) # type: ignore assert m4l != m4l_eu assert m4l_eu == eu_builder.instance_type(eu_builder.region, instance_type) diff --git a/plugins/aws/test/resources/base_test.py b/plugins/aws/test/resources/base_test.py index a8502d6833..163e54abb1 100644 --- a/plugins/aws/test/resources/base_test.py +++ b/plugins/aws/test/resources/base_test.py @@ -87,8 +87,7 @@ def test_instance_type_handling(builder: GraphBuilder) -> None: region1 = AwsRegion(id="us-east-1") region2 = AwsRegion(id="us-east-2") it = AwsEc2InstanceType(id="t3.micro") - builder.global_instance_types[(region1.id, it.safe_name)] = it - builder.global_instance_types[(region2.id, it.safe_name)] = it + builder.global_instance_types[it.safe_name] = it it1: AwsEc2InstanceType = builder.instance_type(region1, it.safe_name) # type: ignore assert it1.region() == region1 it2: AwsEc2InstanceType = builder.instance_type(region2, it.safe_name) # type: ignore