Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[azure][feat] Update security assessments collection #2266

Merged
merged 15 commits into from
Nov 4, 2024
Merged
4 changes: 4 additions & 0 deletions plugins/azure/fix_plugin_azure/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ def get_last_run() -> Optional[datetime]:
self.collect_with(builder, locations)
queue.wait_for_submitted_work()

# call all registered after collect hooks
for after_collect in builder.after_collect_actions:
after_collect()

# connect nodes
log.info(f"[Azure:{self.account.safe_name}] Connect resources and create edges.")
for node, data in list(self.graph.nodes(data=True)):
Expand Down
3 changes: 3 additions & 0 deletions plugins/azure/fix_plugin_azure/resource/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,7 @@ def __init__(
location: Optional[BaseRegion] = None,
graph_access_lock: Optional[RWLock] = None,
last_run_started_at: Optional[datetime] = None,
after_collect_actions: Optional[List[Callable[[], Any]]] = None,
) -> None:
self.graph = graph
self.cloud = cloud
Expand All @@ -796,6 +797,7 @@ def __init__(
self.config = config
self.last_run_started_at = last_run_started_at
self.created_at = utc()
self.after_collect_actions = after_collect_actions if after_collect_actions is not None else []

if last_run_started_at:
now = utc()
Expand Down Expand Up @@ -1002,6 +1004,7 @@ def with_location(self, location: BaseRegion) -> GraphBuilder:
graph_access_lock=self.graph_access_lock,
config=self.config,
last_run_started_at=self.last_run_started_at,
after_collect_actions=self.after_collect_actions,
)


Expand Down
2 changes: 2 additions & 0 deletions plugins/azure/fix_plugin_azure/resource/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -5274,6 +5274,7 @@ class AzureNetworkVirtualNetwork(MicrosoftResource, BaseNetwork):
"virtual_network_peerings": S("properties", "virtualNetworkPeerings")
>> ForallBend(AzureVirtualNetworkPeering.mapping),
"location": S("location"),
"type": S("type"),
}
address_space: Optional[AzureAddressSpace] = field(default=None, metadata={'description': 'AddressSpace contains an array of IP address ranges that can be used by subnets of the virtual network.'}) # fmt: skip
bgp_communities: Optional[AzureVirtualNetworkBgpCommunities] = field(default=None, metadata={'description': 'Bgp Communities sent over ExpressRoute with each route corresponding to a prefix in this VNET.'}) # fmt: skip
Expand All @@ -5290,6 +5291,7 @@ class AzureNetworkVirtualNetwork(MicrosoftResource, BaseNetwork):
_subnet_ids: Optional[List[str]] = field(default=None, metadata={'description': 'A list of subnets in a Virtual Network.'}) # fmt: skip
virtual_network_peerings: Optional[List[AzureVirtualNetworkPeering]] = field(default=None, metadata={'description': 'A list of peerings in a Virtual Network.'}) # fmt: skip
location: Optional[str] = field(default=None, metadata={"description": "Resource location."})
type: Optional[str] = field(default=None, metadata={"description": "Type of the resource."})

def post_process(self, graph_builder: GraphBuilder, source: Json) -> None:
def collect_subnets() -> None:
Expand Down
68 changes: 51 additions & 17 deletions plugins/azure/fix_plugin_azure/resource/security.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from datetime import datetime
from functools import partial
import logging
from typing import ClassVar, Dict, Optional, List, Any, Type

from attr import define, field

from fix_plugin_azure.azure_client import AzureResourceSpec
from fix_plugin_azure.resource.base import MicrosoftResource, AzureSystemData, GraphBuilder
from fixlib.baseresources import ModelReference, PhantomBaseResource
from fixlib.baseresources import SEVERITY_MAPPING, Finding, PhantomBaseResource, Severity
from fixlib.json_bender import Bender, S, Bend, ForallBend, F
from fixlib.types import Json

service_name = "security"
log = logging.getLogger("fix.plugins.azure")


