Skip to content

Commit

Permalink
Add document upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt authored and laurensWe committed Dec 17, 2024
1 parent 06e3964 commit 434fa23
Show file tree
Hide file tree
Showing 24 changed files with 2,461 additions and 1,445 deletions.
55 changes: 55 additions & 0 deletions amt/api/forms/measure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from gettext import NullTranslations

from amt.schema.webform import WebForm, WebFormField, WebFormFieldType, WebFormOption, WebFormTextCloneableField


async def get_measure_form(
id: str, current_values: dict[str, str | list[str] | list[tuple[str, str]]], translations: NullTranslations
) -> WebForm:
_ = translations.gettext

measure_form: WebForm = WebForm(id="", post_url="")

measure_form.fields = [
WebFormField(
type=WebFormFieldType.SELECT,
name="measure_state",
label=_("Status"),
options=[
WebFormOption(value="to do", display_value="to do"),
WebFormOption(value="in progress", display_value="in progress"),
WebFormOption(value="in review", display_value="in review"),
WebFormOption(value="done", display_value="done"),
WebFormOption(value="not implemented", display_value="not implemented"),
],
default_value=current_values.get("measure_state"),
group="1",
),
WebFormField(
type=WebFormFieldType.TEXTAREA,
name="measure_value",
default_value=current_values.get("measure_value"),
label=_("Information on how this measure is implemented"),
placeholder="",
group="1",
),
WebFormField(
type=WebFormFieldType.FILE,
name="measure_files",
description=_("Select one or more to upload. The files will be saved once you confirm changes by pressing the save button."),
default_value=current_values.get("measure_files"),
label=_("Add files"),
placeholder=_("No files selected."),
group="1",
),
WebFormTextCloneableField(
clone_button_name=_("Add URI"),
name="measure_links",
default_value=current_values.get("measure_links"),
label=_("Add links to documents"),
placeholder="",
group="1",
),
]

