diff --git a/backend/scripts/create_user.py b/backend/scripts/create_user.py index dc92dedc..1263b255 100644 --- a/backend/scripts/create_user.py +++ b/backend/scripts/create_user.py @@ -9,7 +9,7 @@ parser.add_argument("--user", required=True) parser.add_argument("--pass", required=True) args = parser.parse_args() - with SessionContextManager() as session: + with SessionContextManager(path="mangement_comment:create_user") as session: try: user = create_user( session=session, username=args.user, password=getattr(args, "pass") diff --git a/backend/scripts/create_worker.py b/backend/scripts/create_worker.py index 8f66a708..536a6e5f 100644 --- a/backend/scripts/create_worker.py +++ b/backend/scripts/create_worker.py @@ -13,7 +13,7 @@ if args.token is None: args.token = utils.get_random_string() - with SessionContextManager() as session: + with SessionContextManager(path="mangement_comment:create_worker") as session: statement = select(Worker).where(Worker.token == args.token) results = session.exec(statement) existing_worker = results.one_or_none() diff --git a/backend/scripts/reset_task.py b/backend/scripts/reset_task.py index 1c2ddbe7..aa9b8f3f 100644 --- a/backend/scripts/reset_task.py +++ b/backend/scripts/reset_task.py @@ -12,7 +12,7 @@ "--uuid", required=True, type=uuid.UUID, help="Task UUID or Document UUID" ) args = parser.parse_args() - with SessionContextManager() as session: + with SessionContextManager(path="mangement_comment:reset_task") as session: task = session.execute( update(Task) .where( diff --git a/backend/scripts/set_password.py b/backend/scripts/set_password.py index f0c08d20..de63dd43 100644 --- a/backend/scripts/set_password.py +++ b/backend/scripts/set_password.py @@ -10,7 +10,7 @@ parser.add_argument("--pass", required=True) args = parser.parse_args() - with SessionContextManager() as session: + with SessionContextManager(path="mangement_comment:set_password") as session: try: user = change_user_password( session=session, username=args.user, 
# Histogram of how many SQL statements were executed during one database
# session, labelled by the route name / task name that opened the session.
query_histogram = Histogram(
    "sql_queries",
    "Number of sql queries executed per db session",
    ["path"],
    # Power-of-two buckets; 64 was missing between 32 and 128.
    buckets=[1, 2, 4, 8, 16, 32, 64, 128, 256, 512],
)


def get_session(request: Request):
    """Yield a query-counted database session for an HTTP request.

    The histogram label is the route name resolved from ``request``.
    """
    handler = routing.get_route_name(request)
    with Session(engine) as session, query_counter(session, path=handler):
        yield session


def get_session_ws(websocket: WebSocket):
    """Yield a query-counted database session for a websocket connection.

    The histogram label is the route name resolved from ``websocket``.
    """
    # get_route_name is typed with a Request, but in reality a HttpConnection
    # (which WebSocket is) is enough
    handler = routing.get_route_name(websocket)  # type: ignore
    with Session(engine) as session, query_counter(session, path=handler):
        yield session


@contextmanager
def SessionContextManager(path: str):
    """Context manager yielding a query-counted session outside of a request.

    Args:
        path: Label under which the session's query count is recorded.
    """
    with Session(engine) as session, query_counter(session, path=path):
        yield session


@contextmanager
def query_counter(session: Session, path: Optional[str]):
    """Count SQL statements executed while the context is active.

    A ``before_cursor_execute`` listener is registered on the engine behind
    ``session`` for the duration of the ``with`` body. On a clean exit the
    number of observed statements is recorded in ``query_histogram`` under
    the label ``path``.

    Args:
        session: Session whose underlying engine is instrumented.
        path: Value for the histogram's ``path`` label.
    """
    engine = session.connection().engine
    count = 0

    def callback(*args, **kwargs):
        nonlocal count
        count += 1

    # NOTE(review): the listener is attached to the engine, not to this
    # session, so statements issued by other sessions on the same engine
    # while this context is open are presumably counted too -- confirm
    # whether that is acceptable for concurrent requests.
    event.listen(engine, "before_cursor_execute", callback)
    try:
        yield
    finally:
        # An exception raised inside the ``with`` body is re-raised at the
        # ``yield``; without ``finally`` the listener would stay registered
        # on the engine forever.
        event.remove(engine, "before_cursor_execute", callback)
    # As before, only sessions that finished without an exception are observed.
    query_histogram.labels(path=path).observe(count)
--- a/backend/transcribee_backend/routers/document.py +++ b/backend/transcribee_backend/routers/document.py @@ -1,8 +1,8 @@ import datetime import enum +import pathlib import uuid from dataclasses import dataclass -from pathlib import Path from typing import Annotated, Callable, List, Optional import magic @@ -13,6 +13,7 @@ Form, Header, HTTPException, + Path, Query, UploadFile, WebSocket, @@ -35,7 +36,10 @@ validate_worker_authorization, ) from transcribee_backend.config import get_model_config, settings -from transcribee_backend.db import get_session +from transcribee_backend.db import ( + get_session, + get_session_ws, +) from transcribee_backend.helpers.sync import DocumentSyncConsumer from transcribee_backend.helpers.time import now_tz_aware from transcribee_backend.models.document import ( @@ -200,7 +204,7 @@ def func( def auth_fn_to_ws(f: Callable): def func( document_id: uuid.UUID, - session: Session = Depends(get_session), + session: Session = Depends(get_session_ws), authorization: Optional[str] = Query(default=None), share_token: Optional[str] = Query(default=None, alias="share_token"), ): @@ -430,7 +434,7 @@ def delete_document( auth: AuthInfo = Depends(get_doc_full_auth), session: Session = Depends(get_session), ) -> None: - paths_to_delete: List[Path] = [] + paths_to_delete: List[pathlib.Path] = [] media_files = select(DocumentMediaFile).where( DocumentMediaFile.document == auth.document ) @@ -463,7 +467,7 @@ def get_document_tasks( async def websocket_endpoint( websocket: WebSocket, auth: AuthInfo = Depends(ws_get_doc_min_readonly_or_worker_auth), - session: Session = Depends(get_session), + session: Session = Depends(get_session_ws), ): connection = DocumentSyncConsumer( document=auth.document,