Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more demo objects to the demo suite #64

Merged
merged 1 commit into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions BUILD.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ When poetry is done installing all dependencies you can start using the tool.
poetry run python -m uvicorn tad.main:app --log-level warning
```

### Suggested development ENVIRONMENT settings
To use a demo environment during local development, you can use the following environment options.
```shell
export ENVIRONMENT=demo AUTO_CREATE_SCHEMA=true
```


## Database

We support most SQL database types. You can use the variable `APP_DATABASE_SCHEME` to change the database. The default scheme is sqlite.
Expand Down
75 changes: 58 additions & 17 deletions tad/core/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,20 @@ def check_db():
with Session(get_engine()) as session:
session.exec(select(1))

logger.info("Finisch Checking database connection")
logger.info("Finish Checking database connection")


def remove_old_demo_objects(session: Session):
user = session.exec(select(User).where(User.name == "Robbert")).first()
if user:
session.delete(user)
status = session.exec(select(Status).where(Status.name == "Todo")).first()
if status:
session.delete(status)
task = session.exec(select(Task).where(Task.title == "First task")).first()
if task:
session.delete(task)
session.commit()


def init_db():
Expand All @@ -44,20 +57,48 @@ def init_db():
with Session(get_engine()) as session:
if get_settings().ENVIRONMENT == "demo":
logger.info("Creating demo data")

user = session.exec(select(User).where(User.name == "Robbert")).first()
if not user:
user = User(name="Robbert", avatar=None)
session.add(user)

status = session.exec(select(Status).where(Status.name == "Todo")).first()
if not status:
status = Status(name="Todo", sort_order=1)
session.add(status)

task = session.exec(select(Task).where(Task.title == "First task")).first()
if not task:
task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id)
session.add(task)
session.commit()
remove_old_demo_objects(session)
uittenbroekrobbert marked this conversation as resolved.
Show resolved Hide resolved
add_demo_users(session, ["default user"])
demo_statuses = add_demo_statuses(session, ["todo", "review", "in_progress", "done"])
add_demo_tasks(session, demo_statuses[0], 3)
logger.info("Finished initializing database")


def add_demo_users(session: Session, user_names: list[str]) -> None:
for user_name in user_names:
user = session.exec(select(User).where(User.name == user_name)).first()
if not user:
session.add(User(name=user_name, avatar=None))
session.commit()


def add_demo_tasks(session: Session, status: Status | None, number_of_tasks: int) -> None:
if status is None:
return
for index in range(1, number_of_tasks + 1):
title = "Example task " + str(index)
task = session.exec(select(Task).where(Task.title == title)).first()
if not task:
session.add(
Task(
title=title,
description="Example description " + str(index),
sort_order=index,
status_id=status.id,
)
)
session.commit()


def add_demo_statuses(session: Session, statuses: list[str]) -> list[Status]:
return_statuses: list[Status] = []
for index, status_name in enumerate(statuses):
status = session.exec(select(Status).where(Status.name == status_name)).first()
if not status:
status = Status(name=status_name, sort_order=index + 1)
session.add(status)
return_statuses.append(status)
session.commit()
for return_status in return_statuses:
session.refresh(return_status)
return return_statuses
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""rename status fields for translation support

Revision ID: bf86033947fc
Revises: b62dbd9468e4
Create Date: 2024-07-05 08:54:30.114471

"""

from collections.abc import Sequence

from alembic import op
from sqlalchemy import String
from sqlalchemy.sql import column, table

# revision identifiers, used by Alembic.
revision: str = "bf86033947fc"
down_revision: str | None = "b62dbd9468e4"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None

status_table = table("status", column("name", String))
status_renames = {"Todo": "todo", "In Progress": "in_progress", "Review": "review", "Done": "done"}


def upgrade() -> None:
for k, v in status_renames.items():
op.execute(
status_table.update()
.where(status_table.c.name == op.inline_literal(k))
.values({"name": op.inline_literal(v)})
)


def downgrade() -> None:
for k, v in status_renames.items():
op.execute(
status_table.update()
.where(status_table.c.name == op.inline_literal(v))
.values({"name": op.inline_literal(k)})
)
158 changes: 126 additions & 32 deletions tests/core/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,26 @@
from unittest.mock import MagicMock

import pytest
from sqlmodel import Session, select
import tad.core.db as testdb
from sqlmodel import Session, SQLModel, select
from tad.core.config import Settings
from tad.core.db import check_db, init_db
from tad.core.db import (
add_demo_statuses,
add_demo_tasks,
add_demo_users,
check_db,
init_db,
remove_old_demo_objects,
)
from tad.models import Status, Task, User

from tests.constants import (
default_status,
default_task,
default_user,
)
from tests.database_test_utils import DatabaseTestUtils

logger = logging.getLogger(__name__)


Expand All @@ -22,46 +37,125 @@ def test_check_database():

@pytest.mark.parametrize(
"patch_settings",
[{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}],
[
({"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}),
({"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": False}),
({"ENVIRONMENT": "production"}),
],
indirect=True,
)
def test_init_database_none(patch_settings: Settings):
org_exec = Session.exec
Session.exec = MagicMock()
Session.exec.return_value.first.return_value = None
def test_init_database(patch_settings: Settings):
remove_old_demo_objects_orig = testdb.remove_old_demo_objects
testdb.remove_old_demo_objects = MagicMock()
create_all_orig = SQLModel.metadata.create_all
SQLModel.metadata.create_all = MagicMock()

init_db()

expected = [
(select(User).where(User.name == "Robbert"),),
(select(Status).where(Status.name == "Todo"),),
(select(Task).where(Task.title == "First task"),),
]
if patch_settings.ENVIRONMENT == "demo":
uittenbroekrobbert marked this conversation as resolved.
Show resolved Hide resolved
assert testdb.remove_old_demo_objects.called
if patch_settings.AUTO_CREATE_SCHEMA is True:
assert SQLModel.metadata.create_all.called

for i, call_args in enumerate(Session.exec.call_args_list):
assert str(expected[i][0]) == str(call_args.args[0])
testdb.remove_old_demo_objects = remove_old_demo_objects_orig
SQLModel.metadata.create_all = create_all_orig

Session.exec = org_exec

def test_remove_old_demo_objects(db: DatabaseTestUtils):
org_delete = db.get_session().delete
db_session = db.get_session()
db_session.delete = MagicMock()

@pytest.mark.parametrize(
"patch_settings",
[{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}],
indirect=True,
)
def test_init_database(patch_settings: Settings):
org_exec = Session.exec
Session.exec = MagicMock()
user = User(name="Robbert", avatar=None)
status = Status(name="Todo", sort_order=1)
task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id)
db.given([user, status, task])

init_db()
remove_old_demo_objects(db.get_session())
assert db_session.delete.call_count == 3

expected = [
(select(User).where(User.name == "Robbert"),),
(select(Status).where(Status.name == "Todo"),),
(select(Task).where(Task.title == "First task"),),
]
db.get_session().delete = org_delete

for i, call_args in enumerate(Session.exec.call_args_list):
assert str(expected[i][0]) == str(call_args.args[0])

Session.exec = org_exec
def test_remove_old_demo_objects_nothing_to_delete(db: DatabaseTestUtils):
org_delete = db.get_session().delete
db_session = db.get_session()
db_session.delete = MagicMock()

remove_old_demo_objects(db.get_session())
assert db_session.delete.call_count == 0

db.get_session().delete = org_delete


def test_add_demo_user(db: DatabaseTestUtils):
user_names = [default_user().name]
add_demo_users(db.get_session(), user_names)
assert db.exists(User, User.name, user_names[0])


def test_add_demo_user_nothing_to_add(db: DatabaseTestUtils):
db.given([default_user()])

orig_add = db.get_session().add
db_session = db.get_session()
db_session.add = MagicMock()

add_demo_users(db.get_session(), [default_user().name])

assert db_session.add.call_count == 0

db.get_session().add = orig_add


def test_add_demo_status(db: DatabaseTestUtils):
add_demo_statuses(db.get_session(), [default_status().name])
assert db.exists(Status, Status.name, default_status().name)


def test_add_demo_status_nothing_to_add(db: DatabaseTestUtils):
db.given([default_status()])

orig_add = db.get_session().add
db_session = db.get_session()
db_session.add = MagicMock()

add_demo_statuses(db.get_session(), [default_status().name])

assert db_session.add.call_count == 0
db.get_session().add = orig_add


def test_add_demo_tasks(db: DatabaseTestUtils):
add_demo_tasks(db.get_session(), default_status(), 3)
assert db.exists(Task, Task.title, "Example task 1")
assert db.exists(Task, Task.title, "Example task 2")
assert db.exists(Task, Task.title, "Example task 3")


def test_add_demo_tasks_nothing_to_add(db: DatabaseTestUtils):
db.given(
[
default_task(title="Example task 1"),
default_task(title="Example task 2"),
default_task(title="Example task 3"),
]
)

orig_add = db.get_session().add
db_session = db.get_session()
db_session.add = MagicMock()

add_demo_tasks(db.get_session(), default_status(), 3)
assert db_session.add.call_count == 0
db.get_session().add = orig_add


def test_add_demo_tasks_no_status(db: DatabaseTestUtils):
orig_add = db.get_session().add
db_session = db.get_session()
db_session.add = MagicMock()

add_demo_tasks(db.get_session(), None, 3)
assert db_session.add.call_count == 0
db.get_session().add = orig_add
5 changes: 4 additions & 1 deletion tests/database_test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel
from sqlmodel import Session, SQLModel
from sqlmodel import Session, SQLModel, select
from tad.core.db import get_engine


Expand Down Expand Up @@ -29,3 +29,6 @@ def given(self, models: list[BaseModel]) -> None:

def get_session(self) -> Session:
return self.session

def exists(self, model: type[SQLModel], model_field: str, field_value: str | int) -> SQLModel | None:
return self.get_session().exec(select(model).where(model_field == field_value)).one() is not None # type: ignore