Skip to content

Commit

Permalink
Refine fetch_for_packages and related unit tests #102
Browse files (browse the repository at this point in the history)
Signed-off-by: tdruez <[email protected]>
  • Loading branch information
tdruez committed Dec 27, 2024
1 parent 5846561 commit a1c6e8e
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 34 deletions.
68 changes: 41 additions & 27 deletions vulnerabilities/fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def fetch_from_vulnerablecode(dataspace, batch_size, update, timeout, log_func=N
if log_func:
log_func(f"{package_count} Packages in the queue.")

created = fetch_for_packages(
results = fetch_for_packages(
queryset=package_qs,
dataspace=dataspace,
batch_size=batch_size,
Expand All @@ -48,7 +48,8 @@ def fetch_from_vulnerablecode(dataspace, batch_size, update, timeout, log_func=N
)
run_time = timer() - start_time
if log_func:
log_func(f"+ Created {intcomma(created)} vulnerabilities")
log_func(f"+ Created {intcomma(results['created'])} vulnerabilities")
log_func(f"+ Updated {intcomma(results['updated'])} vulnerabilities")
log_func(f"Completed in {humanize_time(run_time)}")

dataspace.vulnerabilities_updated_at = timezone.now()
Expand All @@ -65,9 +66,7 @@ def fetch_for_packages(
return

vulnerablecode = VulnerableCode(dataspace)
vulnerability_qs = Vulnerability.objects.scope(dataspace)
created_vulnerabilities = 0
updated_vulnerabilities = 0
results = {"created": 0, "updated": 0}

for index, batch in enumerate(chunked_queryset(queryset, batch_size), start=1):
if log_func:
Expand All @@ -76,6 +75,7 @@ def fetch_for_packages(
progress_count = object_count
log_func(f"Progress: {intcomma(progress_count)}/{intcomma(object_count)}")

batch_affected_packages = []
vc_entries = vulnerablecode.get_vulnerable_purls(batch, purl_only=False, timeout=timeout)
for vc_entry in vc_entries:
affected_by_vulnerabilities = vc_entry.get("affected_by_vulnerabilities")
Expand All @@ -91,30 +91,44 @@ def fetch_for_packages(
if not affected_packages:
raise CommandError("Could not find package!")

# Store all packages of that batch to then trigger the update_weighted_risk_score
batch_affected_packages.extend(affected_packages)

for vulnerability_data in affected_by_vulnerabilities:
vulnerability_id = vulnerability_data["vulnerability_id"]
vulnerability = vulnerability_qs.get_or_none(vulnerability_id=vulnerability_id)
if not vulnerability:
vulnerability = Vulnerability.create_from_data(
dataspace=dataspace,
data=vulnerability_data,
)
created_vulnerabilities += 1
elif update:
updated_fields = vulnerability.update_from_data(
user=None,
data=vulnerability_data,
override=True,
)
if updated_fields:
updated_vulnerabilities += 1

vulnerability.add_affected_packages(affected_packages)
create_or_update_vulnerability(
vulnerability_data, dataspace, affected_packages, update, results
)

if package_risk_score := vc_entry.get("risk_score"):
affected_packages.update(risk_score=package_risk_score)
# TODO: This should be improved.
product_package_qs = ProductPackage.objects.filter(package__in=affected_packages)
product_package_qs.update_weighted_risk_score()

return created_vulnerabilities
product_package_qs = ProductPackage.objects.filter(package__in=batch_affected_packages)
product_package_qs.update_weighted_risk_score()

return results


def create_or_update_vulnerability(
    vulnerability_data, dataspace, affected_packages, update, results
):
    """
    Return the Vulnerability matching ``vulnerability_data["vulnerability_id"]``
    in the ``dataspace`` scope, creating it from ``vulnerability_data`` when no
    match is found.

    When a match exists and ``update`` is truthy, the existing entry is
    refreshed through ``update_from_data`` with ``override=True``.

    The ``results`` dict is mutated in place: its "created" counter is
    incremented on creation, and its "updated" counter is incremented only when
    ``update_from_data`` reports updated fields.

    In every case, ``affected_packages`` are attached to the vulnerability
    before it is returned.
    """
    vcid = vulnerability_data["vulnerability_id"]
    scoped_vulnerabilities = Vulnerability.objects.scope(dataspace)
    existing = scoped_vulnerabilities.get_or_none(vulnerability_id=vcid)

    if existing:
        vulnerability = existing
        if update:
            # Only count the entry as updated when at least one field changed.
            changed_fields = vulnerability.update_from_data(
                user=None,
                data=vulnerability_data,
                override=True,
            )
            if changed_fields:
                results["updated"] += 1
    else:
        vulnerability = Vulnerability.create_from_data(
            dataspace=dataspace,
            data=vulnerability_data,
        )
        results["created"] += 1

    # Packages are linked whether the entry was created, updated, or untouched.
    vulnerability.add_affected_packages(affected_packages)
    return vulnerability
26 changes: 19 additions & 7 deletions vulnerabilities/tests/test_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from component_catalog.models import Package
from component_catalog.tests import make_package
from dje.models import Dataspace
from product_portfolio.tests import make_product
from product_portfolio.tests import make_product_item_purpose
from product_portfolio.tests import make_product_package
from vulnerabilities.fetch import fetch_for_packages
from vulnerabilities.fetch import fetch_from_vulnerablecode

Expand Down Expand Up @@ -49,33 +52,42 @@ def test_vulnerabilities_fetch_from_vulnerablecode(
def test_vulnerabilities_fetch_for_packages(self, mock_bulk_search_by_purl):
buffer = io.StringIO()
package1 = make_package(self.dataspace, package_url="pkg:pypi/[email protected]")
make_package(self.dataspace, package_url="pkg:pypi/[email protected]")
product1 = make_product(self.dataspace)
pp1 = make_product_package(product1, package=package1)
queryset = Package.objects.scope(self.dataspace)
response_file = self.data / "vulnerabilities" / "idna_3.6_response.json"
response_json = json.loads(response_file.read_text())
mock_bulk_search_by_purl.return_value = response_json["results"]

created_vulnerabilities = fetch_for_packages(
results = fetch_for_packages(
queryset, self.dataspace, batch_size=1, update=True, log_func=buffer.write
)
self.assertEqual(1, created_vulnerabilities)
self.assertEqual("Progress: 1/2Progress: 2/2", buffer.getvalue())
self.assertEqual(results, {"created": 1, "updated": 0})

self.assertEqual("Progress: 1/1", buffer.getvalue())
self.assertEqual(1, package1.affected_by_vulnerabilities.count())
vulnerability = package1.affected_by_vulnerabilities.get()
self.assertEqual("VCID-j3au-usaz-aaag", vulnerability.vulnerability_id)
self.assertEqual(Decimal("2.0"), vulnerability.exploitability)
self.assertEqual(Decimal("4.2"), vulnerability.weighted_severity)
self.assertEqual(Decimal("8.4"), vulnerability.risk_score)

package1.refresh_from_db()
pp1.refresh_from_db()
self.assertEqual(Decimal("8.4"), package1.risk_score)
self.assertEqual(Decimal("8.4"), pp1.weighted_risk_score)

# Update
purpose1 = make_product_item_purpose(self.dataspace, exposure_factor=0.5)
pp1.raw_update(purpose=purpose1)
response_json["results"][0]["affected_by_vulnerabilities"][0]["risk_score"] = 10.0
mock_bulk_search_by_purl.return_value = response_json["results"]
created_vulnerabilities = fetch_for_packages(
results = fetch_for_packages(
queryset, self.dataspace, batch_size=1, update=True, log_func=buffer.write
)
self.assertEqual(0, created_vulnerabilities)
self.assertEqual(results, {"created": 0, "updated": 1})
vulnerability = package1.affected_by_vulnerabilities.get()
self.assertEqual(Decimal("10.0"), vulnerability.risk_score)
package1.refresh_from_db()
pp1.refresh_from_db()
self.assertEqual(Decimal("8.4"), package1.risk_score)
self.assertEqual(Decimal("4.2"), pp1.weighted_risk_score)

0 comments on commit a1c6e8e

Please sign in to comment.