diff --git a/tad/core/config.py b/tad/core/config.py index 284b56264..71b802014 100644 --- a/tad/core/config.py +++ b/tad/core/config.py @@ -76,10 +76,6 @@ def SQLALCHEMY_DATABASE_URI(self) -> str: ) ) - @SQLALCHEMY_DATABASE_URI.setter - def SQLALCHEMY_DATABASE_URI(self, value: str) -> None: - self.SQLALCHEMY_DATABASE_URI = value - @model_validator(mode="after") def _enforce_database_rules(self: SelfSettings) -> SelfSettings: if self.ENVIRONMENT == "production" and self.APP_DATABASE_SCHEME == "sqlite": diff --git a/tad/services/storage.py b/tad/services/storage.py index f62de852f..fdf842fb7 100644 --- a/tad/services/storage.py +++ b/tad/services/storage.py @@ -17,18 +17,18 @@ def close(self) -> None: class WriterFactory: @staticmethod - def get_writer(writer_type: str = "file", **kwargs: str) -> Writer: + def get_writer(writer_type: str = "file", **kwargs: Any) -> Writer: match writer_type: case "file": if not all(k in kwargs for k in ("location", "filename")): raise KeyError("The `location` or `filename` variables are not provided as input for get_writer()") - return FileSystemWriteService(location=str(kwargs["location"]), filename=str(kwargs["filename"])) + return FileSystemWriteService(location=Path(kwargs["location"]), filename=str(kwargs["filename"])) case _: raise ValueError(f"Unknown writer type: {writer_type}") class FileSystemWriteService(Writer): - def __init__(self, location: str = "./tests/data", filename: str = "system_card.yaml") -> None: + def __init__(self, location: str | Path = "./tests/data", filename: str = "system_card.yaml") -> None: self.location = location if not filename.endswith(".yaml"): raise ValueError(f"Filename {filename} must end with .yaml instead of .{filename.split('.')[-1]}") diff --git a/tad/services/tasks.py b/tad/services/tasks.py index 820b26071..13c70ac42 100644 --- a/tad/services/tasks.py +++ b/tad/services/tasks.py @@ -53,7 +53,7 @@ def move_task( self.storage_writer.write(self.system_card.model_dump()) if not isinstance(status.id, int): - raise TypeError("status_id must be an integer") + raise TypeError("status_id must be an integer") # pragma: no cover # assign the task to the current user if status.id > 1: diff --git a/tests/core/test_db.py b/tests/core/test_db.py index c01c99313..dbc7a8a99 100644 --- a/tests/core/test_db.py +++ b/tests/core/test_db.py @@ -5,6 +5,7 @@ from sqlmodel import Session, select from tad.core.config import Settings from tad.core.db import check_db, init_db +from tad.models import Status, Task, User logger = logging.getLogger(__name__) @@ -21,9 +22,31 @@ def test_check_database(): @pytest.mark.parametrize( "patch_settings", - [ - {"ENVIRONMENT": "demo"}, - ], + [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], + indirect=True, +) +def test_init_database_none(patch_settings: Settings): + org_exec = Session.exec + Session.exec = MagicMock() + Session.exec.return_value.first.return_value = None + + init_db() + + expected = [ + (select(User).where(User.name == "Robbert"),), + (select(Status).where(Status.name == "Todo"),), + (select(Task).where(Task.title == "First task"),), + ] + + for i, call_args in enumerate(Session.exec.call_args_list): + assert str(expected[i][0]) == str(call_args.args[0]) + + Session.exec = org_exec + + +@pytest.mark.parametrize( + "patch_settings", + [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], indirect=True, ) def test_init_database(patch_settings: Settings): @@ -32,5 +55,13 @@ def test_init_database(patch_settings: Settings): init_db() - Session.exec.assert_called() + expected = [ + (select(User).where(User.name == "Robbert"),), + (select(Status).where(Status.name == "Todo"),), + (select(Task).where(Task.title == "First task"),), + ] + + for i, call_args in enumerate(Session.exec.call_args_list): + assert str(expected[i][0]) == str(call_args.args[0]) + Session.exec = org_exec diff --git a/tests/services/test_storage.py b/tests/services/test_storage.py index d29aab8fd..314343feb 100644 --- a/tests/services/test_storage.py +++ b/tests/services/test_storage.py @@ -7,12 +7,12 @@ @pytest.fixture() -def setup_and_teardown(tmp_path: Path) -> tuple[str, str]: +def setup_and_teardown(tmp_path: Path) -> tuple[str, Path]: filename = "test.yaml" - return filename, str(tmp_path.absolute()) + return filename, tmp_path.absolute() -def test_file_system_writer_empty_yaml(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_empty_yaml(setup_and_teardown: tuple[str, Path]) -> None: filename, location = setup_and_teardown storage_writer = WriterFactory.get_writer(writer_type="file", location=location, filename=filename) @@ -21,7 +21,7 @@ def test_file_system_writer_empty_yaml(setup_and_teardown: tuple[str, str]) -> N assert Path.is_file(Path(location) / filename), True -def test_file_system_writer_no_location_variable(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_no_location_variable(setup_and_teardown: tuple[str, Path]) -> None: filename, _ = setup_and_teardown with pytest.raises( KeyError, match="The `location` or `filename` variables are not provided as input for get_writer()" @@ -29,7 +29,7 @@ def test_file_system_writer_no_location_variable(setup_and_teardown: tuple[str, WriterFactory.get_writer(writer_type="file", filename=filename) -def test_file_system_writer_no_filename_variable(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_no_filename_variable(setup_and_teardown: tuple[str, Path]) -> None: _, location = setup_and_teardown with pytest.raises( KeyError, match="The `location` or `filename` variables are not provided as input for get_writer()" @@ -37,7 +37,7 @@ def test_file_system_writer_no_filename_variable(setup_and_teardown: tuple[str, WriterFactory.get_writer(writer_type="file", location=location) -def test_file_system_writer_yaml_with_content(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_yaml_with_content(setup_and_teardown: tuple[str, Path]) -> None: filename, location = setup_and_teardown data = {"test": "test"} storage_writer = WriterFactory.get_writer(writer_type="file", location=location, filename=filename) @@ -47,7 +47,19 @@ def test_file_system_writer_yaml_with_content(setup_and_teardown: tuple[str, str assert safe_load(f) == data, True -def test_file_system_writer_with_system_card(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_yaml_with_content_in_dir(setup_and_teardown: tuple[str, Path]) -> None: + filename, location = setup_and_teardown + data = {"test": "test"} + + new_location = Path(location) / "new_dir" + storage_writer = WriterFactory.get_writer(writer_type="file", location=new_location, filename=filename) + storage_writer.write(data) + + with open(new_location / filename) as f: + assert safe_load(f) == data, True + + +def test_file_system_writer_with_system_card(setup_and_teardown: tuple[str, Path]) -> None: filename, location = setup_and_teardown data = SystemCard() data.title = "test" @@ -60,7 +72,7 @@ def test_file_system_writer_with_system_card(setup_and_teardown: tuple[str, str] assert safe_load(f) == data_dict, True -def test_abstract_writer_non_yaml_filename(setup_and_teardown: tuple[str, str]) -> None: +def test_abstract_writer_non_yaml_filename(setup_and_teardown: tuple[str, Path]) -> None: _, location = setup_and_teardown filename = "test.csv" with pytest.raises(