From 392011622a916160f553bab4a660baed7a464ba8 Mon Sep 17 00:00:00 2001 From: Berry den Hartog <38954346+berrydenhartog@users.noreply.github.com> Date: Thu, 6 Jun 2024 12:32:39 +0000 Subject: [PATCH] Add database schema initialization --- .devcontainer/devcontainer.json | 4 +- .env.test | 24 ---- .gitignore | 1 + .pre-commit-config.yaml | 3 +- .vscode/launch.json | 17 ++- Dockerfile | 7 +- compose.yml | 15 +- database/init-user-db.sh | 3 +- docker-entrypoint.sh | 51 +++++++ poetry.lock | 130 +++++++++--------- .env => prod.env | 9 +- pyproject.toml | 5 +- script/build | 5 + script/format | 2 +- script/lint | 3 +- script/test | 8 +- tad/api/{routes => }/deps.py | 6 +- tad/api/routes/pages.py | 2 +- tad/api/routes/root.py | 9 +- tad/api/routes/tasks.py | 2 +- tad/core/config.py | 69 ++++++---- tad/core/db.py | 63 +++++++-- tad/core/log.py | 2 +- tad/core/types.py | 3 +- tad/main.py | 34 ++--- .../versions/006c480a1920_a_message.py | 36 ----- ...68e4_create_status_user_and_task_table.py} | 34 +++-- tad/repositories/statuses.py | 17 ++- tad/repositories/tasks.py | 14 ++ tad/services/storage.py | 6 +- tad/services/tasks.py | 5 +- tad/site/templates/default_layout.jinja | 1 + tests/api/routes/test_pages.py | 14 +- tests/api/routes/test_root.py | 12 +- tests/api/routes/test_static.py | 7 +- tests/api/routes/test_status.py | 26 ++-- tests/api/routes/test_tasks_move.py | 14 +- tests/conftest.py | 128 ++++++++++------- tests/constants.py | 39 ++++++ tests/core/test_config.py | 47 +++---- tests/core/test_db.py | 68 +++++++-- tests/core/test_log.py | 16 --- tests/database_test_utils.py | 120 +++------------- tests/e2e/test_move_task.py | 28 ++-- tests/repositories/test_statuses.py | 82 ++++++----- tests/repositories/test_tasks.py | 100 ++++++++------ tests/services/test_storage.py | 28 ++-- 47 files changed, 727 insertions(+), 592 deletions(-) delete mode 100644 .env.test create mode 100755 docker-entrypoint.sh rename .env => prod.env (80%) create mode 100755 script/build rename tad/api/{routes => }/deps.py (64%) delete mode 100644 tad/migrations/versions/006c480a1920_a_message.py rename tad/migrations/versions/{eb2eed884ae9_a_message.py => b62dbd9468e4_create_status_user_and_task_table.py} (70%) create mode 100644 tests/constants.py diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 94443caa..35552b11 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -45,7 +45,9 @@ "editor.defaultFormatter": "charliermarsh.ruff" }, "python.analysis.typeCheckingMode": "strict", - "markiscodecoverage.searchCriteria": "coverage.lcov" + "markiscodecoverage.searchCriteria": "coverage.lcov", + "remote.autoForwardPorts": false, + "remote.restoreForwardedPorts": false } } } diff --git a/.env.test b/.env.test deleted file mode 100644 index 4b36cdf8..00000000 --- a/.env.test +++ /dev/null @@ -1,24 +0,0 @@ -# Domain -DOMAIN=localhost - -# Environment: local, staging, production -ENVIRONMENT=local -PROJECT_NAME="TAD" - -# TAD backend -BACKEND_CORS_ORIGINS="http://localhost,https://localhost,http://127.0.0.1,https://127.0.0.1" -SECRET_KEY=changethis -APP_DATABASE_SCHEME="sqlite" -APP_DATABASE_USER=tad -APP_DATABASE_DB=tad -APP_DATABASE_PASSWORD=changethis - -# Postgres database -POSTGRES_SERVER=db -POSTGRES_PORT=5432 -POSTGRES_DB=postgres -POSTGRES_USER=postgres -POSTGRES_PASSWORD=changethis - -# Database viewer -PGADMIN_DEFAULT_PASSWORD=changethis diff --git a/.gitignore b/.gitignore index b2a56297..c94f0c1a 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 
+35,4 @@ __pypackages__/ # tad tool tad.log* database.sqlite3 +output/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f4cb9b7..83340105 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,10 +5,11 @@ repos: rev: v4.6.0 hooks: - id: end-of-file-fixer - exclude: ^tad/static/vendor/* + exclude: ^tad/static/vendor/.* - id: trailing-whitespace - id: check-yaml - id: check-json + - id: check-added-large-files - id: check-merge-conflict - id: check-toml diff --git a/.vscode/launch.json b/.vscode/launch.json index 00ed195b..59374257 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -7,12 +7,19 @@ "request": "launch", "module": "uvicorn", "justMyCode": false, - "args": [ "--log-level", "warning" ,"tad.main:app"], + "args": [ + "--log-level", + "warning", + "tad.main:app" + ], "cwd": "${workspaceFolder}/", "env": { - "PYTHONPATH": "${workspaceFolder}" - }, - "envFile": "${workspaceFolder}/.env.test" + "PYTHONPATH": "${workspaceFolder}", + "DEBUG": "True", + "AUTO_CREATE_SCHEMA": "True", + "ENVIRONMENT": "demo", + "LOGGING_LEVEL": "DEBUG" + } }, { "name": "Project: tests", @@ -20,7 +27,7 @@ "request": "launch", "module": "pytest", "cwd": "${workspaceFolder}", - "justMyCode": true, + "justMyCode": false, "args": [] } ] diff --git a/Dockerfile b/Dockerfile index 51faf57d..8fc3c6ff 100644 --- a/Dockerfile +++ b/Dockerfile @@ -61,10 +61,13 @@ USER tad COPY --chown=root:root --chmod=755 ./tad /app/tad COPY --chown=root:root --chmod=755 alembic.ini /app/alembic.ini -COPY --chown=root:root --chmod=755 .env /app/.env +COPY --chown=root:root --chmod=755 prod.env /app/.env COPY --chown=root:root --chmod=755 LICENSE /app/LICENSE +COPY --chown=tad:tad --chmod=755 docker-entrypoint.sh /app/docker-entrypoint.sh ENV PYTHONPATH=/app/ WORKDIR /app/ -CMD ["python", "-m", "uvicorn", "--host", "0.0.0.0", "tad.main:app", "--log-level", "warning" ] +ENV PATH="/app/:$PATH" + +CMD [ "docker-entrypoint.sh" ] diff --git a/compose.yml b/compose.yml index 4f4a93c4..d1c7a372 100644 --- a/compose.yml +++ b/compose.yml @@ -9,10 +9,10 @@ services: db: condition: service_healthy env_file: - - path: .env + - path: prod.env required: true environment: - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?Variable not set} + - ENVIRONMENT=demo ports: - 8000:8000 healthcheck: @@ -25,12 +25,10 @@ services: - app-db-data:/var/lib/postgresql/data/pgdata - ./database/:/docker-entrypoint-initdb.d/:cached env_file: - - path: .env + - path: prod.env required: true environment: - PGDATA=/var/lib/postgresql/data/pgdata - - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:?Variable not set} - - SECRET_KEY=${SECRET_KEY:?Variable not set} healthcheck: test: ["CMD", "pg_isready", "-q", "-d", "tad", "-U", "tad"] @@ -40,16 +38,15 @@ services: ports: - 8080:8080 environment: - - PGADMIN_DEFAULT_EMAIL=${PGADMIN_DEFAULT_EMAIL:-tad@minbzk.nl} - - PGADMIN_DEFAULT_PASSWORD=${PGADMIN_DEFAULT_PASSWORD:?Variable not set} - PGADMIN_LISTEN_PORT=${PGADMIN_LISTEN_PORT:-8080} + env_file: + - path: prod.env + required: true depends_on: db: condition: service_healthy healthcheck: test: ["CMD", "wget", "-O", "-", "http://localhost:8080/misc/ping"] -#TODO(berry): Traefik - volumes: app-db-data: diff --git a/database/init-user-db.sh b/database/init-user-db.sh index b13ccff2..57a7e55a 100755 --- a/database/init-user-db.sh +++ b/database/init-user-db.sh @@ -4,6 +4,5 @@ set -e # todo(berry): make user and database variables psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL CREATE USER tad WITH PASSWORD 
'changethis'; - CREATE DATABASE tad; - GRANT ALL PRIVILEGES ON DATABASE tad TO tad; + CREATE DATABASE tad OWNER tad; EOSQL diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh new file mode 100755 index 00000000..fa5698e3 --- /dev/null +++ b/docker-entrypoint.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash + +DATABASE_MIGRATE="" +HOST="0.0.0.0" +LOGLEVEL="warning" +PORT="8000" + +while getopts "dh:l:p:" opt; do + case $opt in + d) + DATABASE_MIGRATE="True" + ;; + h) + HOST=$OPTARG + ;; + l) + LOGLEVEL=$OPTARG + ;; + p) + PORT=$OPTARG + ;; + :) + echo "Option -${OPTARG} requires an argument." + exit 1 + ;; + + ?) + echo "Invalid option: $OPTARG" + + echo "Usage: docker-entrypoint.sh [-d] [-h host] [-l loglevel] [-p port]" + exit 1 + ;; + esac +done + +echo "DATABASE_MIGRATE: $DATABASE_MIGRATE" +echo "HOST: $HOST" +echo "LOGLEVEL: $LOGLEVEL" +echo "PORT: $PORT" + + +if [ -z "$DATABASE_MIGRATE" ]; then + echo "Upgrading database" + if ! alembic upgrade head; then + echo "Failed to upgrade database" + exit 1 + fi +fi + +echo "Starting server" +python -m uvicorn --host "$HOST" tad.main:app --port "$PORT" --log-level "$LOGLEVEL" diff --git a/poetry.lock b/poetry.lock index 58726fe6..8a8378c0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -295,13 +295,13 @@ wmi = ["wmi (>=1.5.1)"] [[package]] name = "email-validator" -version = "2.1.1" +version = "2.1.2" description = "A robust email address syntax and deliverability validation library." optional = false python-versions = ">=3.8" files = [ - {file = "email_validator-2.1.1-py3-none-any.whl", hash = "sha256:97d882d174e2a65732fb43bfce81a3a834cbc1bde8bf419e30ef5ea976370a05"}, - {file = "email_validator-2.1.1.tar.gz", hash = "sha256:200a70680ba08904be6d1eef729205cc0d687634399a5924d842533efb824b84"}, + {file = "email_validator-2.1.2-py3-none-any.whl", hash = "sha256:d89f6324e13b1e39889eab7f9ca2f91dc9aebb6fa50a6d8bd4329ab50f251115"}, + {file = "email_validator-2.1.2.tar.gz", hash = "sha256:14c0f3d343c4beda37400421b39fa411bbe33a75df20825df73ad53e06a9f04c"}, ] [package.dependencies] @@ -368,18 +368,18 @@ standard = ["fastapi", "uvicorn[standard] (>=0.15.0)"] [[package]] name = "filelock" -version = "3.14.0" +version = "3.15.1" description = "A platform independent file lock."
optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.14.0-py3-none-any.whl", hash = "sha256:43339835842f110ca7ae60f1e1c160714c5a6afd15a2873419ab185334975c0f"}, - {file = "filelock-3.14.0.tar.gz", hash = "sha256:6ea72da3be9b8c82afd3edcf99f2fffbb5076335a5ae4d03248bb5b6c3eae78a"}, + {file = "filelock-3.15.1-py3-none-any.whl", hash = "sha256:71b3102950e91dfc1bb4209b64be4dc8854f40e5f534428d8684f953ac847fac"}, + {file = "filelock-3.15.1.tar.gz", hash = "sha256:58a2549afdf9e02e10720eaa4d4470f56386d7a6f72edd7d0596337af8ed7ad8"}, ] [package.extras] docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] typing = ["typing-extensions (>=4.8)"] [[package]] @@ -761,68 +761,68 @@ files = [ [[package]] name = "orjson" -version = "3.10.3" +version = "3.10.5" description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" optional = false python-versions = ">=3.8" files = [ - {file = "orjson-3.10.3-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9fb6c3f9f5490a3eb4ddd46fc1b6eadb0d6fc16fb3f07320149c3286a1409dd8"}, - {file = "orjson-3.10.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:252124b198662eee80428f1af8c63f7ff077c88723fe206a25df8dc57a57b1fa"}, - {file = "orjson-3.10.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9f3e87733823089a338ef9bbf363ef4de45e5c599a9bf50a7a9b82e86d0228da"}, - {file = "orjson-3.10.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8334c0d87103bb9fbbe59b78129f1f40d1d1e8355bbed2ca71853af15fa4ed3"}, - {file = "orjson-3.10.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1952c03439e4dce23482ac846e7961f9d4ec62086eb98ae76d97bd41d72644d7"}, - {file = "orjson-3.10.3-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c0403ed9c706dcd2809f1600ed18f4aae50be263bd7112e54b50e2c2bc3ebd6d"}, - {file = "orjson-3.10.3-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:382e52aa4270a037d41f325e7d1dfa395b7de0c367800b6f337d8157367bf3a7"}, - {file = "orjson-3.10.3-cp310-none-win32.whl", hash = "sha256:be2aab54313752c04f2cbaab4515291ef5af8c2256ce22abc007f89f42f49109"}, - {file = "orjson-3.10.3-cp310-none-win_amd64.whl", hash = "sha256:416b195f78ae461601893f482287cee1e3059ec49b4f99479aedf22a20b1098b"}, - {file = "orjson-3.10.3-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:73100d9abbbe730331f2242c1fc0bcb46a3ea3b4ae3348847e5a141265479700"}, - {file = "orjson-3.10.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:544a12eee96e3ab828dbfcb4d5a0023aa971b27143a1d35dc214c176fdfb29b3"}, - {file = "orjson-3.10.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:520de5e2ef0b4ae546bea25129d6c7c74edb43fc6cf5213f511a927f2b28148b"}, - {file = "orjson-3.10.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ccaa0a401fc02e8828a5bedfd80f8cd389d24f65e5ca3954d72c6582495b4bcf"}, - {file = "orjson-3.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:9a7bc9e8bc11bac40f905640acd41cbeaa87209e7e1f57ade386da658092dc16"}, - {file = "orjson-3.10.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3582b34b70543a1ed6944aca75e219e1192661a63da4d039d088a09c67543b08"}, - {file = "orjson-3.10.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c23dfa91481de880890d17aa7b91d586a4746a4c2aa9a145bebdbaf233768d5"}, - {file = "orjson-3.10.3-cp311-none-win32.whl", hash = "sha256:1770e2a0eae728b050705206d84eda8b074b65ee835e7f85c919f5705b006c9b"}, - {file = "orjson-3.10.3-cp311-none-win_amd64.whl", hash = "sha256:93433b3c1f852660eb5abdc1f4dd0ced2be031ba30900433223b28ee0140cde5"}, - {file = "orjson-3.10.3-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:a39aa73e53bec8d410875683bfa3a8edf61e5a1c7bb4014f65f81d36467ea098"}, - {file = "orjson-3.10.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0943a96b3fa09bee1afdfccc2cb236c9c64715afa375b2af296c73d91c23eab2"}, - {file = "orjson-3.10.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e852baafceff8da3c9defae29414cc8513a1586ad93e45f27b89a639c68e8176"}, - {file = "orjson-3.10.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:18566beb5acd76f3769c1d1a7ec06cdb81edc4d55d2765fb677e3eaa10fa99e0"}, - {file = "orjson-3.10.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1bd2218d5a3aa43060efe649ec564ebedec8ce6ae0a43654b81376216d5ebd42"}, - {file = "orjson-3.10.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:cf20465e74c6e17a104ecf01bf8cd3b7b252565b4ccee4548f18b012ff2f8069"}, - {file = "orjson-3.10.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ba7f67aa7f983c4345eeda16054a4677289011a478ca947cd69c0a86ea45e534"}, - {file = "orjson-3.10.3-cp312-none-win32.whl", hash = "sha256:17e0713fc159abc261eea0f4feda611d32eabc35708b74bef6ad44f6c78d5ea0"}, - {file = "orjson-3.10.3-cp312-none-win_amd64.whl", hash = "sha256:4c895383b1ec42b017dd2c75ae8a5b862fc489006afde06f14afbdd0309b2af0"}, - {file = "orjson-3.10.3-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:be2719e5041e9fb76c8c2c06b9600fe8e8584e6980061ff88dcbc2691a16d20d"}, - {file = "orjson-3.10.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0175a5798bdc878956099f5c54b9837cb62cfbf5d0b86ba6d77e43861bcec2"}, - {file = "orjson-3.10.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:978be58a68ade24f1af7758626806e13cff7748a677faf95fbb298359aa1e20d"}, - {file = "orjson-3.10.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16bda83b5c61586f6f788333d3cf3ed19015e3b9019188c56983b5a299210eb5"}, - {file = "orjson-3.10.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ad1f26bea425041e0a1adad34630c4825a9e3adec49079b1fb6ac8d36f8b754"}, - {file = "orjson-3.10.3-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:9e253498bee561fe85d6325ba55ff2ff08fb5e7184cd6a4d7754133bd19c9195"}, - {file = "orjson-3.10.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0a62f9968bab8a676a164263e485f30a0b748255ee2f4ae49a0224be95f4532b"}, - {file = "orjson-3.10.3-cp38-none-win32.whl", hash = "sha256:8d0b84403d287d4bfa9bf7d1dc298d5c1c5d9f444f3737929a66f2fe4fb8f134"}, - {file = "orjson-3.10.3-cp38-none-win_amd64.whl", hash = "sha256:8bc7a4df90da5d535e18157220d7915780d07198b54f4de0110eca6b6c11e290"}, - {file = 
"orjson-3.10.3-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9059d15c30e675a58fdcd6f95465c1522b8426e092de9fff20edebfdc15e1cb0"}, - {file = "orjson-3.10.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8d40c7f7938c9c2b934b297412c067936d0b54e4b8ab916fd1a9eb8f54c02294"}, - {file = "orjson-3.10.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d4a654ec1de8fdaae1d80d55cee65893cb06494e124681ab335218be6a0691e7"}, - {file = "orjson-3.10.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:831c6ef73f9aa53c5f40ae8f949ff7681b38eaddb6904aab89dca4d85099cb78"}, - {file = "orjson-3.10.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99b880d7e34542db89f48d14ddecbd26f06838b12427d5a25d71baceb5ba119d"}, - {file = "orjson-3.10.3-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:2e5e176c994ce4bd434d7aafb9ecc893c15f347d3d2bbd8e7ce0b63071c52e25"}, - {file = "orjson-3.10.3-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b69a58a37dab856491bf2d3bbf259775fdce262b727f96aafbda359cb1d114d8"}, - {file = "orjson-3.10.3-cp39-none-win32.whl", hash = "sha256:b8d4d1a6868cde356f1402c8faeb50d62cee765a1f7ffcfd6de732ab0581e063"}, - {file = "orjson-3.10.3-cp39-none-win_amd64.whl", hash = "sha256:5102f50c5fc46d94f2033fe00d392588564378260d64377aec702f21a7a22912"}, - {file = "orjson-3.10.3.tar.gz", hash = "sha256:2b166507acae7ba2f7c315dcf185a9111ad5e992ac81f2d507aac39193c2c818"}, + {file = "orjson-3.10.5-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:545d493c1f560d5ccfc134803ceb8955a14c3fcb47bbb4b2fee0232646d0b932"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f4324929c2dd917598212bfd554757feca3e5e0fa60da08be11b4aa8b90013c1"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c13ca5e2ddded0ce6a927ea5a9f27cae77eee4c75547b4297252cb20c4d30e6"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b6c8e30adfa52c025f042a87f450a6b9ea29649d828e0fec4858ed5e6caecf63"}, + {file = "orjson-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:338fd4f071b242f26e9ca802f443edc588fa4ab60bfa81f38beaedf42eda226c"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6970ed7a3126cfed873c5d21ece1cd5d6f83ca6c9afb71bbae21a0b034588d96"}, + {file = "orjson-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:235dadefb793ad12f7fa11e98a480db1f7c6469ff9e3da5e73c7809c700d746b"}, + {file = "orjson-3.10.5-cp310-none-win32.whl", hash = "sha256:be79e2393679eda6a590638abda16d167754393f5d0850dcbca2d0c3735cebe2"}, + {file = "orjson-3.10.5-cp310-none-win_amd64.whl", hash = "sha256:c4a65310ccb5c9910c47b078ba78e2787cb3878cdded1702ac3d0da71ddc5228"}, + {file = "orjson-3.10.5-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:cdf7365063e80899ae3a697def1277c17a7df7ccfc979990a403dfe77bb54d40"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b68742c469745d0e6ca5724506858f75e2f1e5b59a4315861f9e2b1df77775a"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7d10cc1b594951522e35a3463da19e899abe6ca95f3c84c69e9e901e0bd93d38"}, + {file = 
"orjson-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dcbe82b35d1ac43b0d84072408330fd3295c2896973112d495e7234f7e3da2e1"}, + {file = "orjson-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10c0eb7e0c75e1e486c7563fe231b40fdd658a035ae125c6ba651ca3b07936f5"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:53ed1c879b10de56f35daf06dbc4a0d9a5db98f6ee853c2dbd3ee9d13e6f302f"}, + {file = "orjson-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:099e81a5975237fda3100f918839af95f42f981447ba8f47adb7b6a3cdb078fa"}, + {file = "orjson-3.10.5-cp311-none-win32.whl", hash = "sha256:1146bf85ea37ac421594107195db8bc77104f74bc83e8ee21a2e58596bfb2f04"}, + {file = "orjson-3.10.5-cp311-none-win_amd64.whl", hash = "sha256:36a10f43c5f3a55c2f680efe07aa93ef4a342d2960dd2b1b7ea2dd764fe4a37c"}, + {file = "orjson-3.10.5-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:68f85ecae7af14a585a563ac741b0547a3f291de81cd1e20903e79f25170458f"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:28afa96f496474ce60d3340fe8d9a263aa93ea01201cd2bad844c45cd21f5268"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9cd684927af3e11b6e754df80b9ffafd9fb6adcaa9d3e8fdd5891be5a5cad51e"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d21b9983da032505f7050795e98b5d9eee0df903258951566ecc358f6696969"}, + {file = "orjson-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ad1de7fef79736dde8c3554e75361ec351158a906d747bd901a52a5c9c8d24b"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:2d97531cdfe9bdd76d492e69800afd97e5930cb0da6a825646667b2c6c6c0211"}, + {file = "orjson-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d69858c32f09c3e1ce44b617b3ebba1aba030e777000ebdf72b0d8e365d0b2b3"}, + {file = "orjson-3.10.5-cp312-none-win32.whl", hash = "sha256:64c9cc089f127e5875901ac05e5c25aa13cfa5dbbbd9602bda51e5c611d6e3e2"}, + {file = "orjson-3.10.5-cp312-none-win_amd64.whl", hash = "sha256:b2efbd67feff8c1f7728937c0d7f6ca8c25ec81373dc8db4ef394c1d93d13dc5"}, + {file = "orjson-3.10.5-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:03b565c3b93f5d6e001db48b747d31ea3819b89abf041ee10ac6988886d18e01"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:584c902ec19ab7928fd5add1783c909094cc53f31ac7acfada817b0847975f26"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a35455cc0b0b3a1eaf67224035f5388591ec72b9b6136d66b49a553ce9eb1e6"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1670fe88b116c2745a3a30b0f099b699a02bb3482c2591514baf5433819e4f4d"}, + {file = "orjson-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:185c394ef45b18b9a7d8e8f333606e2e8194a50c6e3c664215aae8cf42c5385e"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:ca0b3a94ac8d3886c9581b9f9de3ce858263865fdaa383fbc31c310b9eac07c9"}, + {file = "orjson-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:dfc91d4720d48e2a709e9c368d5125b4b5899dced34b5400c3837dadc7d6271b"}, + {file = "orjson-3.10.5-cp38-none-win32.whl", hash = 
"sha256:c05f16701ab2a4ca146d0bca950af254cb7c02f3c01fca8efbbad82d23b3d9d4"}, + {file = "orjson-3.10.5-cp38-none-win_amd64.whl", hash = "sha256:8a11d459338f96a9aa7f232ba95679fc0c7cedbd1b990d736467894210205c09"}, + {file = "orjson-3.10.5-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:85c89131d7b3218db1b24c4abecea92fd6c7f9fab87441cfc342d3acc725d807"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb66215277a230c456f9038d5e2d84778141643207f85336ef8d2a9da26bd7ca"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:51bbcdea96cdefa4a9b4461e690c75ad4e33796530d182bdd5c38980202c134a"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbead71dbe65f959b7bd8cf91e0e11d5338033eba34c114f69078d59827ee139"}, + {file = "orjson-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5df58d206e78c40da118a8c14fc189207fffdcb1f21b3b4c9c0c18e839b5a214"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c4057c3b511bb8aef605616bd3f1f002a697c7e4da6adf095ca5b84c0fd43595"}, + {file = "orjson-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:b39e006b00c57125ab974362e740c14a0c6a66ff695bff44615dcf4a70ce2b86"}, + {file = "orjson-3.10.5-cp39-none-win32.whl", hash = "sha256:eded5138cc565a9d618e111c6d5c2547bbdd951114eb822f7f6309e04db0fb47"}, + {file = "orjson-3.10.5-cp39-none-win_amd64.whl", hash = "sha256:cc28e90a7cae7fcba2493953cff61da5a52950e78dc2dacfe931a317ee3d8de7"}, + {file = "orjson-3.10.5.tar.gz", hash = "sha256:7a5baef8a4284405d96c90c7c62b755e9ef1ada84c2406c24a9ebec86b89f46d"}, ] [[package]] name = "packaging" -version = "24.0" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-24.0-py3-none-any.whl", hash = "sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5"}, - {file = "packaging-24.0.tar.gz", hash = "sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] @@ -1588,13 +1588,13 @@ typing-extensions = ">=3.7.4.3" [[package]] name = "typing-extensions" -version = "4.12.1" +version = "4.12.2" description = "Backported and Experimental Type Hints for Python 3.8+" optional = false python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.12.1-py3-none-any.whl", hash = "sha256:6024b58b69089e5a89c347397254e35f1bf02a907728ec7fee9bf0fe837d203a"}, - {file = "typing_extensions-4.12.1.tar.gz", hash = "sha256:915f5e35ff76f56588223f15fdd5938f9a1cf9195c0de25130c627e4d597f6d1"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] @@ -1686,13 +1686,13 @@ files = [ [[package]] name = "urllib3" -version = "2.2.1" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.8" files = [ - {file = "urllib3-2.2.1-py3-none-any.whl", hash = "sha256:450b20ec296a467077128bff42b73080516e71b56ff59a60a02bef2232c4fa9d"}, - {file = "urllib3-2.2.1.tar.gz", hash = "sha256:d0570876c61ab9e520d776c38acbbb5b05a776d3f9ff98a5c8fd5162a444cf19"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] @@ -1962,4 +1962,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "ff0867c4b530a26f3e3ace242fa75bc143ac2bc1b60d9d3b00b7814ff6f9a34b" +content-hash = "91b78ddf8ee4ba11d31441fceca67b9749af720a2590be41456a89d9270d5e41" diff --git a/.env b/prod.env similarity index 80% rename from .env rename to prod.env index b21bb0ee..331ae2e2 100644 --- a/.env +++ b/prod.env @@ -1,9 +1,5 @@ -# Domain -DOMAIN=localhost - -# Environment: local, staging, production -ENVIRONMENT=local -PROJECT_NAME="TAD" +# Environment: local, production, demo +ENVIRONMENT=production # TAD backend BACKEND_CORS_ORIGINS="http://localhost,https://localhost,http://127.0.0.1,https://127.0.0.1" @@ -22,3 +18,4 @@ POSTGRES_PASSWORD=changethis # Database viewer PGADMIN_DEFAULT_PASSWORD=changethis +PGADMIN_DEFAULT_EMAIL=admin@admin.com diff --git a/pyproject.toml b/pyproject.toml index dbac15a5..7f2fff0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,6 @@ pyyaml = "^6.0.1" pytest = "^8.2.1" coverage = "^7.5.3" httpx = "^0.27.0" -urllib3 = "^2.2.1" playwright = "^1.44.0" pytest-playwright = "^0.5.0" @@ -107,10 +106,12 @@ title = "tad" testpaths = [ "tests" ] -addopts = "--strict-markers" +addopts = "--strict-markers -v -q" filterwarnings = [ "ignore::UserWarning" ] +log_cli = true +log_cli_level = "INFO" [tool.liccheck] level = "PARANOID" diff --git a/script/build b/script/build new file mode 100755 index 00000000..9252a9d7 --- /dev/null +++ b/script/build @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +set -x + +docker build . -t ghcr.io/minbzk/tad:latest "$@" diff --git a/script/format b/script/format index 00fd706f..a8cc2b54 100755 --- a/script/format +++ b/script/format @@ -2,4 +2,4 @@ set -x -ruff format $@ +ruff format "$@" diff --git a/script/lint b/script/lint index 971b7902..dd730f1b 100755 --- a/script/lint +++ b/script/lint @@ -1,6 +1,5 @@ #!/usr/bin/env bash -set -e set -x -ruff check --fix $@ +ruff check --fix "$@" diff --git a/script/test b/script/test index 1f8a3941..b17492e5 100755 --- a/script/test +++ b/script/test @@ -3,8 +3,8 @@ set -e set -x -coverage run -m pytest $@ -if [ $? -ne 0 ]; then + +if ! coverage run -m pytest "$@" ; then echo "Test failed" exit 1 fi @@ -12,8 +12,8 @@ fi coverage report coverage html coverage lcov -pyright $@ -if [ $? -ne 0 ]; then + +if ! 
pyright; then echo "Typecheck failed" exit 1 fi diff --git a/tad/api/routes/deps.py b/tad/api/deps.py similarity index 64% rename from tad/api/routes/deps.py rename to tad/api/deps.py index d5e96e33..9dd18215 100644 --- a/tad/api/routes/deps.py +++ b/tad/api/deps.py @@ -2,14 +2,14 @@ from fastapi.templating import Jinja2Templates from jinja2 import Environment -from tad.core.config import settings +from tad.core.config import VERSION def version_context_processor(request: Request): - return {"version": settings.VERSION} + return {"version": VERSION} env = Environment( autoescape=True, ) -templates = Jinja2Templates(directory=settings.TEMPLATE_DIR, context_processors=[version_context_processor], env=env) +templates = Jinja2Templates(directory="tad/site/templates/", context_processors=[version_context_processor], env=env) diff --git a/tad/api/routes/pages.py b/tad/api/routes/pages.py index f07972a6..648f6ac7 100644 --- a/tad/api/routes/pages.py +++ b/tad/api/routes/pages.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse -from tad.api.routes.deps import templates +from tad.api.deps import templates from tad.services.statuses import StatusesService from tad.services.tasks import TasksService diff --git a/tad/api/routes/root.py b/tad/api/routes/root.py index 7293628f..1ba579ee 100644 --- a/tad/api/routes/root.py +++ b/tad/api/routes/root.py @@ -1,7 +1,5 @@ from fastapi import APIRouter -from fastapi.responses import FileResponse, RedirectResponse - -from tad.core.config import settings +from fastapi.responses import RedirectResponse router = APIRouter() @@ -9,8 +7,3 @@ @router.get("/") async def base() -> RedirectResponse: return RedirectResponse("/pages/") - - -@router.get("/favicon.ico", include_in_schema=False) -async def favicon(): - return FileResponse(settings.STATIC_DIR + "/favicon.ico") diff --git a/tad/api/routes/tasks.py b/tad/api/routes/tasks.py index 1bc86e15..120434d2 100644 --- a/tad/api/routes/tasks.py +++ b/tad/api/routes/tasks.py @@ -3,7 +3,7 @@ from fastapi import APIRouter, Depends, Request, status from fastapi.responses import HTMLResponse -from tad.api.routes.deps import templates +from tad.api.deps import templates from tad.schema.task import MovedTask from tad.services.tasks import TasksService diff --git a/tad/core/config.py b/tad/core/config.py index a1bec7ae..71b80201 100644 --- a/tad/core/config.py +++ b/tad/core/config.py @@ -1,5 +1,6 @@ import logging import secrets +from functools import lru_cache from typing import Any, TypeVar from pydantic import ( @@ -12,38 +13,26 @@ from tad.core.exceptions import SettingsError from tad.core.types import DatabaseSchemaType, EnvironmentType, LoggingLevelType +logger = logging.getLogger(__name__) + # Self type is not available in Python 3.10 so create our own with TypeVar SelfSettings = TypeVar("SelfSettings", bound="Settings") +PROJECT_NAME: str = "TAD" +PROJECT_DESCRIPTION: str = "Transparency of Algorithmic Decision making" +VERSION: str = "0.1.0" # replace in CI/CD pipeline + class Settings(BaseSettings): - # todo(berry): investigate yaml, toml or json file support for SettingsConfigDict - # todo(berry): investigate multiple .env files support for SettingsConfigDict - model_config = SettingsConfigDict( - env_file=(".env", ".env.test", ".env.prod"), env_ignore_empty=True, extra="ignore" - ) SECRET_KEY: str = secrets.token_urlsafe(32) - DOMAIN: str = "localhost" ENVIRONMENT: EnvironmentType = "local" - @computed_field # type: ignore[misc] - @property - def 
server_host(self) -> str: - if self.ENVIRONMENT == "local": - return f"http://{self.DOMAIN}" - return f"https://{self.DOMAIN}" - - VERSION: str = "0.1.0" - LOGGING_LEVEL: LoggingLevelType = "INFO" LOGGING_CONFIG: dict[str, Any] | None = None - PROJECT_NAME: str = "TAD" - PROJECT_DESCRIPTION: str = "Transparency of Algorithmic Decision making" - - STATIC_DIR: str = "tad/site/static/" - TEMPLATE_DIR: str = "tad/site/templates" + DEBUG: bool = False + AUTO_CREATE_SCHEMA: bool = False # todo(berry): create submodel for database settings APP_DATABASE_SCHEME: DatabaseSchemaType = "sqlite" @@ -55,22 +44,27 @@ def server_host(self) -> str: APP_DATABASE_PASSWORD: str | None = None APP_DATABASE_DB: str = "tad" - APP_DATABASE_FILE: str = "database.sqlite3" + APP_DATABASE_FILE: str = "/database.sqlite3" + + model_config = SettingsConfigDict(extra="ignore") @computed_field # type: ignore[misc] @property - def SQLALCHEMY_DATABASE_URI(self) -> str: - logging.info(f"test: {self.APP_DATABASE_SCHEME}") - - if self.APP_DATABASE_SCHEME == "sqlite": - return str(MultiHostUrl.build(scheme=self.APP_DATABASE_SCHEME, host="", path=self.APP_DATABASE_FILE)) + def SQLALCHEMY_ECHO(self) -> bool: + return self.DEBUG + @computed_field # type: ignore[misc] + @property + def SQLALCHEMY_DATABASE_URI(self) -> str: scheme: str = ( f"{self.APP_DATABASE_SCHEME}+{self.APP_DATABASE_DRIVER}" if isinstance(self.APP_DATABASE_DRIVER, str) else self.APP_DATABASE_SCHEME ) + if self.APP_DATABASE_SCHEME == "sqlite": + return f"{scheme}://{self.APP_DATABASE_FILE}" + return str( MultiHostUrl.build( scheme=scheme, @@ -84,9 +78,26 @@ def SQLALCHEMY_DATABASE_URI(self) -> str: @model_validator(mode="after") def _enforce_database_rules(self: SelfSettings) -> SelfSettings: - if self.ENVIRONMENT != "local" and self.APP_DATABASE_SCHEME == "sqlite": - raise SettingsError("SQLite is not supported in production") + if self.ENVIRONMENT == "production" and self.APP_DATABASE_SCHEME == "sqlite": + raise SettingsError("APP_DATABASE_SCHEME=SQLITE is not supported in production") + return self + + @model_validator(mode="after") + def _enforce_debug_rules(self: SelfSettings) -> SelfSettings: + if self.ENVIRONMENT == "production" and self.DEBUG: + raise SettingsError("DEBUG=True is not supported in production") + return self + + @model_validator(mode="after") + def _enforce_autocreate_rules(self: SelfSettings) -> SelfSettings: + if self.ENVIRONMENT == "production" and self.AUTO_CREATE_SCHEMA: + raise SettingsError("AUTO_CREATE_SCHEMA=True is not supported in production") return self -settings = Settings() # type: ignore +# TODO(berry): make it a function with lrucache + + +@lru_cache(maxsize=8) +def get_settings() -> Settings: + return Settings() diff --git a/tad/core/db.py b/tad/core/db.py index dda4f7c4..e07be8b1 100644 --- a/tad/core/db.py +++ b/tad/core/db.py @@ -1,18 +1,63 @@ -from sqlalchemy.engine.base import Engine -from sqlmodel import Session, create_engine, select +import logging +from functools import lru_cache -from tad.core.config import settings +from sqlalchemy.engine import Engine +from sqlalchemy.pool import QueuePool, StaticPool +from sqlmodel import Session, SQLModel, create_engine, select -_engine: None | Engine = None +from tad.core.config import get_settings +from tad.models import Status, Task, User +logger = logging.getLogger(__name__) + +@lru_cache(maxsize=8) def get_engine() -> Engine: - global _engine - if _engine is None: - _engine = create_engine(settings.SQLALCHEMY_DATABASE_URI) - return _engine + connect_args = ( + 
{"check_same_thread": False, "isolation_level": None} if get_settings().APP_DATABASE_SCHEME == "sqlite" else {} + ) + poolclass = StaticPool if get_settings().APP_DATABASE_SCHEME == "sqlite" else QueuePool + + return create_engine( + str(get_settings().SQLALCHEMY_DATABASE_URI), + connect_args=connect_args, + poolclass=poolclass, + echo=get_settings().SQLALCHEMY_ECHO, + ) -async def check_db(): +def check_db(): + logger.info("Checking database connection") with Session(get_engine()) as session: session.exec(select(1)) + + logger.info("Finisch Checking database connection") + + +def init_db(): + logger.info("Initializing database") + + if get_settings().AUTO_CREATE_SCHEMA: + logger.info("Creating database schema") + SQLModel.metadata.create_all(get_engine()) + + with Session(get_engine()) as session: + if get_settings().ENVIRONMENT == "demo": + logger.info("Creating demo data") + + user = session.exec(select(User).where(User.name == "Robbert")).first() + if not user: + user = User(name="Robbert", avatar=None) + session.add(user) + + status = session.exec(select(Status).where(Status.name == "Todo")).first() + if not status: + status = Status(name="Todo", sort_order=1) + session.add(status) + + task = session.exec(select(Task).where(Task.title == "First task")).first() + if not task: + task = Task(title="First task", description="This is the first task", sort_order=1, status_id=status.id) + session.add(task) + session.commit() + logger.info("Finished initializing database") diff --git a/tad/core/log.py b/tad/core/log.py index 4850c419..92eec496 100644 --- a/tad/core/log.py +++ b/tad/core/log.py @@ -30,7 +30,7 @@ }, }, "loggers": { - "tad": {"handlers": ["console", "file"], "level": "DEBUG", "propagate": False}, + "": {"handlers": ["console", "file"], "level": "DEBUG", "propagate": False}, }, } diff --git a/tad/core/types.py b/tad/core/types.py index 4848e7f3..a18f4139 100644 --- a/tad/core/types.py +++ b/tad/core/types.py @@ -1,5 +1,6 @@ from typing import Literal -EnvironmentType = Literal["local", "staging", "production"] +# TODO(berry): make enums and convert to types +EnvironmentType = Literal["local", "production", "demo"] LoggingLevelType = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] DatabaseSchemaType = Literal["sqlite", "postgresql", "mysql", "oracle"] diff --git a/tad/main.py b/tad/main.py index 68740234..f219d1aa 100644 --- a/tad/main.py +++ b/tad/main.py @@ -5,12 +5,11 @@ from fastapi.exceptions import RequestValidationError from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles -from fastapi.templating import Jinja2Templates from starlette.exceptions import HTTPException as StarletteHTTPException from tad.api.main import api_router -from tad.core.config import settings -from tad.core.db import check_db +from tad.core.config import PROJECT_DESCRIPTION, PROJECT_NAME, VERSION, get_settings +from tad.core.db import check_db, init_db from tad.core.exception_handlers import ( http_exception_handler as tad_http_exception_handler, ) @@ -22,39 +21,38 @@ from .middleware.route_logging import RequestLoggingMiddleware -configure_logging(settings.LOGGING_LEVEL, settings.LOGGING_CONFIG) - +configure_logging(get_settings().LOGGING_LEVEL, get_settings().LOGGING_CONFIG) logger = logging.getLogger(__name__) -mask = Mask(mask_keywords=["database_uri"]) # todo(berry): move lifespan to own file @asynccontextmanager async def lifespan(app: FastAPI): - logger.info(f"Starting {settings.PROJECT_NAME} version {settings.VERSION}") - logger.info(f"Settings: 
{mask.secrets(settings.model_dump())}") - # todo(berry): setup database connection - await check_db() + mask = Mask(mask_keywords=["database_uri"]) + check_db() + init_db() + logger.info(f"Starting {PROJECT_NAME} version {VERSION}") + logger.info(f"Settings: {mask.secrets(get_settings().model_dump())}") yield - logger.info(f"Stopping application {settings.PROJECT_NAME} version {settings.VERSION}") + logger.info(f"Stopping application {PROJECT_NAME} version {VERSION}") logging.shutdown() -templates = Jinja2Templates(directory="templates") - +# todo(berry): Create factory for the FastAPI app app = FastAPI( lifespan=lifespan, - title=settings.PROJECT_NAME, - summary=settings.PROJECT_DESCRIPTION, - version=settings.VERSION, + title=PROJECT_NAME, + summary=PROJECT_DESCRIPTION, + version=VERSION, openapi_url=None, default_response_class=HTMLResponse, redirect_slashes=False, + debug=get_settings().DEBUG, ) app.add_middleware(RequestLoggingMiddleware) -app.mount("/static", StaticFiles(directory=settings.STATIC_DIR), name="static") +app.mount("/static", StaticFiles(directory="tad/site/static/"), name="static") @@ -68,5 +66,3 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE app.include_router(api_router) - -# todo (robbert) add init code for example tasks and statuses diff --git a/tad/migrations/versions/006c480a1920_a_message.py b/tad/migrations/versions/006c480a1920_a_message.py deleted file mode 100644 index fa83e759..00000000 --- a/tad/migrations/versions/006c480a1920_a_message.py +++ /dev/null @@ -1,36 +0,0 @@ -"""a message - -Revision ID: 006c480a1920 -Revises: -Create Date: 2024-05-13 12:36:32.647256 - -""" - -from collections.abc import Sequence - -import sqlalchemy as sa -import sqlmodel.sql.sqltypes -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "006c480a1920" -down_revision: str | None = None -branch_labels: str | Sequence[str] | None = None -depends_on: str | Sequence[str] | None = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "hero", - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), - sa.PrimaryKeyConstraint("id"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("hero") - # ### end Alembic commands ### diff --git a/tad/migrations/versions/eb2eed884ae9_a_message.py b/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py similarity index 70% rename from tad/migrations/versions/eb2eed884ae9_a_message.py rename to tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py index 4a861421..d5912d55 100644 --- a/tad/migrations/versions/eb2eed884ae9_a_message.py +++ b/tad/migrations/versions/b62dbd9468e4_create_status_user_and_task_table.py @@ -1,8 +1,8 @@ -"""Create the user, status and task tables, drop table hero +"""Create Status, User and Task table -Revision ID: eb2eed884ae9 +Revision ID: b62dbd9468e4 Revises: -Create Date: 2024-05-14 13:36:23.551663 +Create Date: 2024-06-06 09:18:14.989874 """ @@ -13,7 +13,7 @@ from alembic import op # revision identifiers, used by Alembic. 
-revision: str = "eb2eed884ae9" +revision: str = "b62dbd9468e4" down_revision: str | None = None branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None @@ -21,14 +21,14 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("hero") - op.create_table( + status = op.create_table( "status", sa.Column("id", sa.Integer(), nullable=False), sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("sort_order", sa.Float(), nullable=False), sa.PrimaryKeyConstraint("id"), ) + op.create_table( "user", sa.Column("id", sa.Integer(), nullable=False), @@ -42,8 +42,12 @@ def upgrade() -> None: sa.Column("title", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("description", sqlmodel.sql.sqltypes.AutoString(), nullable=False), sa.Column("sort_order", sa.Float(), nullable=False), - sa.Column("status_id", sa.Integer(), nullable=False), + sa.Column("status_id", sa.Integer(), nullable=True), sa.Column("user_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["status_id"], + ["status.id"], + ), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], @@ -52,9 +56,21 @@ def upgrade() -> None: ) # ### end Alembic commands ### + # ### custom commands ### + op.bulk_insert( + status, + [ + {"name": "Todo", "sort_order": 1}, + {"name": "In Progress", "sort_order": 2}, + {"name": "Review", "sort_order": 3}, + {"name": "Done", "sort_order": 4}, + ], + ) + def downgrade() -> None: - # we do not delete any tables on a downgrade - pass # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("task") + op.drop_table("user") + op.drop_table("status") # ### end Alembic commands ### diff --git a/tad/repositories/statuses.py b/tad/repositories/statuses.py index bdaeb9e3..56d86c6b 100644 --- a/tad/repositories/statuses.py +++ b/tad/repositories/statuses.py @@ -34,14 +34,29 @@ def save(self, status: Status) -> Status: :param status: the status to store :return: the updated status after storing """ - self.session.add(status) try: + self.session.add(status) self.session.commit() + self.session.refresh(status) except Exception as e: self.session.rollback() raise RepositoryError from e return status + def delete(self, status: Status) -> None: + """ + Deletes the given status in the repository. + :param status: the status to delete + :return: None + """ + try: + self.session.delete(status) + self.session.commit() + except Exception as e: + self.session.rollback() + raise RepositoryError from e + return None + def find_by_id(self, status_id: int) -> Status: """ Returns the status with the given id or an exception if the id does not exist. diff --git a/tad/repositories/tasks.py b/tad/repositories/tasks.py index c39c2b5c..82000783 100644 --- a/tad/repositories/tasks.py +++ b/tad/repositories/tasks.py @@ -53,6 +53,20 @@ def save(self, task: Task) -> Task: raise RepositoryError from e return task + def delete(self, task: Task) -> None: + """ + Deletes the given task in the repository or raises a RepositoryError + :param task: the task to delete + :return: None + """ + try: + self.session.delete(task) + self.session.commit() + except Exception as e: + self.session.rollback() + raise RepositoryError from e + return None + def find_by_id(self, task_id: int) -> Task: """ Returns the task with the given id. 
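Review note on the two repositories above: save() and delete() now share the same commit-or-rollback shape around the session. A minimal sketch of how a caller would exercise the pair (the StatusesRepository class name and its Session-taking constructor are assumptions; this hunk only shows the methods):

    from sqlmodel import Session

    from tad.core.db import get_engine
    from tad.models import Status
    from tad.repositories.statuses import StatusesRepository  # hypothetical class name

    with Session(get_engine()) as session:
        repository = StatusesRepository(session)
        # save() commits and refreshes, so the generated primary key is populated
        status = repository.save(Status(name="Archived", sort_order=5))
        # delete() also commits; on failure both roll back and raise RepositoryError
        repository.delete(status)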
diff --git a/tad/services/storage.py b/tad/services/storage.py index f62de852..fdf842fb 100644 --- a/tad/services/storage.py +++ b/tad/services/storage.py @@ -17,18 +17,18 @@ def close(self) -> None: class WriterFactory: @staticmethod - def get_writer(writer_type: str = "file", **kwargs: str) -> Writer: + def get_writer(writer_type: str = "file", **kwargs: Any) -> Writer: match writer_type: case "file": if not all(k in kwargs for k in ("location", "filename")): raise KeyError("The `location` or `filename` variables are not provided as input for get_writer()") - return FileSystemWriteService(location=str(kwargs["location"]), filename=str(kwargs["filename"])) + return FileSystemWriteService(location=Path(kwargs["location"]), filename=str(kwargs["filename"])) case _: raise ValueError(f"Unknown writer type: {writer_type}") class FileSystemWriteService(Writer): - def __init__(self, location: str = "./tests/data", filename: str = "system_card.yaml") -> None: + def __init__(self, location: str | Path = "./tests/data", filename: str = "system_card.yaml") -> None: self.location = location if not filename.endswith(".yaml"): raise ValueError(f"Filename {filename} must end with .yaml instead of .{filename.split('.')[-1]}") diff --git a/tad/services/tasks.py b/tad/services/tasks.py index 363b821d..13c70ac4 100644 --- a/tad/services/tasks.py +++ b/tad/services/tasks.py @@ -52,8 +52,11 @@ def move_task( self.system_card.title = task.title self.storage_writer.write(self.system_card.model_dump()) + if not isinstance(status.id, int): + raise TypeError("status_id must be an integer") # pragma: no cover + # assign the task to the current user - if status.name == "in_progress": + if status.id > 1: task.user_id = 1 # update the status for the task (this may not be needed if the status has not changed) diff --git a/tad/site/templates/default_layout.jinja b/tad/site/templates/default_layout.jinja index 2dedbb79..cbea94c6 100644 --- a/tad/site/templates/default_layout.jinja +++ b/tad/site/templates/default_layout.jinja @@ -22,6 +22,7 @@ + diff --git a/tests/api/routes/test_pages.py b/tests/api/routes/test_pages.py index c177102a..8111d9d9 100644 --- a/tests/api/routes/test_pages.py +++ b/tests/api/routes/test_pages.py @@ -1,17 +1,17 @@ from fastapi.testclient import TestClient +from tests.constants import all_statusses, default_task from tests.database_test_utils import DatabaseTestUtils def test_get_main_page(client: TestClient, db: DatabaseTestUtils) -> None: - db.init( - [ - {"table": "status", "id": 1}, - {"table": "task", "id": 1, "status_id": 1}, - {"table": "task", "id": 2, "status_id": 1}, - ] - ) + # given + db.given([*all_statusses(), default_task()]) + + # when response = client.get("/pages/") + + # then assert response.status_code == 200 assert response.headers["content-type"] == "text/html; charset=utf-8" assert b"" in response.content diff --git a/tests/api/routes/test_root.py b/tests/api/routes/test_root.py index 6ff8b1f8..735bb625 100644 --- a/tests/api/routes/test_root.py +++ b/tests/api/routes/test_root.py @@ -4,15 +4,7 @@ def test_get_root(client: TestClient) -> None: response = client.get( "/", + follow_redirects=False, ) # todo (robbert) this is a quick test to see if we (most likely) get the expected page - assert response.status_code == 200 - assert response.headers["content-type"] == "text/html; charset=utf-8" - assert b"Transparency of Algorithmic Decision making (TAD)" in response.content - - -def test_get_favicon(client: TestClient) -> None: - response = client.get( - "/favicon.ico", - ) 
- assert response.status_code == 200 + assert response.status_code == 307 diff --git a/tests/api/routes/test_static.py b/tests/api/routes/test_static.py index b2d1a367..2146a5c1 100644 --- a/tests/api/routes/test_static.py +++ b/tests/api/routes/test_static.py @@ -1,14 +1,11 @@ -import pytest from fastapi.testclient import TestClient -@pytest.mark.skip(reason="Not working yet") def test_static_css(client: TestClient) -> None: - response = client.get("/static/styles.css") + response = client.get("/static/css/layout.css") assert response.status_code == 200 -@pytest.mark.skip(reason="Not working yet") def test_static_js(client: TestClient) -> None: - response = client.get("/static/main.js") + response = client.get("/static/js/tad.js") assert response.status_code == 200 diff --git a/tests/api/routes/test_status.py b/tests/api/routes/test_status.py index 9a45fd01..2992a300 100644 --- a/tests/api/routes/test_status.py +++ b/tests/api/routes/test_status.py @@ -1,19 +1,16 @@ from fastapi.testclient import TestClient from tad.schema.task import MovedTask +from tests.constants import all_statusses, default_task from tests.database_test_utils import DatabaseTestUtils def test_post_move_task(client: TestClient, db: DatabaseTestUtils) -> None: - db.init( - [ - {"table": "status", "id": 2}, - {"table": "task", "id": 1, "status_id": 2}, - {"table": "task", "id": 2, "status_id": 2}, - {"table": "task", "id": 3, "status_id": 2}, - ] - ) + db.given([*all_statusses()]) + db.given([default_task(), default_task(), default_task()]) + move_task: MovedTask = MovedTask(taskId=2, statusId=2, previousSiblingId=1, nextSiblingId=3) + response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True)) assert response.status_code == 200 assert response.headers["content-type"] == "text/html; charset=utf-8" @@ -21,17 +18,12 @@ def test_post_move_task(client: TestClient, db: DatabaseTestUtils) -> None: def test_post_move_task_no_siblings(client: TestClient, db: DatabaseTestUtils) -> None: - db.init( - [ - {"table": "status", "id": 2}, - {"table": "status", "id": 1}, - {"table": "task", "id": 1, "status_id": 2}, - {"table": "task", "id": 2, "status_id": 2}, - {"table": "task", "id": 3, "status_id": 2}, - ] - ) + db.given([*all_statusses()]) + db.given([default_task(), default_task(), default_task()]) + move_task: MovedTask = MovedTask(taskId=2, statusId=1, previousSiblingId=-1, nextSiblingId=-1) response = client.patch("/tasks/", json=move_task.model_dump(by_alias=True)) + assert response.status_code == 200 assert response.headers["content-type"] == "text/html; charset=utf-8" assert b'class="progress_card_container"' in response.content diff --git a/tests/api/routes/test_tasks_move.py b/tests/api/routes/test_tasks_move.py index d60eac05..2b9a5d71 100644 --- a/tests/api/routes/test_tasks_move.py +++ b/tests/api/routes/test_tasks_move.py @@ -1,16 +1,13 @@ from fastapi.testclient import TestClient +from tests.constants import all_statusses, default_task from tests.database_test_utils import DatabaseTestUtils def test_post_task_move(client: TestClient, db: DatabaseTestUtils) -> None: - db.init( - [ - {"table": "status", "id": 1}, - {"table": "task", "id": 1, "status_id": 1}, - {"table": "task", "id": 2, "status_id": 1}, - ] - ) + db.given([*all_statusses()]) + db.given([default_task(), default_task(), default_task()]) + response = client.patch( "/tasks/", json={"taskId": "1", "statusId": "1", "previousSiblingId": "2", "nextSiblingId": "-1"} ) @@ -19,8 +16,7 @@ def test_post_task_move(client: TestClient, db: 
DatabaseTestUtils) -> None: assert b'id="card-1"' in response.content -def test_task_move_error(client: TestClient, db: DatabaseTestUtils) -> None: - db.init() +def test_task_move_error(client: TestClient) -> None: response = client.patch( "/tasks/", json={"taskId": "1", "statusId": "1", "previousSiblingId": "2", "nextSiblingId": "-1"} ) diff --git a/tests/conftest.py b/tests/conftest.py index 1dc12247..8673b097 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,79 +1,67 @@ import os -import urllib -from collections.abc import Generator +from collections.abc import Generator, Iterator from multiprocessing import Process -from time import sleep from typing import Any -from urllib.error import URLError +import httpx import pytest -import uvicorn from _pytest.fixtures import SubRequest from fastapi.testclient import TestClient from playwright.sync_api import Page, Playwright, sync_playwright -from sqlmodel import Session -from tad.core.config import settings +from sqlmodel import Session, SQLModel +from tad.core.config import Settings, get_settings from tad.core.db import get_engine from tad.main import app +from uvicorn.main import run as uvicorn_run from tests.database_test_utils import DatabaseTestUtils -class TestSettings: - HTTP_SERVER_SCHEME: str = "http://" - HTTP_SERVER_HOST: str = "127.0.0.1" - HTTP_SERVER_PORT: int = 8000 - RETRY: int = 10 - - -def run_server() -> None: - uvicorn.run(app, host=TestSettings.HTTP_SERVER_HOST, port=TestSettings.HTTP_SERVER_PORT) - - -def wait_for_server_ready(server: Generator[Any, Any, Any]) -> None: - for _ in range(TestSettings.RETRY): - try: - # we use urllib instead of playwright, because we only want a simple request - # not a full page with all assets - assert urllib.request.urlopen(server).getcode() == 200 # type: ignore # noqa - break - # todo (robbert) find out what exception to catch - except URLError: # server was not ready - sleep(1) +def run_uvicorn(uvicorn: Any) -> None: + uvicorn_run(app, host=uvicorn["host"], port=uvicorn["port"]) @pytest.fixture(scope="module") -def server() -> Generator[Any, Any, Any]: - # todo (robbert) use a better way to get the test database in the app configuration - os.environ["APP_DATABASE_FILE"] = "database.sqlite3.test" - process = Process(target=run_server) +def run_server(request: pytest.FixtureRequest) -> Generator[Any, None, None]: + uvicorn_settings = request.config.uvicorn # type: ignore + + process = Process(target=run_uvicorn, args=(uvicorn_settings,)) # type: ignore process.start() - server_address = ( - f"{TestSettings.HTTP_SERVER_SCHEME}" f"{TestSettings.HTTP_SERVER_HOST}:{TestSettings.HTTP_SERVER_PORT}" - ) - yield server_address + yield f"http://{uvicorn_settings['host']}:{uvicorn_settings['port']}" process.terminate() - del os.environ["APP_DATABASE_FILE"] -@pytest.fixture(scope="session") +@pytest.fixture() def get_session() -> Generator[Session, Any, Any]: - with Session(get_engine()) as session: + with Session(get_engine(), expire_on_commit=False) as session: yield session -def pytest_configure() -> None: - """ - Called after the Session object has been created and - before performing collection and entering the run test loop. - """ - # todo (robbert) creating an in memory database does not work right, tables seem to get lost? 
-    settings.APP_DATABASE_FILE = "database.sqlite3.test"  # set to none so we'll use an in memory database
+def pytest_configure(config: pytest.Config) -> None:
+    os.environ.clear()  # let's always start with a clean environment to make tests consistent
+    os.environ["ENVIRONMENT"] = "local"
+    os.environ["APP_DATABASE_SCHEME"] = "sqlite"
+    config.uvicorn = {  # type: ignore
+        "host": "127.0.0.1",
+        "port": 8756,
+    }
-@pytest.fixture(scope="module")
+
+def pytest_sessionstart(session: pytest.Session) -> None:
+    get_settings.cache_clear()
+    get_engine.cache_clear()
+    SQLModel.metadata.create_all(get_engine())
+
+
+def pytest_sessionfinish(session: pytest.Session) -> None:
+    SQLModel.metadata.drop_all(get_engine())
+
+
+@pytest.fixture(scope="session")
 def client() -> Generator[TestClient, None, None]:
     with TestClient(app, raise_server_exceptions=True) as c:
+        # app.dependency_overrides[get_app_session] = get_session  # noqa: ERA001
         c.timeout = 5
         yield c

@@ -84,16 +72,50 @@ def playwright():
         yield p


-@pytest.fixture(params=["chromium", "firefox", "webkit"])
-def browser(playwright: Playwright, request: SubRequest, server: Generator[Any, Any, Any]) -> Generator[Page, Any, Any]:
+@pytest.fixture(params=["chromium"])  # let's start with 1 browser for now, we can add more later
+def browser(
+    playwright: Playwright, request: SubRequest, run_server: Generator[str, Any, Any]
+) -> Generator[Page, Any, Any]:
     browser = getattr(playwright, request.param).launch(headless=True)
-    context = browser.new_context(base_url=server)
+    context = browser.new_context(base_url=run_server)
     page = context.new_page()
-    wait_for_server_ready(server)
+
+    # poll until the server answers, so the test never starts against a cold server
+    transport = httpx.HTTPTransport(retries=5, local_address="127.0.0.1")
+    with httpx.Client(transport=transport, verify=False) as client:  # noqa: S501
+        client.get(f"{run_server}/", timeout=0.8)
+
     yield page
     browser.close()


 @pytest.fixture()
-def db(get_session: Generator[Session, Any, Any]):
-    return DatabaseTestUtils(get_session)
+def db(get_session: Session) -> Generator[DatabaseTestUtils, None, None]:
+    database = DatabaseTestUtils(get_session)
+    yield database
+    del database  # drop our reference so DatabaseTestUtils.__del__ can clean up the models it created
+
+
+@pytest.fixture()
+def patch_settings(request: pytest.FixtureRequest) -> Iterator[Settings]:
+    settings = get_settings()
+    original_settings = settings.model_copy()

+    vars_to_patch = getattr(request, "param", {})
+
+    # reset all settings to their defaults before applying the requested patches
+    for k, v in settings.model_fields.items():
+        setattr(settings, k, v.default)
+
+    for key, val in vars_to_patch.items():
+        if not hasattr(settings, key):
+            raise ValueError(f"Unknown setting: {key}")
+
+        # Raise an error if the patched value has an invalid type
+        expected_type = getattr(settings, key).__class__
+        if not isinstance(val, expected_type):
+            raise ValueError(f"Invalid type for {key}: {val.__class__} instead of {expected_type}")  # noqa: TRY004
+        setattr(settings, key, val)
+
+    yield settings
+
+    # Restore the original settings
+    settings.__dict__.update(original_settings.__dict__)
diff --git a/tests/constants.py b/tests/constants.py
new file mode 100644
index 00000000..69ef56d1
--- /dev/null
+++ b/tests/constants.py
@@ -0,0 +1,39 @@
+from tad.models import Status, Task, User
+
+
+def default_status() -> Status:
+    return Status(name="Todo", sort_order=1)
+
+
+def todo_status() -> Status:
+    return default_status()
+
+
+def in_progress_status() -> Status:
+    return Status(name="In progress", sort_order=2)
+
+
+def in_review_status() -> Status:
+    return Status(name="In review", sort_order=3)
+
+
+def done_status() -> Status:
+    return Status(name="Done", sort_order=4)
+
+
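+# Note: each builder above returns a fresh model instance, so tests never share
+# ORM state; all_statusses() bundles one of each status defined above.
+def all_statusses()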
-> list[Status]: + return [todo_status(), in_progress_status(), in_review_status(), done_status()] + + +def default_user(name: str = "default user", avatar: str | None = None) -> User: + return User(name=name, avatar=avatar) + + +def default_task( + title: str = "Default Task", + description: str = "My default task", + sort_order: float = 1.0, + status_id: int | None = None, + user_id: int | None = None, +) -> Task: + return Task(title=title, description=description, sort_order=sort_order, status_id=status_id, user_id=user_id) diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 3ca07a16..2cad3c0e 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -4,44 +4,33 @@ def test_default_settings(): - settings = Settings(_env_file="nonexisitingfile") # type: ignore + settings = Settings(_env_file=None) # type: ignore - assert settings.DOMAIN == "localhost" assert settings.ENVIRONMENT == "local" - assert settings.server_host == "http://localhost" - assert settings.VERSION == "0.1.0" assert settings.LOGGING_LEVEL == "INFO" - assert settings.PROJECT_NAME == "TAD" - assert settings.PROJECT_DESCRIPTION == "Transparency of Algorithmic Decision making" assert settings.APP_DATABASE_SCHEME == "sqlite" - # todo (robbert) we change the database for the test and use the default config - assert settings.SQLALCHEMY_DATABASE_URI == "sqlite:///database.sqlite3" + assert settings.APP_DATABASE_SERVER == "db" + assert settings.APP_DATABASE_PORT == 5432 + assert settings.APP_DATABASE_USER == "tad" + assert settings.APP_DATABASE_DB == "tad" def test_environment_settings(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setenv("DOMAIN", "google.com") monkeypatch.setenv("ENVIRONMENT", "production") - monkeypatch.setenv("PROJECT_NAME", "TAD2") monkeypatch.setenv("SECRET_KEY", "mysecret") monkeypatch.setenv("APP_DATABASE_SCHEME", "postgresql") monkeypatch.setenv("APP_DATABASE_USER", "tad2") monkeypatch.setenv("APP_DATABASE_DB", "tad2") monkeypatch.setenv("APP_DATABASE_PASSWORD", "mypassword") - settings = Settings(_env_file="nonexisitingfile") # type: ignore + settings = Settings(_env_file=None) # type: ignore assert settings.SECRET_KEY == "mysecret" # noqa: S105 - assert settings.DOMAIN == "google.com" assert settings.ENVIRONMENT == "production" - assert settings.server_host == "https://google.com" - assert settings.VERSION == "0.1.0" assert settings.LOGGING_LEVEL == "INFO" - assert settings.PROJECT_NAME == "TAD2" - assert settings.PROJECT_DESCRIPTION == "Transparency of Algorithmic Decision making" assert settings.APP_DATABASE_SCHEME == "postgresql" assert settings.APP_DATABASE_SERVER == "db" assert settings.APP_DATABASE_PORT == 5432 assert settings.APP_DATABASE_USER == "tad2" - assert settings.APP_DATABASE_PASSWORD == "mypassword" # noqa: S105 assert settings.APP_DATABASE_DB == "tad2" assert settings.SQLALCHEMY_DATABASE_URI == "postgresql://tad2:mypassword@db:5432/tad2" @@ -49,15 +38,27 @@ def test_environment_settings(monkeypatch: pytest.MonkeyPatch): def test_environment_settings_production_sqlite_error(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("ENVIRONMENT", "production") monkeypatch.setenv("APP_DATABASE_SCHEME", "sqlite") - monkeypatch.setenv("APP_DATABASE_PASSWORD", "32452345432") with pytest.raises(SettingsError) as e: - _settings = Settings(_env_file="nonexisitingfile") # type: ignore + _settings = Settings(_env_file=None) # type: ignore + + assert e.value.message == "APP_DATABASE_SCHEME=SQLITE is not supported in production" + + +def 
test_environment_settings_production_debug_error(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("ENVIRONMENT", "production") + monkeypatch.setenv("DEBUG", "True") + monkeypatch.setenv("APP_DATABASE_SCHEME", "postgresql") + with pytest.raises(SettingsError) as e: + _settings = Settings(_env_file=None) # type: ignore - assert e.value.message == "SQLite is not supported in production" + assert e.value.message == "DEBUG=True is not supported in production" -def test_environment_settings_production_nopassword_error(monkeypatch: pytest.MonkeyPatch): +def test_environment_settings_production_autocreate_error(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("ENVIRONMENT", "production") + monkeypatch.setenv("AUTO_CREATE_SCHEMA", "True") + monkeypatch.setenv("APP_DATABASE_SCHEME", "postgresql") + with pytest.raises(SettingsError) as e: + _settings = Settings(_env_file=None) # type: ignore - with pytest.raises(SettingsError): - _settings = Settings(_env_file="nonexisitingfile") # type: ignore + assert e.value.message == "AUTO_CREATE_SCHEMA=True is not supported in production" diff --git a/tests/core/test_db.py b/tests/core/test_db.py index 23470d99..dbc7a8a9 100644 --- a/tests/core/test_db.py +++ b/tests/core/test_db.py @@ -1,15 +1,67 @@ -from unittest.mock import Mock, patch +import logging +from unittest.mock import MagicMock import pytest from sqlmodel import Session, select -from tad.core.db import check_db +from tad.core.config import Settings +from tad.core.db import check_db, init_db +from tad.models import Status, Task, User +logger = logging.getLogger(__name__) -@pytest.mark.skip(reason="not working yet") -async def test_check_database(): - mock_session = Mock(spec=Session) - with patch("sqlmodel.Session", return_value=mock_session): - await check_db() +def test_check_database(): + org_exec = Session.exec + Session.exec = MagicMock() + check_db() - assert mock_session.exec.assert_called_once_with(select(1)) + assert Session.exec.call_args is not None + assert str(select(1)) == str(Session.exec.call_args.args[0]) + Session.exec = org_exec + + +@pytest.mark.parametrize( + "patch_settings", + [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], + indirect=True, +) +def test_init_database_none(patch_settings: Settings): + org_exec = Session.exec + Session.exec = MagicMock() + Session.exec.return_value.first.return_value = None + + init_db() + + expected = [ + (select(User).where(User.name == "Robbert"),), + (select(Status).where(Status.name == "Todo"),), + (select(Task).where(Task.title == "First task"),), + ] + + for i, call_args in enumerate(Session.exec.call_args_list): + assert str(expected[i][0]) == str(call_args.args[0]) + + Session.exec = org_exec + + +@pytest.mark.parametrize( + "patch_settings", + [{"ENVIRONMENT": "demo", "AUTO_CREATE_SCHEMA": True}], + indirect=True, +) +def test_init_database(patch_settings: Settings): + org_exec = Session.exec + Session.exec = MagicMock() + + init_db() + + expected = [ + (select(User).where(User.name == "Robbert"),), + (select(Status).where(Status.name == "Todo"),), + (select(Task).where(Task.title == "First task"),), + ] + + for i, call_args in enumerate(Session.exec.call_args_list): + assert str(expected[i][0]) == str(call_args.args[0]) + + Session.exec = org_exec diff --git a/tests/core/test_log.py b/tests/core/test_log.py index 4aae5629..09bc31f3 100644 --- a/tests/core/test_log.py +++ b/tests/core/test_log.py @@ -24,22 +24,6 @@ def test_logging_tad_module(caplog: pytest.LogCaptureFixture): assert caplog.records[0].message == 
message -def test_logging_root(caplog: pytest.LogCaptureFixture): - configure_logging() - - logger = logging.getLogger("") - - message = "This is a test log message" - logger.debug(message) - logger.info(message) - logger.warning(message) # defaults to warning - logger.error(message) - logger.critical(message) - - assert len(caplog.records) == 3 - assert caplog.records[0].message == message - - def test_logging_submodule(caplog: pytest.LogCaptureFixture): config = {"loggers": {"tad": {"propagate": True}}} diff --git a/tests/database_test_utils.py b/tests/database_test_utils.py index 0ad228f8..539efb60 100644 --- a/tests/database_test_utils.py +++ b/tests/database_test_utils.py @@ -1,113 +1,31 @@ -from collections.abc import Generator -from typing import Any - -from sqlalchemy import text +from pydantic import BaseModel from sqlmodel import Session, SQLModel from tad.core.db import get_engine class DatabaseTestUtils: - """ - Class to use for testing database calls. On creation, this class destroys and recreates the database tables. - """ - - def __init__(self, session: Generator[Session, Any, Any]): - self.clear() - self.session: Generator[Session, Any, Any] = session - - def clear(self) -> None: - """ - Drops and recreates the database tables. - :return: None - """ + def __init__(self, session: Session) -> None: SQLModel.metadata.drop_all(get_engine()) SQLModel.metadata.create_all(get_engine()) + self.session: Session = session + self.models: list[BaseModel] = [] - def _enrich_with_default_values(self, specification: dict[str, str | int]) -> dict[str, str | int]: - """ - If a known table dictionary is given, like a task or status, default values will be added - and an enriched dictionary is returned. - :param specification: the dictionary to be enriched - :return: an enriched dictionary - """ - default_specification: dict[str, str | int] = {} - if specification["table"] == "task": - default_specification["title"] = "Test task " + str(specification["id"]) - default_specification["description"] = "Test task description " + str(specification["id"]) - default_specification["sort_order"] = specification["id"] - default_specification["status_id"] = 1 - elif specification["table"] == "status": - default_specification["name"] = "Status " + str(specification["id"]) - default_specification["sort_order"] = specification["id"] - return default_specification | specification - - def _fix_missing_relations(self, specification: dict[str, Any]) -> None: - """ - If a dictionary with a known table is given, like a task, the related item, - for example a status, will be created in the database if the id does not - exist yet. We do this to comply with database relationships and make it - easier to set up tests with minimal effort. 
- :param specification: a dictionary with a table specification - :return: None - """ - if specification["table"] == "task": - status_specification = {"id": specification["status_id"], "table": "status"} - if not self.item_exists(status_specification): - self.init([status_specification]) - - def get_items(self, specification: dict[str, str | int]) -> Any: - """ - Create a query based on the dictionary specification and return the result - :param specification: a dictionary with a table specification - :return: the results of the query - """ - values = ", ".join( - key + "=" + str(val) if str(val).isnumeric() else str('"' + val + '"') # type: ignore - for key, val in specification.items() # type: ignore - if key != "table" # type: ignore - ) - table = specification["table"] - statement = f"SELECT * FROM {table} WHERE {values}" # noqa S608 - return self.session.exec(text(statement)).all() # type: ignore - - def item_exists(self, specification: dict[str, Any]) -> bool: - """ - Check if an item exists in the database with the table and id given - in the dictionary - :param specification: a dictionary with a table specification - :return: True if the item exists in the database, False otherwise - """ - result = self.get_items(specification) - return len(result) != 0 - - def init(self, specifications: list[dict[str, str | int]] | None = None) -> None: - """ - Given an array of specifications, create the database entries. - - Example: [{'table': 'task', 'id': 1 'title': 'Test task 1', 'description': 'Test task description 1'}] + def __del__(self): + for model in self.models: + try: + self.session.delete(model) + self.session.commit() + except Exception: # noqa: S110 + pass - The example below will be enriched so all required fields for a task will have a value. 
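+    # given() persists the supplied models and records them so __del__ can
+    # remove them again; typical usage is db.given([*all_statusses()]) followed
+    # by db.given([default_task()]) once the statuses have their ids.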
+    def given(self, models: list[BaseModel]) -> None:
+        self.models.extend(models)
+        self.session.add_all(models)

-        Example: [{'table': 'task', 'id': 1}]
+        self.session.commit()

-        Example: [{"table": "status", "id": 1},{"table": "task", "id": 1, "status_id": 1}]
+        for model in models:
+            self.session.refresh(model)  # inefficient, but needed to populate the generated ids and relations of the models

-        :param specifications: an array of dictionaries with table specifications
-        :return: None
-        """
-        if specifications is None:
-            return
-        for specification in specifications:
-            specification = self._enrich_with_default_values(specification)
-            exists_specification = {"table": specification["table"], "id": specification["id"]}
-            if not self.item_exists(exists_specification):
-                self._fix_missing_relations(specification)
-            table = specification.pop("table")
-            keys = ", ".join(key for key in specification)
-            values = ", ".join(
-                str(val) if str(val).isnumeric() else str("'" + val + "'")  # type: ignore
-                for val in specification.values()  # type: ignore
-            )
-            statement = f"INSERT INTO {table} ({keys}) VALUES ({values})"  # noqa S608
-            self.session.exec(text(statement))  # type: ignore
-            self.session.commit()  # type: ignore
+
+    def get_session(self) -> Session:
+        return self.session
diff --git a/tests/e2e/test_move_task.py b/tests/e2e/test_move_task.py
index 764cbffb..e796b76b 100644
--- a/tests/e2e/test_move_task.py
+++ b/tests/e2e/test_move_task.py
@@ -1,5 +1,6 @@
 from playwright.sync_api import Page, expect

+from tests.constants import all_statusses, default_task
 from tests.database_test_utils import DatabaseTestUtils


@@ -10,15 +11,9 @@ def test_move_task_to_column(browser: Page, db: DatabaseTestUtils) -> None:
     :param start_server: the start server fixture
     :return: None
     """
-    db.init(
-        [
-            {"table": "status", "id": 1},
-            {"table": "status", "id": 2},
-            {"table": "status", "id": 3},
-            {"table": "status", "id": 4},
-            {"table": "task", "id": 1, "status_id": 1},
-        ]
-    )
+    all_status = all_statusses()
+    db.given([*all_status])
+    db.given([default_task(status_id=all_status[0].id)])

     browser.goto("/pages/")

@@ -42,13 +37,14 @@ def test_move_task_order_in_same_column(browser: Page, db: DatabaseTestUtils) ->
     it is in the right position in the column.
:return: None """ - db.init( - [ - {"table": "task", "id": 1, "status_id": 1}, - {"table": "task", "id": 2, "status_id": 1}, - {"table": "task", "id": 3, "status_id": 1}, - ] - ) + all_status = [*all_statusses()] + db.given([*all_status]) + + task1 = default_task(status_id=all_status[0].id) + task2 = default_task(status_id=all_status[0].id) + task3 = default_task(status_id=all_status[0].id) + + db.given([task1, task2, task3]) browser.goto("/pages/") diff --git a/tests/repositories/test_statuses.py b/tests/repositories/test_statuses.py index ac28237d..a2f77d38 100644 --- a/tests/repositories/test_statuses.py +++ b/tests/repositories/test_statuses.py @@ -1,63 +1,83 @@ import pytest -from sqlmodel import Session from tad.core.exceptions import RepositoryError from tad.models import Status from tad.repositories.statuses import StatusesRepository +from tests.constants import in_progress_status, todo_status from tests.database_test_utils import DatabaseTestUtils -def test_find_all(get_session: Session, db: DatabaseTestUtils): - db.init( - [ - {"table": "status", "id": 1}, - {"table": "status", "id": 2}, - ] - ) - status_repository: StatusesRepository = StatusesRepository(get_session) +def test_find_all(db: DatabaseTestUtils): + db.given([todo_status(), in_progress_status()]) + status_repository: StatusesRepository = StatusesRepository(db.get_session()) results = status_repository.find_all() assert results[0].id == 1 assert results[1].id == 2 assert len(results) == 2 -def test_find_all_no_results(get_session: Session, db: DatabaseTestUtils): - db.init() - status_repository: StatusesRepository = StatusesRepository(get_session) +def test_find_all_no_results(db: DatabaseTestUtils): + status_repository: StatusesRepository = StatusesRepository(db.get_session()) results = status_repository.find_all() assert len(results) == 0 -def test_save(get_session: Session, db: DatabaseTestUtils): - db.init() - status_repository: StatusesRepository = StatusesRepository(get_session) +def test_save(db: DatabaseTestUtils): + status_repository: StatusesRepository = StatusesRepository(db.get_session()) status: Status = Status(id=1, name="test", sort_order=10) status_repository.save(status) - result = db.get_items({"table": "status", "id": 1}) - assert result[0].id == 1 - assert result[0].name == "test" - assert result[0].sort_order == 10 + result = status_repository.find_by_id(1) -def test_save_failed(get_session: Session, db: DatabaseTestUtils): - db.init() - status_repository: StatusesRepository = StatusesRepository(get_session) + status_repository.delete(status) # cleanup + + assert result.id == 1 + assert result.name == "test" + assert result.sort_order == 10 + + +def test_delete(db: DatabaseTestUtils): + status_repository: StatusesRepository = StatusesRepository(db.get_session()) status: Status = Status(id=1, name="test", sort_order=10) status_repository.save(status) - status: Status = Status(id=1, name="test has duplicate id", sort_order=10) + status_repository.delete(status) + + results = status_repository.find_all() + + assert len(results) == 0 + + +def test_save_failed(db: DatabaseTestUtils): + status_repository: StatusesRepository = StatusesRepository(db.get_session()) + status: Status = Status(id=1, name="test", sort_order=10) + status_duplicate: Status = Status(id=1, name="test has duplicate id", sort_order=10) + + status_repository.save(status) + + with pytest.raises(RepositoryError): + status_repository.save(status_duplicate) + + status_repository.delete(status) # cleanup + + +def test_delete_failed(db: 
DatabaseTestUtils): + status_repository: StatusesRepository = StatusesRepository(db.get_session()) + status: Status = Status(id=1, name="test", sort_order=10) + with pytest.raises(RepositoryError): - status_repository.save(status) + status_repository.delete(status) + +def test_find_by_id(db: DatabaseTestUtils): + status = todo_status() + db.given([status]) -def test_find_by_id(get_session: Session, db: DatabaseTestUtils): - db.init([{"table": "status", "id": 1, "name": "test for find by id"}]) - status_repository: StatusesRepository = StatusesRepository(get_session) + status_repository: StatusesRepository = StatusesRepository(db.get_session()) result: Status = status_repository.find_by_id(1) assert result.id == 1 - assert result.name == "test for find by id" + assert result.name == status.name -def test_find_by_id_failed(get_session: Session, db: DatabaseTestUtils): - db.init() - status_repository: StatusesRepository = StatusesRepository(get_session) +def test_find_by_id_failed(db: DatabaseTestUtils): + status_repository: StatusesRepository = StatusesRepository(db.get_session()) with pytest.raises(RepositoryError): status_repository.find_by_id(1) diff --git a/tests/repositories/test_tasks.py b/tests/repositories/test_tasks.py index 56228fc9..80d8d5a6 100644 --- a/tests/repositories/test_tasks.py +++ b/tests/repositories/test_tasks.py @@ -1,80 +1,96 @@ import pytest -from sqlmodel import Session from tad.core.exceptions import RepositoryError from tad.models import Task from tad.repositories.tasks import TasksRepository +from tests.constants import all_statusses, default_task from tests.database_test_utils import DatabaseTestUtils -def test_find_all(get_session: Session, db: DatabaseTestUtils): - db.init( - [ - {"table": "task", "id": 1, "status_id": 1}, - {"table": "task", "id": 2, "status_id": 1}, - ] - ) - tasks_repository: TasksRepository = TasksRepository(get_session) +def test_find_all(db: DatabaseTestUtils): + db.given([*all_statusses()]) + db.given([default_task(), default_task()]) + + tasks_repository: TasksRepository = TasksRepository(db.get_session()) results = tasks_repository.find_all() assert results[0].id == 1 assert results[1].id == 2 assert len(results) == 2 -def test_find_all_no_results(get_session: Session, db: DatabaseTestUtils): - db.init() - tasks_repository: TasksRepository = TasksRepository(get_session) +def test_find_all_no_results(db: DatabaseTestUtils): + tasks_repository: TasksRepository = TasksRepository(db.get_session()) results = tasks_repository.find_all() assert len(results) == 0 -def test_save(get_session: Session, db: DatabaseTestUtils): - db.init() - tasks_repository: TasksRepository = TasksRepository(get_session) +def test_save(db: DatabaseTestUtils): + tasks_repository: TasksRepository = TasksRepository(db.get_session()) task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10) tasks_repository.save(task) - result = db.get_items({"table": "task", "id": 1}) - assert result[0].id == 1 - assert result[0].title == "Test title" - assert result[0].description == "Test description" - assert result[0].sort_order == 10 + result = tasks_repository.find_by_id(1) + + assert result.id == 1 + assert result.title == "Test title" + assert result.description == "Test description" + assert result.sort_order == 10 + + tasks_repository.delete(task) # cleanup + + +def test_delete(db: DatabaseTestUtils): + tasks_repository: TasksRepository = TasksRepository(db.get_session()) + task: Task = Task(id=1, title="Test title", description="Test 
description", sort_order=10) + + tasks_repository.save(task) + tasks_repository.delete(task) # cleanup + + results = tasks_repository.find_all() + + assert len(results) == 0 -@pytest.mark.filterwarnings("ignore:New instance") -def test_save_failed(get_session: Session, db: DatabaseTestUtils): - db.init() - tasks_repository: TasksRepository = TasksRepository(get_session) +def test_save_failed(db: DatabaseTestUtils): + tasks_repository: TasksRepository = TasksRepository(db.get_session()) task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10) tasks_repository.save(task) task_duplicate: Task = Task(id=1, title="Test title duplicate", description="Test description", sort_order=10) with pytest.raises(RepositoryError): tasks_repository.save(task_duplicate) + tasks_repository.delete(task) # cleanup + + +def test_delete_failed(db: DatabaseTestUtils): + tasks_repository: TasksRepository = TasksRepository(db.get_session()) + task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10) + with pytest.raises(RepositoryError): + tasks_repository.delete(task) -def test_find_by_id(get_session: Session, db: DatabaseTestUtils): - db.init([{"table": "task", "id": 1, "title": "test for find by id"}]) - tasks_repository: TasksRepository = TasksRepository(get_session) + +def test_find_by_id(db: DatabaseTestUtils): + db.given([*all_statusses()]) + task = default_task() + db.given([task]) + + tasks_repository: TasksRepository = TasksRepository(db.get_session()) result: Task = tasks_repository.find_by_id(1) assert result.id == 1 - assert result.title == "test for find by id" + assert result.title == "Default Task" -def test_find_by_id_failed(get_session: Session, db: DatabaseTestUtils): - db.init() - tasks_repository: TasksRepository = TasksRepository(get_session) +def test_find_by_id_failed(db: DatabaseTestUtils): + tasks_repository: TasksRepository = TasksRepository(db.get_session()) with pytest.raises(RepositoryError): tasks_repository.find_by_id(1) -def test_find_by_status_id(get_session: Session, db: DatabaseTestUtils): - db.init( - [ - {"table": "status", "id": 1}, - {"table": "task", "id": 1, "status_id": 1}, - {"table": "task", "id": 2, "status_id": 1}, - ] - ) - tasks_repository: TasksRepository = TasksRepository(get_session) +def test_find_by_status_id(db: DatabaseTestUtils): + all_status = [*all_statusses()] + db.given([*all_status]) + task = default_task(status_id=all_status[0].id) + db.given([task, default_task()]) + + tasks_repository: TasksRepository = TasksRepository(db.get_session()) results = tasks_repository.find_by_status_id(1) - assert len(results) == 2 + assert len(results) == 1 assert results[0].id == 1 - assert results[1].id == 2 diff --git a/tests/services/test_storage.py b/tests/services/test_storage.py index d29aab8f..314343fe 100644 --- a/tests/services/test_storage.py +++ b/tests/services/test_storage.py @@ -7,12 +7,12 @@ @pytest.fixture() -def setup_and_teardown(tmp_path: Path) -> tuple[str, str]: +def setup_and_teardown(tmp_path: Path) -> tuple[str, Path]: filename = "test.yaml" - return filename, str(tmp_path.absolute()) + return filename, tmp_path.absolute() -def test_file_system_writer_empty_yaml(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_empty_yaml(setup_and_teardown: tuple[str, Path]) -> None: filename, location = setup_and_teardown storage_writer = WriterFactory.get_writer(writer_type="file", location=location, filename=filename) @@ -21,7 +21,7 @@ def 
test_file_system_writer_empty_yaml(setup_and_teardown: tuple[str, str]) -> N assert Path.is_file(Path(location) / filename), True -def test_file_system_writer_no_location_variable(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_no_location_variable(setup_and_teardown: tuple[str, Path]) -> None: filename, _ = setup_and_teardown with pytest.raises( KeyError, match="The `location` or `filename` variables are not provided as input for get_writer()" @@ -29,7 +29,7 @@ def test_file_system_writer_no_location_variable(setup_and_teardown: tuple[str, WriterFactory.get_writer(writer_type="file", filename=filename) -def test_file_system_writer_no_filename_variable(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_no_filename_variable(setup_and_teardown: tuple[str, Path]) -> None: _, location = setup_and_teardown with pytest.raises( KeyError, match="The `location` or `filename` variables are not provided as input for get_writer()" @@ -37,7 +37,7 @@ def test_file_system_writer_no_filename_variable(setup_and_teardown: tuple[str, WriterFactory.get_writer(writer_type="file", location=location) -def test_file_system_writer_yaml_with_content(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_yaml_with_content(setup_and_teardown: tuple[str, Path]) -> None: filename, location = setup_and_teardown data = {"test": "test"} storage_writer = WriterFactory.get_writer(writer_type="file", location=location, filename=filename) @@ -47,7 +47,19 @@ def test_file_system_writer_yaml_with_content(setup_and_teardown: tuple[str, str assert safe_load(f) == data, True -def test_file_system_writer_with_system_card(setup_and_teardown: tuple[str, str]) -> None: +def test_file_system_writer_yaml_with_content_in_dir(setup_and_teardown: tuple[str, Path]) -> None: + filename, location = setup_and_teardown + data = {"test": "test"} + + new_location = Path(location) / "new_dir" + storage_writer = WriterFactory.get_writer(writer_type="file", location=new_location, filename=filename) + storage_writer.write(data) + + with open(new_location / filename) as f: + assert safe_load(f) == data, True + + +def test_file_system_writer_with_system_card(setup_and_teardown: tuple[str, Path]) -> None: filename, location = setup_and_teardown data = SystemCard() data.title = "test" @@ -60,7 +72,7 @@ def test_file_system_writer_with_system_card(setup_and_teardown: tuple[str, str] assert safe_load(f) == data_dict, True -def test_abstract_writer_non_yaml_filename(setup_and_teardown: tuple[str, str]) -> None: +def test_abstract_writer_non_yaml_filename(setup_and_teardown: tuple[str, Path]) -> None: _, location = setup_and_teardown filename = "test.csv" with pytest.raises(