diff --git a/tad/api/routes/pages.py b/tad/api/routes/pages.py index 73c31fe3..9983d673 100644 --- a/tad/api/routes/pages.py +++ b/tad/api/routes/pages.py @@ -5,8 +5,6 @@ from tad.services.statuses import StatusesService from tad.services.tasks import TasksService -tasks_service = TasksService() -statuses_service = StatusesService() router = APIRouter() templates = Jinja2Templates(directory="tad/site/templates") @@ -15,7 +13,7 @@ async def default_layout(request: Request): context = { "page_title": "This is the page title", - "tasks_service": tasks_service, - "statuses_service": statuses_service, + "tasks_service": TasksService(), + "statuses_service": StatusesService(), } return templates.TemplateResponse(request=request, name="default_layout.jinja", context=context) diff --git a/tad/api/routes/tasks.py b/tad/api/routes/tasks.py index bc5d2b04..9b17ae7e 100644 --- a/tad/api/routes/tasks.py +++ b/tad/api/routes/tasks.py @@ -1,24 +1,35 @@ +from typing import Any + from fastapi import APIRouter, Request from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates +from tad.models.task import MoveTask from tad.services.tasks import TasksService router = APIRouter() - -tasks_service = TasksService() templates = Jinja2Templates(directory="tad/site/templates") -@router.get("/") -async def test(): - return [{"username": "Rick"}, {"username": "Morty"}] - - @router.post("/move", response_class=HTMLResponse) -async def move_task(request: Request): - json = await request.json() - task = tasks_service.move_task( - int(json["taskId"]), int(json["statusId"]), json["previousSiblingId"], json["nextSiblingId"] +async def move_task(request: Request, move_task: MoveTask) -> HTMLResponse: + """ + Move a task through an API call. + :param request: the request object + :param move_task: the move task object + :return: a HTMLResponse object, in this case the html code of the card that was moved + """ + task = TasksService.move_task( + move_task.id, + move_task.status_id, + convert_to_int_if_is_int(move_task.previous_sibling_id), + convert_to_int_if_is_int(move_task.next_sibling_id), ) return templates.TemplateResponse(request=request, name="task.jinja", context={"task": task}) + + +def convert_to_int_if_is_int(value: Any) -> int | Any: + # If the given value is of type integer, convert it to integer, otherwise return the given value + if isinstance(value, int): + return int(value) + return value diff --git a/tad/core/db.py b/tad/core/db.py index d5317814..798139c0 100644 --- a/tad/core/db.py +++ b/tad/core/db.py @@ -1,8 +1,9 @@ +from sqlalchemy.engine.base import Engine from sqlmodel import Session, create_engine, select from tad.core.config import settings -engine = create_engine(settings.SQLALCHEMY_DATABASE_URI) +engine: Engine = create_engine(settings.SQLALCHEMY_DATABASE_URI) async def check_db(): diff --git a/tad/core/singleton.py b/tad/core/singleton.py deleted file mode 100644 index 845ba727..00000000 --- a/tad/core/singleton.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import ClassVar - - -class Singleton(type): - """The Singleton metaclass can be used to mark classes as singleton. - Based on https://stackoverflow.com/questions/6760685/what-is-the-best-way-of-implementing-singleton-in-python - - Usage: class Classname(metaclass=Singleton): - """ - - _instances = ClassVar[{}] - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super().__call__(*args, **kwargs) - return cls._instances[cls] diff --git a/tad/main.py b/tad/main.py index 681a88d4..9ac34687 100644 --- a/tad/main.py +++ b/tad/main.py @@ -71,7 +71,5 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE app.include_router(api_router) -tasks_repository = TasksRepository() -statuses_repository = StatusesRepository() - -logger.info("Hallo ik ben een logger") +TasksRepository().create_example_tasks() +StatusesRepository().create_example_statuses() diff --git a/tad/models/task.py b/tad/models/task.py index dc64ef08..0d484412 100644 --- a/tad/models/task.py +++ b/tad/models/task.py @@ -1,3 +1,5 @@ +from pydantic import BaseModel +from pydantic.fields import Field as PydanticField from sqlmodel import Field, SQLModel @@ -9,3 +11,12 @@ class Task(SQLModel, table=True): status_id: int | None = Field(default=None, foreign_key="status.id") user_id: int | None = Field(default=None, foreign_key="user.id") # todo(robbert) Tasks probably are grouped (and sub-grouped), so we probably need a reference to a group_id + + +class MoveTask(BaseModel): + # todo(robbert) values from htmx json are all strings, using type int does not work for + # sibling variables (they are optional) + id: int = PydanticField(None, alias="taskId", strict=False) + status_id: int = PydanticField(None, alias="statusId", strict=False) + previous_sibling_id: str | None = PydanticField(None, alias="previousSiblingId", strict=False) + next_sibling_id: str | None = PydanticField(None, alias="nextSiblingId", strict=False) diff --git a/tad/repositories/statuses.py b/tad/repositories/statuses.py index 25b4a380..616b41b6 100644 --- a/tad/repositories/statuses.py +++ b/tad/repositories/statuses.py @@ -4,42 +4,46 @@ from sqlmodel import Session, select from tad.core.db import engine -from tad.core.singleton import Singleton from tad.models import Status logger = logging.getLogger(__name__) -class StatusesRepository(metaclass=Singleton): +class StatusesRepository: # TODO find out how to reuse Session - def __init__(self): - logger.info("Hello world from statuses repo") - statuses = self.find_all() + @staticmethod + def create_example_statuses(): + statuses = StatusesRepository.find_all() if len(statuses) == 0: - self.__add_test_statuses() - - def __add_test_statuses(self): - with Session(engine) as session: - session.add(Status(id=1, name="todo", sort_order=1)) - session.add(Status(id=2, name="in_progress", sort_order=2)) - session.add(Status(id=3, name="review", sort_order=3)) - session.add(Status(id=4, name="done", sort_order=4)) - session.commit() - - def find_all(self) -> Sequence[Status]: + with Session(engine) as session: + session.add(Status(id=1, name="todo", sort_order=1)) + session.add(Status(id=2, name="in_progress", sort_order=2)) + session.add(Status(id=3, name="review", sort_order=3)) + session.add(Status(id=4, name="done", sort_order=4)) + session.commit() + + @staticmethod + def find_all() -> Sequence[Status]: with Session(engine) as session: statement = select(Status) return session.exec(statement).all() - def save(self, status) -> Status: + @staticmethod + def save(status) -> Status: with Session(engine) as session: session.add(status) session.commit() session.refresh(status) return status - def find_by_id(self, status_id) -> Status: + @staticmethod + def find_by_id(status_id) -> Status: + """ + Returns the status with the given id or an exception if the id does not exist. + :param status_id: the id of the status + :return: the status with the given id or an exception + """ with Session(engine) as session: statement = select(Status).where(Status.id == status_id) return session.exec(statement).one() diff --git a/tad/repositories/tasks.py b/tad/repositories/tasks.py index 3901682c..36907f53 100644 --- a/tad/repositories/tasks.py +++ b/tad/repositories/tasks.py @@ -3,49 +3,54 @@ from sqlmodel import Session, select from tad.core.db import engine -from tad.core.singleton import Singleton from tad.models import Task +# todo(robbert) sessionmanagement should be done better, using a pool or maybe fastAPI dependencies -class TasksRepository(metaclass=Singleton): - def __init__(self): - tasks = self.find_all() - if len(tasks) == 0: - self.__add_test_tasks() - def __add_test_tasks(self): - with Session(engine) as session: - session.add( - Task( - status_id=1, - title="IAMA", - description="Impact Assessment Mensenrechten en Algoritmes", - sort_order=10, +class TasksRepository: + @staticmethod + def create_example_tasks(): + tasks = TasksRepository.find_all() + if len(tasks) == 0: + with Session(engine) as session: + session.add( + Task( + status_id=1, + title="IAMA", + description="Impact Assessment Mensenrechten en Algoritmes", + sort_order=10, + ) ) - ) - session.add(Task(status_id=1, title="SHAP", description="SHAP", sort_order=20)) - session.add(Task(status_id=1, title="This is title 3", description="This is description 3", sort_order=30)) - session.commit() + session.add(Task(status_id=1, title="SHAP", description="SHAP", sort_order=20)) + session.add( + Task(status_id=1, title="This is title 3", description="This is description 3", sort_order=30) + ) + session.commit() - def find_all(self) -> Sequence[Task]: + @staticmethod + def find_all() -> Sequence[Task]: """Returns all the tasks from the repository.""" with Session(engine) as session: statement = select(Task) return session.exec(statement).all() - def find_by_status_id(self, status_id) -> Sequence[Task]: + @staticmethod + def find_by_status_id(status_id) -> Sequence[Task]: with Session(engine) as session: statement = select(Task).where(Task.status_id == status_id).order_by(Task.sort_order) return session.exec(statement).all() - def save(self, task) -> Task: + @staticmethod + def save(task) -> Task: with Session(engine) as session: session.add(task) session.commit() session.refresh(task) return task - def find_by_id(self, task_id) -> Task: + @staticmethod + def find_by_id(task_id) -> Task: with Session(engine) as session: statement = select(Task).where(Task.id == task_id) return session.exec(statement).one() diff --git a/tad/services/statuses.py b/tad/services/statuses.py index fe96ac17..833cec7b 100644 --- a/tad/services/statuses.py +++ b/tad/services/statuses.py @@ -1,20 +1,15 @@ import logging -from tad.core.singleton import Singleton from tad.repositories.statuses import StatusesRepository logger = logging.getLogger(__name__) -class StatusesService(metaclass=Singleton): - __statuses_repository = StatusesRepository() +class StatusesService: + @staticmethod + def get_status(status_id): + return StatusesRepository.find_by_id(status_id) - def __init__(self): - logger.info("Statuses service initialized") - # TODO find out why logging is not visible - - def get_status(self, status_id): - return self.__statuses_repository.find_by_id(status_id) - - def get_statuses(self) -> []: - return self.__statuses_repository.find_all() + @staticmethod + def get_statuses() -> []: + return StatusesRepository.find_all() diff --git a/tad/services/tasks.py b/tad/services/tasks.py index 9292d268..86172ec5 100644 --- a/tad/services/tasks.py +++ b/tad/services/tasks.py @@ -1,6 +1,5 @@ import logging -from tad.core.singleton import Singleton from tad.models.task import Task from tad.models.user import User from tad.repositories.tasks import TasksRepository @@ -9,23 +8,28 @@ logger = logging.getLogger(__name__) -class TasksService(metaclass=Singleton): - __tasks_repository = TasksRepository() - __statuses_service = StatusesService() +class TasksService: + @staticmethod + def get_tasks(status_id): + return TasksRepository.find_by_status_id(status_id) - def __init__(self): - pass - - def get_tasks(self, status_id): - return self.__tasks_repository.find_by_status_id(status_id) - - def assign_task(self, task: Task, user: User): + @staticmethod + def assign_task(task: Task, user: User) -> Task: task.user_id = user.id - self.__tasks_repository.save(task) - - def move_task(self, task_id, status_id, previous_sibling_id, next_sibling_id) -> Task: - status = self.__statuses_service.get_status(status_id) - task = self.__tasks_repository.find_by_id(task_id) + return TasksRepository.save(task) + + @staticmethod + def move_task(task_id: int, status_id: int, previous_sibling_id: int, next_sibling_id: int) -> Task: + """ + Updates the task with the given task_id + :param task_id: the id of the task + :param status_id: the id of the status of the task + :param previous_sibling_id: the id of the previous sibling of the task + :param next_sibling_id: the id of the next sibling of the task + :return: the updated task + """ + status = StatusesService.get_status(status_id) + task = TasksRepository.find_by_id(task_id) if status.name == "done": # TODO implement logic for done @@ -42,17 +46,17 @@ def move_task(self, task_id, status_id, previous_sibling_id, next_sibling_id) -> if not previous_sibling_id and not next_sibling_id: task.sort_order = 10 elif previous_sibling_id and next_sibling_id: - previous_task = self.__tasks_repository.find_by_id(int(previous_sibling_id)) - next_task = self.__tasks_repository.find_by_id(int(next_sibling_id)) + previous_task = TasksRepository().find_by_id(int(previous_sibling_id)) + next_task = TasksRepository().find_by_id(int(next_sibling_id)) new_sort_order = previous_task.sort_order + ((next_task.sort_order - previous_task.sort_order) / 2) task.sort_order = new_sort_order elif previous_sibling_id and not next_sibling_id: - previous_task = self.__tasks_repository.find_by_id(int(previous_sibling_id)) + previous_task = TasksRepository().find_by_id(int(previous_sibling_id)) task.sort_order = previous_task.sort_order + 10 elif not previous_sibling_id and next_sibling_id: - next_task = self.__tasks_repository.find_by_id(int(next_sibling_id)) + next_task = TasksRepository().find_by_id(int(next_sibling_id)) task.sort_order = next_task.sort_order / 2 - task = self.__tasks_repository.save(task) + task = TasksRepository().save(task) return task diff --git a/tests/api/routes/test_status.py b/tests/api/routes/test_status.py new file mode 100644 index 00000000..e4e6f388 --- /dev/null +++ b/tests/api/routes/test_status.py @@ -0,0 +1,12 @@ +from fastapi.testclient import TestClient +from tad.models.task import MoveTask + + +def test_get_root(client: TestClient) -> None: + move_task: MoveTask = MoveTask(taskId="1", statusId="2", previousSiblingId="3", nextSiblingId="4") + print(move_task.model_dump()) + response = client.post("/tasks/move", data=move_task.model_dump()) + assert response.status_code == 200 + assert response.headers["content-type"] == "text/html; charset=utf-8" + + assert b"<h1>Welcome to the Home Page</h1>" in response.content