Skip to content

Commit

Permalink
Merge branch 'main' into feat_switch_to_ruff_linter_and_formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
benyissa authored Jan 23, 2024
2 parents c817b9c + dcea6dc commit 0244762
Show file tree
Hide file tree
Showing 12 changed files with 580 additions and 79 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ license_files = LICENSE
long_description = file: README.md
long_description_content_type = text/markdown
url = https://github.com/Ostorlab/ostorlab
version=0.18.15
version=0.18.18
project_urls =
Documentation = https://docs.ostorlab.co/
Source = https://github.com/Ostorlab/ostorlab
Expand All @@ -39,7 +39,7 @@ package_dir =
=src

# Require a min/specific Python version (comma-separated conditions)
# python_requires = >=3.8
python_requires = >=3.9

# Add here dependencies of your project (line-separated), e.g. requests>=2.2,<3.0.
# Version specifiers like >=2.2,<3.0 avoid problems due to API changes in
Expand Down
9 changes: 8 additions & 1 deletion src/ostorlab/agent/kb/kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import pathlib

from typing import Dict
from typing import Dict, Optional, Union

KB_FOLDER = "KB"
META_JSON = "meta.json"
Expand All @@ -28,6 +28,7 @@ class Entry:
targeted_by_ransomware: bool = False
targeted_by_nation_state: bool = False
cvss_v3_vector: str = ""
category_groups: Optional[list[dict[str, Union[str, list[str]]]]] = None


class MetaKB(type):
Expand All @@ -47,6 +48,11 @@ def __getattr__(cls, item: str) -> Entry:
encoding="utf-8"
) as r:
meta = json.loads(f.read())
categories = meta.get("categories", {})
category_groups = [
{"key": k, "categories": v} for k, v in categories.items()
]

return Entry(
title=meta.get("title"),
risk_rating=meta.get("risk_rating"),
Expand All @@ -61,6 +67,7 @@ def __getattr__(cls, item: str) -> Entry:
targeted_by_ransomware=meta.get("targeted_by_ransomware", False),
targeted_by_nation_state=meta.get("targeted_by_nation_state", False),
cvss_v3_vector=meta.get("cvss_v3_vector", ""),
category_groups=category_groups,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ message VulnerabilityLocation {
}
repeated Metadata metadata = 100 ;
}
message CategoryGroup {
optional string key = 1;
repeated string categories = 2;
}

