diff --git a/services/api-server/src/simcore_service_api_server/api/dependencies/rabbitmq.py b/services/api-server/src/simcore_service_api_server/api/dependencies/rabbitmq.py
index daf9272af4a..f7a9939f817 100644
--- a/services/api-server/src/simcore_service_api_server/api/dependencies/rabbitmq.py
+++ b/services/api-server/src/simcore_service_api_server/api/dependencies/rabbitmq.py
@@ -1,12 +1,19 @@
-from typing import Annotated, cast
+import logging
+from typing import Annotated, Final, cast

 from fastapi import Depends, FastAPI
 from pydantic import NonNegativeInt
+from servicelib.aiohttp.application_setup import ApplicationSetupError
 from servicelib.fastapi.dependencies import get_app
 from servicelib.rabbitmq import RabbitMQClient
+from tenacity import before_sleep_log, retry, stop_after_delay, wait_fixed

 from ...services.log_streaming import LogDistributor

+_MAX_WAIT_FOR_LOG_DISTRIBUTOR_SECONDS: Final[int] = 10
+
+_logger = logging.getLogger(__name__)
+

 def get_rabbitmq_client(app: Annotated[FastAPI, Depends(get_app)]) -> RabbitMQClient:
     assert app.state.rabbitmq_client  # nosec
@@ -18,6 +25,20 @@ def get_log_distributor(app: Annotated[FastAPI, Depends(get_app)]) -> LogDistrib
     return cast(LogDistributor, app.state.log_distributor)


+@retry(
+    wait=wait_fixed(2),
+    stop=stop_after_delay(_MAX_WAIT_FOR_LOG_DISTRIBUTOR_SECONDS),
+    before_sleep=before_sleep_log(_logger, logging.WARNING),
+    reraise=True,
+)
+async def wait_till_log_distributor_ready(app) -> None:
+    if not hasattr(app.state, "log_distributor"):
+        raise ApplicationSetupError(
+            f"Api server's log_distributor was not ready within {_MAX_WAIT_FOR_LOG_DISTRIBUTOR_SECONDS=} seconds"
+        )
+    return
+
+
 def get_log_check_timeout(app: Annotated[FastAPI, Depends(get_app)]) -> NonNegativeInt:
     assert app.state.settings  # nosec
     return cast(NonNegativeInt, app.state.settings.API_SERVER_LOG_CHECK_TIMEOUT_SECONDS)
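Note on the hunk above: tenacity's `@retry` wraps coroutine functions natively, so `wait_till_log_distributor_ready` re-checks `app.state` every 2 s and, after 10 s, re-raises the `ApplicationSetupError` itself (`reraise=True`) rather than tenacity's `RetryError`. A minimal standalone sketch of the same policy; the names below are illustrative stand-ins, not code from this PR:

```python
import asyncio
import logging

from tenacity import before_sleep_log, retry, stop_after_delay, wait_fixed

logging.basicConfig(level=logging.WARNING)
_logger = logging.getLogger(__name__)

state: dict = {}  # stand-in for app.state


@retry(
    wait=wait_fixed(2),  # poll every 2 seconds...
    stop=stop_after_delay(10),  # ...for at most 10 seconds
    before_sleep=before_sleep_log(_logger, logging.WARNING),
    reraise=True,  # surface the last error, not tenacity's RetryError
)
async def wait_until_ready() -> None:
    if "log_distributor" not in state:
        raise RuntimeError("log_distributor not ready")


async def main() -> None:
    # the "distributor" becomes ready after 3 s, so the second retry succeeds
    loop = asyncio.get_running_loop()
    loop.call_later(3, state.__setitem__, "log_distributor", object())
    await wait_until_ready()


asyncio.run(main())
```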
diff --git a/services/api-server/src/simcore_service_api_server/api/errors/custom_errors.py b/services/api-server/src/simcore_service_api_server/api/errors/custom_errors.py
index 73d59598ca6..0fe5f648823 100644
--- a/services/api-server/src/simcore_service_api_server/api/errors/custom_errors.py
+++ b/services/api-server/src/simcore_service_api_server/api/errors/custom_errors.py
@@ -19,6 +19,10 @@ class MissingWallet(CustomBaseError):
     pass


+class ApplicationSetupError(CustomBaseError):
+    pass
+
+
 async def custom_error_handler(_: Request, exc: CustomBaseError):
     if isinstance(exc, InsufficientCredits):
         return JSONResponse(
diff --git a/services/api-server/src/simcore_service_api_server/core/_prometheus_instrumentation.py b/services/api-server/src/simcore_service_api_server/core/_prometheus_instrumentation.py
new file mode 100644
index 00000000000..3922784510a
--- /dev/null
+++ b/services/api-server/src/simcore_service_api_server/core/_prometheus_instrumentation.py
@@ -0,0 +1,76 @@
+from dataclasses import dataclass, field
+from datetime import timedelta
+from typing import Final, cast
+
+from fastapi import FastAPI
+from prometheus_client import CollectorRegistry, Gauge
+from servicelib.background_task import start_periodic_task, stop_periodic_task
+from servicelib.fastapi.prometheus_instrumentation import (
+    setup_prometheus_instrumentation as setup_rest_instrumentation,
+)
+from simcore_service_api_server.api.dependencies.rabbitmq import (
+    get_log_distributor,
+    wait_till_log_distributor_ready,
+)
+from simcore_service_api_server.models.schemas.jobs import JobID
+
+from .._meta import PROJECT_NAME
+
+METRICS_NAMESPACE: Final[str] = PROJECT_NAME.replace("-", "_")
+
+
+@dataclass(slots=True, kw_only=True)
+class ApiServerPrometheusInstrumentation:
+    registry: CollectorRegistry
+    _logstreaming_queues: Gauge = field(init=False)
+
+    def __post_init__(self) -> None:
+        self._logstreaming_queues = Gauge(
+            "log_stream_queue_length",
+            "#Logs in log streaming queue",
+            ["job_id"],
+            namespace=METRICS_NAMESPACE,
+        )
+
+    def update_metrics(self, log_queue_sizes: dict[JobID, int]):
+        self._logstreaming_queues.clear()
+        for job_id, length in log_queue_sizes.items():
+            self._logstreaming_queues.labels(job_id=job_id).set(length)
+
+
+async def _collect_prometheus_metrics_task(app: FastAPI):
+    get_instrumentation(app).update_metrics(
+        log_queue_sizes=get_log_distributor(app).get_log_queue_sizes()
+    )
+
+
+def setup_prometheus_instrumentation(app: FastAPI):
+    instrumentator = setup_rest_instrumentation(app)
+
+    async def on_startup() -> None:
+        app.state.instrumentation = ApiServerPrometheusInstrumentation(
+            registry=instrumentator.registry
+        )
+        await wait_till_log_distributor_ready(app)
+        app.state.instrumentation_task = start_periodic_task(
+            task=_collect_prometheus_metrics_task,
+            interval=timedelta(
+                seconds=app.state.settings.API_SERVER_PROMETHEUS_INSTRUMENTATION_COLLECT_SECONDS
+            ),
+            task_name="prometheus_metrics_collection_task",
+            app=app,
+        )
+
+    async def on_shutdown() -> None:
+        assert app.state.instrumentation_task  # nosec
+        await stop_periodic_task(app.state.instrumentation_task)
+
+    app.add_event_handler("startup", on_startup)
+    app.add_event_handler("shutdown", on_shutdown)
+
+
+def get_instrumentation(app: FastAPI) -> ApiServerPrometheusInstrumentation:
+    assert (
+        app.state.instrumentation
+    ), "Instrumentation not setup. Please check the configuration"  # nosec
+    return cast(ApiServerPrometheusInstrumentation, app.state.instrumentation)
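`update_metrics()` in the new file above clears the whole gauge before re-setting it, so jobs whose queues disappeared between collection cycles do not keep exporting stale series. A self-contained sketch of that clear-then-set pattern with `prometheus_client`; the namespace string is an illustrative stand-in for `PROJECT_NAME.replace("-", "_")`:

```python
from prometheus_client import CollectorRegistry, Gauge, generate_latest

registry = CollectorRegistry()  # the PR reuses the instrumentator's registry instead
gauge = Gauge(
    "log_stream_queue_length",
    "#Logs in log streaming queue",
    ["job_id"],
    namespace="simcore_service_api_server",  # assumption: stand-in namespace
    registry=registry,
)

# two collection cycles: job-1 disappears between them
for sizes in ({"job-1": 3, "job-2": 7}, {"job-2": 1}):
    gauge.clear()  # drop every labelled child so vanished jobs leave no stale series
    for job_id, length in sizes.items():
        gauge.labels(job_id=job_id).set(length)

# only the job-2 series is exported after the second cycle
print(generate_latest(registry).decode())
```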
diff --git a/services/api-server/src/simcore_service_api_server/core/application.py b/services/api-server/src/simcore_service_api_server/core/application.py
index 7f31f2ecf52..9775d298825 100644
--- a/services/api-server/src/simcore_service_api_server/core/application.py
+++ b/services/api-server/src/simcore_service_api_server/core/application.py
@@ -5,9 +5,6 @@
 from fastapi_pagination import add_pagination
 from httpx import HTTPError as HttpxException
 from models_library.basic_types import BootModeEnum
-from servicelib.fastapi.prometheus_instrumentation import (
-    setup_prometheus_instrumentation,
-)
 from servicelib.logging_utils import config_all_loggers
 from simcore_service_api_server.api.errors.log_handling_error import (
     log_handling_error_handler,
@@ -30,6 +27,7 @@
 from ..api.routes.health import router as health_router
 from ..services import catalog, director_v2, storage, webserver
 from ..services.rabbitmq import setup_rabbitmq
+from ._prometheus_instrumentation import setup_prometheus_instrumentation
 from .events import create_start_app_handler, create_stop_app_handler
 from .openapi import override_openapi_method, use_route_names_as_operation_ids
 from .settings import ApplicationSettings
diff --git a/services/api-server/src/simcore_service_api_server/core/settings.py b/services/api-server/src/simcore_service_api_server/core/settings.py
index 8868464ff99..5e64c80c415 100644
--- a/services/api-server/src/simcore_service_api_server/core/settings.py
+++ b/services/api-server/src/simcore_service_api_server/core/settings.py
@@ -4,7 +4,7 @@
 from typing import Any

 from models_library.basic_types import BootModeEnum, LogLevel
-from pydantic import Field, NonNegativeInt, SecretStr, parse_obj_as
+from pydantic import Field, NonNegativeInt, PositiveInt, SecretStr, parse_obj_as
 from pydantic.class_validators import validator
 from settings_library.base import BaseCustomSettings
 from settings_library.basic_types import PortInt, VersionTag
@@ -137,6 +137,7 @@ class ApplicationSettings(BasicSettings):
     )
     API_SERVER_LOG_CHECK_TIMEOUT_SECONDS: NonNegativeInt = 3 * 60
     API_SERVER_PROMETHEUS_INSTRUMENTATION_ENABLED: bool = True
+    API_SERVER_PROMETHEUS_INSTRUMENTATION_COLLECT_SECONDS: PositiveInt = 5
     # DEV-TOOLS
     API_SERVER_DEV_HTTP_CALLS_LOGS_PATH: Path | None = Field(
         default=None,
diff --git a/services/api-server/src/simcore_service_api_server/services/log_streaming.py b/services/api-server/src/simcore_service_api_server/services/log_streaming.py
index 44e6e2f2f76..457a196f7db 100644
--- a/services/api-server/src/simcore_service_api_server/services/log_streaming.py
+++ b/services/api-server/src/simcore_service_api_server/services/log_streaming.py
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from asyncio import Queue
-from typing import AsyncIterable, Awaitable, Callable, Final
+from typing import AsyncIterable, Final

 from models_library.rabbitmq_messages import LoggerRabbitMessage
 from models_library.users import UserID
@@ -31,7 +31,7 @@ class LogStreamerRegistionConflict(LogDistributionBaseException):
 class LogDistributor:
     def __init__(self, rabbitmq_client: RabbitMQClient):
         self._rabbit_client = rabbitmq_client
-        self._log_streamers: dict[JobID, Callable[[JobLog], Awaitable[None]]] = {}
+        self._log_streamers: dict[JobID, Queue] = {}
         self._queue_name: str

     async def setup(self):
@@ -72,22 +72,20 @@ async def _distribute_logs(self, data: bytes):
             log_level=got.log_level,
             messages=got.messages,
         )
-        callback = self._log_streamers.get(item.job_id)
-        if callback is None:
+        queue = self._log_streamers.get(item.job_id)
+        if queue is None:
             raise LogStreamerNotRegistered(
                 f"Could not forward log because a logstreamer associated with job_id={item.job_id} was not registered"
             )
-        await callback(item)
+        await queue.put(item)
         return True

-    async def register(
-        self, job_id: JobID, callback: Callable[[JobLog], Awaitable[None]]
-    ):
+    async def register(self, job_id: JobID, queue: Queue):
         if job_id in self._log_streamers:
             raise LogStreamerRegistionConflict(
                 f"A stream was already connected to {job_id=}. Only a single stream can be connected at the time"
             )
-        self._log_streamers[job_id] = callback
+        self._log_streamers[job_id] = queue
         await self._rabbit_client.add_topics(
             LoggerRabbitMessage.get_channel_name(), topics=[f"{job_id}.*"]
         )
@@ -100,6 +98,9 @@ async def deregister(self, job_id: JobID):
         )
         del self._log_streamers[job_id]

+    def get_log_queue_sizes(self) -> dict[JobID, int]:
+        return {k: v.qsize() for k, v in self._log_streamers.items()}
+

 class LogStreamer:
     def __init__(
@@ -120,7 +121,7 @@ def __init__(
         self._log_check_timeout: NonNegativeInt = log_check_timeout

     async def setup(self):
-        await self._log_distributor.register(self._job_id, self._queue.put)
+        await self._log_distributor.register(self._job_id, self._queue)
         self._is_registered = True

     async def teardown(self):
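With the hunks above, `LogDistributor.register()` takes the queue itself instead of a callback: `_distribute_logs()` awaits `queue.put(item)`, the new `get_log_queue_sizes()` reads `qsize()` per job (feeding the Prometheus gauge), and `LogStreamer` hands over its whole `asyncio.Queue` rather than its bound `put`. A toy sketch of that producer/consumer contract, illustrative only:

```python
import asyncio


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()

    # producer side: what LogDistributor._distribute_logs() now does per job
    await queue.put({"job_id": "...", "messages": ["hello"]})  # stand-in for a JobLog

    # metrics side: what get_log_queue_sizes() reports for this job
    assert queue.qsize() == 1

    # consumer side: what the LogStreamer drains when streaming logs
    print(await queue.get())


asyncio.run(main())
```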
diff --git a/services/api-server/tests/unit/_with_db/conftest.py b/services/api-server/tests/unit/_with_db/conftest.py
index 3570c60dbb1..5aa6e190749 100644
--- a/services/api-server/tests/unit/_with_db/conftest.py
+++ b/services/api-server/tests/unit/_with_db/conftest.py
@@ -154,6 +154,9 @@ def app_environment(
 ) -> EnvVarsDict:
     """app environments WITH database settings"""
     mocker.patch("simcore_service_api_server.core.application.setup_rabbitmq")
+    mocker.patch(
+        "simcore_service_api_server.core.application.setup_prometheus_instrumentation"
+    )

     envs = setenvs_from_dict(monkeypatch, default_app_env_vars)
     assert "API_SERVER_POSTGRES" not in envs
diff --git a/services/api-server/tests/unit/conftest.py b/services/api-server/tests/unit/conftest.py
index 3897f0a6f7f..4644a1a30d6 100644
--- a/services/api-server/tests/unit/conftest.py
+++ b/services/api-server/tests/unit/conftest.py
@@ -87,6 +87,9 @@ def mock_missing_plugins(app_environment: EnvVarsDict, mocker: MockerFixture):
     settings = ApplicationSettings.create_from_envs()
     if settings.API_SERVER_RABBITMQ is None:
         mocker.patch("simcore_service_api_server.core.application.setup_rabbitmq")
+        mocker.patch(
+            "simcore_service_api_server.core.application.setup_prometheus_instrumentation"
+        )

     return app_environment
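The two conftest hunks above patch `setup_prometheus_instrumentation` at its import site, mirroring the existing `setup_rabbitmq` mock. A restated sketch of the pattern, assuming pytest-mock; only the patch target string comes from this PR:

```python
import pytest
from pytest_mock import MockerFixture


@pytest.fixture
def mock_instrumentation(mocker: MockerFixture) -> None:
    # Patch the name where application.py looks it up, not where it is defined:
    # application.py binds it via
    #   from ._prometheus_instrumentation import setup_prometheus_instrumentation
    # so patching the defining module after that import would leave the app's
    # bound reference untouched.
    mocker.patch(
        "simcore_service_api_server.core.application.setup_prometheus_instrumentation"
    )
```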
diff --git a/services/api-server/tests/unit/test_services_rabbitmq.py b/services/api-server/tests/unit/test_services_rabbitmq.py
index d5e2cf71a66..a58d99b54e2 100644
--- a/services/api-server/tests/unit/test_services_rabbitmq.py
+++ b/services/api-server/tests/unit/test_services_rabbitmq.py
@@ -3,6 +3,7 @@
 # pylint: disable=too-many-arguments
 # pylint: disable=unused-argument
 # pylint: disable=unused-variable
+# pylint: disable=R6301

 import asyncio
 import logging
@@ -16,6 +17,7 @@
 import httpx
 import pytest
 import respx
+from attr import dataclass
 from faker import Faker
 from fastapi import FastAPI, status
 from fastapi.encoders import jsonable_encoder
@@ -113,12 +115,18 @@ async def test_subscribe_publish_receive_logs(
     log_distributor: LogDistributor,
     mocker: MockerFixture,
 ):
-    async def _consumer_message_handler(job_log: JobLog):
-        _consumer_message_handler.called = True
-        _consumer_message_handler.job_log = job_log
-        assert isinstance(job_log, JobLog)
+    @dataclass
+    class MockQueue:
+        called: bool = False
+        job_log: JobLog | None = None

-    await log_distributor.register(project_id, _consumer_message_handler)
+        async def put(self, job_log: JobLog):
+            self.called = True
+            self.job_log = job_log
+            assert isinstance(job_log, JobLog)
+
+    mock_queue = MockQueue()
+    await log_distributor.register(project_id, mock_queue)  # type: ignore

     # log producer
     rabbitmq_producer = create_rabbitmq_client("pytest_producer")
@@ -128,16 +136,14 @@ async def _consumer_message_handler(job_log: JobLog):
         node_id=node_id,
         messages=[faker.text() for _ in range(10)],
     )
-    _consumer_message_handler.called = False
-    _consumer_message_handler.job_log = None

     await rabbitmq_producer.publish(log_message.channel_name, log_message)

     # check it received
     await asyncio.sleep(1)
     await log_distributor.deregister(project_id)

-    assert _consumer_message_handler.called
-    job_log = _consumer_message_handler.job_log
+    assert mock_queue.called
+    job_log = mock_queue.job_log
     assert isinstance(job_log, JobLog)
     assert job_log.job_id == log_message.project_id

@@ -147,12 +153,13 @@ async def rabbit_consuming_context(
     app: FastAPI,
     project_id: ProjectID,
 ) -> AsyncIterable[AsyncMock]:
-    consumer_message_handler = AsyncMock()
+    queue = asyncio.Queue()
+    queue.put = AsyncMock()
     log_distributor: LogDistributor = get_log_distributor(app)
-    await log_distributor.register(project_id, consumer_message_handler)
+    await log_distributor.register(project_id, queue)

-    yield consumer_message_handler
+    yield queue.put

     await log_distributor.deregister(project_id)

@@ -233,10 +240,12 @@ async def test_log_distributor_register_deregister(
 ):
     collected_logs: list[str] = []

-    async def callback(job_log: JobLog):
-        for msg in job_log.messages:
-            collected_logs.append(msg)
+    class MockQueue:
+        async def put(self, job_log: JobLog):
+            for msg in job_log.messages:
+                collected_logs.append(msg)

+    queue = MockQueue()
     published_logs: list[str] = []

     async def _log_publisher():
@@ -246,12 +255,12 @@ async def _log_publisher():
             await produce_logs("expected", project_id, node_id, [msg], logging.DEBUG)
             published_logs.append(msg)

-    await log_distributor.register(project_id, callback)
+    await log_distributor.register(project_id, queue)  # type: ignore
     publisher_task = asyncio.create_task(_log_publisher())
     await asyncio.sleep(0.1)
     await log_distributor.deregister(project_id)
     await asyncio.sleep(0.1)
-    await log_distributor.register(project_id, callback)
+    await log_distributor.register(project_id, queue)  # type: ignore
     await asyncio.gather(publisher_task)
     await asyncio.sleep(0.5)
     await log_distributor.deregister(project_id)
@@ -274,12 +283,14 @@ async def test_log_distributor_multiple_streams(

     collected_logs: dict[JobID, list[str]] = {id_: [] for id_ in job_ids}

-    async def callback(job_log: JobLog):
-        job_id = job_log.job_id
-        assert (msgs := collected_logs.get(job_id)) is not None
-        for msg in job_log.messages:
-            msgs.append(msg)
+    class MockQueue:
+        async def put(self, job_log: JobLog):
+            job_id = job_log.job_id
+            assert (msgs := collected_logs.get(job_id)) is not None
+            for msg in job_log.messages:
+                msgs.append(msg)

+    queue = MockQueue()
     published_logs: dict[JobID, list[str]] = {id_: [] for id_ in job_ids}

     async def _log_publisher():
@@ -291,7 +302,7 @@ async def _log_publisher():
             published_logs[job_id].append(msg)

     for job_id in job_ids:
-        await log_distributor.register(job_id, callback)
+        await log_distributor.register(job_id, queue)  # type: ignore
     publisher_task = asyncio.create_task(_log_publisher())
     await asyncio.gather(publisher_task)
     await asyncio.sleep(0.5)
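The test changes above pass duck-typed `MockQueue` objects where `register()` annotates `Queue`, hence the `# type: ignore` markers at each call site. If one wanted to make that contract explicit, a structural type would do it; a sketch, not part of this PR:

```python
from typing import Any, Protocol


class LogQueue(Protocol):
    """The surface LogDistributor actually relies on."""

    async def put(self, item: Any) -> None:  # used by _distribute_logs()
        ...

    def qsize(self) -> int:  # used by get_log_queue_sizes()
        ...
```

`asyncio.Queue` satisfies this protocol as-is; the `MockQueue` doubles would only need an extra `qsize()` stub if the metrics collection task ran inside these tests.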