Skip to content

Commit

Permalink
feat: optimized implementation of ec2 instance types
Browse files Browse the repository at this point in the history
  • Loading branch information
1101-1 committed Oct 29, 2024
1 parent 49b77b8 commit 279cd86
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 47 deletions.
8 changes: 3 additions & 5 deletions plugins/aws/fix_plugin_aws/resource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def __init__(
client: AwsClient,
executor: ExecutorQueue,
core_feedback: CoreFeedback,
global_instance_types: Optional[Dict[Tuple[str, str], Any]] = None,
global_instance_types: Optional[Dict[str, Any]] = None,
graph_nodes_access: Optional[RWLock] = None,
graph_edges_access: Optional[RWLock] = None,
last_run_started_at: Optional[datetime] = None,
Expand All @@ -462,9 +462,7 @@ def __init__(
self.client = client
self.executor = executor
self.name = f"AWS:{account.name}:{region.name}"
self.global_instance_types: Dict[Tuple[str, str], Any] = (
global_instance_types if global_instance_types is not None else {}
)
self.global_instance_types: Dict[str, Any] = global_instance_types if global_instance_types is not None else {}
self.core_feedback = core_feedback
self.graph_nodes_access = graph_nodes_access or RWLock()
self.graph_edges_access = graph_edges_access or RWLock()
Expand Down Expand Up @@ -659,7 +657,7 @@ def edges_of(

@lru_cache(maxsize=None)
def instance_type(self, region: AwsRegion, instance_type: str) -> Optional[Any]:
if (it := self.global_instance_types.get((region.id, instance_type))) is None:
if (it := self.global_instance_types.get(instance_type)) is None:
return None # instance type not found

price = value_in_path(cloud_instance_data, ["aws", instance_type, "pricing", region.id, "linux", "ondemand"])
Expand Down
50 changes: 12 additions & 38 deletions plugins/aws/fix_plugin_aws/resource/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ class AwsEc2InstanceType(AwsResource, BaseInstanceType):
_kind_service: ClassVar[Optional[str]] = service_name
_metadata: ClassVar[Dict[str, Any]] = {"icon": "type", "group": "compute"}
_aws_metadata: ClassVar[Dict[str, Any]] = {"arn_tpl": "arn:{partition}:ec2:{region}:{account}:instance/{id}"} # fmt: skip
# api_spec defined in `collect_resource_types` method
# api_spec defined in `collect_resource_types` method and collected by AwsEc2Instance
_reference_kinds: ClassVar[ModelReference] = {
"successors": {
"default": ["aws_ec2_instance"],
Expand Down Expand Up @@ -491,7 +491,7 @@ def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) ->
# Only "used" instance type will be stored in the graph
# note: not all instance types are returned in any region.
# we collect instance types in all regions and make the data unique in the builder
builder.global_instance_types[(builder.region.id, it.safe_name)] = it
builder.global_instance_types[it.safe_name] = it

@classmethod
def service_name(cls) -> Optional[str]:
Expand Down Expand Up @@ -1407,41 +1407,15 @@ class AwsEc2Instance(EC2Taggable, AwsResource, BaseInstance):

@classmethod
def collect_resources(cls: Type[AwsResource], builder: GraphBuilder) -> None:
# Default behavior: in case the class has an ApiSpec, call the api and call collect.
log.debug(f"Collecting {cls.__name__} in region {builder.region.name}")
if spec := cls.api_spec:
try:
kwargs = spec.parameter or {}
items = builder.client.list(
aws_service=spec.service,
action=spec.api_action,
result_name=spec.result_property,
expected_errors=spec.expected_errors,
**kwargs,
)
if not items:
return
ec2_instance_types = []
checked_types = set()
for item in items:
for instance in item.get("Instances", []):
if instance_type := instance.get("InstanceType"):
if instance_type not in checked_types:
ec2_instance_types.append(instance_type)
checked_types.add(instance_type)
if ec2_instance_types:
builder.submit_work(
service_name, AwsEc2InstanceType.collect_resource_types, builder, ec2_instance_types
)
cls.collect(items, builder)
except Boto3Error as e:
msg = f"Error while collecting {cls.__name__} in region {builder.region.name}: {e}"
builder.core_feedback.error(msg, log)
raise
except Exception as e:
msg = f"Error while collecting {cls.__name__} in region {builder.region.name}: {e}"
builder.core_feedback.info(msg, log)
raise
super().collect_resources(builder) # type: ignore # mypy bug: https://github.com/python/mypy/issues/12885
ec2_instance_types = []
checked_types = set()
for instance in builder.nodes(clazz=AwsEc2Instance):
if (instance_type := instance.instance_type) and instance_type not in checked_types:
ec2_instance_types.append(instance_type)
checked_types.add(instance_type)
if ec2_instance_types:
builder.submit_work(service_name, AwsEc2InstanceType.collect_resource_types, builder, ec2_instance_types)

@classmethod
def collect(cls: Type[AwsResource], json: List[Json], builder: GraphBuilder) -> None:
Expand Down Expand Up @@ -4020,7 +3994,7 @@ class AwsEc2LaunchTemplate(EC2Taggable, AwsResource):
# endregion

resources: List[Type[AwsResource]] = [
# AwsEc2InstanceType, Collected via AwsEc2Instance
AwsEc2InstanceType,
AwsEc2ElasticIp,
AwsEc2FlowLog,
AwsEc2Host,
Expand Down
4 changes: 2 additions & 2 deletions plugins/aws/test/graphbuilder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@

def test_instance_type(builder: GraphBuilder) -> None:
instance_type = "m4.large"
builder.global_instance_types[(builder.region.id, instance_type)] = AwsEc2InstanceType(id=instance_type)
builder.global_instance_types[instance_type] = AwsEc2InstanceType(id=instance_type)
m4l: AwsEc2InstanceType = builder.instance_type(builder.region, instance_type) # type: ignore
assert m4l == builder.instance_type(builder.region, instance_type)
assert m4l.ondemand_cost == value_in_path(
cloud_instance_data, ["aws", instance_type, "pricing", builder.region.id, "linux", "ondemand"]
)
eu_builder = builder.for_region(AwsRegion(id="eu-central-1"))
builder.global_instance_types[(eu_builder.region.id, instance_type)] = AwsEc2InstanceType(id=instance_type)
builder.global_instance_types[instance_type] = AwsEc2InstanceType(id=instance_type)
m4l_eu: AwsEc2InstanceType = eu_builder.instance_type(eu_builder.region, instance_type) # type: ignore
assert m4l != m4l_eu
assert m4l_eu == eu_builder.instance_type(eu_builder.region, instance_type)
Expand Down
3 changes: 1 addition & 2 deletions plugins/aws/test/resources/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,7 @@ def test_instance_type_handling(builder: GraphBuilder) -> None:
region1 = AwsRegion(id="us-east-1")
region2 = AwsRegion(id="us-east-2")
it = AwsEc2InstanceType(id="t3.micro")
builder.global_instance_types[(region1.id, it.safe_name)] = it
builder.global_instance_types[(region2.id, it.safe_name)] = it
builder.global_instance_types[it.safe_name] = it
it1: AwsEc2InstanceType = builder.instance_type(region1, it.safe_name) # type: ignore
assert it1.region() == region1
it2: AwsEc2InstanceType = builder.instance_type(region2, it.safe_name) # type: ignore
Expand Down

0 comments on commit 279cd86

Please sign in to comment.