return measure_form
2 changes: 2 additions & 0 deletions amt/api/forms/organization.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def get_organization_form(id: str, translations: NullTranslations, user: User |
placeholder=_("Name of the organization"),
attributes={"onkeyup": "amt.generate_slug('" + id + "name', '" + id + "slug')"},
group="1",
required=True,
),
WebFormField(
type=WebFormFieldType.TEXT,
Expand All @@ -32,6 +33,7 @@ def get_organization_form(id: str, translations: NullTranslations, user: User |
label=_("Slug"),
placeholder=_("The slug for this organization"),
group="1",
required=True,
),
WebFormSearchField(
name="user_ids",
Expand Down
129 changes: 97 additions & 32 deletions amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from typing import Annotated, Any, cast

import yaml
from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends, File, Form, Request, Response, UploadFile
from fastapi.responses import FileResponse, HTMLResponse
from pydantic import BaseModel, Field
from pydantic import BaseModel
from ulid import ULID

from amt.api.deps import templates
from amt.api.forms.measure import get_measure_form
from amt.api.navigation import (
BaseNavigationItem,
Navigation,
Expand All @@ -18,7 +20,8 @@
resolve_navigation_items,
)
from amt.core.authorization import get_user
from amt.core.exceptions import AMTNotFound, AMTRepositoryError
from amt.core.exceptions import AMTError, AMTNotFound, AMTRepositoryError
from amt.core.internationalization import get_current_translation
from amt.enums.status import Status
from amt.models import Algorithm
from amt.models.task import Task
Expand All @@ -31,6 +34,7 @@
from amt.services.algorithms import AlgorithmsService
from amt.services.instruments_and_requirements_state import InstrumentStateService, RequirementsStateService
from amt.services.measures import MeasuresService, create_measures_service
from amt.services.object_storage import ObjectStorageService, create_object_storage_service
from amt.services.organizations import OrganizationsService
from amt.services.requirements import RequirementsService, create_requirements_service
from amt.services.tasks import TasksService
Expand Down Expand Up @@ -76,6 +80,20 @@ async def get_algorithm_or_error(
return algorithm


def get_user_id_or_error(request: Request) -> str:
user = get_user(request)
if user is None or user["sub"] is None:
raise AMTError
return user["sub"]


def get_measure_task_or_error(system_card: SystemCard, measure_urn: str) -> MeasureTask:
measure_task = find_measure_task(system_card, measure_urn)
if not measure_task:
raise AMTNotFound
return measure_task


def get_algorithm_details_tabs(request: Request) -> list[NavigationItem]:
return resolve_navigation_items(
[
Expand Down Expand Up @@ -503,41 +521,63 @@ async def get_measure(
algorithm_id: int,
measure_urn: str,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
measures_service: Annotated[MeasuresService, Depends(create_measures_service)],
object_storage_service: Annotated[ObjectStorageService, Depends(create_object_storage_service)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
measures_service = create_measures_service()
measure = await measures_service.fetch_measures([measure_urn])
measure_task = find_measure_task(algorithm.system_card, measure_urn)
measure_task = get_measure_task_or_error(algorithm.system_card, measure_urn)

filenames: list[tuple[str, str]] = []
for file in measure_task.files:
metadata = object_storage_service.get_file_metadata_from_object_name(file)
filenames.append((file.split("/")[-1], f"{metadata.filename}.{metadata.ext}"))

measure_form = await get_measure_form(
id="measure_state",
current_values={
"measure_state": measure_task.state,
"measure_value": measure_task.value,
"measure_links": measure_task.links,
"measure_files": filenames,
},
translations=get_current_translation(request),
)

context = {
"measure": measure[0],
"measure_state": measure_task.state, # pyright: ignore [reportOptionalMemberAccess]
"measure_value": measure_task.value, # pyright: ignore [reportOptionalMemberAccess]
"algorithm_id": algorithm_id,
"form": measure_form,
}

return templates.TemplateResponse(request, "algorithms/details_measure_modal.html.j2", context)


class MeasureUpdate(BaseModel):
measure_state: str = Field(default=None)
measure_value: str = Field(default=None)


@router.post("/{algorithm_id}/measure/{measure_urn}")
async def update_measure_value(
request: Request,
algorithm_id: int,
measure_urn: str,
measure_update: MeasureUpdate,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
requirements_service: Annotated[RequirementsService, Depends(create_requirements_service)],
object_storage_service: Annotated[ObjectStorageService, Depends(create_object_storage_service)],
measure_state: Annotated[str, Form()],
measure_value: Annotated[str | None, Form()] = None,
measure_links: Annotated[list[str] | None, Form()] = None,
measure_files: Annotated[list[UploadFile] | None, File()] = None,
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)

measure_task = find_measure_task(algorithm.system_card, measure_urn)
measure_task.state = measure_update.measure_state # pyright: ignore [reportOptionalMemberAccess]
measure_task.value = measure_update.measure_value # pyright: ignore [reportOptionalMemberAccess]
user_id = get_user_id_or_error(request)
measure_task = get_measure_task_or_error(algorithm.system_card, measure_urn)

paths = (
object_storage_service.upload_files(
algorithm.organization_id, algorithm.id, measure_urn, user_id, measure_files
)
if measure_files
else None
)
measure_task.update(measure_state, measure_value, measure_links, paths)

# update for the linked requirements the state based on all it's measures
requirement_tasks = await find_requirement_tasks_by_measure_urn(algorithm.system_card, measure_urn)
Expand All @@ -564,11 +604,6 @@ async def update_measure_value(
return templates.Redirect(request, f"/algorithm/{algorithm_id}/details/system_card/requirements")


# !!!
# Implementation of this endpoint is for now independent of the algorithm ID, meaning
# that the same system card is rendered for all algorithm ID's. This is due to the fact
# that the logical process flow of a system card is not complete.
# !!!
@router.get("/{algorithm_id}/details/system_card/data")
async def get_system_card_data_page(
request: Request,
Expand Down Expand Up @@ -602,11 +637,6 @@ async def get_system_card_data_page(
return templates.TemplateResponse(request, "algorithms/details_data.html.j2", context)


# !!!
# Implementation of this endpoint is for now independent of the algorithm ID, meaning
# that the same system card is rendered for all algorithm ID's. This is due to the fact
# that the logical process flow of a system card is not complete.
# !!!
@router.get("/{algorithm_id}/details/system_card/instruments")
async def get_system_card_instruments(
request: Request,
Expand Down Expand Up @@ -685,11 +715,6 @@ async def get_assessment_card(
return templates.TemplateResponse(request, "pages/assessment_card.html.j2", context)


# !!!
# Implementation of this endpoint is for now independent of the algorithm ID, meaning
# that the same system card is rendered for all algorithm ID's. This is due to the fact
# that the logical process flow of a system card is not complete.
# !!!
@router.get("/{algorithm_id}/details/system_card/models/{model_card}")
async def get_model_card(
request: Request,
Expand Down Expand Up @@ -752,3 +777,43 @@ async def download_algorithm_system_card_as_yaml(
return FileResponse(filename, filename=filename)
except AMTRepositoryError as e:
raise AMTNotFound from e


@router.get("/{algorithm_id}/file/{ulid}")
async def get_file(
request: Request,
algorithm_id: int,
ulid: ULID,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
object_storage_service: Annotated[ObjectStorageService, Depends(create_object_storage_service)],
) -> Response:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
file = object_storage_service.get_file(algorithm.organization_id, algorithm_id, ulid)
file_metadata = object_storage_service.get_file_metadata(algorithm.organization_id, algorithm_id, ulid)

return Response(
content=file.read(decode_content=True),
headers={
"Content-Disposition": f"attachment;filename={file_metadata.filename}.{file_metadata.ext}",
"Content-Type": "application/octet-stream",
},
)


@router.delete("/{algorithm_id}/file/{ulid}")
async def delete_file(
request: Request,
algorithm_id: int,
ulid: ULID,
algorithms_service: Annotated[AlgorithmsService, Depends(AlgorithmsService)],
object_storage_service: Annotated[ObjectStorageService, Depends(create_object_storage_service)],
) -> HTMLResponse:
algorithm = await get_algorithm_or_error(algorithm_id, algorithms_service, request)
metadata = object_storage_service.get_file_metadata(algorithm.organization_id, algorithm_id, ulid)
measure_task = get_measure_task_or_error(algorithm.system_card, metadata.measure_urn)

entry_to_delete = object_storage_service.delete_file(algorithm.organization_id, algorithm_id, ulid)
measure_task.files.remove(entry_to_delete)
await algorithms_service.update(algorithm)

return HTMLResponse(content="", status_code=200)
5 changes: 5 additions & 0 deletions amt/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ class Settings(BaseSettings):

TASK_REGISTRY_URL: str = "https://task-registry.apps.digilab.network"

OBJECT_STORE_URL: str = "localhost:9000"
OBJECT_STORE_USER: str = "amt"
OBJECT_STORE_PASSWORD: str = "changeme"
OBJECT_STORE_BUCKET_NAME: str = "amt"

@computed_field
def SQLALCHEMY_ECHO(self) -> bool:
return self.DEBUG
Expand Down
6 changes: 6 additions & 0 deletions amt/core/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,9 @@ class AMTAuthorizationFlowError(AMTHTTPException):
def __init__(self) -> None:
self.detail: str = _("Something went wrong during the authorization flow. Please try again later.")
super().__init__(status.HTTP_401_UNAUTHORIZED, self.detail)


class AMTStorageError(AMTHTTPException):
def __init__(self) -> None:
self.detail: str = _("Something went wrong storing your file. PLease try again later.")
super().__init__(status.HTTP_500_INTERNAL_SERVER_ERROR, self.detail)
Loading

0 comments on commit 434fa23

Please sign in to comment.