Skip to content

Commit

Permalink
Adding more structure and database support
Browse files Browse the repository at this point in the history
  • Loading branch information
uittenbroekrobbert committed May 23, 2024
1 parent ff5d50e commit 71bff01
Show file tree
Hide file tree
Showing 17 changed files with 215 additions and 177 deletions.
21 changes: 0 additions & 21 deletions tad/api/deps.py

This file was deleted.

14 changes: 10 additions & 4 deletions tad/api/routes/pages.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from fastapi import APIRouter, Request
from typing import Annotated

from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates

Expand All @@ -10,10 +12,14 @@


@router.get("/", response_class=HTMLResponse)
async def default_layout(request: Request):
async def default_layout(
request: Request,
status_service: Annotated[StatusesService, Depends(StatusesService)],
tasks_service: Annotated[TasksService, Depends(TasksService)],
):
context = {
"page_title": "This is the page title",
"tasks_service": TasksService(),
"statuses_service": StatusesService(),
"tasks_service": tasks_service,
"statuses_service": status_service,
}
return templates.TemplateResponse(request=request, name="default_layout.jinja", context=context)
2 changes: 1 addition & 1 deletion tad/api/routes/root.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse

from tad.api.deps import templates
from tad.repositories.deps import templates

router = APIRouter()

Expand Down
37 changes: 25 additions & 12 deletions tad/api/routes/tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any
from typing import Annotated, Any

from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Request, status
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates

Expand All @@ -12,24 +12,37 @@


@router.post("/move", response_class=HTMLResponse)
async def move_task(request: Request, move_task: MoveTask) -> HTMLResponse:
async def move_task(
request: Request, move_task: MoveTask, tasks_service: Annotated[TasksService, Depends(TasksService)]
) -> HTMLResponse:
"""
Move a task through an API call.
:param tasks_service: the task service
:param request: the request object
:param move_task: the move task object
:return: a HTMLResponse object, in this case the html code of the card that was moved
"""
task = TasksService.move_task(
move_task.id,
move_task.status_id,
convert_to_int_if_is_int(move_task.previous_sibling_id),
convert_to_int_if_is_int(move_task.next_sibling_id),
)
return templates.TemplateResponse(request=request, name="task.jinja", context={"task": task})
try:
task = tasks_service.move_task(
convert_to_int_if_is_int(move_task.id),
convert_to_int_if_is_int(move_task.status_id),
convert_to_int_if_is_int(move_task.previous_sibling_id),
convert_to_int_if_is_int(move_task.next_sibling_id),
)
# todo(Robbert) add error handling for input error or task error handling
return templates.TemplateResponse(request=request, name="task.jinja", context={"task": task})
except Exception:
return templates.TemplateResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, request=request, name="error.jinja"
)


def convert_to_int_if_is_int(value: Any) -> int | Any:
# If the given value is of type integer, convert it to integer, otherwise return the given value
if isinstance(value, int):
"""
If the given value is of type int, return it as int, otherwise return the input value as is.
:param value: the value to convert
:return: the value as int or the original type
"""
if value is not None and isinstance(value, str) and value.isdigit():
return int(value)
return value
11 changes: 9 additions & 2 deletions tad/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@

from tad.core.config import settings

engine: Engine = create_engine(settings.SQLALCHEMY_DATABASE_URI)
_engine: None | Engine = None


def get_engine() -> Engine:
global _engine
if _engine is None:
_engine = create_engine(settings.SQLALCHEMY_DATABASE_URI, echo=True, connect_args={"check_same_thread": False})
return _engine


async def check_db():
with Session(engine) as session:
with Session(get_engine()) as session:
session.exec(select(1))
5 changes: 1 addition & 4 deletions tad/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@
validation_exception_handler as tad_validation_exception_handler,
)
from tad.core.log import configure_logging
from tad.repositories.tasks import TasksRepository
from tad.utils.mask import Mask

from .middleware.route_logging import RequestLoggingMiddleware
from .repositories.statuses import StatusesRepository

configure_logging(settings.LOGGING_LEVEL, settings.LOGGING_CONFIG)

Expand Down Expand Up @@ -71,5 +69,4 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE

app.include_router(api_router)

TasksRepository().create_example_tasks()
StatusesRepository().create_example_statuses()
# todo (robbert) add init code for example tasks and statuses
24 changes: 16 additions & 8 deletions tad/models/task.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
from pydantic import BaseModel
from pydantic.fields import Field as PydanticField
from sqlmodel import Field, SQLModel
from pydantic import BaseModel, ValidationInfo, field_validator
from pydantic import Field as PydanticField
from sqlmodel import Field as SQLField
from sqlmodel import SQLModel


class Task(SQLModel, table=True):
id: int = Field(default=None, primary_key=True)
id: int = SQLField(default=None, primary_key=True)
title: str
description: str
sort_order: float
status_id: int | None = Field(default=None, foreign_key="status.id")
user_id: int | None = Field(default=None, foreign_key="user.id")
status_id: int | None = SQLField(default=None, foreign_key="status.id")
user_id: int | None = SQLField(default=None, foreign_key="user.id")
# todo(robbert) Tasks probably are grouped (and sub-grouped), so we probably need a reference to a group_id


class MoveTask(BaseModel):
# todo(robbert) values from htmx json are all strings, using type int does not work for
# sibling variables (they are optional)
id: int = PydanticField(None, alias="taskId", strict=False)
status_id: int = PydanticField(None, alias="statusId", strict=False)
id: str = PydanticField(None, alias="taskId", strict=False)
status_id: str = PydanticField(None, alias="statusId", strict=False)
previous_sibling_id: str | None = PydanticField(None, alias="previousSiblingId", strict=False)
next_sibling_id: str | None = PydanticField(None, alias="nextSiblingId", strict=False)

