Skip to content

Commit

Permalink
fix(core): adapt tests and exceptions to sftkit changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mikonse committed Dec 27, 2024
1 parent dd2dee0 commit b2b510c
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 16 deletions.
2 changes: 1 addition & 1 deletion abrechnung/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ async def create_user(config: Config, name: str, email: str, skip_email_check: b
print("Passwords do not match!")
return

database = get_database(config)
database = get_database(config.database)
db_pool = await database.create_pool()
user_service = UserService(db_pool, config)
user_service.enable_registration = True
Expand Down
2 changes: 1 addition & 1 deletion abrechnung/application/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ async def delete_group(self, *, conn: Connection, user: User, group_id: int):
group_id,
)
if n_members != 1:
raise PermissionError(f"Can only delete a group when you are the last member")
raise InvalidArgument(f"Can only delete a group when you are the last member")

await conn.execute("delete from grp where id = $1", group_id)

Expand Down
9 changes: 8 additions & 1 deletion abrechnung/application/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from abrechnung.config import Config
from abrechnung.domain.users import Session, User
from abrechnung.util import is_valid_uuid

ALGORITHM = "HS256"

Expand Down Expand Up @@ -246,6 +247,8 @@ async def register_user(

@with_db_transaction
async def confirm_registration(self, *, conn: Connection, token: str) -> int:
if not is_valid_uuid(token):
raise InvalidArgument(f"Invalid confirmation token")
row = await conn.fetchrow(
"select user_id, valid_until from pending_registration where token = $1",
token,
Expand Down Expand Up @@ -342,6 +345,8 @@ async def request_email_change(self, *, conn: Connection, user: User, password:

@with_db_transaction
async def confirm_email_change(self, *, conn: Connection, token: str) -> int:
if not is_valid_uuid(token):
raise InvalidArgument(f"Invalid confirmation token")
row = await conn.fetchrow(
"select user_id, new_email, valid_until from pending_email_change where token = $1",
token,
Expand All @@ -360,7 +365,7 @@ async def confirm_email_change(self, *, conn: Connection, token: str) -> int:
async def request_password_recovery(self, *, conn: Connection, email: str):
user_id = await conn.fetchval("select id from usr where email = $1", email)
if not user_id:
raise PermissionError
raise InvalidArgument("permission denied")

await conn.execute(
"insert into pending_password_recovery (user_id) values ($1)",
Expand All @@ -369,6 +374,8 @@ async def request_password_recovery(self, *, conn: Connection, email: str):

@with_db_transaction
async def confirm_password_recovery(self, *, conn: Connection, token: str, new_password: str) -> int:
if not is_valid_uuid(token):
raise InvalidArgument(f"Invalid confirmation token")
row = await conn.fetchrow(
"select user_id, valid_until from pending_password_recovery where token = $1",
token,
Expand Down
2 changes: 1 addition & 1 deletion abrechnung/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ async def cleanup(config: Config):

deletion_threshold = datetime.now() - config.demo.wipe_interval

database = get_database(config)
database = get_database(config.database)
db_pool = await database.create_pool()
async with db_pool.acquire() as conn:
async with conn.transaction():
Expand Down
2 changes: 1 addition & 1 deletion abrechnung/mailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, config: Config):
self.config = config
self.events: Optional[asyncio.Queue] = None
self.psql: Connection | None = None
self.database = get_database(config)
self.database = get_database(config.database)
self.mailer = None
self.logger = logging.getLogger(__name__)

Expand Down
9 changes: 9 additions & 0 deletions abrechnung/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import re
import uuid
from datetime import datetime, timedelta, timezone

postgres_timestamp_format = re.compile(
Expand Down Expand Up @@ -63,3 +64,11 @@ def log_setup(setting, default=1):
def clamp(number, smallest, largest):
"""return number but limit it to the inclusive given value range"""
return max(smallest, min(number, largest))


def is_valid_uuid(val: str):
try:
uuid.UUID(val)
return True
except ValueError:
return False
11 changes: 6 additions & 5 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@

def get_test_db_config() -> DatabaseConfig:
return DatabaseConfig(
user=os.environ.get("TEST_DB_USER", "abrechnung-test"),
password=os.environ.get("TEST_DB_PASSWORD", "asdf1234"),
host=os.environ.get("TEST_DB_HOST", "localhost"),
dbname=os.environ.get("TEST_DB_DATABASE", "abrechnung-test"),
user=os.environ.get("TEST_DB_USER"),
password=os.environ.get("TEST_DB_PASSWORD"),
host=os.environ.get("TEST_DB_HOST"),
dbname=os.environ.get("TEST_DB_DATABASE", "abrechnung_test"),
port=int(os.environ.get("TEST_DB_PORT", 5432)),
sslrootcert=None,
)


Expand Down Expand Up @@ -57,7 +58,7 @@ async def get_test_db() -> Pool:
"""
get a connection pool to the test database
"""
database = get_database(TEST_CONFIG)
database = get_database(TEST_CONFIG.database)
pool = await database.create_pool()

await reset_schema(pool)
Expand Down
7 changes: 7 additions & 0 deletions tests/http_tests/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=attribute-defined-outside-init
from httpx import ASGITransport, AsyncClient
from sftkit.http._context import ContextMiddleware

from abrechnung.http.api import Api
from tests.common import TEST_CONFIG, BaseTestCase
Expand All @@ -12,6 +13,12 @@ async def asyncSetUp(self) -> None:
self.http_service = Api(config=self.test_config)
await self.http_service._setup()

# workaround for bad testability in sftkit
self.http_service.server.api.add_middleware(
ContextMiddleware,
context=self.http_service.context,
)

self.transport = ASGITransport(app=self.http_service.server.api)
self.client = AsyncClient(transport=self.transport, base_url="https://abrechnung.sft.lol")
self.transaction_service = self.http_service.transaction_service
Expand Down
2 changes: 1 addition & 1 deletion tests/http_tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ async def test_reset_password(self):
f"/api/v1/auth/recover_password",
json={"email": "[email protected]"},
)
self.assertEqual(403, resp.status_code)
self.assertEqual(400, resp.status_code)

resp = await self.client.post(
f"/api/v1/auth/recover_password",
Expand Down
8 changes: 4 additions & 4 deletions tests/http_tests/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def test_create_group(self):
group = await self._fetch_group(group_id)
self.assertEqual("name", group["name"])

await self._fetch_group(13333, 404)
await self._fetch_group(13333, 400)

resp = await self._post(
f"/api/v1/groups/{group_id}",
Expand Down Expand Up @@ -128,12 +128,12 @@ async def test_delete_group(self):
)

resp = await self._delete(f"/api/v1/groups/{group_id}")
self.assertEqual(403, resp.status_code)
self.assertEqual(400, resp.status_code)

resp = await self._post(f"/api/v1/groups/{group_id}/leave")
self.assertEqual(204, resp.status_code)

await self._fetch_group(group_id, expected_status=404)
await self._fetch_group(group_id, expected_status=400)

resp = await self.client.delete(
f"/api/v1/groups/{group_id}",
Expand Down Expand Up @@ -345,7 +345,7 @@ async def test_get_account(self):
self.assertEqual(422, resp.status_code)

resp = await self._get(f"/api/v1/groups/{group_id}/accounts/13232")
self.assertEqual(404, resp.status_code)
self.assertEqual(400, resp.status_code)

async def test_invites(self):
group_id = await self.group_service.create_group(
Expand Down
2 changes: 1 addition & 1 deletion tools/generate_dummy_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def main(
):
config = read_config(Path(config_path))

database = get_database(config)
database = get_database(config.database)
db_pool = await database.create_pool()
user_service = UserService(db_pool, config)
group_service = GroupService(db_pool, config)
Expand Down

0 comments on commit b2b510c

Please sign in to comment.