@define(eq=False, slots=False)
Expand Down Expand Up @@ -94,14 +97,7 @@ class AzureAssessmentStatus:
@define(eq=False, slots=False)
class AzureSecurityAssessment(MicrosoftResource, PhantomBaseResource):
kind: ClassVar[str] = "azure_security_assessment"
_kind_display: ClassVar[str] = "Azure Security Assessment"
_kind_service: ClassVar[Optional[str]] = service_name
_kind_description: ClassVar[str] = "Azure Security Assessment is a service that evaluates Azure resources for potential security vulnerabilities and compliance issues. It scans configurations, identifies risks, and provides recommendations to improve security posture. The assessment covers various aspects including network security, data protection, and access control, offering insights to help organizations strengthen their Azure environment's security." # fmt: skip
_docs_url: ClassVar[str] = (
"https://learn.microsoft.com/en-us/azure/defender-for-cloud/secure-score-security-controls"
)
_metadata: ClassVar[Dict[str, Any]] = {"icon": "log", "group": "management"}
_reference_kinds: ClassVar[ModelReference] = {"successors": {"default": [MicrosoftResource.kind]}}
_model_export: ClassVar[bool] = False
api_spec: ClassVar[AzureResourceSpec] = AzureResourceSpec(
service=service_name,
version="2021-06-01",
Expand All @@ -118,24 +114,62 @@ class AzureSecurityAssessment(MicrosoftResource, PhantomBaseResource):
"assessment_status": S("properties", "status") >> Bend(AzureAssessmentStatus.mapping),
"resource_source": S("properties", "resourceDetails", "Source"),
"resource_id": S("properties", "resourceDetails", "ResourceId"),
"resource_type": S("properties", "resourceDetails", "ResourceType"),
"additional_date": S("properties", "additionalData"),
"azurePortalUri": S("properties", "links", "azurePortalUri"),
}
assessment_status: Optional[AzureAssessmentStatus] = field(default=None, metadata={'description': 'The result of the assessment'}) # fmt: skip
resource_source: Optional[str] = field(default=None, metadata={'description': 'The source of the resource that the assessment is performed on'}) # fmt: skip
resource_id: Optional[str] = field(default=None, metadata={'description': 'The id of the resource that the assessment is performed on'}) # fmt: skip
resource_type: Optional[str] = field(default=None, metadata={'description': 'The resource type'}) # fmt: skip
additional_data: Optional[Dict[str, Any]] = field(default=None, metadata={'description': 'Additional data for the assessment'}) # fmt: skip
subscription_issue: Optional[bool] = field(default=False, metadata={'description': 'Indicates if the assessment is a subscription issue'}) # fmt: skip

def post_process(self, builder: GraphBuilder, source: Json) -> None:
# mark as subscription issue, when the resource id is the same as the account id
if (rid := self.resource_id) and (sub := self._account):
self.subscription_issue = rid.split("/")[-1] == sub.id
def parse_finding(self, source: Json) -> Finding:
finding_title = self.safe_name
properties = source.get("properties") or {}
if metadata := properties.get("metadata", {}):
finding_severity = SEVERITY_MAPPING.get(metadata.get("severity", "").upper(), Severity.medium)
else:
finding_severity = Severity.medium
if status := self.assessment_status:
description = status.description
updated_at = status.status_change_date
else:
description = None
updated_at = None
details = self.additional_data or {} | properties.get("metadata", {})
return Finding(finding_title, finding_severity, description, None, updated_at, details)

@classmethod
def collect_resources(cls, builder: GraphBuilder, **kwargs: Any) -> List["AzureSecurityAssessment"]:
def add_finding(provider: str, finding: Finding, resource_id: str) -> None:
if resource := builder.node(clazz=MicrosoftResource, id=resource_id):
resource.add_finding(provider, finding)

# Default behavior: in case the class has an ApiSpec, call the api and call collect.
log.debug(f"[Azure:{builder.account.id}] Collecting {cls.__name__} with ({kwargs})")
if spec := cls.api_spec:
try:
for item in builder.client.list(spec, **kwargs):
if finding := AzureSecurityAssessment.from_api(item, builder):
if finding.resource_source == "Azure" and (rid := finding.resource_id):
if finding.resource_type == "subscription":
rid = "/subscriptions/" + rid
builder.after_collect_actions.append(
partial(
add_finding,
"azure_security_assessment",
finding.parse_finding(item),
rid,
)
)
except Exception as e:
msg = f"Error while collecting {cls.__name__} with service {spec.service} and location: {builder.location}: {e}"
builder.core_feedback.info(msg, log)
raise

