Skip to content

Commit

Permalink
Add more demo objects to the demo suite
Browse files Browse the repository at this point in the history
  • Loading branch information
uittenbroekrobbert committed Jul 4, 2024
1 parent a857a7d commit 4c97b93
Show file tree
Hide file tree
Showing 6 changed files with 205 additions and 69 deletions.
7 changes: 7 additions & 0 deletions BUILD.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ When poetry is done installing all dependencies you can start using the tool.
poetry run python -m uvicorn tad.main:app --log-level warning
```

### Suggested development ENVIRONMENT settings
To use a demo environment during local development, you can use the following environment options. You can leave out the TRUNCATE_TABLES option if you wish to keep the state between runs.
```shell
export ENVIRONMENT=demo AUTO_CREATE_SCHEMA=true TRUNCATE_TABLES=true
```


## Database

We support most SQL database types. You can use the variable `APP_DATABASE_SCHEME` to change the database. The default scheme is sqlite.
Expand Down
1 change: 1 addition & 0 deletions tad/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Settings(BaseSettings):

DEBUG: bool = False
AUTO_CREATE_SCHEMA: bool = False
TRUNCATE_TABLES: bool = False

# todo(berry): create submodel for database settings
APP_DATABASE_SCHEME: DatabaseSchemaType = "sqlite"
Expand Down
87 changes: 69 additions & 18 deletions tad/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from sqlalchemy.engine import Engine
from sqlalchemy.pool import QueuePool, StaticPool
from sqlmodel import Session, SQLModel, create_engine, select
from sqlmodel import Session, SQLModel, create_engine, delete, select

from tad.core.config import get_settings
from tad.models import Status, Task, User
Expand Down Expand Up @@ -31,7 +31,20 @@ def check_db():
with Session(get_engine()) as session:
session.exec(select(1))

logger.info("Finisch Checking database connection")
logger.info("Finish Checking database connection")


def remove_old_demo_objects(session: Session):
user = session.exec(select(User).where(User.name == "Robbert")).first()
if user:
session.delete(user)
status = session.exec(select(Status).where(Status.name == "Todo")).first()
if status:
session.delete(status)
task = session.exec(select(Task).where(Task.title == "First task")).first()
if task:
session.delete(task)
session.commit()


def init_db():
Expand All @@ -43,21 +56,59 @@ def init_db():

with Session(get_engine()) as session:
if get_settings().ENVIRONMENT == "demo":
if get_settings().TRUNCATE_TABLES:
truncate_tables(session)
logger.info("Creating demo data")

user = session.exec(select(User).where(User.name == "Robbert")).first()
if not user:
user = User(name="Robbert", avatar=None)
session.add(user)

status = session.exec(select(Status).where(Status.name == "Todo")).first()
if not status:
status = Status(name="Todo", sort_order=1)
session.add(status)

task = session.exec(select(Task).where(Task.title == "First task")).first()
if not task:
task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id)
session.add(task)
session.commit()
remove_old_demo_objects(session)
add_demo_users(session, ["default user"])
demo_statuses = add_demo_statuses(session, ["todo", "review", "in_progress", "done"])
add_demo_tasks(session, demo_statuses[0], 3)
logger.info("Finished initializing database")


def truncate_tables(session: Session) -> None:
logger.info("Truncating tables")
session.exec(delete(Task)) # type: ignore
session.exec(delete(User)) # type: ignore
session.exec(delete(Status)) # type: ignore
session.commit()


def add_demo_users(session: Session, user_names: list[str]) -> None:
for user_name in user_names:
user = session.exec(select(User).where(User.name == user_name)).first()
if not user:
session.add(User(name=user_name, avatar=None))
session.commit()


def add_demo_tasks(session: Session, status: Status | None, number_of_tasks: int) -> None:
if status is None:
return
for index in range(1, number_of_tasks + 1):
title = "Example task " + str(index)
task = session.exec(select(Task).where(Task.title == title)).first()
if not task:
session.add(
Task(
title=title,
description="Example description " + str(index),
sort_order=index,
status_id=status.id,
)
)
session.commit()


def add_demo_statuses(session: Session, statuses: list[str]) -> list[Status]:
return_statuses: list[Status] = []
for index, status_name in enumerate(statuses):
status = session.exec(select(Status).where(Status.name == status_name)).first()
if not status:
status = Status(name=status_name, sort_order=index + 1)
session.add(status)
return_statuses.append(status)
session.commit()
for return_status in return_statuses:
session.refresh(return_status)
return return_statuses
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
status = op.create_table(
op.create_table(
"status",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
Expand Down Expand Up @@ -56,17 +56,6 @@ def upgrade() -> None:
)
# ### end Alembic commands ###

# ### custom commands ###
op.bulk_insert(
status,
[
{"name": "Todo", "sort_order": 1},
{"name": "In Progress", "sort_order": 2},
{"name": "Review", "sort_order": 3},
{"name": "Done", "sort_order": 4},
],
)


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
Expand Down
160 changes: 122 additions & 38 deletions tests/core/test_db.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
import logging
from typing import Any
from unittest.mock import MagicMock

import pytest
from sqlmodel import Session, select
from tad.core.config import Settings
from tad.core.db import check_db, init_db
from tad.core.db import (
add_demo_statuses,
add_demo_tasks,
add_demo_users,
check_db,
init_db,
remove_old_demo_objects,
truncate_tables,
)
from tad.models import Status, Task, User

from tests.constants import (
default_status,
default_task,
default_user,
)
from tests.database_test_utils import DatabaseTestUtils

logger = logging.getLogger(__name__)


Expand All @@ -22,46 +38,114 @@ def test_check_database():

@pytest.mark.parametrize(
"patch_settings",
[{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}],
indirect=True,
)
def test_init_database_none(patch_settings: Settings):
org_exec = Session.exec
Session.exec = MagicMock()
Session.exec.return_value.first.return_value = None

init_db()

expected = [
(select(User).where(User.name == "Robbert"),),
(select(Status).where(Status.name == "Todo"),),
(select(Task).where(Task.title == "First task"),),
]

for i, call_args in enumerate(Session.exec.call_args_list):
assert str(expected[i][0]) == str(call_args.args[0])

Session.exec = org_exec


@pytest.mark.parametrize(
"patch_settings",
[{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}],
[
{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True},
{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": False},
{"ENVIRONMENT": "demo", "TRUNCATE_TABLES": True},
{"ENVIRONMENT": "production", "TRUNCATE_TABLES": True},
],
indirect=True,
)
def test_init_database(patch_settings: Settings):
org_exec = Session.exec
Session.exec = MagicMock()

init_db()

expected = [
(select(User).where(User.name == "Robbert"),),
(select(Status).where(Status.name == "Todo"),),
(select(Task).where(Task.title == "First task"),),
]

for i, call_args in enumerate(Session.exec.call_args_list):
assert str(expected[i][0]) == str(call_args.args[0])

Session.exec = org_exec
def test_remove_old_demo_objects(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch):
user = User(name="Robbert", avatar=None)
status = Status(name="Todo", sort_order=1)
task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id)

db.given([user, status, task])
session_delete_counter = count_calls(monkeypatch, db.get_session(), db.get_session().delete)
remove_old_demo_objects(db.get_session())
assert session_delete_counter.get() == 3

session_delete_counter = count_calls(monkeypatch, db.get_session(), db.get_session().delete)
remove_old_demo_objects(db.get_session())
assert session_delete_counter.get() == 0


def test_truncate_tables(db: DatabaseTestUtils):
db.given([default_task(), default_user(), default_status()])
truncate_tables(db.get_session())
assert not db.exists(Task.title, default_task().title)
assert not db.exists(User.name, default_user().name)
assert not db.exists(Status.name, default_status().name)


def test_add_demo_user(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch):
user_names = [default_user().name]

session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add)
add_demo_users(db.get_session(), user_names)
assert db.exists(User.name, user_names[0])
assert session_add_counter.get() == 1
# test again, the user already exists so new new user should be created
session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add)
add_demo_users(db.get_session(), user_names)
assert db.exists(User.name, user_names[0])
assert session_add_counter.get() == 0


def test_add_demo_status(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch):
session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add)
add_demo_statuses(db.get_session(), [default_status().name])
assert db.exists(Status.name, default_status().name)
assert session_add_counter.get() == 1
# test again, the status already exists so no new status should be created
session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add)
add_demo_statuses(db.get_session(), [default_status().name])
assert db.exists(Status.name, default_status().name)
assert session_add_counter.get() == 0


def test_add_demo_tasks(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch):
db.given([default_status()])
session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add)
add_demo_tasks(db.get_session(), default_status(), 3)
assert session_add_counter.get() == 3
assert db.exists(Task.title, "Example task 1")
# test again, the tasks already exist so no new cards should be created
session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add)
add_demo_tasks(db.get_session(), default_status(), 3)
assert session_add_counter.get() == 0
assert db.exists(Task.title, "Example task 1")
# test without status
session_add_counter = count_calls(monkeypatch, db.get_session(), db.get_session().add)
add_demo_tasks(db.get_session(), None, 1)
assert session_add_counter.get() == 0


class MutableCounter:
"""
Used by "count_calls".
"""

n = 0

def inc(self):
self.n += 1

def get(self):
return self.n

def reset(self):
self.n = 0


def count_calls(monkeypatch: pytest.MonkeyPatch, module: object, fn: object) -> MutableCounter:
"""
Returns a mutable object containing the number of times the given function
"fn", from the module "module", was called.
Intended to be used inside pytest functions.
"""
cnt = MutableCounter()

def mock_fn(*args: Any, **kwargs: Any) -> Any:
nonlocal cnt
cnt.inc()
return fn(*args, **kwargs) # type: ignore

mock_fn.__name__ = fn.__name__ # type: ignore
monkeypatch.setattr(module, fn.__name__, mock_fn) # type: ignore
return cnt
6 changes: 5 additions & 1 deletion tests/database_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel
from sqlmodel import Session, SQLModel
from sqlmodel import Session, SQLModel, exists
from tad.core.db import get_engine


Expand Down Expand Up @@ -29,3 +29,7 @@ def given(self, models: list[BaseModel]) -> None:

def get_session(self) -> Session:
return self.session

def exists(self, model_field: str, field_value: str | int) -> SQLModel | None:
# todo (robbert): this should be an exec
return self.get_session().query(exists().where(model_field == field_value)).scalar() # type: ignore

0 comments on commit 4c97b93

Please sign in to comment.