diff --git a/tad/core/db.py b/tad/core/db.py index 0e2600ae..f55fae0f 100644 --- a/tad/core/db.py +++ b/tad/core/db.py @@ -3,7 +3,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.pool import QueuePool, StaticPool -from sqlmodel import Session, SQLModel, create_engine, select +from sqlmodel import Session, SQLModel, create_engine, select, update from tad.core.config import get_settings from tad.models import Status, Task, User @@ -35,15 +35,16 @@ def check_db(): def remove_old_demo_objects(session: Session): + task = session.exec(select(Task).where(Task.title == "First task")).first() + if task: + session.delete(task) + session.exec(update(Task).values(status_id=None, user_id=None)) # type: ignore user = session.exec(select(User).where(User.name == "Robbert")).first() if user: session.delete(user) - status = session.exec(select(Status).where(Status.name == "Todo")).first() + status = session.exec(select(Status).where(Status.name == "todo")).first() if status: session.delete(status) - task = session.exec(select(Task).where(Task.title == "First task")).first() - if task: - session.delete(task) session.commit() @@ -87,6 +88,7 @@ def add_demo_tasks(session: Session, status: Status | None, number_of_tasks: int status_id=status.id, ) ) + session.exec(update(Task).values(status_id=status.id)) # type: ignore session.commit() diff --git a/tests/core/test_db.py b/tests/core/test_db.py index 65eb7f47..467df0f4 100644 --- a/tests/core/test_db.py +++ b/tests/core/test_db.py @@ -67,7 +67,7 @@ def test_remove_old_demo_objects(db: DatabaseTestUtils): db_session.delete = MagicMock() user = User(name="Robbert", avatar=None) - status = Status(name="Todo", sort_order=1) + status = Status(name="todo", sort_order=1) task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id) db.given([user, status, task])