@field_validator("id", "status_id", "previous_sibling_id", "next_sibling_id")
@classmethod
def check_is_int(cls, value: str, info: ValidationInfo) -> str:
if value is not None and isinstance(value, str) and value.isdigit():
assert value.isdigit(), f"{info.field_name} must be an integer" # noqa: S101
return value
14 changes: 14 additions & 0 deletions tad/repositories/deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from collections.abc import Generator

from fastapi.templating import Jinja2Templates
from sqlmodel import Session

from tad.core.config import settings
from tad.core.db import get_engine

templates = Jinja2Templates(directory=settings.TEMPLATE_DIR)


def get_session() -> Generator[Session, None, None]:
with Session(get_engine()) as session:
yield session
60 changes: 27 additions & 33 deletions tad/repositories/statuses.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,43 @@
import logging
from collections.abc import Sequence
from typing import Annotated

from fastapi import Depends
from sqlmodel import Session, select

from tad.core.db import engine
from tad.models import Status
from tad.repositories.deps import get_session

logger = logging.getLogger(__name__)


class StatusesRepository:
# TODO find out how to reuse Session

@staticmethod
def create_example_statuses():
statuses = StatusesRepository.find_all()
if len(statuses) == 0:
with Session(engine) as session:
session.add(Status(id=1, name="todo", sort_order=1))
session.add(Status(id=2, name="in_progress", sort_order=2))
session.add(Status(id=3, name="review", sort_order=3))
session.add(Status(id=4, name="done", sort_order=4))
session.commit()

@staticmethod
def find_all() -> Sequence[Status]:
with Session(engine) as session:
statement = select(Status)
return session.exec(statement).all()

@staticmethod
def save(status) -> Status:
with Session(engine) as session:
session.add(status)
session.commit()
session.refresh(status)
return status

@staticmethod
def find_by_id(status_id) -> Status:
def __init__(self, session: Annotated[Session, Depends(get_session)]):
self.session = session

def create_example_statuses(self):
if len(self.find_all()) == 0:
self.session.add(Status(id=1, name="todo", sort_order=1))
self.session.add(Status(id=2, name="in_progress", sort_order=2))
self.session.add(Status(id=3, name="review", sort_order=3))
self.session.add(Status(id=4, name="done", sort_order=4))
self.session.commit()

def find_all(self) -> Sequence[Status]:
statement = select(Status)
return self.session.exec(statement).all()

def save(self, status) -> Status:
self.session.add(status)
self.session.commit()
self.session.refresh(status)
return status

def find_by_id(self, status_id) -> Status:
"""
Returns the status with the given id or an exception if the id does not exist.
:param status_id: the id of the status
:return: the status with the given id or an exception
"""
with Session(engine) as session:
statement = select(Status).where(Status.id == status_id)
return session.exec(statement).one()
statement = select(Status).where(Status.id == status_id)
return self.session.exec(statement).one()
84 changes: 38 additions & 46 deletions tad/repositories/tasks.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,48 @@
from collections.abc import Sequence
from typing import Annotated

from fastapi import Depends
from sqlmodel import Session, select

from tad.core.db import engine
from tad.models import Task

# todo(robbert) sessionmanagement should be done better, using a pool or maybe fastAPI dependencies
from tad.repositories.deps import get_session


class TasksRepository:
@staticmethod
def create_example_tasks():
tasks = TasksRepository.find_all()
if len(tasks) == 0:
with Session(engine) as session:
session.add(
Task(
status_id=1,
title="IAMA",
description="Impact Assessment Mensenrechten en Algoritmes",
sort_order=10,
)
)
session.add(Task(status_id=1, title="SHAP", description="SHAP", sort_order=20))
session.add(
Task(status_id=1, title="This is title 3", description="This is description 3", sort_order=30)
def __init__(self, session: Annotated[Session, Depends(get_session)]):
self.session = session

def create_example_tasks(self):
# todo (robbert) find_all should be a count query
if len(self.find_all()) == 0:
self.session.add(
Task(
status_id=1,
title="IAMA",
description="Impact Assessment Mensenrechten en Algoritmes",
sort_order=10,
)
session.commit()

@staticmethod
def find_all() -> Sequence[Task]:
)
self.session.add(Task(status_id=1, title="SHAP", description="SHAP", sort_order=20))
self.session.add(
Task(status_id=1, title="This is title 3", description="This is description 3", sort_order=30)
)
self.session.commit()

def find_all(self) -> Sequence[Task]:
"""Returns all the tasks from the repository."""
with Session(engine) as session:
statement = select(Task)
return session.exec(statement).all()

@staticmethod
def find_by_status_id(status_id) -> Sequence[Task]:
with Session(engine) as session:
statement = select(Task).where(Task.status_id == status_id).order_by(Task.sort_order)
return session.exec(statement).all()

@staticmethod
def save(task) -> Task:
with Session(engine) as session:
session.add(task)
session.commit()
session.refresh(task)
return task

@staticmethod
def find_by_id(task_id) -> Task:
with Session(engine) as session:
statement = select(Task).where(Task.id == task_id)
return session.exec(statement).one()
return self.session.exec(select(Task)).all()

def find_by_status_id(self, status_id) -> Sequence[Task]:
statement = select(Task).where(Task.status_id == status_id).order_by(Task.sort_order)
return self.session.exec(statement).all()

def save(self, task) -> Task:
self.session.add(task)
self.session.commit()
self.session.refresh(task)
return task

def find_by_id(self, task_id) -> Task:
statement = select(Task).where(Task.id == task_id)
return self.session.exec(statement).one()
Loading

0 comments on commit 71bff01

Please sign in to comment.