From 2f16eccea3bf178ba0e443e9343d13acbb3f653e Mon Sep 17 00:00:00 2001 From: robbertuittenbroek Date: Fri, 28 Jun 2024 12:40:44 +0200 Subject: [PATCH] Add more demo objects to the demo suite --- BUILD.md | 7 + tad/core/db.py | 75 +++++++-- ...c_rename_status_fields_for_translation_.py | 40 +++++ tests/core/test_db.py | 158 ++++++++++++++---- tests/database_test_utils.py | 5 +- 5 files changed, 235 insertions(+), 50 deletions(-) create mode 100644 tad/migrations/versions/bf86033947fc_rename_status_fields_for_translation_.py diff --git a/BUILD.md b/BUILD.md index 5d542068..9d146224 100644 --- a/BUILD.md +++ b/BUILD.md @@ -23,6 +23,13 @@ When poetry is done installing all dependencies you can start using the tool. poetry run python -m uvicorn tad.main:app --log-level warning ``` +### Suggested development ENVIRONMENT settings +To use a demo environment during local development, you can use the following environment options. +```shell +export ENVIRONMENT=demo AUTO_CREATE_SCHEMA=true +``` + + ## Database We support most SQL database types. You can use the variable `APP_DATABASE_SCHEME` to change the database. The default scheme is sqlite. diff --git a/tad/core/db.py b/tad/core/db.py index e07be8b1..0e2600ae 100644 --- a/tad/core/db.py +++ b/tad/core/db.py @@ -31,7 +31,20 @@ def check_db(): with Session(get_engine()) as session: session.exec(select(1)) - logger.info("Finisch Checking database connection") + logger.info("Finish Checking database connection") + + +def remove_old_demo_objects(session: Session): + user = session.exec(select(User).where(User.name == "Robbert")).first() + if user: + session.delete(user) + status = session.exec(select(Status).where(Status.name == "Todo")).first() + if status: + session.delete(status) + task = session.exec(select(Task).where(Task.title == "First task")).first() + if task: + session.delete(task) + session.commit() def init_db(): @@ -44,20 +57,48 @@ def init_db(): with Session(get_engine()) as session: if get_settings().ENVIRONMENT == "demo": logger.info("Creating demo data") - - user = session.exec(select(User).where(User.name == "Robbert")).first() - if not user: - user = User(name="Robbert", avatar=None) - session.add(user) - - status = session.exec(select(Status).where(Status.name == "Todo")).first() - if not status: - status = Status(name="Todo", sort_order=1) - session.add(status) - - task = session.exec(select(Task).where(Task.title == "First task")).first() - if not task: - task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id) - session.add(task) - session.commit() + remove_old_demo_objects(session) + add_demo_users(session, ["default user"]) + demo_statuses = add_demo_statuses(session, ["todo", "review", "in_progress", "done"]) + add_demo_tasks(session, demo_statuses[0], 3) logger.info("Finished initializing database") + + +def add_demo_users(session: Session, user_names: list[str]) -> None: + for user_name in user_names: + user = session.exec(select(User).where(User.name == user_name)).first() + if not user: + session.add(User(name=user_name, avatar=None)) + session.commit() + + +def add_demo_tasks(session: Session, status: Status | None, number_of_tasks: int) -> None: + if status is None: + return + for index in range(1, number_of_tasks + 1): + title = "Example task " + str(index) + task = session.exec(select(Task).where(Task.title == title)).first() + if not task: + session.add( + Task( + title=title, + description="Example description " + str(index), + sort_order=index, + status_id=status.id, + ) + ) + session.commit() + + +def add_demo_statuses(session: Session, statuses: list[str]) -> list[Status]: + return_statuses: list[Status] = [] + for index, status_name in enumerate(statuses): + status = session.exec(select(Status).where(Status.name == status_name)).first() + if not status: + status = Status(name=status_name, sort_order=index + 1) + session.add(status) + return_statuses.append(status) + session.commit() + for return_status in return_statuses: + session.refresh(return_status) + return return_statuses diff --git a/tad/migrations/versions/bf86033947fc_rename_status_fields_for_translation_.py b/tad/migrations/versions/bf86033947fc_rename_status_fields_for_translation_.py new file mode 100644 index 00000000..3bdee47f --- /dev/null +++ b/tad/migrations/versions/bf86033947fc_rename_status_fields_for_translation_.py @@ -0,0 +1,40 @@ +"""rename status fields for translation support + +Revision ID: bf86033947fc +Revises: b62dbd9468e4 +Create Date: 2024-07-05 08:54:30.114471 + +""" + +from collections.abc import Sequence + +from alembic import op +from sqlalchemy import String +from sqlalchemy.sql import column, table + +# revision identifiers, used by Alembic. +revision: str = "bf86033947fc" +down_revision: str | None = "b62dbd9468e4" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + +status_table = table("status", column("name", String)) +status_renames = {"Todo": "todo", "In Progress": "in_progress", "Review": "review", "Done": "done"} + + +def upgrade() -> None: + for k, v in status_renames.items(): + op.execute( + status_table.update() + .where(status_table.c.name == op.inline_literal(k)) + .values({"name": op.inline_literal(v)}) + ) + + +def downgrade() -> None: + for k, v in status_renames.items(): + op.execute( + status_table.update() + .where(status_table.c.name == op.inline_literal(v)) + .values({"name": op.inline_literal(k)}) + ) diff --git a/tests/core/test_db.py b/tests/core/test_db.py index dbc7a8a9..65eb7f47 100644 --- a/tests/core/test_db.py +++ b/tests/core/test_db.py @@ -2,11 +2,26 @@ from unittest.mock import MagicMock import pytest -from sqlmodel import Session, select +import tad.core.db as testdb +from sqlmodel import Session, SQLModel, select from tad.core.config import Settings -from tad.core.db import check_db, init_db +from tad.core.db import ( + add_demo_statuses, + add_demo_tasks, + add_demo_users, + check_db, + init_db, + remove_old_demo_objects, +) from tad.models import Status, Task, User +from tests.constants import ( + default_status, + default_task, + default_user, +) +from tests.database_test_utils import DatabaseTestUtils + logger = logging.getLogger(__name__) @@ -22,46 +37,125 @@ def test_check_database(): @pytest.mark.parametrize( "patch_settings", - [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], + [ + ({"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}), + ({"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": False}), + ({"ENVIRONMENT": "production"}), + ], indirect=True, ) -def test_init_database_none(patch_settings: Settings): - org_exec = Session.exec - Session.exec = MagicMock() - Session.exec.return_value.first.return_value = None +def test_init_database(patch_settings: Settings): + remove_old_demo_objects_orig = testdb.remove_old_demo_objects + testdb.remove_old_demo_objects = MagicMock() + create_all_orig = SQLModel.metadata.create_all + SQLModel.metadata.create_all = MagicMock() init_db() - expected = [ - (select(User).where(User.name == "Robbert"),), - (select(Status).where(Status.name == "Todo"),), - (select(Task).where(Task.title == "First task"),), - ] + if patch_settings.ENVIRONMENT == "demo": + assert testdb.remove_old_demo_objects.called + if patch_settings.AUTO_CREATE_SCHEMA is True: + assert SQLModel.metadata.create_all.called - for i, call_args in enumerate(Session.exec.call_args_list): - assert str(expected[i][0]) == str(call_args.args[0]) + testdb.remove_old_demo_objects = remove_old_demo_objects_orig + SQLModel.metadata.create_all = create_all_orig - Session.exec = org_exec +def test_remove_old_demo_objects(db: DatabaseTestUtils): + org_delete = db.get_session().delete + db_session = db.get_session() + db_session.delete = MagicMock() -@pytest.mark.parametrize( - "patch_settings", - [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], - indirect=True, -) -def test_init_database(patch_settings: Settings): - org_exec = Session.exec - Session.exec = MagicMock() + user = User(name="Robbert", avatar=None) + status = Status(name="Todo", sort_order=1) + task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id) + db.given([user, status, task]) - init_db() + remove_old_demo_objects(db.get_session()) + assert db_session.delete.call_count == 3 - expected = [ - (select(User).where(User.name == "Robbert"),), - (select(Status).where(Status.name == "Todo"),), - (select(Task).where(Task.title == "First task"),), - ] + db.get_session().delete = org_delete - for i, call_args in enumerate(Session.exec.call_args_list): - assert str(expected[i][0]) == str(call_args.args[0]) - Session.exec = org_exec +def test_remove_old_demo_objects_nothing_to_delete(db: DatabaseTestUtils): + org_delete = db.get_session().delete + db_session = db.get_session() + db_session.delete = MagicMock() + + remove_old_demo_objects(db.get_session()) + assert db_session.delete.call_count == 0 + + db.get_session().delete = org_delete + + +def test_add_demo_user(db: DatabaseTestUtils): + user_names = [default_user().name] + add_demo_users(db.get_session(), user_names) + assert db.exists(User, User.name, user_names[0]) + + +def test_add_demo_user_nothing_to_add(db: DatabaseTestUtils): + db.given([default_user()]) + + orig_add = db.get_session().add + db_session = db.get_session() + db_session.add = MagicMock() + + add_demo_users(db.get_session(), [default_user().name]) + + assert db_session.add.call_count == 0 + + db.get_session().add = orig_add + + +def test_add_demo_status(db: DatabaseTestUtils): + add_demo_statuses(db.get_session(), [default_status().name]) + assert db.exists(Status, Status.name, default_status().name) + + +def test_add_demo_status_nothing_to_add(db: DatabaseTestUtils): + db.given([default_status()]) + + orig_add = db.get_session().add + db_session = db.get_session() + db_session.add = MagicMock() + + add_demo_statuses(db.get_session(), [default_status().name]) + + assert db_session.add.call_count == 0 + db.get_session().add = orig_add + + +def test_add_demo_tasks(db: DatabaseTestUtils): + add_demo_tasks(db.get_session(), default_status(), 3) + assert db.exists(Task, Task.title, "Example task 1") + assert db.exists(Task, Task.title, "Example task 2") + assert db.exists(Task, Task.title, "Example task 3") + + +def test_add_demo_tasks_nothing_to_add(db: DatabaseTestUtils): + db.given( + [ + default_task(title="Example task 1"), + default_task(title="Example task 2"), + default_task(title="Example task 3"), + ] + ) + + orig_add = db.get_session().add + db_session = db.get_session() + db_session.add = MagicMock() + + add_demo_tasks(db.get_session(), default_status(), 3) + assert db_session.add.call_count == 0 + db.get_session().add = orig_add + + +def test_add_demo_tasks_no_status(db: DatabaseTestUtils): + orig_add = db.get_session().add + db_session = db.get_session() + db_session.add = MagicMock() + + add_demo_tasks(db.get_session(), None, 3) + assert db_session.add.call_count == 0 + db.get_session().add = orig_add diff --git a/tests/database_test_utils.py b/tests/database_test_utils.py index 539efb60..9c04af7a 100644 --- a/tests/database_test_utils.py +++ b/tests/database_test_utils.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -from sqlmodel import Session, SQLModel +from sqlmodel import Session, SQLModel, select from tad.core.db import get_engine @@ -29,3 +29,6 @@ def given(self, models: list[BaseModel]) -> None: def get_session(self) -> Session: return self.session + + def exists(self, model: type[SQLModel], model_field: str, field_value: str | int) -> SQLModel | None: + return self.get_session().exec(select(model).where(model_field == field_value)).one() is not None # type: ignore