Skip to content

Commit

Permalink
feat: added new collection way for ec2, lambda and ecr
Browse files Browse the repository at this point in the history
  • Loading branch information
1101-1 committed Oct 18, 2024
1 parent 403a4c6 commit e50938e
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 36 deletions.
51 changes: 50 additions & 1 deletion plugins/aws/fix_plugin_aws/resource/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import defaultdict
import logging
import re
from abc import ABC
Expand Down Expand Up @@ -27,6 +28,8 @@
BaseVolumeType,
Cloud,
EdgeType,
Finding,
Assessment,
ModelReference,
PhantomBaseResource,
BaseOrganizationalRoot,
Expand Down Expand Up @@ -169,6 +172,37 @@ def service_name(cls) -> Optional[str]:
"""
return cls.api_spec.service if cls.api_spec else None

def set_findings(self, builder: GraphBuilder, to_check: str = "id") -> None:
"""
Set the assessment findings for the resource based on its ID or ARN.
Args:
builder (GraphBuilder): The builder object that holds assessment findings.
to_check (str): A string indicating whether to use "id" or "arn" to check findings.
Default is "id".
"""
# Ensure this method is only applied to subclasses of AwsResource, not AwsResource itself
if isinstance(self, AwsResource) and self.__class__ == AwsResource:
return

id_or_arn = ""

if to_check == "arn":
if not self.arn:
return
id_or_arn = self.arn
elif to_check == "id":
id_or_arn = self.id
else:
return
for provider in ["inspector", "guard_duty"]:
provider_findings = builder._assessment_findings.get(
(provider, self.region().id, self.__class__.__name__), {}
).get(id_or_arn, [])
if provider_findings:
# Set the findings in the resource's _assessments dictionary
self._assessments.append(Assessment(provider, provider_findings))

def set_arn(
self,
builder: GraphBuilder,
Expand Down Expand Up @@ -469,7 +503,19 @@ def __init__(
self.last_run_started_at = last_run_started_at
self.created_at = utc()
self.__builder_cache = {region.safe_name: self}

self._assessment_findings: Dict[Tuple[str, str, str], Dict[str, List[Finding]]] = defaultdict(
lambda: defaultdict(list)
)
"""
AWS assessment findings that hold a list of AwsInspectorFinding or AwsGuardDutyFinding.
The outer dictionary's keys are tuples:
- The first element is the assessment provider (str).
- The second element is the region of the finding (str).
- The third element is the class name (str).
The values are dictionaries where:
- The keys are class IDs (str).
- The values are lists of Finding instances.
"""
if last_run_started_at:
now = utc()

Expand Down Expand Up @@ -499,6 +545,9 @@ def __init__(
def suppress(self, message: str) -> SuppressWithFeedback:
return SuppressWithFeedback(message, self.core_feedback, log)

def add_finding(self, provider: str, class_name: str, region: str, class_id: str, finding: Finding) -> None:
self._assessment_findings[(provider, region, class_name)][class_id].append(finding)

def submit_work(self, service: str, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Future[T]:
"""
Use this method for work that can be done in parallel.
Expand Down
2 changes: 2 additions & 0 deletions plugins/aws/fix_plugin_aws/resource/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1523,6 +1523,8 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
if iam_profile := self.instance_iam_instance_profile:
builder.add_edge(self, reverse=True, clazz=AwsIamInstanceProfile, arn=iam_profile.arn)

self.set_findings(builder)

def delete_resource(self, client: AwsClient, graph: Graph) -> bool:
if self.instance_status == InstanceStatus.TERMINATED:
self.log("Instance is already terminated")
Expand Down
3 changes: 3 additions & 0 deletions plugins/aws/fix_plugin_aws/resource/ecr.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def called_collect_apis(cls) -> List[AwsApiSpec]:
AwsApiSpec(service_name, "get-repository-policy"),
]

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
self.set_findings(builder)


# @define(eq=False, slots=False)
# class AwsEcrImageIdentifier:
Expand Down
89 changes: 54 additions & 35 deletions plugins/aws/fix_plugin_aws/resource/inspector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from datetime import datetime
from typing import ClassVar, Dict, Optional, List, Type, Any
from typing import ClassVar, Dict, Optional, List, Tuple, Type, Any

from attrs import define, field
from boto3.exceptions import Boto3Error
Expand All @@ -9,9 +9,9 @@
from fix_plugin_aws.resource.ec2 import AwsEc2Instance
from fix_plugin_aws.resource.ecr import AwsEcrRepository
from fix_plugin_aws.resource.lambda_ import AwsLambdaFunction
from fixlib.baseresources import ModelReference, PhantomBaseResource
from fixlib.json_bender import Bender, S, ForallBend, Bend, F
from fixlib.baseresources import ModelReference, PhantomBaseResource, Severity, Finding
from fixlib.types import Json
from fixlib.json_bender import Bender, S, ForallBend, Bend, F

