Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Nov 20, 2024
1 parent de93a32 commit 071bd9f
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 16 deletions.
35 changes: 25 additions & 10 deletions amt/services/task_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,20 @@


def is_requirement_applicable(requirement: Requirement, ai_act_profile: AiActProfile) -> bool:
"""
Determine if a specific requirement is applicable to a given AI Act profile.
Evaluation Criteria:
- Always applicable requirements automatically return True.
- For the 'role' attribute, handles compound values like
"gebruiksverantwoordelijke + aanbieder".
- A requirement is applicable if all specified attributes match or have no
specific restrictions.
"""
if requirement.always_applicable == 1:
return True

# We can assume the ai_act_profile field always is of length 1.
# We can assume the ai_act_profile field always contains exactly 1 element.
requirement_profile = requirement.ai_act_profile[0]
comparison_attrs = ["type", "risk_category", "type", "open_source", "systemic_risk", "transparency_obligations"]

Expand All @@ -23,11 +33,7 @@ def is_requirement_applicable(requirement: Requirement, ai_act_profile: AiActPro
if not requirement_attr_values:
continue

# In the system card the field role has values "gebruiksverantwoordelijke", "aanbieder" and
# "gebruiksverantwoordelijke + aanbieder", so we need to split the latter into a list of two
# strings.
raw_input_value = getattr(ai_act_profile, attr)
input_value = {raw_input_value} if attr != "role" else {s.strip() for s in raw_input_value.split("+")}
input_value = _parse_attribute_values(attr, getattr(ai_act_profile, attr))

if not input_value & {attr_value.value for attr_value in requirement_attr_values}:
return False
Expand All @@ -43,15 +49,24 @@ async def get_requirements_and_measures(
all_requirements = await requirements_service.fetch_requirements()

applicable_requirements: list[RequirementTask] = []
applicable_measures: dict[str, MeasureTask] = {}
applicable_measures: list[MeasureTask] = []
measure_urns: set[str] = set()

for requirement in all_requirements:
if is_requirement_applicable(requirement, ai_act_profile):
applicable_requirements.append(RequirementTask(urn=requirement.urn, version=requirement.schema_version))

for measure_urn in requirement.links:
if measure_urn not in applicable_measures:
if measure_urn not in measure_urns:
measure = await measure_service.fetch_measures(measure_urn)
applicable_measures[measure_urn] = MeasureTask(urn=measure_urn, version=measure[0].schema_version)
applicable_measures.append(MeasureTask(urn=measure_urn, version=measure[0].schema_version))
measure_urns.add(measure_urn)

return applicable_requirements, [*applicable_measures.values()]
return applicable_requirements, applicable_measures


def _parse_attribute_values(attr: str, raw_input_value: str) -> set[str]:
"""
Helper function needed in `is_requirement_applicable`, handling special case for `role`.
"""
return {raw_input_value} if attr != "role" else {s.strip() for s in raw_input_value.split("+")}
160 changes: 154 additions & 6 deletions tests/services/test_task_registry_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from amt.schema.ai_act_profile import AiActProfile
from amt.schema.measure import Measure
from amt.schema.requirement import (
OpenSourceEnum,
Requirement,
Expand All @@ -10,7 +11,8 @@
TransparencyObligation,
TypeEnum,
)
from amt.services.task_registry import is_requirement_applicable
from amt.services.task_registry import get_requirements_and_measures, is_requirement_applicable
from pytest_mock import MockerFixture


@pytest.mark.asyncio
Expand Down Expand Up @@ -65,11 +67,6 @@ async def test_is_requirement_applicable_empty_profile():

@pytest.mark.asyncio
async def test_is_requirement_applicable_with_profile():
# setup
# mock_requirements_service = mocker.AsyncMock()
# mocker.patch("amt.services.requirements.RequirementsService", return_value=mock_requirements_service)
# mock_requirements_service.fetch_requirements = mocker.AsyncMock(return_value=sample_data)

# given
ai_act_profile = AiActProfile(
type="AI-systeem",
Expand Down Expand Up @@ -230,3 +227,154 @@ async def test_is_requirement_applicable_with_matching_profile():

# then
assert result is True


@pytest.mark.asyncio
async def test_get_requirements_and_measures_no_applicable_requirements(mocker: MockerFixture):
mock_requirements_service = mocker.AsyncMock()
mock_measures_service = mocker.AsyncMock()

mocker.patch("amt.services.task_registry.create_requirements_service", return_value=mock_requirements_service)
mocker.patch("amt.services.task_registry.create_measures_service", return_value=mock_measures_service)
mocker.patch("amt.services.task_registry.is_requirement_applicable", return_value=False)

mock_requirements_service.fetch_requirements.return_value = [
Requirement(
name="requirement",
urn="urn:requirement:1",
description="description",
schema_version="1.1.0",
links=[],
always_applicable=0,
ai_act_profile=[
RequirementAiActProfile(
type=[],
risk_category=[],
role=[],
open_source=[],
systemic_risk=[],
transparency_obligations=[],
),
],
),
]

ai_act_profile = AiActProfile()

requirements, measures = await get_requirements_and_measures(ai_act_profile)

assert requirements == []
assert measures == []
mock_requirements_service.fetch_requirements.assert_awaited_once()
mock_measures_service.fetch_measures.assert_not_called()


@pytest.mark.asyncio
async def test_get_requirements_and_measures_single_requirement_with_measures(mocker: MockerFixture):
mock_requirements_service = mocker.AsyncMock()
mock_measures_service = mocker.AsyncMock()

mocker.patch("amt.services.task_registry.create_requirements_service", return_value=mock_requirements_service)
mocker.patch("amt.services.task_registry.create_measures_service", return_value=mock_measures_service)
mocker.patch("amt.services.task_registry.is_requirement_applicable", return_value=True)

mock_requirements_service.fetch_requirements.return_value = [
Requirement(
name="requirement",
urn="urn:requirement:1",
description="description",
schema_version="1.1.0",
links=["urn:measure:1", "urn:measure:2"],
always_applicable=0,
ai_act_profile=[
RequirementAiActProfile(
type=[],
risk_category=[],
role=[],
open_source=[],
systemic_risk=[],
transparency_obligations=[],
),
],
),
]

mock_measures_service.fetch_measures.side_effect = [
[Measure(name="name 1", urn="measure:urn:1", description="", url="", schema_version="1.1.0")],
[Measure(name="name 2", urn="measure:urn:2", description="", url="", schema_version="1.1.0")],
]

ai_act_profile = AiActProfile()

requirements, measures = await get_requirements_and_measures(ai_act_profile)

assert len(requirements) == 1
assert requirements[0].urn == "urn:requirement:1"

assert len(measures) == 2
assert {measure.urn for measure in measures} == {"urn:measure:1", "urn:measure:2"}


@pytest.mark.asyncio
async def test_get_requirements_and_measures_duplicate_measure_urns(mocker: MockerFixture):
mock_requirements_service = mocker.AsyncMock()
mock_measures_service = mocker.AsyncMock()

mocker.patch("amt.services.task_registry.create_requirements_service", return_value=mock_requirements_service)
mocker.patch("amt.services.task_registry.create_measures_service", return_value=mock_measures_service)
mocker.patch("amt.services.task_registry.is_requirement_applicable", return_value=True)

mock_requirements_service.fetch_requirements.return_value = [
Requirement(
name="requirement",
urn="urn:requirement:1",
description="description",
schema_version="1.1.0",
links=["urn:measure:1", "urn:measure:2"],
always_applicable=0,
ai_act_profile=[
RequirementAiActProfile(
type=[],
risk_category=[],
role=[],
open_source=[],
systemic_risk=[],
transparency_obligations=[],
),
],
),
Requirement(
name="requirement",
urn="urn:requirement:2",
description="description",
schema_version="1.1.0",
links=["urn:measure:1", "urn:measure:3"],
always_applicable=0,
ai_act_profile=[
RequirementAiActProfile(
type=[],
risk_category=[],
role=[],
open_source=[],
systemic_risk=[],
transparency_obligations=[],
),
],
),
]

mock_measures_service.fetch_measures.side_effect = [
[Measure(name="name 1", urn="measure:urn:1", description="", url="", schema_version="1.1.0")],
[Measure(name="name 2", urn="measure:urn:2", description="", url="", schema_version="1.1.0")],
[Measure(name="name 3", urn="measure:urn:3", description="", url="", schema_version="1.1.0")],
]

ai_act_profile = AiActProfile()

requirements, measures = await get_requirements_and_measures(ai_act_profile)

assert len(requirements) == 2
assert {requirement.urn for requirement in requirements} == {"urn:requirement:1", "urn:requirement:2"}

assert len(measures) == 3
assert {measure.urn for measure in measures} == {"urn:measure:1", "urn:measure:2", "urn:measure:3"}

0 comments on commit 071bd9f

Please sign in to comment.