diff --git a/BUILD.md b/BUILD.md index 5d5420684..5a0152368 100644 --- a/BUILD.md +++ b/BUILD.md @@ -23,6 +23,13 @@ When poetry is done installing all dependencies you can start using the tool. poetry run python -m uvicorn tad.main:app --log-level warning ``` +### Suggested development ENVIRONMENT settings +To use a demo environment during local development, you can use the following environment options. You can leave out the TRUNCATE_TABLES option if you wish to keep the state between runs. +```shell +export ENVIRONMENT=demo AUTO_CREATE_SCHEMA=true TRUNCATE_TABLES=true +``` + + ## Database We support most SQL database types. You can use the variable `APP_DATABASE_SCHEME` to change the database. The default scheme is sqlite. diff --git a/tad/core/config.py b/tad/core/config.py index 71b802014..4b1113b8d 100644 --- a/tad/core/config.py +++ b/tad/core/config.py @@ -33,6 +33,7 @@ class Settings(BaseSettings): DEBUG: bool = False AUTO_CREATE_SCHEMA: bool = False + TRUNCATE_TABLES: bool = False # todo(berry): create submodel for database settings APP_DATABASE_SCHEME: DatabaseSchemaType = "sqlite" diff --git a/tad/core/db.py b/tad/core/db.py index e07be8b13..2b489de6a 100644 --- a/tad/core/db.py +++ b/tad/core/db.py @@ -3,7 +3,7 @@ from sqlalchemy.engine import Engine from sqlalchemy.pool import QueuePool, StaticPool -from sqlmodel import Session, SQLModel, create_engine, select +from sqlmodel import Session, SQLModel, create_engine, delete, select from tad.core.config import get_settings from tad.models import Status, Task, User @@ -31,7 +31,20 @@ def check_db(): with Session(get_engine()) as session: session.exec(select(1)) - logger.info("Finisch Checking database connection") + logger.info("Finish Checking database connection") + + +def remove_old_demo_objects(session: Session): + user = session.exec(select(User).where(User.name == "Robbert")).first() + if user: + session.delete(user) + status = session.exec(select(Status).where(Status.name == "Todo")).first() + if status: + session.delete(status) + task = session.exec(select(Task).where(Task.title == "First task")).first() + if task: + session.delete(task) + session.commit() def init_db(): @@ -43,21 +56,59 @@ def init_db(): with Session(get_engine()) as session: if get_settings().ENVIRONMENT == "demo": + if get_settings().TRUNCATE_TABLES: + truncate_tables(session) logger.info("Creating demo data") - - user = session.exec(select(User).where(User.name == "Robbert")).first() - if not user: - user = User(name="Robbert", avatar=None) - session.add(user) - - status = session.exec(select(Status).where(Status.name == "Todo")).first() - if not status: - status = Status(name="Todo", sort_order=1) - session.add(status) - - task = session.exec(select(Task).where(Task.title == "First task")).first() - if not task: - task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id) - session.add(task) - session.commit() + remove_old_demo_objects(session) + add_demo_users(session, ["default user"]) + demo_statuses = add_demo_statuses(session, ["todo", "review", "in_progress", "done"]) + add_demo_tasks(session, demo_statuses[0], 3) logger.info("Finished initializing database") + + +def truncate_tables(session: Session) -> None: + logger.info("Truncating tables") + session.exec(delete(Task)) # type: ignore + session.exec(delete(User)) # type: ignore + session.exec(delete(Status)) # type: ignore + session.commit() + + +def add_demo_users(session: Session, user_names: list[str]) -> None: + for user_name in user_names: + user = session.exec(select(User).where(User.name == user_name)).first() + if not user: + session.add(User(name=user_name, avatar=None)) + session.commit() + + +def add_demo_tasks(session: Session, status: Status | None, number_of_tasks: int) -> None: + if status is None: + return + for index in range(1, number_of_tasks + 1): + title = "Example task " + str(index) + task = session.exec(select(Task).where(Task.title == title)).first() + if not task: + session.add( + Task( + title=title, + description="Example description " + str(index), + sort_order=index, + status_id=status.id, + ) + ) + session.commit() + + +def add_demo_statuses(session: Session, statuses: list[str]) -> list[Status]: + return_statuses: list[Status] = [] + for index, status_name in enumerate(statuses): + status = session.exec(select(Status).where(Status.name == status_name)).first() + if not status: + status = Status(name=status_name, sort_order=index + 1) + session.add(status) + return_statuses.append(status) + session.commit() + for return_status in return_statuses: + session.refresh(return_status) + return return_statuses diff --git a/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py b/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py index d5912d550..ab57151a7 100644 --- a/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py +++ b/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py @@ -21,7 +21,7 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - status = op.create_table( + op.create_table( "status", sa.Column("id", sa.Integer(), nullable=False), sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), @@ -56,17 +56,6 @@ def upgrade() -> None: ) # ### end Alembic commands ### - # ### custom commands ### - op.bulk_insert( - status, - [ - {"name": "Todo", "sort_order": 1}, - {"name": "In Progress", "sort_order": 2}, - {"name": "Review", "sort_order": 3}, - {"name": "Done", "sort_order": 4}, - ], - ) - def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### diff --git a/tests/core/test_db.py b/tests/core/test_db.py index dbc7a8a99..5e9fa5030 100644 --- a/tests/core/test_db.py +++ b/tests/core/test_db.py @@ -1,12 +1,28 @@ import logging +from typing import Any from unittest.mock import MagicMock import pytest from sqlmodel import Session, select from tad.core.config import Settings -from tad.core.db import check_db, init_db +from tad.core.db import ( + add_demo_statuses, + add_demo_tasks, + add_demo_users, + check_db, + init_db, + remove_old_demo_objects, + truncate_tables, +) from tad.models import Status, Task, User +from tests.constants import ( + default_status, + default_task, + default_user, +) +from tests.database_test_utils import DatabaseTestUtils + logger = logging.getLogger(__name__) @@ -22,46 +38,114 @@ def test_check_database(): @pytest.mark.parametrize( "patch_settings", - [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], - indirect=True, -) -def test_init_database_none(patch_settings: Settings): - org_exec = Session.exec - Session.exec = MagicMock() - Session.exec.return_value.first.return_value = None - - init_db() - - expected = [ - (select(User).where(User.name == "Robbert"),), - (select(Status).where(Status.name == "Todo"),), - (select(Task).where(Task.title == "First task"),), - ] - - for i, call_args in enumerate(Session.exec.call_args_list): - assert str(expected[i][0]) == str(call_args.args[0]) - - Session.exec = org_exec - - -@pytest.mark.parametrize( - "patch_settings", - [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], + [ + {"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}, + {"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": False}, + {"ENVIRONMENT": "demo", "TRUNCATE_TABLES": True}, + {"ENVIRONMENT": "production", "TRUNCATE_TABLES": True}, + ], indirect=True, ) def test_init_database(patch_settings: Settings): - org_exec = Session.exec - Session.exec = MagicMock() - init_db() - expected = [ - (select(User).where(User.name == "Robbert"),), - (select(Status).where(Status.name == "Todo"),), - (select(Task).where(Task.title == "First task"),), - ] - for i, call_args in enumerate(Session.exec.call_args_list): - assert str(expected[i][0]) == str(call_args.args[0]) - - Session.exec = org_exec +def test_remove_old_demo_objects(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch): + user = User(name="Robbert", avatar=None) + status = Status(name="Todo", sort_order=1) + task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id) + + db.given([user, status, task]) + session_delete_counter = count_calls(monkeypatch, db.get_session(), db.get_session().delete) + remove_old_demo_objects(db.get_session()) + assert session_delete_counter.get() == 3 + + session_delete_counter = count_calls(monkeypatch, db.get_session(), db.get_session().delete) + remove_old_demo_objects(db.get_session()) + assert session_delete_counter.get() == 0 + + +def test_truncate_tables(db: DatabaseTestUtils): + db.given([default_task(), default_user(), default_status()]) + truncate_tables(db.get_session()) + assert not db.exists(Task.title, default_task().title) + assert not db.exists(User.name, default_user().name) + assert not db.exists(Status.name, default_status().name) + + +def test_add_demo_user(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch): + user_names = [default_user().name] + + session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add) + add_demo_users(db.get_session(), user_names) + assert db.exists(User.name, user_names[0]) + assert session_add_counter.get() == 1 + # test again, the user already exists so new new user should be created + session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add) + add_demo_users(db.get_session(), user_names) + assert db.exists(User.name, user_names[0]) + assert session_add_counter.get() == 0 + + +def test_add_demo_status(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch): + session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add) + add_demo_statuses(db.get_session(), [default_status().name]) + assert db.exists(Status.name, default_status().name) + assert session_add_counter.get() == 1 + # test again, the status already exists so no new status should be created + session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add) + add_demo_statuses(db.get_session(), [default_status().name]) + assert db.exists(Status.name, default_status().name) + assert session_add_counter.get() == 0 + + +def test_add_demo_tasks(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch): + db.given([default_status()]) + session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add) + add_demo_tasks(db.get_session(), default_status(), 3) + assert session_add_counter.get() == 3 + assert db.exists(Task.title, "Example task 1") + # test again, the tasks already exist so no new cards should be created + session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add) + add_demo_tasks(db.get_session(), default_status(), 3) + assert session_add_counter.get() == 0 + assert db.exists(Task.title, "Example task 1") + # test without status + session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add) + add_demo_tasks(db.get_session(), None, 1) + assert session_add_counter.get() == 0 + + +class MutableCounter: + """ + Used by "count_calls". + """ + + n = 0 + + def inc(self): + self.n += 1 + + def get(self): + return self.n + + def reset(self): + self.n = 0 + + +def count_calls(monkeypatch: pytest.MonkeyPatch, module: object, fn: object) -> MutableCounter: + """ + Returns a mutable object containing the number of times the given function + "fn", from the module "module", was called. + Intended to be used inside pytest functions. + """ + cnt = MutableCounter() + + def mock_fn(*args: Any, **kwargs: Any) -> Any: + nonlocal cnt + cnt.inc() + return fn(*args, **kwargs) # type: ignore + + mock_fn.__name__ = fn.__name__ # type: ignore + monkeypatch.setattr(module, fn.__name__, mock_fn) # type: ignore + return cnt diff --git a/tests/database_test_utils.py b/tests/database_test_utils.py index 539efb609..c5aa1dcf2 100644 --- a/tests/database_test_utils.py +++ b/tests/database_test_utils.py @@ -1,5 +1,5 @@ from pydantic import BaseModel -from sqlmodel import Session, SQLModel +from sqlmodel import Session, SQLModel, exists from tad.core.db import get_engine @@ -29,3 +29,7 @@ def given(self, models: list[BaseModel]) -> None: def get_session(self) -> Session: return self.session + + def exists(self, model_field: str, field_value: str | int) -> SQLModel | None: + # todo (robbert): this should be an exec + return self.get_session().query(exists().where(model_field == field_value)).scalar() # type: ignore