diff --git a/tad/api/deps.py b/tad/api/deps.py deleted file mode 100644 index a4bca175..00000000 --- a/tad/api/deps.py +++ /dev/null @@ -1,21 +0,0 @@ -from collections.abc import Generator -from typing import Annotated - -from fastapi import Depends -from fastapi.templating import Jinja2Templates -from sqlmodel import Session - -from tad.core.config import settings -from tad.core.db import engine - -templates = Jinja2Templates(directory=settings.TEMPLATE_DIR) - - -def get_db() -> Generator[Session, None, None]: - with Session(engine) as session: - yield session - - session.get() - - -SessionDep = Annotated[Session, Depends(get_db)] diff --git a/tad/api/routes/pages.py b/tad/api/routes/pages.py index 9983d673..2f17fc96 100644 --- a/tad/api/routes/pages.py +++ b/tad/api/routes/pages.py @@ -1,4 +1,6 @@ -from fastapi import APIRouter, Request +from typing import Annotated + +from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates @@ -10,10 +12,14 @@ @router.get("/", response_class=HTMLResponse) -async def default_layout(request: Request): +async def default_layout( + request: Request, + status_service: Annotated[StatusesService, Depends(StatusesService)], + tasks_service: Annotated[TasksService, Depends(TasksService)], +): context = { "page_title": "This is the page title", - "tasks_service": TasksService(), - "statuses_service": StatusesService(), + "tasks_service": tasks_service, + "statuses_service": status_service, } return templates.TemplateResponse(request=request, name="default_layout.jinja", context=context) diff --git a/tad/api/routes/root.py b/tad/api/routes/root.py index a33a202e..abe4fb08 100644 --- a/tad/api/routes/root.py +++ b/tad/api/routes/root.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Request from fastapi.responses import HTMLResponse -from tad.api.deps import templates +from tad.repositories.deps import templates router = APIRouter() diff --git a/tad/api/routes/tasks.py b/tad/api/routes/tasks.py index 9b17ae7e..e085922c 100644 --- a/tad/api/routes/tasks.py +++ b/tad/api/routes/tasks.py @@ -1,6 +1,6 @@ -from typing import Any +from typing import Annotated, Any -from fastapi import APIRouter, Request +from fastapi import APIRouter, Depends, Request, status from fastapi.responses import HTMLResponse from fastapi.templating import Jinja2Templates @@ -12,24 +12,37 @@ @router.post("/move", response_class=HTMLResponse) -async def move_task(request: Request, move_task: MoveTask) -> HTMLResponse: +async def move_task( + request: Request, move_task: MoveTask, tasks_service: Annotated[TasksService, Depends(TasksService)] +) -> HTMLResponse: """ Move a task through an API call. + :param tasks_service: the task service :param request: the request object :param move_task: the move task object :return: a HTMLResponse object, in this case the html code of the card that was moved """ - task = TasksService.move_task( - move_task.id, - move_task.status_id, - convert_to_int_if_is_int(move_task.previous_sibling_id), - convert_to_int_if_is_int(move_task.next_sibling_id), - ) - return templates.TemplateResponse(request=request, name="task.jinja", context={"task": task}) + try: + task = tasks_service.move_task( + convert_to_int_if_is_int(move_task.id), + convert_to_int_if_is_int(move_task.status_id), + convert_to_int_if_is_int(move_task.previous_sibling_id), + convert_to_int_if_is_int(move_task.next_sibling_id), + ) + # todo(Robbert) add error handling for input error or task error handling + return templates.TemplateResponse(request=request, name="task.jinja", context={"task": task}) + except Exception: + return templates.TemplateResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, request=request, name="error.jinja" + ) def convert_to_int_if_is_int(value: Any) -> int | Any: - # If the given value is of type integer, convert it to integer, otherwise return the given value - if isinstance(value, int): + """ + If the given value is of type int, return it as int, otherwise return the input value as is. + :param value: the value to convert + :return: the value as int or the original type + """ + if value is not None and isinstance(value, str) and value.isdigit(): return int(value) return value diff --git a/tad/core/db.py b/tad/core/db.py index 798139c0..0fa852c8 100644 --- a/tad/core/db.py +++ b/tad/core/db.py @@ -3,9 +3,16 @@ from tad.core.config import settings -engine: Engine = create_engine(settings.SQLALCHEMY_DATABASE_URI) +_engine: None | Engine = None + + +def get_engine() -> Engine: + global _engine + if _engine is None: + _engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, echo=True, connect_args={"check_same_thread": False}) + return _engine async def check_db(): - with Session(engine) as session: + with Session(get_engine()) as session: session.exec(select(1)) diff --git a/tad/main.py b/tad/main.py index 9ac34687..68740234 100644 --- a/tad/main.py +++ b/tad/main.py @@ -18,11 +18,9 @@ validation_exception_handler as tad_validation_exception_handler, ) from tad.core.log import configure_logging -from tad.repositories.tasks import TasksRepository from tad.utils.mask import Mask from .middleware.route_logging import RequestLoggingMiddleware -from .repositories.statuses import StatusesRepository configure_logging(settings.LOGGING_LEVEL, settings.LOGGING_CONFIG) @@ -71,5 +69,4 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE app.include_router(api_router) -TasksRepository().create_example_tasks() -StatusesRepository().create_example_statuses() +# todo (robbert) add init code for example tasks and statuses diff --git a/tad/models/task.py b/tad/models/task.py index 0d484412..a3fe54aa 100644 --- a/tad/models/task.py +++ b/tad/models/task.py @@ -1,22 +1,30 @@ -from pydantic import BaseModel -from pydantic.fields import Field as PydanticField -from sqlmodel import Field, SQLModel +from pydantic import BaseModel, ValidationInfo, field_validator +from pydantic import Field as PydanticField +from sqlmodel import Field as SQLField +from sqlmodel import SQLModel class Task(SQLModel, table=True): - id: int = Field(default=None, primary_key=True) + id: int = SQLField(default=None, primary_key=True) title: str description: str sort_order: float - status_id: int | None = Field(default=None, foreign_key="status.id") - user_id: int | None = Field(default=None, foreign_key="user.id") + status_id: int | None = SQLField(default=None, foreign_key="status.id") + user_id: int | None = SQLField(default=None, foreign_key="user.id") # todo(robbert) Tasks probably are grouped (and sub-grouped), so we probably need a reference to a group_id class MoveTask(BaseModel): # todo(robbert) values from htmx json are all strings, using type int does not work for # sibling variables (they are optional) - id: int = PydanticField(None, alias="taskId", strict=False) - status_id: int = PydanticField(None, alias="statusId", strict=False) + id: str = PydanticField(None, alias="taskId", strict=False) + status_id: str = PydanticField(None, alias="statusId", strict=False) previous_sibling_id: str | None = PydanticField(None, alias="previousSiblingId", strict=False) next_sibling_id: str | None = PydanticField(None, alias="nextSiblingId", strict=False) + + @field_validator("id", "status_id", "previous_sibling_id", "next_sibling_id") + @classmethod + def check_is_int(cls, value: str, info: ValidationInfo) -> str: + if value is not None and isinstance(value, str) and value.isdigit(): + assert value.isdigit(), f"{info.field_name} must be an integer" # noqa: S101 + return value diff --git a/tad/repositories/deps.py b/tad/repositories/deps.py new file mode 100644 index 00000000..988da5f8 --- /dev/null +++ b/tad/repositories/deps.py @@ -0,0 +1,14 @@ +from collections.abc import Generator + +from fastapi.templating import Jinja2Templates +from sqlmodel import Session + +from tad.core.config import settings +from tad.core.db import get_engine + +templates = Jinja2Templates(directory=settings.TEMPLATE_DIR) + + +def get_session() -> Generator[Session, None, None]: + with Session(get_engine()) as session: + yield session diff --git a/tad/repositories/statuses.py b/tad/repositories/statuses.py index 616b41b6..ed7b77c9 100644 --- a/tad/repositories/statuses.py +++ b/tad/repositories/statuses.py @@ -1,49 +1,43 @@ import logging from collections.abc import Sequence +from typing import Annotated +from fastapi import Depends from sqlmodel import Session, select -from tad.core.db import engine from tad.models import Status +from tad.repositories.deps import get_session logger = logging.getLogger(__name__) class StatusesRepository: - # TODO find out how to reuse Session - - @staticmethod - def create_example_statuses(): - statuses = StatusesRepository.find_all() - if len(statuses) == 0: - with Session(engine) as session: - session.add(Status(id=1, name="todo", sort_order=1)) - session.add(Status(id=2, name="in_progress", sort_order=2)) - session.add(Status(id=3, name="review", sort_order=3)) - session.add(Status(id=4, name="done", sort_order=4)) - session.commit() - - @staticmethod - def find_all() -> Sequence[Status]: - with Session(engine) as session: - statement = select(Status) - return session.exec(statement).all() - - @staticmethod - def save(status) -> Status: - with Session(engine) as session: - session.add(status) - session.commit() - session.refresh(status) - return status - - @staticmethod - def find_by_id(status_id) -> Status: + def __init__(self, session: Annotated[Session, Depends(get_session)]): + self.session = session + + def create_example_statuses(self): + if len(self.find_all()) == 0: + self.session.add(Status(id=1, name="todo", sort_order=1)) + self.session.add(Status(id=2, name="in_progress", sort_order=2)) + self.session.add(Status(id=3, name="review", sort_order=3)) + self.session.add(Status(id=4, name="done", sort_order=4)) + self.session.commit() + + def find_all(self) -> Sequence[Status]: + statement = select(Status) + return self.session.exec(statement).all() + + def save(self, status) -> Status: + self.session.add(status) + self.session.commit() + self.session.refresh(status) + return status + + def find_by_id(self, status_id) -> Status: """ Returns the status with the given id or an exception if the id does not exist. :param status_id: the id of the status :return: the status with the given id or an exception """ - with Session(engine) as session: - statement = select(Status).where(Status.id == status_id) - return session.exec(statement).one() + statement = select(Status).where(Status.id == status_id) + return self.session.exec(statement).one() diff --git a/tad/repositories/tasks.py b/tad/repositories/tasks.py index 36907f53..87865147 100644 --- a/tad/repositories/tasks.py +++ b/tad/repositories/tasks.py @@ -1,56 +1,48 @@ from collections.abc import Sequence +from typing import Annotated +from fastapi import Depends from sqlmodel import Session, select -from tad.core.db import engine from tad.models import Task - -# todo(robbert) sessionmanagement should be done better, using a pool or maybe fastAPI dependencies +from tad.repositories.deps import get_session class TasksRepository: - @staticmethod - def create_example_tasks(): - tasks = TasksRepository.find_all() - if len(tasks) == 0: - with Session(engine) as session: - session.add( - Task( - status_id=1, - title="IAMA", - description="Impact Assessment Mensenrechten en Algoritmes", - sort_order=10, - ) - ) - session.add(Task(status_id=1, title="SHAP", description="SHAP", sort_order=20)) - session.add( - Task(status_id=1, title="This is title 3", description="This is description 3", sort_order=30) + def __init__(self, session: Annotated[Session, Depends(get_session)]): + self.session = session + + def create_example_tasks(self): + # todo (robbert) find_all should be a count query + if len(self.find_all()) == 0: + self.session.add( + Task( + status_id=1, + title="IAMA", + description="Impact Assessment Mensenrechten en Algoritmes", + sort_order=10, ) - session.commit() - - @staticmethod - def find_all() -> Sequence[Task]: + ) + self.session.add(Task(status_id=1, title="SHAP", description="SHAP", sort_order=20)) + self.session.add( + Task(status_id=1, title="This is title 3", description="This is description 3", sort_order=30) + ) + self.session.commit() + + def find_all(self) -> Sequence[Task]: """Returns all the tasks from the repository.""" - with Session(engine) as session: - statement = select(Task) - return session.exec(statement).all() - - @staticmethod - def find_by_status_id(status_id) -> Sequence[Task]: - with Session(engine) as session: - statement = select(Task).where(Task.status_id == status_id).order_by(Task.sort_order) - return session.exec(statement).all() - - @staticmethod - def save(task) -> Task: - with Session(engine) as session: - session.add(task) - session.commit() - session.refresh(task) - return task - - @staticmethod - def find_by_id(task_id) -> Task: - with Session(engine) as session: - statement = select(Task).where(Task.id == task_id) - return session.exec(statement).one() + return self.session.exec(select(Task)).all() + + def find_by_status_id(self, status_id) -> Sequence[Task]: + statement = select(Task).where(Task.status_id == status_id).order_by(Task.sort_order) + return self.session.exec(statement).all() + + def save(self, task) -> Task: + self.session.add(task) + self.session.commit() + self.session.refresh(task) + return task + + def find_by_id(self, task_id) -> Task: + statement = select(Task).where(Task.id == task_id) + return self.session.exec(statement).one() diff --git a/tad/services/statuses.py b/tad/services/statuses.py index 833cec7b..a1fef0b8 100644 --- a/tad/services/statuses.py +++ b/tad/services/statuses.py @@ -1,4 +1,7 @@ import logging +from typing import Annotated + +from fastapi import Depends from tad.repositories.statuses import StatusesRepository @@ -6,10 +9,11 @@ class StatusesService: - @staticmethod - def get_status(status_id): - return StatusesRepository.find_by_id(status_id) + def __init__(self, repository: Annotated[StatusesRepository, Depends(StatusesRepository)]): + self.repository = repository + + def get_status(self, status_id): + return self.repository.find_by_id(status_id) - @staticmethod - def get_statuses() -> []: - return StatusesRepository.find_all() + def get_statuses(self) -> []: + return self.repository.find_all() diff --git a/tad/services/tasks.py b/tad/services/tasks.py index 86172ec5..80cc8067 100644 --- a/tad/services/tasks.py +++ b/tad/services/tasks.py @@ -1,4 +1,7 @@ import logging +from typing import Annotated + +from fastapi import Depends from tad.models.task import Task from tad.models.user import User @@ -9,17 +12,22 @@ class TasksService: - @staticmethod - def get_tasks(status_id): - return TasksRepository.find_by_status_id(status_id) + def __init__( + self, + statuses_service: Annotated[StatusesService, Depends(StatusesService)], + repository: Annotated[TasksRepository, Depends(TasksRepository)], + ): + self.repository = repository + self.statuses_service = statuses_service + + def get_tasks(self, status_id): + return self.repository.find_by_status_id(status_id) - @staticmethod - def assign_task(task: Task, user: User) -> Task: + def assign_task(self, task: Task, user: User) -> Task: task.user_id = user.id - return TasksRepository.save(task) + return self.repository.save(task) - @staticmethod - def move_task(task_id: int, status_id: int, previous_sibling_id: int, next_sibling_id: int) -> Task: + def move_task(self, task_id: int, status_id: int, previous_sibling_id: int, next_sibling_id: int) -> Task: """ Updates the task with the given task_id :param task_id: the id of the task @@ -28,8 +36,8 @@ def move_task(task_id: int, status_id: int, previous_sibling_id: int, next_sibli :param next_sibling_id: the id of the next sibling of the task :return: the updated task """ - status = StatusesService.get_status(status_id) - task = TasksRepository.find_by_id(task_id) + status = self.statuses_service.get_status(status_id) + task = self.repository.find_by_id(task_id) if status.name == "done": # TODO implement logic for done @@ -46,17 +54,17 @@ def move_task(task_id: int, status_id: int, previous_sibling_id: int, next_sibli if not previous_sibling_id and not next_sibling_id: task.sort_order = 10 elif previous_sibling_id and next_sibling_id: - previous_task = TasksRepository().find_by_id(int(previous_sibling_id)) - next_task = TasksRepository().find_by_id(int(next_sibling_id)) + previous_task = self.repository.find_by_id(int(previous_sibling_id)) + next_task = self.repository.find_by_id(int(next_sibling_id)) new_sort_order = previous_task.sort_order + ((next_task.sort_order - previous_task.sort_order) / 2) task.sort_order = new_sort_order elif previous_sibling_id and not next_sibling_id: - previous_task = TasksRepository().find_by_id(int(previous_sibling_id)) + previous_task = self.repository.find_by_id(int(previous_sibling_id)) task.sort_order = previous_task.sort_order + 10 elif not previous_sibling_id and next_sibling_id: - next_task = TasksRepository().find_by_id(int(next_sibling_id)) + next_task = self.repository.find_by_id(int(next_sibling_id)) task.sort_order = next_task.sort_order / 2 - task = TasksRepository().save(task) + task = self.repository.save(task) return task diff --git a/tad/site/templates/default_layout.jinja b/tad/site/templates/default_layout.jinja index 479744cb..3f94c963 100644 --- a/tad/site/templates/default_layout.jinja +++ b/tad/site/templates/default_layout.jinja @@ -23,6 +23,7 @@ + @@ -51,14 +52,17 @@ -