Skip to content

Commit

Permalink
Add test for database init on demo mode
Browse files Browse the repository at this point in the history
  • Loading branch information
berrydenhartog committed Jun 11, 2024
1 parent 9b81d75 commit 11a7468
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 20 deletions.
4 changes: 0 additions & 4 deletions tad/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ def SQLALCHEMY_DATABASE_URI(self) -> str:
)
)

@SQLALCHEMY_DATABASE_URI.setter
def SQLALCHEMY_DATABASE_URI(self, value: str) -> None:
self.SQLALCHEMY_DATABASE_URI = value

@model_validator(mode="after")
def _enforce_database_rules(self: SelfSettings) -> SelfSettings:
if self.ENVIRONMENT == "production" and self.APP_DATABASE_SCHEME == "sqlite":
Expand Down
6 changes: 3 additions & 3 deletions tad/services/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,18 @@ def close(self) -> None:

class WriterFactory:
@staticmethod
def get_writer(writer_type: str = "file", **kwargs: str) -> Writer:
def get_writer(writer_type: str = "file", **kwargs: Any) -> Writer:
match writer_type:
case "file":
if not all(k in kwargs for k in ("location", "filename")):
raise KeyError("The `location` or `filename` variables are not provided as input for get_writer()")
return FileSystemWriteService(location=str(kwargs["location"]), filename=str(kwargs["filename"]))
return FileSystemWriteService(location=Path(kwargs["location"]), filename=str(kwargs["filename"]))
case _:
raise ValueError(f"Unknown writer type: {writer_type}")


class FileSystemWriteService(Writer):
def __init__(self, location: str = "./tests/data", filename: str = "system_card.yaml") -> None:
def __init__(self, location: str | Path = "./tests/data", filename: str = "system_card.yaml") -> None:
self.location = location
if not filename.endswith(".yaml"):
raise ValueError(f"Filename {filename} must end with .yaml instead of .{filename.split('.')[-1]}")
Expand Down
2 changes: 1 addition & 1 deletion tad/services/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def move_task(
self.storage_writer.write(self.system_card.model_dump())

if not isinstance(status.id, int):
raise TypeError("status_id must be an integer")
raise TypeError("status_id must be an integer") # pragma: no cover

# assign the task to the current user
if status.id > 1:
Expand Down
39 changes: 35 additions & 4 deletions tests/core/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlmodel import Session, select
from tad.core.config import Settings
from tad.core.db import check_db, init_db
from tad.models import Status, Task, User

logger = logging.getLogger(__name__)

Expand All @@ -21,9 +22,31 @@ def test_check_database():

@pytest.mark.parametrize(
"patch_settings",
[
{"ENVIRONMENT": "demo"},
],
[{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}],
indirect=True,
)
def test_init_database_none(patch_settings: Settings):
org_exec = Session.exec
Session.exec = MagicMock()
Session.exec.return_value.first.return_value = None

init_db()

expected = [
(select(User).where(User.name == "Robbert"),),
(select(Status).where(Status.name == "Todo"),),
(select(Task).where(Task.title == "First task"),),
]

for i, call_args in enumerate(Session.exec.call_args_list):
assert str(expected[i][0]) == str(call_args.args[0])

Session.exec = org_exec


@pytest.mark.parametrize(
"patch_settings",
[{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}],
indirect=True,
)
def test_init_database(patch_settings: Settings):
Expand All @@ -32,5 +55,13 @@ def test_init_database(patch_settings: Settings):

init_db()

Session.exec.assert_called()
expected = [
(select(User).where(User.name == "Robbert"),),
(select(Status).where(Status.name == "Todo"),),
(select(Task).where(Task.title == "First task"),),
]

for i, call_args in enumerate(Session.exec.call_args_list):
assert str(expected[i][0]) == str(call_args.args[0])

Session.exec = org_exec
28 changes: 20 additions & 8 deletions tests/services/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@


@pytest.fixture()
def setup_and_teardown(tmp_path: Path) -> tuple[str, str]:
def setup_and_teardown(tmp_path: Path) -> tuple[str, Path]:
filename = "test.yaml"
return filename, str(tmp_path.absolute())
return filename, tmp_path.absolute()


def test_file_system_writer_empty_yaml(setup_and_teardown: tuple[str, str]) -> None:
def test_file_system_writer_empty_yaml(setup_and_teardown: tuple[str, Path]) -> None:
filename, location = setup_and_teardown

storage_writer = WriterFactory.get_writer(writer_type="file", location=location, filename=filename)
Expand All @@ -21,23 +21,23 @@ def test_file_system_writer_empty_yaml(setup_and_teardown: tuple[str, str]) -> N
assert Path.is_file(Path(location) / filename), True


def test_file_system_writer_no_location_variable(setup_and_teardown: tuple[str, str]) -> None:
def test_file_system_writer_no_location_variable(setup_and_teardown: tuple[str, Path]) -> None:
filename, _ = setup_and_teardown
with pytest.raises(
KeyError, match="The `location` or `filename` variables are not provided as input for get_writer()"
):
WriterFactory.get_writer(writer_type="file", filename=filename)


def test_file_system_writer_no_filename_variable(setup_and_teardown: tuple[str, str]) -> None:
def test_file_system_writer_no_filename_variable(setup_and_teardown: tuple[str, Path]) -> None:
_, location = setup_and_teardown
with pytest.raises(
KeyError, match="The `location` or `filename` variables are not provided as input for get_writer()"
):
WriterFactory.get_writer(writer_type="file", location=location)


def test_file_system_writer_yaml_with_content(setup_and_teardown: tuple[str, str]) -> None:
def test_file_system_writer_yaml_with_content(setup_and_teardown: tuple[str, Path]) -> None:
filename, location = setup_and_teardown
data = {"test": "test"}
storage_writer = WriterFactory.get_writer(writer_type="file", location=location, filename=filename)
Expand All @@ -47,7 +47,19 @@ def test_file_system_writer_yaml_with_content(setup_and_teardown: tuple[str, str
assert safe_load(f) == data, True


def test_file_system_writer_with_system_card(setup_and_teardown: tuple[str, str]) -> None:
def test_file_system_writer_yaml_with_content_in_dir(setup_and_teardown: tuple[str, Path]) -> None:
filename, location = setup_and_teardown
data = {"test": "test"}

new_location = Path(location) / "new_dir"
storage_writer = WriterFactory.get_writer(writer_type="file", location=new_location, filename=filename)
storage_writer.write(data)

with open(new_location / filename) as f:
assert safe_load(f) == data, True


def test_file_system_writer_with_system_card(setup_and_teardown: tuple[str, Path]) -> None:
filename, location = setup_and_teardown
data = SystemCard()
data.title = "test"
Expand All @@ -60,7 +72,7 @@ def test_file_system_writer_with_system_card(setup_and_teardown: tuple[str, str]
assert safe_load(f) == data_dict, True


def test_abstract_writer_non_yaml_filename(setup_and_teardown: tuple[str, str]) -> None:
def test_abstract_writer_non_yaml_filename(setup_and_teardown: tuple[str, Path]) -> None:
_, location = setup_and_teardown
filename = "test.csv"
with pytest.raises(
Expand Down

0 comments on commit 11a7468

Please sign in to comment.