Skip to content

Commit

Permalink
Add functions to measure which are changable in the modal (#426)
Browse files Browse the repository at this point in the history
  • Loading branch information
laurensWe authored Dec 18, 2024
2 parents 36398f2 + 908c638 commit 12b2868
Show file tree
Hide file tree
Showing 13 changed files with 381 additions and 98 deletions.
33 changes: 32 additions & 1 deletion amt/api/forms/measure.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,47 @@
from collections.abc import Sequence
from gettext import NullTranslations

from amt.models import User
from amt.schema.webform import WebForm, WebFormField, WebFormFieldType, WebFormOption, WebFormTextCloneableField


async def get_measure_form(
id: str, current_values: dict[str, str | list[str] | list[tuple[str, str]]], translations: NullTranslations
id: str,
current_values: dict[str, str | list[str] | list[tuple[str, str]]],
members: Sequence[User],
translations: NullTranslations,
) -> WebForm:
_ = translations.gettext

measure_form: WebForm = WebForm(id="", post_url="")

member_option_list = [WebFormOption(value=member.name, display_value=member.name) for member in members]
member_option_list.append(WebFormOption(value="", display_value=""))
measure_form.fields = [
WebFormField(
type=WebFormFieldType.SELECT,
name="measure_responsible",
label=_("Responsible"),
options=member_option_list,
default_value=current_values.get("measure_responsible"),
group="1",
),
WebFormField(
type=WebFormFieldType.SELECT,
name="measure_reviewer",
label=_("Reviewer"),
options=member_option_list,
default_value=current_values.get("measure_reviewer"),
group="1",
),
WebFormField(
type=WebFormFieldType.SELECT,
name="measure_accountable",
label=_("Accountable"),
options=member_option_list,
default_value=current_values.get("measure_accountable"),
group="1",
),
WebFormField(
type=WebFormFieldType.SELECT,
name="measure_state",
Expand Down
116 changes: 107 additions & 9 deletions amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import asyncio
import datetime
import logging
from collections import defaultdict
from collections.abc import Sequence
from typing import Annotated, Any, cast

import yaml
from fastapi import APIRouter, Depends, File, Form, Request, Response, UploadFile
from fastapi import APIRouter, Depends, File, Form, Query, Request, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from pydantic import BaseModel
from ulid import ULID
Expand All @@ -19,14 +20,16 @@
resolve_base_navigation_items,
resolve_navigation_items,
)
from amt.api.routes.shared import get_filters_and_sort_by
from amt.core.authorization import get_user
from amt.core.exceptions import AMTError, AMTNotFound, AMTRepositoryError
from amt.core.internationalization import get_current_translation
from amt.enums.status import Status
from amt.models import Algorithm
from amt.models.task import Task
from amt.repositories.organizations import OrganizationsRepository
from amt.schema.measure import ExtendedMeasureTask, MeasureTask
from amt.repositories.users import UsersRepository
from amt.schema.measure import ExtendedMeasureTask, MeasureTask, Person
from amt.schema.requirement import RequirementTask
from amt.schema.system_card import Owner, SystemCard
from amt.schema.task import MovedTask
Expand Down Expand Up @@ -414,6 +417,8 @@ async def get_algorithm_inference(
async def get_system_card_requirements(
request: Request,
algorithm_id: int,
organizations_repository: Annotated[OrganizationsRepository, Depends(OrganizationsRepository)],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)],
measures_service: Annotated[MeasuresService, Depends(create_measures_service)],
Expand All @@ -422,6 +427,9 @@ async def get_system_card_requirements(
instrument_state = await get_instrument_state(algorithm.system_card)
requirements_state = await get_requirements_state(algorithm.system_card)
tab_items = get_algorithm_details_tabs(request)
filters, _, _, sort_by = get_filters_and_sort_by(request)
organization = await organizations_repository.find_by_id(algorithm.organization_id)
filters["organization-id"] = str(organization.id)

breadcrumbs = resolve_base_navigation_items(
[
Expand All @@ -436,14 +444,17 @@ async def get_system_card_requirements(
[requirement.urn for requirement in algorithm.system_card.requirements]
)

# Get measures that correspond to the requirements and merge them with the measuretasks
# Get measures that correspond to the requirements and merge them with the measure tasks
requirements_and_measures = []
measure_tasks: list[MeasureTask | None] = []
for requirement in requirements:
completed_measures_count = 0
linked_measures = await measures_service.fetch_measures(requirement.links)
extended_linked_measures: list[ExtendedMeasureTask] = []
for measure in linked_measures:
measure_task = find_measure_task(algorithm.system_card, measure.urn)
if measure_task not in measure_tasks:
measure_tasks.append(measure_task)
if measure_task:
ext_measure_task = ExtendedMeasureTask(
name=measure.name,
Expand All @@ -458,6 +469,8 @@ async def get_system_card_requirements(
extended_linked_measures.append(ext_measure_task)
requirements_and_measures.append((requirement, completed_measures_count, extended_linked_measures)) # pyright: ignore [reportUnknownMemberType]

measure_task_functions = await get_measure_task_functions(measure_tasks, users_repository, sort_by, filters)

context = {
"instrument_state": instrument_state,
"requirements_state": requirements_state,
Expand All @@ -466,11 +479,49 @@ async def get_system_card_requirements(
"tab_items": tab_items,
"breadcrumbs": breadcrumbs,
"requirements_and_measures": requirements_and_measures,
"measure_task_functions": measure_task_functions,
}

return templates.TemplateResponse(request, "algorithms/details_requirements.html.j2", context)


async def get_measure_task_functions(
measure_tasks: list[MeasureTask | None],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
sort_by: dict[str, str],
filters: dict[str, str],
) -> dict[str, list[Any]]:
measure_task_functions: dict[str, list[Any]] = defaultdict(list)
for measure_task in measure_tasks:
if measure_task.accountable_persons: # pyright: ignore [reportOptionalMemberAccess]
members_accountable = await users_repository.find_all(
search=measure_task.accountable_persons[0].name, # pyright: ignore [reportOptionalMemberAccess]
sort=sort_by,
filters=filters,
)
if members_accountable:
measure_task_functions[measure_task.urn].append(members_accountable[0]) # pyright: ignore [reportOptionalMemberAccess]

if measure_task.reviewer_persons: # pyright: ignore [reportOptionalMemberAccess]
members_reviewer = await users_repository.find_all(
search=measure_task.reviewer_persons[0].name, # pyright: ignore [reportOptionalMemberAccess]
sort=sort_by,
filters=filters,
)
if members_reviewer:
measure_task_functions[measure_task.urn].append(members_reviewer[0]) # pyright: ignore [reportOptionalMemberAccess]

if measure_task.responsible_persons: # pyright: ignore [reportOptionalMemberAccess]
members_responsible = await users_repository.find_all(
search=measure_task.responsible_persons[0].name, # pyright: ignore [reportOptionalMemberAccess]
sort=sort_by,
filters=filters,
)
if members_responsible:
measure_task_functions[measure_task.urn].append(members_responsible[0]) # pyright: ignore [reportOptionalMemberAccess]
return measure_task_functions


def find_measure_task(system_card: SystemCard, urn: str) -> MeasureTask | None:
for measure in system_card.measures:
if measure.urn == urn:
Expand Down Expand Up @@ -518,12 +569,16 @@ async def delete_algorithm(
@router.get("/{algorithm_id}/measure/{measure_urn}")
async def get_measure(
request: Request,
organizations_repository: Annotated[OrganizationsRepository, Depends(OrganizationsRepository)],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
algorithm_id: int,
measure_urn: str,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
measures_service: Annotated[MeasuresService, Depends(create_measures_service)],
object_storage_service: Annotated[ObjectStorageService, Depends(create_object_storage_service)],
search: str = Query(""),
) -> HTMLResponse:
filters, _, _, sort_by = get_filters_and_sort_by(request)
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
measure = await measures_service.fetch_measures([measure_urn])
measure_task = get_measure_task_or_error(algorithm.system_card, measure_urn)
Expand All @@ -533,39 +588,74 @@ async def get_measure(
metadata = object_storage_service.get_file_metadata_from_object_name(file)
filenames.append((file.split("/")[-1], f"{metadata.filename}.{metadata.ext}"))

organization = await organizations_repository.find_by_id(algorithm.organization_id)
filters["organization-id"] = str(organization.id)
members = await users_repository.find_all(search=search, sort=sort_by, filters=filters)

measure_accountable = measure_task.accountable_persons[0].name if measure_task.accountable_persons else "" # pyright: ignore [reportOptionalMemberAccess]
measure_reviewer = measure_task.reviewer_persons[0].name if measure_task.reviewer_persons else "" # pyright: ignore [reportOptionalMemberAccess]
measure_responsible = measure_task.responsible_persons[0].name if measure_task.responsible_persons else "" # pyright: ignore [reportOptionalMemberAccess]

measure_form = await get_measure_form(
id="measure_state",
current_values={
"measure_state": measure_task.state,
"measure_value": measure_task.value,
"measure_links": measure_task.links,
"measure_files": filenames,
"measure_accountable": measure_accountable,
"measure_reviewer": measure_reviewer,
"measure_responsible": measure_responsible,
},
members=members,
translations=get_current_translation(request),
)

context = {
"measure": measure[0],
"algorithm_id": algorithm_id,
"form": measure_form,
}
context = {"measure": measure[0], "algorithm_id": algorithm_id, "form": measure_form}

return templates.TemplateResponse(request, "algorithms/details_measure_modal.html.j2", context)


async def get_users_from_function_name(
measure_accountable: Annotated[str | None, Form()],
measure_reviewer: Annotated[str | None, Form()],
measure_responsible: Annotated[str | None, Form()],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
sort_by: dict[str, str],
filters: dict[str, str],
) -> tuple[list[Person], list[Person], list[Person]]:
accountable_persons, reviewer_persons, responsible_persons = [], [], []
if measure_accountable:
accountable_member = await users_repository.find_all(search=measure_accountable, sort=sort_by, filters=filters)
accountable_persons = [Person(name=accountable_member[0].name, uuid=str(accountable_member[0].id))] # pyright: ignore [reportOptionalMemberAccess]
if measure_reviewer:
reviewer_member = await users_repository.find_all(search=measure_reviewer, sort=sort_by, filters=filters)
reviewer_persons = [Person(name=reviewer_member[0].name, uuid=str(reviewer_member[0].id))] # pyright: ignore [reportOptionalMemberAccess]
if measure_responsible:
responsible_member = await users_repository.find_all(search=measure_responsible, sort=sort_by, filters=filters)
responsible_persons = [Person(name=responsible_member[0].name, uuid=str(responsible_member[0].id))] # pyright: ignore [reportOptionalMemberAccess]
return accountable_persons, reviewer_persons, responsible_persons


@router.post("/{algorithm_id}/measure/{measure_urn}")
async def update_measure_value(
request: Request,
algorithm_id: int,
measure_urn: str,
organizations_repository: Annotated[OrganizationsRepository, Depends(OrganizationsRepository)],
users_repository: Annotated[UsersRepository, Depends(UsersRepository)],
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)],
object_storage_service: Annotated[ObjectStorageService, Depends(create_object_storage_service)],
measure_state: Annotated[str, Form()],
measure_responsible: Annotated[str | None, Form()] = None,
measure_reviewer: Annotated[str | None, Form()] = None,
measure_accountable: Annotated[str | None, Form()] = None,
measure_value: Annotated[str | None, Form()] = None,
measure_links: Annotated[list[str] | None, Form()] = None,
measure_files: Annotated[list[UploadFile] | None, File()] = None,
) -> HTMLResponse:
filters, _, _, sort_by = get_filters_and_sort_by(request)
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
user_id = get_user_id_or_error(request)
measure_task = get_measure_task_or_error(algorithm.system_card, measure_urn)
Expand All @@ -577,7 +667,15 @@ async def update_measure_value(
if measure_files
else None
)
measure_task.update(measure_state, measure_value, measure_links, paths)
accountable_persons, reviewer_persons, responsible_persons = await get_users_from_function_name(
measure_accountable, measure_reviewer, measure_responsible, users_repository, sort_by, filters
)

measure_task.update(
measure_state, measure_value, measure_links, paths, responsible_persons, accountable_persons, reviewer_persons
)
organization = await organizations_repository.find_by_id(algorithm.organization_id)
filters["organization-id"] = str(organization.id)

# update for the linked requirements the state based on all it's measures
requirement_tasks = await find_requirement_tasks_by_measure_urn(algorithm.system_card, measure_urn)
Expand Down
40 changes: 26 additions & 14 deletions amt/locale/base.pot
Original file line number Diff line number Diff line change
Expand Up @@ -198,33 +198,45 @@ msgstr ""
msgid "Organization"
msgstr ""

#: amt/api/forms/measure.py:17
msgid "Status"
#: amt/api/forms/measure.py:24
msgid "Responsible"
msgstr ""

#: amt/api/forms/measure.py:32
msgid "Reviewer"
msgstr ""

#: amt/api/forms/measure.py:40
msgid "Accountable"
msgstr ""

#: amt/api/forms/measure.py:48
msgid "Status"
msgstr ""

#: amt/api/forms/measure.py:63
msgid "Information on how this measure is implemented"
msgstr ""

#: amt/api/forms/measure.py:39
#: amt/api/forms/measure.py:70
msgid ""
"Select one or more to upload. The files will be saved once you confirm "
"changes by pressing the save button."
msgstr ""

#: amt/api/forms/measure.py:44
#: amt/api/forms/measure.py:75
msgid "Add files"
msgstr ""

#: amt/api/forms/measure.py:45
#: amt/api/forms/measure.py:76
msgid "No files selected."
msgstr ""

#: amt/api/forms/measure.py:49
#: amt/api/forms/measure.py:80
msgid "Add URI"
msgstr ""

#: amt/api/forms/measure.py:52
#: amt/api/forms/measure.py:83
msgid "Add links to documents"
msgstr ""

Expand Down Expand Up @@ -458,7 +470,7 @@ msgid "Failed to estimate WOZ value: "
msgstr ""

#: amt/site/templates/algorithms/details_info.html.j2:16
#: amt/site/templates/algorithms/details_measure_modal.html.j2:26
#: amt/site/templates/algorithms/details_measure_modal.html.j2:27
msgid "Description"
msgstr ""

Expand Down Expand Up @@ -488,26 +500,26 @@ msgstr ""
msgid "References"
msgstr ""

#: amt/site/templates/algorithms/details_measure_modal.html.j2:36
#: amt/site/templates/algorithms/details_measure_modal.html.j2:37
msgid "Read more on the algoritmekader"
msgstr ""

#: amt/site/templates/algorithms/details_measure_modal.html.j2:47
#: amt/site/templates/algorithms/details_measure_modal.html.j2:63
#: amt/site/templates/macros/editable.html.j2:82
msgid "Save"
msgstr ""

#: amt/site/templates/algorithms/details_measure_modal.html.j2:51
#: amt/site/templates/algorithms/details_measure_modal.html.j2:67
#: amt/site/templates/macros/editable.html.j2:87
#: amt/site/templates/organizations/parts/add_members_modal.html.j2:26
msgid "Cancel"
msgstr ""

#: amt/site/templates/algorithms/details_requirements.html.j2:26
#: amt/site/templates/algorithms/details_requirements.html.j2:27
msgid "measures executed"
msgstr ""

#: amt/site/templates/algorithms/details_requirements.html.j2:59
#: amt/site/templates/algorithms/details_requirements.html.j2:60
#: amt/site/templates/macros/editable.html.j2:24
#: amt/site/templates/macros/editable.html.j2:27
msgid "Edit"
Expand Down Expand Up @@ -562,7 +574,7 @@ msgstr ""
#: amt/site/templates/algorithms/new.html.j2:172
msgid ""
"Overview of instruments for the responsible development, deployment, "
"assessment and monitoring of algorithms and AI-systems."
"assessment, and monitoring of algorithms and AI-systems."
msgstr ""

#: amt/site/templates/algorithms/new.html.j2:180
Expand Down
Loading

0 comments on commit 12b2868

Please sign in to comment.