message Message{
optional string title = 1;
Expand All @@ -72,4 +76,5 @@ message Message{
optional string dna = 14;
optional string cvss_v3_vector = 15;
optional VulnerabilityLocation vulnerability_location = 16;
repeated CategoryGroup category_groups = 17;
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions src/ostorlab/agent/mixins/agent_mq_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import tenacity

logger = logging.getLogger(__name__)
NUMBER_RETRIES = 3
WAIT_FIXED_TIME = 1


class AgentMQMixin:
Expand Down Expand Up @@ -168,15 +170,20 @@ async def async_mq_send_message(
await exchange.publish(routing_key=key, message=pika_message)

@tenacity.retry(
retry=tenacity.retry_if_exception_type(aio_pika.exceptions.ConnectionClosed),
retry=tenacity.retry_if_exception_type(
(aio_pika.exceptions.ConnectionClosed, ConnectionResetError)
),
stop=tenacity.stop_after_attempt(NUMBER_RETRIES),
wait=tenacity.wait_fixed(WAIT_FIXED_TIME),
reraise=True,
)
def mq_send_message(
self, key: str, message: bytes, message_priority: Optional[int] = None
) -> None:
"""the method sends the message to the selected key with the defined priority in async mode .
"""The method sends the message to the selected key with the defined priority in async mode .
Args:
keys: Selectors that the queue listens to.
message: Message to send .
key: Selector that the queue listens to.
message: Message to send.
message_priority: the priority to use for the message default is 0.
"""
logger.debug("sending %s to %s", message, key)
Expand Down
2 changes: 2 additions & 0 deletions src/ostorlab/agent/mixins/agent_report_vulnerability_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def report_vulnerability(
targeted_by_ransomware = entry.targeted_by_ransomware
targeted_by_nation_state = entry.targeted_by_nation_state
cvss_v3_vector = entry.cvss_v3_vector
category_groups = entry.category_groups

references = []
for key, value in entry.references.items():
Expand All @@ -121,6 +122,7 @@ def report_vulnerability(
"targeted_by_nation_state": targeted_by_nation_state,
"cvss_v3_vector": cvss_v3_vector,
"dna": dna,
"category_groups": category_groups,
}
if vulnerability_location is not None:
data["vulnerability_location"] = vulnerability_location.to_dict()
Expand Down
90 changes: 26 additions & 64 deletions src/ostorlab/scanner/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Defines call back to trigger a scan after receiving a startAgentScan messages in the NATS."""
import logging
import ipaddress
from typing import List, Any, Optional
from typing import Any, Optional

import docker

Expand All @@ -23,6 +23,7 @@
from ostorlab.assets import agent as agent_asset
from ostorlab.utils import scanner_state_reporter
from ostorlab.scanner import scanner_conf
from ostorlab.agent.message import proto_dict


logger = logging.getLogger(__name__)
Expand All @@ -44,97 +45,58 @@ def _install_agents(
logger.warning("agent %s not found on the store", agent.key)


def _prepare_ip_asset(ip_request) -> asset.Asset:
"""Return IP assets from a NATs received message."""
ip_network = ipaddress.ip_network(ip_request.host, strict=False)
def _prepare_ip_asset(ip_asset_value: dict[str, Any]) -> asset.Asset:
"""Return IP assets from ip_asset_value dict."""
host = ip_asset_value.get("host")
ip_network = ipaddress.ip_network(host, strict=False)
if ip_network.version == 4:
return ipv4.IPv4(
host=ip_network.network_address.exploded,
mask=ip_request.mask or str(ip_network.prefixlen),
mask=ip_asset_value.get("mask") or str(ip_network.prefixlen),
)
elif ip_network.version == 6:
return ipv6.IPv6(
host=ip_network.network_address.exploded,
mask=ip_request.mask or str(ip_network.prefixlen),
mask=ip_asset_value.get("mask") or str(ip_network.prefixlen),
)
else:
raise ValueError(f"Invalid Ip address {ip_request.host}")
raise ValueError(f"Invalid Ip address {host}")


def _extract_assets(request: Any) -> List[asset.Asset]:
def _extract_assets(request: Any) -> list[asset.Asset]:
"""Returns list of specific Ostorlab-injectable assets, from a message received from NATs."""
logger.debug("Extracting assets.")
assets = []
asset_type = request.WhichOneof("asset")
if asset_type in ("ip", "ip4v", "ipv6"):
ip_request = request.ip or request.ipv4 or request.ipv6
assets.append(_prepare_ip_asset(ip_request))

asset_value = proto_dict.protobuf_to_dict(getattr(request, asset_type))
if asset_type in ("ip", "ipv4", "ipv6"):
assets.append(_prepare_ip_asset(asset_value))
elif asset_type == "android_store":
assets.append(
android_store.AndroidStore(package_name=request.android_store.package_name)
)

assets.append(android_store.AndroidStore(**asset_value))
elif asset_type == "ios_store":
assets.append(ios_store.IOSStore(bundle_id=request.ios_store.bundle_id))

assets.append(ios_store.IOSStore(**asset_value))
elif asset_type == "ipa":
assets.append(
ios_ipa.IOSIpa(
content=request.ipa.content,
path=request.ipa.path,
content_url=request.ipa.content_url,
)
)
assets.append(ios_ipa.IOSIpa(**asset_value))
elif asset_type == "apk":
assets.append(
android_apk.AndroidApk(
content=request.apk.content,
path=request.apk.path,
content_url=request.apk.content_url,
)
)
assets.append(android_apk.AndroidApk(**asset_value))
elif asset_type == "aab":
assets.append(
android_aab.AndroidAab(
content=request.aab.content,
path=request.aab.path,
content_url=request.aab.content_url,
)
)
elif asset_type == "domain":
assets.append(domain_name.DomainName(name=request.domain_name.name))

assets.append(android_aab.AndroidAab(**asset_value))
elif asset_type == "domain_name":
assets.append(domain_name.DomainName(**asset_value))
elif asset_type == "agent":
assets.append(
agent_asset.Agent(
key=request.agent.key,
version=request.agent.version,
git_location=request.agent.git_location,
docker_location=request.agent.docker_location,
yaml_file_location=request.agent.yaml_file_location,
)
)

assets.append(agent_asset.Agent(**asset_value))
elif asset_type == "file":
assets.append(
file.File(
content=request.file.content,
path=request.file.path,
content_url=request.file.content_url,
)
)
assets.append(file.File(**asset_value))
elif asset_type == "network":
for ip in request.network.ips:
for ip in asset_value.get("ips"):
ip_asset = _prepare_ip_asset(ip)
assets.append(ip_asset)

elif asset_type == "links":
for link in request.links.links:
for link in asset_value.get("links"):
assets.append(
link_asset.Link(
url=link.url,
method=link.url or "GET",
url=link.get("url"),
method=link.get("method") or "GET",
)
)

Expand Down
19 changes: 19 additions & 0 deletions tests/agent/mixins/agent_mq_mixin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest import mock

import pytest
from requests_mock import mocker as req_mocker

from ostorlab.agent.mixins import agent_mq_mixin
from ostorlab.utils import strings
Expand Down Expand Up @@ -126,3 +127,21 @@ async def testClient_whenClientDisconnects_messageIsNotLost(mocker, mq_service):
# make sure the message is received and was not deleted
stub.assert_called_with(word)
assert stub.call_count == 1


def testMqSendMessage_onConnectionResetError_shouldRetriesAndReraise(
mocker: req_mocker.Mocker,
):
mock_send_message = mocker.patch.object(agent_mq_mixin.AgentMQMixin, "_get_channel")
mock_send_message.side_effect = ConnectionResetError
agent = agent_mq_mixin.AgentMQMixin(
name="test",
keys=["a.#"],
url="amqp://guest:guest@localhost:5672/",
topic="test_topic",
)

with pytest.raises(ConnectionResetError):
agent.mq_send_message(key="a.1.2", message=b"test message")

assert mock_send_message.call_count == 3
Loading

0 comments on commit 0244762

Please sign in to comment.