log = logging.getLogger("fix.plugins.aws")
service_name = "inspector2"
Expand Down Expand Up @@ -403,20 +403,67 @@ class AwsInspectorFinding(AwsResource, PhantomBaseResource):
type: Optional[str] = field(default=None, metadata={"description": "The type of the finding. The type value determines the valid values for resource in your request. For more information, see Finding types in the Amazon Inspector user guide."}) # fmt: skip
updated_at: Optional[datetime] = field(default=None, metadata={"description": "The date and time the finding was last updated at."}) # fmt: skip

def parse_finding(self, source: Json) -> Finding:
severity_mapping = {
"INFORMATIONAL": Severity.info,
"LOW": Severity.low,
"MEDIUM": Severity.medium,
"HIGH": Severity.high,
"CRITICAL": Severity.critical,
}
finding_title = self.safe_name
if not self.finding_severity:
finding_severity = Severity.unknown
else:
finding_severity = severity_mapping.get(self.finding_severity, Severity.unknown)
description = self.description
remidiation = ""
if self.remediation and self.remediation.recommendation:
remidiation = self.remediation.recommendation.text or ""
updated_at = self.updated_at
details = source.get("packageVulnerabilityDetails", {}) | source.get("codeVulnerabilityDetails", {})
return Finding(finding_title, finding_severity, description, remidiation, updated_at, details)

@classmethod
def collect_resources(cls: Type[AwsResource], builder: GraphBuilder) -> None:
def collect_resources(cls, builder: GraphBuilder) -> None:
def check_type_and_adjust_id(
class_type: Optional[str], class_id: Optional[str]
) -> Tuple[Optional[str], Optional[str]]:
if not class_id or not class_type:
return None, None
match class_type:
case "AWS_LAMBDA_FUNCTION":
# remove lambda's version from arn
lambda_arn = class_id.rsplit(":", 1)[0]
return AwsLambdaFunction.__name__, lambda_arn
case "AWS_EC2_INSTANCE":
return AwsEc2Instance.__name__, class_id
case "AWS_ECR_REPOSITORY":
return AwsEcrRepository.__name__, class_id
case _:
return None, None

# Default behavior: in case the class has an ApiSpec, call the api and call collect.
log.debug(f"Collecting {cls.__name__} in region {builder.region.name}")
if spec := cls.api_spec:
try:
items = builder.client.list(
for item in builder.client.list(
aws_service=spec.service,
action=spec.api_action,
result_name=spec.result_property,
expected_errors=spec.expected_errors,
filterCriteria={"awsAccountId": [{"comparison": "EQUALS", "value": f"{builder.account.id}"}]},
)
cls.collect(items, builder)
):
if finding := AwsInspectorFinding.from_api(item, builder):
if finding_resources := finding.finding_resources:
for fr in finding_resources:
class_name, class_id = check_type_and_adjust_id(fr.type, fr.id)
if class_name and class_id:
adjusted_finding = finding.parse_finding(item)
builder.add_finding(
"inspector", class_name, fr.region or "global", class_id, adjusted_finding
)

except Boto3Error as e:
msg = f"Error while collecting {cls.__name__} in region {builder.region.name}: {e}"
builder.core_feedback.error(msg, log)
Expand All @@ -426,33 +473,5 @@ def collect_resources(cls: Type[AwsResource], builder: GraphBuilder) -> None:
builder.core_feedback.info(msg, log)
raise

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
if finding_resources := self.finding_resources:
for finding_resource in finding_resources:
if rid := finding_resource.id:
match finding_resource.type:
case "AWS_LAMBDA_FUNCTION":
# remove lambda's version from arn to connect by arn
lambda_arn = rid.rsplit(":", 1)[0]
builder.add_edge(
self,
clazz=AwsLambdaFunction,
arn=lambda_arn,
)
case "AWS_EC2_INSTANCE":
builder.add_edge(
self,
clazz=AwsEc2Instance,
id=rid,
)
case "AWS_ECR_REPOSITORY":
builder.add_edge(
self,
clazz=AwsEcrRepository,
id=rid,
)
case _:
continue


resources: List[Type[AwsResource]] = [AwsInspectorFinding]
1 change: 1 addition & 0 deletions plugins/aws/fix_plugin_aws/resource/lambda_.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
clazz=AwsKmsKey,
arn=self.function_kms_key_arn,
)
self.set_findings(builder, "arn")

def update_resource_tag(self, client: AwsClient, key: str, value: str) -> bool:
client.call(
Expand Down

0 comments on commit e50938e

Please sign in to comment.