def connect_in_graph(self, builder: GraphBuilder, source: Json) -> None:
# this will not connect subscription issues.
if self.resource_source == "Azure" and (rid := self.resource_id):
builder.add_edge(self, clazz=MicrosoftResource, id=rid)
return []


@define(eq=False, slots=False)
Expand Down
6 changes: 4 additions & 2 deletions plugins/azure/test/collector_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def test_collect(
config, Cloud(id="azure"), azure_subscription, credentials, core_feedback, filter_unused_resources=False
)
subscription_collector.collect()
assert len(subscription_collector.graph.nodes) == 889
assert len(subscription_collector.graph.edges) == 1284
assert len(subscription_collector.graph.nodes) == 887
assert len(subscription_collector.graph.edges) == 1282

graph_collector = MicrosoftGraphOrganizationCollector(
config, Cloud(id="azure"), MicrosoftGraphOrganization(id="test", name="test"), credentials, core_feedback
Expand Down Expand Up @@ -113,6 +113,8 @@ def all_base_classes(cls: Type[Any]) -> Set[Type[Any]]:
expected_declared_properties = ["kind", "_kind_display"]
expected_props_in_hierarchy = ["_kind_service", "_metadata"]
for rc in all_resources:
if not rc._model_export:
continue
for prop in expected_declared_properties:
assert prop in rc.__dict__, f"{rc.__name__} missing {prop}"
with_bases = (all_base_classes(rc) | {rc}) - {MicrosoftResource, BaseResource}
Expand Down
6 changes: 3 additions & 3 deletions plugins/azure/test/files/security/assessments.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
"type": "Microsoft.Security/assessments",
"properties": {
"resourceDetails": {
"source": "Azure",
"id": "/subscriptions/20ff7fc3-e762-44dd-bd96-b71116dcdc23/resourceGroups/myRg/providers/Microsoft.Compute/virtualMachineScaleSets/vmss1"
"Source": "Azure",
"ResourceId": "/subscriptions/{subscription-id}/resourceGroups/{resourceGroupName}/providers/Microsoft.Compute/virtualMachineScaleSets/{virtualMachineScaleSetName}"
},
"displayName": "Install endpoint protection solution on virtual machine scale sets",
"status": {
Expand Down Expand Up @@ -40,4 +40,4 @@
}
}
]
}
}
34 changes: 28 additions & 6 deletions plugins/azure/test/security_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,39 @@
from conftest import roundtrip_check
from fix_plugin_azure.resource.base import GraphBuilder

from fix_plugin_azure.azure_client import MicrosoftClient
from fix_plugin_azure.collector import AzureSubscriptionCollector
from fix_plugin_azure.config import AzureConfig, AzureCredentials
from fix_plugin_azure.resource.base import AzureSubscription, GraphBuilder
from fix_plugin_azure.resource.compute import AzureComputeVirtualMachineScaleSet
from fix_plugin_azure.resource.security import (
AzureSecurityAssessment,
AzureSecurityPricing,
AzureSecurityServerVulnerabilityAssessmentsSetting,
AzureSecuritySetting,
AzureSecurityAutoProvisioningSetting,
)


def test_security_assessment(builder: GraphBuilder) -> None:
collected = roundtrip_check(AzureSecurityAssessment, builder)
assert len(collected) == 2
from fixlib.baseresources import Cloud, Severity
from fixlib.core.actions import CoreFeedback


def test_security_assessment(
config: AzureConfig,
azure_subscription: AzureSubscription,
credentials: AzureCredentials,
core_feedback: CoreFeedback,
azure_client: MicrosoftClient,
) -> None:
subscription_collector = AzureSubscriptionCollector(
config, Cloud(id="azure"), azure_subscription, credentials, core_feedback, filter_unused_resources=False
)
subscription_collector.collect()
instances = list(subscription_collector.graph.search("kind", AzureComputeVirtualMachineScaleSet.kind))
assert instances[0]._assessments[0].provider == "azure_security_assessment"
assert (
instances[0]._assessments[0].findings[0].title
== "Install endpoint protection solution on virtual machine scale sets"
)
assert instances[0]._assessments[0].findings[0].severity == Severity.medium


def test_security_pricing(builder: GraphBuilder) -> None:
Expand Down
Loading