From 0fe34965c241130964e5b0e57941c382a5b9fc4d Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Mon, 2 Dec 2024 09:15:38 +0100 Subject: [PATCH] @pcrespov review: remove match and use a mapping --- .../modules/db/repositories/comp_runs.py | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py index 68798d89a29..50070611394 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py @@ -1,6 +1,6 @@ import datetime import logging -from typing import Any +from typing import Any, Final import arrow import sqlalchemy as sa @@ -20,7 +20,6 @@ ClusterNotFoundError, ComputationalRunNotFoundError, DirectorError, - ProjectNotFoundError, UserNotFoundError, ) from ....models.comp_runs import CompRunsAtDB, RunMetadataDict @@ -30,6 +29,20 @@ logger = logging.getLogger(__name__) +_POSTGRES_ERROR_TO_ERROR_MAP: Final[ + dict[tuple[str, ...], tuple[type[DirectorError], tuple[str, ...]]] +] = { + ("users", "user_id"): (UserNotFoundError, ("users", "user_id")), + ("projects", "project_uuid"): ( + UserNotFoundError, + ("projects", "project_id"), + ), + ("clusters", "cluster_id"): ( + ClusterNotFoundError, + ("clusters", "cluster_id"), + ), +} + class CompRunsRepository(BaseRepository): async def get( @@ -173,15 +186,13 @@ async def create( return CompRunsAtDB.model_validate(row) except ForeignKeyViolation as exc: message = exc.args[0] - match message: - case s if "users" in s and "user_id" in s: - raise UserNotFoundError(user_id=user_id) from exc - case s if "projects" in s and "project_uuid" in s: - raise ProjectNotFoundError(project_id=project_id) from exc - case s if "clusters" in s and "cluster_id" in s: - raise ClusterNotFoundError(cluster_id=cluster_id) from exc - case _: - raise DirectorError from exc + + for pg_keys, (exc_type, exc_keys) in _POSTGRES_ERROR_TO_ERROR_MAP.items(): + if all(k in message for k in pg_keys): + raise exc_type( + **{f"{k}": locals().get(k) for k in exc_keys} + ) from exc + raise DirectorError from exc async def update( self, user_id: UserID, project_id: ProjectID, iteration: PositiveInt, **values