From bf087c35b04910554fcb95dc9b5e8e76dfeb08e3 Mon Sep 17 00:00:00 2001 From: aldbr Date: Fri, 10 Jan 2025 16:51:16 +0100 Subject: [PATCH] fix(diracx-routers): remove sqlalchemy dependency --- diracx-core/src/diracx/core/exceptions.py | 10 ++++ .../src/diracx/db/sql/sandbox_metadata/db.py | 48 ++++++++++--------- diracx-db/tests/jobs/test_sandbox_metadata.py | 3 +- diracx-routers/pyproject.toml | 1 - .../src/diracx/routers/jobs/sandboxes.py | 35 ++++++-------- 5 files changed, 51 insertions(+), 46 deletions(-) diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 79834b1c..cf9fe312 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -44,6 +44,16 @@ def __init__(self, job_id: int, detail: str | None = None): super().__init__(f"Job {job_id} not found" + (" ({detail})" if detail else "")) +class SandboxNotFoundError(Exception): + def __init__(self, pfn: str, se_name: str, detail: str | None = None): + self.pfn: str = pfn + self.se_name: str = se_name + super().__init__( + f"Sandbox with {pfn} and {se_name} not found" + + (" ({detail})" if detail else "") + ) + + class JobError(Exception): def __init__(self, job_id, detail: str | None = None): self.job_id: int = job_id diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py index 28462778..8ff09c6d 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py @@ -2,8 +2,10 @@ from typing import Any -import sqlalchemy +from sqlalchemy import Executable, delete, insert, literal, select, update +from sqlalchemy.exc import IntegrityError, NoResultFound +from diracx.core.exceptions import SandboxNotFoundError from diracx.core.models import SandboxInfo, SandboxType, UserInfo from diracx.db.sql.utils import BaseSQLDB, UTCNow @@ -17,7 +19,7 @@ class SandboxMetadataDB(BaseSQLDB): async def upsert_owner(self, user: UserInfo) -> int: """Get the id of the owner from the database.""" # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 - stmt = sqlalchemy.select(SBOwners.OwnerID).where( + stmt = select(SBOwners.OwnerID).where( SBOwners.Owner == user.preferred_username, SBOwners.OwnerGroup == user.dirac_group, SBOwners.VO == user.vo, @@ -26,7 +28,7 @@ async def upsert_owner(self, user: UserInfo) -> int: if owner_id := result.scalar_one_or_none(): return owner_id - stmt = sqlalchemy.insert(SBOwners).values( + stmt = insert(SBOwners).values( Owner=user.preferred_username, OwnerGroup=user.dirac_group, VO=user.vo, @@ -53,7 +55,7 @@ async def insert_sandbox( """Add a new sandbox in SandboxMetadataDB.""" # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49 owner_id = await self.upsert_owner(user) - stmt = sqlalchemy.insert(SandBoxes).values( + stmt = insert(SandBoxes).values( OwnerId=owner_id, SEName=se_name, SEPFN=pfn, @@ -63,27 +65,31 @@ async def insert_sandbox( ) try: result = await self.conn.execute(stmt) - except sqlalchemy.exc.IntegrityError: + except IntegrityError: await self.update_sandbox_last_access_time(se_name, pfn) else: assert result.rowcount == 1 async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None: stmt = ( - sqlalchemy.update(SandBoxes) + update(SandBoxes) .where(SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn) .values(LastAccessTime=UTCNow()) ) result = await self.conn.execute(stmt) assert result.rowcount == 1 - async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool: + async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool | None: """Checks if a sandbox exists and has been assigned.""" - stmt: sqlalchemy.Executable = sqlalchemy.select(SandBoxes.Assigned).where( + stmt: Executable = select(SandBoxes.Assigned).where( SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn ) result = await self.conn.execute(stmt) - is_assigned = result.scalar_one() + try: + is_assigned = result.scalar_one() + except NoResultFound as e: + raise SandboxNotFoundError(pfn, se_name) from e + return is_assigned @staticmethod @@ -97,7 +103,7 @@ async def get_sandbox_assigned_to_job( """Get the sandbox assign to job.""" entity_id = self.jobid_to_entity_id(job_id) stmt = ( - sqlalchemy.select(SandBoxes.SEPFN) + select(SandBoxes.SEPFN) .where(SandBoxes.SBId == SBEntityMapping.SBId) .where( SBEntityMapping.EntityId == entity_id, @@ -118,24 +124,20 @@ async def assign_sandbox_to_jobs( for job_id in jobs_ids: # Define the entity id as 'Entity:entity_id' due to the DB definition: entity_id = self.jobid_to_entity_id(job_id) - select_sb_id = sqlalchemy.select( + select_sb_id = select( SandBoxes.SBId, - sqlalchemy.literal(entity_id).label("EntityId"), - sqlalchemy.literal(sb_type).label("Type"), + literal(entity_id).label("EntityId"), + literal(sb_type).label("Type"), ).where( SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn, ) - stmt = sqlalchemy.insert(SBEntityMapping).from_select( + stmt = insert(SBEntityMapping).from_select( ["SBId", "EntityId", "Type"], select_sb_id ) await self.conn.execute(stmt) - stmt = ( - sqlalchemy.update(SandBoxes) - .where(SandBoxes.SEPFN == pfn) - .values(Assigned=True) - ) + stmt = update(SandBoxes).where(SandBoxes.SEPFN == pfn).values(Assigned=True) result = await self.conn.execute(stmt) assert result.rowcount == 1 @@ -143,7 +145,7 @@ async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None: """Delete mapping between jobs and sandboxes.""" for job_id in jobs_ids: entity_id = self.jobid_to_entity_id(job_id) - sb_sel_stmt = sqlalchemy.select(SandBoxes.SBId) + sb_sel_stmt = select(SandBoxes.SBId) sb_sel_stmt = sb_sel_stmt.join( SBEntityMapping, SBEntityMapping.SBId == SandBoxes.SBId ) @@ -152,19 +154,19 @@ async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None: result = await self.conn.execute(sb_sel_stmt) sb_ids = [row.SBId for row in result] - del_stmt = sqlalchemy.delete(SBEntityMapping).where( + del_stmt = delete(SBEntityMapping).where( SBEntityMapping.EntityId == entity_id ) await self.conn.execute(del_stmt) - sb_entity_sel_stmt = sqlalchemy.select(SBEntityMapping.SBId).where( + sb_entity_sel_stmt = select(SBEntityMapping.SBId).where( SBEntityMapping.SBId.in_(sb_ids) ) result = await self.conn.execute(sb_entity_sel_stmt) remaining_sb_ids = [row.SBId for row in result] if not remaining_sb_ids: unassign_stmt = ( - sqlalchemy.update(SandBoxes) + update(SandBoxes) .where(SandBoxes.SBId.in_(sb_ids)) .values(Assigned=False) ) diff --git a/diracx-db/tests/jobs/test_sandbox_metadata.py b/diracx-db/tests/jobs/test_sandbox_metadata.py index bcb1c2cc..d7b6bcbd 100644 --- a/diracx-db/tests/jobs/test_sandbox_metadata.py +++ b/diracx-db/tests/jobs/test_sandbox_metadata.py @@ -7,6 +7,7 @@ import pytest import sqlalchemy +from diracx.core.exceptions import SandboxNotFoundError from diracx.core.models import SandboxInfo, UserInfo from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB from diracx.db.sql.sandbox_metadata.schema import SandBoxes, SBEntityMapping @@ -48,7 +49,7 @@ async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB): db_contents = await _dump_db(sandbox_metadata_db) assert pfn1 not in db_contents async with sandbox_metadata_db: - with pytest.raises(sqlalchemy.exc.NoResultFound): + with pytest.raises(SandboxNotFoundError): await sandbox_metadata_db.sandbox_is_assigned(pfn1, "SandboxSE") # Insert the sandbox diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 7bae7dd8..722c49da 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -26,7 +26,6 @@ dependencies = [ "httpx", "pydantic >=2.10", "uvicorn", - "sqlalchemy", "opentelemetry-api", "opentelemetry-exporter-otlp", "opentelemetry-instrumentation-fastapi", diff --git a/diracx-routers/src/diracx/routers/jobs/sandboxes.py b/diracx-routers/src/diracx/routers/jobs/sandboxes.py index 8277d697..c331ab2f 100644 --- a/diracx-routers/src/diracx/routers/jobs/sandboxes.py +++ b/diracx-routers/src/diracx/routers/jobs/sandboxes.py @@ -12,7 +12,6 @@ from pydantic import BaseModel, PrivateAttr from pydantic_settings import SettingsConfigDict from pyparsing import Any -from sqlalchemy.exc import NoResultFound from diracx.core.models import ( SandboxInfo, @@ -121,26 +120,20 @@ async def initiate_sandbox_upload( detail=f"Sandbox too large. Max size is {MAX_SANDBOX_SIZE_BYTES} bytes", ) - try: - exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned( - pfn, settings.se_name - ) - except NoResultFound: - # The sandbox doesn't exist in the database - pass - else: - # As sandboxes are registered in the DB before uploading to the storage - # backend we can't rely on their existence in the database to determine if - # they have been uploaded. Instead we check if the sandbox has been - # assigned to a job. If it has then we know it has been uploaded and we - # can avoid communicating with the storage backend. - if exists_and_assigned or s3_object_exists( - settings.s3_client, settings.bucket_name, pfn_to_key(pfn) - ): - await sandbox_metadata_db.update_sandbox_last_access_time( - settings.se_name, pfn - ) - return SandboxUploadResponse(pfn=full_pfn) + exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned( + pfn, settings.se_name + ) + + # As sandboxes are registered in the DB before uploading to the storage + # backend we can't rely on their existence in the database to determine if + # they have been uploaded. Instead we check if the sandbox has been + # assigned to a job. If it has then we know it has been uploaded and we + # can avoid communicating with the storage backend. + if exists_and_assigned or s3_object_exists( + settings.s3_client, settings.bucket_name, pfn_to_key(pfn) + ): + await sandbox_metadata_db.update_sandbox_last_access_time(settings.se_name, pfn) + return SandboxUploadResponse(pfn=full_pfn) upload_info = await generate_presigned_upload( settings.s3_client,