From a2f9058df1ac025fc7018cd1a6bbee9d3b9e18f7 Mon Sep 17 00:00:00 2001 From: Sylvain <35365065+sanderegg@users.noreply.github.com> Date: Mon, 2 Dec 2024 11:36:01 +0100 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8Computational=20backend:=20DV-2=20c?= =?UTF-8?q?omputational=20scheduler=20becomes=20replicable=20(=F0=9F=97=83?= =?UTF-8?q?=EF=B8=8F=F0=9F=9A=A8)=20(#6736)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .coveragerc | 9 +- .../src/models_library/projects_nodes_io.py | 2 +- .../7ad64e963e0f_add_timezone_comp_tasks.py | 68 + ...6d8aa2_added_distributed_comp_scheduler.py | 33 + .../e05bdc5b3c7b_add_timezone_comp_runs.py | 87 + .../models/comp_runs.py | 37 +- .../models/comp_tasks.py | 13 +- .../src/pytest_simcore/rabbit_service.py | 23 +- .../src/servicelib/exception_utils.py | 38 +- .../src/servicelib/rabbitmq/_models.py | 6 +- .../tests/test_exception_utils.py | 59 +- .../api/dependencies/scheduler.py | 8 +- .../api/routes/computations.py | 38 +- .../core/errors.py | 4 + .../core/settings.py | 12 +- .../models/comp_pipelines.py | 4 +- .../models/comp_runs.py | 16 +- .../modules/comp_scheduler/__init__.py | 26 +- .../modules/comp_scheduler/_constants.py | 7 + .../modules/comp_scheduler/_manager.py | 173 ++ .../modules/comp_scheduler/_models.py | 19 + .../modules/comp_scheduler/_publisher.py | 32 + ...{_base_scheduler.py => _scheduler_base.py} | 347 ++-- ...{_dask_scheduler.py => _scheduler_dask.py} | 9 +- .../comp_scheduler/_scheduler_factory.py | 19 +- .../comp_scheduler/_utils.py} | 44 +- .../modules/comp_scheduler/_worker.py | 94 ++ .../modules/db/__init__.py | 12 +- .../modules/db/repositories/comp_runs.py | 142 +- .../db/repositories/comp_tasks/_utils.py | 30 +- services/director-v2/tests/conftest.py | 17 +- services/director-v2/tests/unit/_helpers.py | 77 +- services/director-v2/tests/unit/conftest.py | 5 + .../tests/unit/test_utils_comp_scheduler.py | 2 +- .../unit/with_dbs/comp_scheduler/conftest.py | 69 + .../test_db_repositories_comp_runs.py | 525 ++++++ .../with_dbs/comp_scheduler/test_manager.py | 371 +++++ .../test_scheduler_dask.py} | 1452 ++++++++++------- .../with_dbs/comp_scheduler/test_worker.py | 135 ++ .../tests/unit/with_dbs/conftest.py | 172 +- .../unit/with_dbs/test_api_route_clusters.py | 40 +- .../test_api_route_clusters_details.py | 6 +- .../with_dbs/test_api_route_computations.py | 41 +- .../test_api_route_computations_tasks.py | 14 +- .../unit/with_dbs/test_utils_rabbitmq.py | 12 +- 45 files changed, 3224 insertions(+), 1125 deletions(-) create mode 100644 packages/postgres-database/src/simcore_postgres_database/migration/versions/7ad64e963e0f_add_timezone_comp_tasks.py create mode 100644 packages/postgres-database/src/simcore_postgres_database/migration/versions/b7f23f6d8aa2_added_distributed_comp_scheduler.py create mode 100644 packages/postgres-database/src/simcore_postgres_database/migration/versions/e05bdc5b3c7b_add_timezone_comp_runs.py create mode 100644 services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_constants.py create mode 100644 services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_manager.py create mode 100644 services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_models.py create mode 100644 services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_publisher.py rename services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/{_base_scheduler.py => _scheduler_base.py} (75%) rename 
services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/{_dask_scheduler.py => _scheduler_dask.py} (98%) rename services/director-v2/src/simcore_service_director_v2/{utils/comp_scheduler.py => modules/comp_scheduler/_utils.py} (61%) create mode 100644 services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_worker.py create mode 100644 services/director-v2/tests/unit/with_dbs/comp_scheduler/conftest.py create mode 100644 services/director-v2/tests/unit/with_dbs/comp_scheduler/test_db_repositories_comp_runs.py create mode 100644 services/director-v2/tests/unit/with_dbs/comp_scheduler/test_manager.py rename services/director-v2/tests/unit/with_dbs/{test_modules_comp_scheduler_dask_scheduler.py => comp_scheduler/test_scheduler_dask.py} (55%) create mode 100644 services/director-v2/tests/unit/with_dbs/comp_scheduler/test_worker.py diff --git a/.coveragerc b/.coveragerc index fb3d7c12624..ebf1465b0fb 100644 --- a/.coveragerc +++ b/.coveragerc @@ -7,25 +7,20 @@ parallel = True [report] # Regexes for lines to exclude from consideration -exclude_lines = - # Have to re-enable the standard pragma - pragma: no cover - +exclude_also = # Don't complain about missing debug-only code: def __repr__ if self\.debug - # Don't complain if tests don't hit defensive assertion code: raise AssertionError raise NotImplementedError - # Don't complain if non-runnable code isn't run: if 0: if __name__ == .__main__.: if __name__ == __main__.: + class .*\bProtocol\): # Don't complain about abstract methods, they aren't run: @(abc\.)?abstract(((class|static)?method)|property) - # Don't complain about type checking if TYPE_CHECKING: diff --git a/packages/models-library/src/models_library/projects_nodes_io.py b/packages/models-library/src/models_library/projects_nodes_io.py index 3a79b6acf00..4d4637ac362 100644 --- a/packages/models-library/src/models_library/projects_nodes_io.py +++ b/packages/models-library/src/models_library/projects_nodes_io.py @@ -34,7 +34,7 @@ UUIDStr: TypeAlias = Annotated[str, StringConstraints(pattern=UUID_RE)] -NodeIDStr = UUIDStr +NodeIDStr: TypeAlias = UUIDStr LocationID = int LocationName = str diff --git a/packages/postgres-database/src/simcore_postgres_database/migration/versions/7ad64e963e0f_add_timezone_comp_tasks.py b/packages/postgres-database/src/simcore_postgres_database/migration/versions/7ad64e963e0f_add_timezone_comp_tasks.py new file mode 100644 index 00000000000..fe56f4c548f --- /dev/null +++ b/packages/postgres-database/src/simcore_postgres_database/migration/versions/7ad64e963e0f_add_timezone_comp_tasks.py @@ -0,0 +1,68 @@ +"""add_timezone_comp_tasks + +Revision ID: 7ad64e963e0f +Revises: b7f23f6d8aa2 +Create Date: 2024-11-27 22:28:51.898433+00:00 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "7ad64e963e0f" +down_revision = "b7f23f6d8aa2" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column( + "comp_tasks", + "submit", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=True, + ) + op.alter_column( + "comp_tasks", + "start", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=True, + ) + op.alter_column( + "comp_tasks", + "end", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=True, + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "comp_tasks", + "end", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=True, + ) + op.alter_column( + "comp_tasks", + "start", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=True, + ) + op.alter_column( + "comp_tasks", + "submit", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=True, + ) + # ### end Alembic commands ### diff --git a/packages/postgres-database/src/simcore_postgres_database/migration/versions/b7f23f6d8aa2_added_distributed_comp_scheduler.py b/packages/postgres-database/src/simcore_postgres_database/migration/versions/b7f23f6d8aa2_added_distributed_comp_scheduler.py new file mode 100644 index 00000000000..b1e5bc9f30c --- /dev/null +++ b/packages/postgres-database/src/simcore_postgres_database/migration/versions/b7f23f6d8aa2_added_distributed_comp_scheduler.py @@ -0,0 +1,33 @@ +"""added_distributed_comp_scheduler + +Revision ID: b7f23f6d8aa2 +Revises: c9db8bf5091e +Create Date: 2024-11-26 17:06:27.053774+00:00 + +""" +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "b7f23f6d8aa2" +down_revision = "c9db8bf5091e" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "comp_runs", sa.Column("scheduled", sa.DateTime(timezone=True), nullable=True) + ) + op.add_column( + "comp_runs", sa.Column("processed", sa.DateTime(timezone=True), nullable=True) + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("comp_runs", "processed") + op.drop_column("comp_runs", "scheduled") + # ### end Alembic commands ### diff --git a/packages/postgres-database/src/simcore_postgres_database/migration/versions/e05bdc5b3c7b_add_timezone_comp_runs.py b/packages/postgres-database/src/simcore_postgres_database/migration/versions/e05bdc5b3c7b_add_timezone_comp_runs.py new file mode 100644 index 00000000000..3d3d6c6896a --- /dev/null +++ b/packages/postgres-database/src/simcore_postgres_database/migration/versions/e05bdc5b3c7b_add_timezone_comp_runs.py @@ -0,0 +1,87 @@ +"""add_timezone_comp_runs + +Revision ID: e05bdc5b3c7b +Revises: 7ad64e963e0f +Create Date: 2024-11-27 22:51:21.112336+00:00 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "e05bdc5b3c7b" +down_revision = "7ad64e963e0f" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.alter_column( + "comp_runs", + "created", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + existing_server_default="now()", + ) + op.alter_column( + "comp_runs", + "modified", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + existing_server_default="now()", + ) + op.alter_column( + "comp_runs", + "started", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=True, + ) + op.alter_column( + "comp_runs", + "ended", + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=True, + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + "comp_runs", + "ended", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=True, + ) + op.alter_column( + "comp_runs", + "started", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=True, + ) + op.alter_column( + "comp_runs", + "modified", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + existing_server_default="now()", + ) + op.alter_column( + "comp_runs", + "created", + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + existing_server_default="now()", + ) + # ### end Alembic commands ### diff --git a/packages/postgres-database/src/simcore_postgres_database/models/comp_runs.py b/packages/postgres-database/src/simcore_postgres_database/models/comp_runs.py index 3975cb91eee..d92227c07e2 100644 --- a/packages/postgres-database/src/simcore_postgres_database/models/comp_runs.py +++ b/packages/postgres-database/src/simcore_postgres_database/models/comp_runs.py @@ -1,11 +1,11 @@ """ Computational Runs Table """ + import sqlalchemy as sa from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.sql import func -from ._common import RefActions +from ._common import RefActions, column_created_datetime, column_modified_datetime from .base import metadata from .comp_pipeline import StateType @@ -72,31 +72,18 @@ doc="The result of the run entry", ), # dag node id and class - sa.Column( - "created", - sa.DateTime(), - nullable=False, - server_default=func.now(), - doc="When the run entry was created", - ), - sa.Column( - "modified", - sa.DateTime(), - nullable=False, - server_default=func.now(), - onupdate=func.now(), # this will auto-update on modification - doc="When the run entry was last modified", - ), + column_created_datetime(timezone=True), + column_modified_datetime(timezone=True), # utc timestamps for submission/start/end sa.Column( "started", - sa.DateTime, + sa.DateTime(timezone=True), nullable=True, doc="When the run was started", ), sa.Column( "ended", - sa.DateTime, + sa.DateTime(timezone=True), nullable=True, doc="When the run was finished", ), @@ -106,6 +93,18 @@ nullable=True, doc="If filled, when cancellation was requested", ), + sa.Column( + "scheduled", + sa.DateTime(timezone=True), + nullable=True, + doc="last time the pipeline was scheduled to be processed", + ), + sa.Column( + "processed", + sa.DateTime(timezone=True), + nullable=True, + doc="last time the pipeline was actually processed", + ), sa.Column("metadata", JSONB, nullable=True, doc="the run optional metadata"), sa.Column( "use_on_demand_clusters", diff --git 
a/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py b/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py index 60bfc3f95c3..af5dc451cc3 100644 --- a/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py +++ b/packages/postgres-database/src/simcore_postgres_database/models/comp_tasks.py @@ -1,6 +1,7 @@ """ Computational Tasks Table """ + import enum import sqlalchemy as sa @@ -77,9 +78,15 @@ class NodeClass(enum.Enum): doc="current progress of the task if available", ), # utc timestamps for submission/start/end - sa.Column("submit", sa.DateTime, doc="UTC timestamp for task submission"), - sa.Column("start", sa.DateTime, doc="UTC timestamp when task started"), - sa.Column("end", sa.DateTime, doc="UTC timestamp for task completion"), + sa.Column( + "submit", sa.DateTime(timezone=True), doc="UTC timestamp for task submission" + ), + sa.Column( + "start", sa.DateTime(timezone=True), doc="UTC timestamp when task started" + ), + sa.Column( + "end", sa.DateTime(timezone=True), doc="UTC timestamp for task completion" + ), sa.Column( "last_heartbeat", sa.DateTime(timezone=True), diff --git a/packages/pytest-simcore/src/pytest_simcore/rabbit_service.py b/packages/pytest-simcore/src/pytest_simcore/rabbit_service.py index d8dc38feb09..91873a69d08 100644 --- a/packages/pytest-simcore/src/pytest_simcore/rabbit_service.py +++ b/packages/pytest-simcore/src/pytest_simcore/rabbit_service.py @@ -6,11 +6,12 @@ import asyncio import logging from collections.abc import AsyncIterator, Awaitable, Callable +from contextlib import suppress import aio_pika import pytest import tenacity -from servicelib.rabbitmq import RabbitMQClient, RabbitMQRPCClient +from servicelib.rabbitmq import QueueName, RabbitMQClient, RabbitMQRPCClient from settings_library.rabbit import RabbitSettings from tenacity.before_sleep import before_sleep_log from tenacity.stop import stop_after_attempt @@ -131,3 +132,23 @@ async def _creator(client_name: str, *, heartbeat: int = 60) -> RabbitMQRPCClien yield _creator # cleanup, properly close the clients await asyncio.gather(*(client.close() for client in created_clients)) + + +@pytest.fixture +async def ensure_parametrized_queue_is_empty( + create_rabbitmq_client: Callable[[str], RabbitMQClient], queue_name: QueueName +) -> AsyncIterator[None]: + rabbitmq_client = create_rabbitmq_client("pytest-purger") + + async def _queue_messages_purger() -> None: + with suppress(aio_pika.exceptions.ChannelClosed): + assert rabbitmq_client._channel_pool # noqa: SLF001 + async with rabbitmq_client._channel_pool.acquire() as channel: # noqa: SLF001 + assert isinstance(channel, aio_pika.RobustChannel) + queue = await channel.get_queue(queue_name) + await queue.purge() + + await _queue_messages_purger() + yield + # cleanup + await _queue_messages_purger() diff --git a/packages/service-library/src/servicelib/exception_utils.py b/packages/service-library/src/servicelib/exception_utils.py index 4f44d673838..2de33fd98e6 100644 --- a/packages/service-library/src/servicelib/exception_utils.py +++ b/packages/service-library/src/servicelib/exception_utils.py @@ -1,6 +1,9 @@ +import inspect import logging +from collections.abc import Callable from datetime import datetime -from typing import Final +from functools import wraps +from typing import Any, Final, ParamSpec, TypeVar from pydantic import BaseModel, Field, NonNegativeFloat, PrivateAttr @@ -65,3 +68,36 @@ def else_reset(self) -> None: """error no longer occurs reset tracking""" 
self._first_exception_skip = None self._failure_counter = 0 + + +P = ParamSpec("P") +R = TypeVar("R") + +F = TypeVar("F", bound=Callable[..., Any]) + + +def silence_exceptions(exceptions: tuple[type[BaseException], ...]) -> Callable[[F], F]: + def _decorator(func_or_coro: F) -> F: + + if inspect.iscoroutinefunction(func_or_coro): + + @wraps(func_or_coro) + async def _async_wrapper(*args, **kwargs) -> Any: + try: + assert inspect.iscoroutinefunction(func_or_coro) # nosec + return await func_or_coro(*args, **kwargs) + except exceptions: + return None + + return _async_wrapper # type: ignore[return-value] # decorators typing is hard + + @wraps(func_or_coro) + def _sync_wrapper(*args, **kwargs) -> Any: + try: + return func_or_coro(*args, **kwargs) + except exceptions: + return None + + return _sync_wrapper # type: ignore[return-value] # decorators typing is hard + + return _decorator diff --git a/packages/service-library/src/servicelib/rabbitmq/_models.py b/packages/service-library/src/servicelib/rabbitmq/_models.py index d713edfdc1d..cd674e526ff 100644 --- a/packages/service-library/src/servicelib/rabbitmq/_models.py +++ b/packages/service-library/src/servicelib/rabbitmq/_models.py @@ -18,11 +18,9 @@ class RabbitMessage(Protocol): - def body(self) -> bytes: - ... + def body(self) -> bytes: ... - def routing_key(self) -> str | None: - ... + def routing_key(self) -> str | None: ... class RPCNamespacedMethodName(ConstrainedStr): diff --git a/packages/service-library/tests/test_exception_utils.py b/packages/service-library/tests/test_exception_utils.py index 299855e8241..a884d3dafb1 100644 --- a/packages/service-library/tests/test_exception_utils.py +++ b/packages/service-library/tests/test_exception_utils.py @@ -4,7 +4,7 @@ import pytest from pydantic import PositiveFloat, PositiveInt -from servicelib.exception_utils import DelayedExceptionHandler +from servicelib.exception_utils import DelayedExceptionHandler, silence_exceptions TOLERANCE: Final[PositiveFloat] = 0.1 SLEEP_FOR: Final[PositiveFloat] = TOLERANCE * 0.1 @@ -49,3 +49,60 @@ def test_workflow_passes() -> None: def test_workflow_raises() -> None: with pytest.raises(TargetException): workflow(stop_raising_after=ITERATIONS + 1) + + +# Define some custom exceptions for testing +class CustomError(Exception): + pass + + +class AnotherCustomError(Exception): + pass + + +@silence_exceptions((CustomError,)) +def sync_function(*, raise_error: bool, raise_another_error: bool) -> str: + if raise_error: + raise CustomError + if raise_another_error: + raise AnotherCustomError + return "Success" + + +@silence_exceptions((CustomError,)) +async def async_function(*, raise_error: bool, raise_another_error: bool) -> str: + if raise_error: + raise CustomError + if raise_another_error: + raise AnotherCustomError + return "Success" + + +def test_sync_function_no_exception(): + result = sync_function(raise_error=False, raise_another_error=False) + assert result == "Success" + + +def test_sync_function_with_exception_is_silenced(): + result = sync_function(raise_error=True, raise_another_error=False) + assert result is None + + +async def test_async_function_no_exception(): + result = await async_function(raise_error=False, raise_another_error=False) + assert result == "Success" + + +async def test_async_function_with_exception_is_silenced(): + result = await async_function(raise_error=True, raise_another_error=False) + assert result is None + + +def test_sync_function_with_different_exception(): + with pytest.raises(AnotherCustomError): + 
sync_function(raise_error=False, raise_another_error=True) + + +async def test_async_function_with_different_exception(): + with pytest.raises(AnotherCustomError): + await async_function(raise_error=False, raise_another_error=True) diff --git a/services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py b/services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py index aa01af1f34b..e480d204d3b 100644 --- a/services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/api/dependencies/scheduler.py @@ -1,17 +1,11 @@ from typing import Annotated -from fastapi import Depends, FastAPI, Request +from fastapi import Depends, FastAPI from ...core.settings import ComputationalBackendSettings -from ...modules.comp_scheduler import BaseCompScheduler from . import get_app -def get_scheduler(request: Request) -> BaseCompScheduler: - scheduler: BaseCompScheduler = request.app.state.scheduler - return scheduler - - def get_scheduler_settings( app: Annotated[FastAPI, Depends(get_app)] ) -> ComputationalBackendSettings: diff --git a/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py b/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py index 251e35fa638..f25fdf32ece 100644 --- a/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py +++ b/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py @@ -21,7 +21,7 @@ from typing import Annotated, Any, Final import networkx as nx -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, FastAPI, HTTPException from models_library.api_schemas_directorv2.comp_tasks import ( ComputationCreate, ComputationDelete, @@ -63,7 +63,7 @@ from ...models.comp_runs import CompRunsAtDB, ProjectMetadataDict, RunMetadataDict from ...models.comp_tasks import CompTaskAtDB from ...modules.catalog import CatalogClient -from ...modules.comp_scheduler import BaseCompScheduler +from ...modules.comp_scheduler import run_new_pipeline, stop_pipeline from ...modules.db.repositories.clusters import ClustersRepository from ...modules.db.repositories.comp_pipelines import CompPipelinesRepository from ...modules.db.repositories.comp_runs import CompRunsRepository @@ -89,7 +89,6 @@ from ..dependencies.director_v0 import get_director_v0_client from ..dependencies.rabbitmq import rabbitmq_rpc_client from ..dependencies.rut_client import get_rut_client -from ..dependencies.scheduler import get_scheduler from .computations_tasks import analyze_pipeline _PIPELINE_ABORT_TIMEOUT_S: Final[int] = 10 @@ -212,12 +211,12 @@ async def _get_project_node_names( async def _try_start_pipeline( + app: FastAPI, *, project_repo: ProjectsRepository, computation: ComputationCreate, complete_dag: nx.DiGraph, minimal_dag: nx.DiGraph, - scheduler: BaseCompScheduler, project: ProjectAtDB, users_repo: UsersRepository, projects_metadata_repo: ProjectsMetadataRepository, @@ -242,11 +241,12 @@ async def _try_start_pipeline( wallet_id = computation.wallet_info.wallet_id wallet_name = computation.wallet_info.wallet_name - await scheduler.run_new_pipeline( - computation.user_id, - computation.project_id, - computation.cluster_id or DEFAULT_CLUSTER_ID, - RunMetadataDict( + await run_new_pipeline( + app, + user_id=computation.user_id, + project_id=computation.project_id, + cluster_id=computation.cluster_id or DEFAULT_CLUSTER_ID, + run_metadata=RunMetadataDict( 
node_id_names_map={ NodeID(node_idstr): node_data.label for node_idstr, node_data in project.workbench.items() @@ -313,7 +313,6 @@ async def create_computation( # noqa: PLR0913 # pylint: disable=too-many-positi ProjectsMetadataRepository, Depends(get_repository(ProjectsMetadataRepository)) ], director_client: Annotated[DirectorV0Client, Depends(get_director_v0_client)], - scheduler: Annotated[BaseCompScheduler, Depends(get_scheduler)], catalog_client: Annotated[CatalogClient, Depends(get_catalog_client)], rut_client: Annotated[ResourceUsageTrackerClient, Depends(get_rut_client)], rpc_client: Annotated[RabbitMQRPCClient, Depends(rabbitmq_rpc_client)], @@ -370,11 +369,11 @@ async def create_computation( # noqa: PLR0913 # pylint: disable=too-many-positi if computation.start_pipeline: await _try_start_pipeline( + request.app, project_repo=project_repo, computation=computation, complete_dag=complete_dag, minimal_dag=minimal_computational_dag, - scheduler=scheduler, project=project, users_repo=users_repo, projects_metadata_repo=projects_metadata_repo, @@ -549,7 +548,6 @@ async def stop_computation( comp_runs_repo: Annotated[ CompRunsRepository, Depends(get_repository(CompRunsRepository)) ], - scheduler: Annotated[BaseCompScheduler, Depends(get_scheduler)], ) -> ComputationGet: _logger.debug( "User %s stopping computation for project %s", @@ -575,7 +573,9 @@ async def stop_computation( pipeline_state = utils.get_pipeline_state_from_task_states(filtered_tasks) if utils.is_pipeline_running(pipeline_state): - await scheduler.stop_pipeline(computation_stop.user_id, project_id) + await stop_pipeline( + request.app, user_id=computation_stop.user_id, project_id=project_id + ) # get run details if any last_run: CompRunsAtDB | None = None @@ -615,6 +615,7 @@ async def stop_computation( async def delete_computation( computation_stop: ComputationDelete, project_id: ProjectID, + request: Request, project_repo: Annotated[ ProjectsRepository, Depends(get_repository(ProjectsRepository)) ], @@ -624,7 +625,6 @@ async def delete_computation( comp_tasks_repo: Annotated[ CompTasksRepository, Depends(get_repository(CompTasksRepository)) ], - scheduler: Annotated[BaseCompScheduler, Depends(get_scheduler)], ) -> None: try: # get the project @@ -642,7 +642,9 @@ async def delete_computation( ) # abort the pipeline first try: - await scheduler.stop_pipeline(computation_stop.user_id, project_id) + await stop_pipeline( + request.app, user_id=computation_stop.user_id, project_id=project_id + ) except ComputationalSchedulerError as e: _logger.warning( "Project %s could not be stopped properly.\n reason: %s", @@ -663,9 +665,9 @@ def return_last_value(retry_state: Any) -> Any: before_sleep=before_sleep_log(_logger, logging.INFO), ) async def check_pipeline_stopped() -> bool: - comp_tasks: list[ - CompTaskAtDB - ] = await comp_tasks_repo.list_computational_tasks(project_id) + comp_tasks: list[CompTaskAtDB] = ( + await comp_tasks_repo.list_computational_tasks(project_id) + ) pipeline_state = utils.get_pipeline_state_from_task_states( comp_tasks, ) diff --git a/services/director-v2/src/simcore_service_director_v2/core/errors.py b/services/director-v2/src/simcore_service_director_v2/core/errors.py index 18a5b674ed2..492e75bdeab 100644 --- a/services/director-v2/src/simcore_service_director_v2/core/errors.py +++ b/services/director-v2/src/simcore_service_director_v2/core/errors.py @@ -35,6 +35,10 @@ class ConfigurationError(DirectorError): msg_template: str = "Application misconfiguration: {msg}" +class 
UserNotFoundError(DirectorError): + msg_template: str = "user {user_id} not found" + + class ProjectNotFoundError(DirectorError): msg_template: str = "project {project_id} not found" diff --git a/services/director-v2/src/simcore_service_director_v2/core/settings.py b/services/director-v2/src/simcore_service_director_v2/core/settings.py index 0ccdce64de1..fe0af49fc5c 100644 --- a/services/director-v2/src/simcore_service_director_v2/core/settings.py +++ b/services/director-v2/src/simcore_service_director_v2/core/settings.py @@ -4,9 +4,10 @@ import datetime from functools import cached_property -from typing import Annotated +from typing import Annotated, cast from common_library.pydantic_validators import validate_numeric_string_as_timedelta +from fastapi import FastAPI from models_library.basic_types import LogLevel, PortInt, VersionTag from models_library.clusters import ( DEFAULT_CLUSTER_ID, @@ -21,6 +22,7 @@ AnyUrl, Field, NonNegativeInt, + PositiveInt, field_validator, ) from servicelib.logging_utils_filtering import LoggerName, MessageSubstring @@ -72,6 +74,10 @@ class ComputationalBackendSettings(BaseCustomSettings): COMPUTATIONAL_BACKEND_ENABLED: bool = Field( default=True, ) + COMPUTATIONAL_BACKEND_SCHEDULING_CONCURRENCY: PositiveInt = Field( + default=50, + description="defines how many pipelines the application can schedule concurrently", + ) COMPUTATIONAL_BACKEND_DASK_CLIENT_ENABLED: bool = Field( default=True, ) @@ -263,3 +269,7 @@ def _validate_loglevel(cls, value: str) -> str: _validate_service_tracking_heartbeat = validate_numeric_string_as_timedelta( "SERVICE_TRACKING_HEARTBEAT" ) + + +def get_application_settings(app: FastAPI) -> AppSettings: + return cast(AppSettings, app.state.settings) diff --git a/services/director-v2/src/simcore_service_director_v2/models/comp_pipelines.py b/services/director-v2/src/simcore_service_director_v2/models/comp_pipelines.py index 5de823d826b..63017ee62e7 100644 --- a/services/director-v2/src/simcore_service_director_v2/models/comp_pipelines.py +++ b/services/director-v2/src/simcore_service_director_v2/models/comp_pipelines.py @@ -17,7 +17,7 @@ class CompPipelineAtDB(BaseModel): @field_validator("state", mode="before") @classmethod - def convert_state_from_state_type_enum_if_needed(cls, v): + def _convert_state_from_state_type_enum_if_needed(cls, v): if isinstance(v, str): # try to convert to a StateType, if it fails the validations will continue # and pydantic will try to convert it to a RunninState later on @@ -29,7 +29,7 @@ def convert_state_from_state_type_enum_if_needed(cls, v): @field_validator("dag_adjacency_list", mode="before") @classmethod - def auto_convert_dag(cls, v): + def _auto_convert_dag(cls, v): # this enforcement is here because the serialization using json is not happy with non str Dict keys, also comparison gets funny if the lists are having sometimes UUIDs or str. 
# NOTE: this might not be necessary anymore once we have something fully defined return {str(key): [str(n) for n in value] for key, value in v.items()} diff --git a/services/director-v2/src/simcore_service_director_v2/models/comp_runs.py b/services/director-v2/src/simcore_service_director_v2/models/comp_runs.py index 62270380293..f3fedc6a9f9 100644 --- a/services/director-v2/src/simcore_service_director_v2/models/comp_runs.py +++ b/services/director-v2/src/simcore_service_director_v2/models/comp_runs.py @@ -1,5 +1,6 @@ import datetime from contextlib import suppress +from typing import TypeAlias from models_library.clusters import DEFAULT_CLUSTER_ID, ClusterID from models_library.projects import ProjectID @@ -37,12 +38,15 @@ class RunMetadataDict(TypedDict, total=False): project_metadata: ProjectMetadataDict +Iteration: TypeAlias = PositiveInt + + class CompRunsAtDB(BaseModel): run_id: PositiveInt project_uuid: ProjectID user_id: UserID cluster_id: ClusterID | None - iteration: PositiveInt + iteration: Iteration result: RunningState created: datetime.datetime modified: datetime.datetime @@ -51,6 +55,8 @@ class CompRunsAtDB(BaseModel): cancelled: datetime.datetime | None metadata: RunMetadataDict = RunMetadataDict() use_on_demand_clusters: bool + scheduled: datetime.datetime | None + processed: datetime.datetime | None @field_validator("result", mode="before") @classmethod @@ -103,6 +109,8 @@ def convert_null_to_empty_metadata(cls, v): "modified": "2021-03-01T13:07:34.191610", "cancelled": None, "use_on_demand_clusters": False, + "scheduled": None, + "processed": None, }, { "run_id": 432, @@ -117,6 +125,8 @@ def convert_null_to_empty_metadata(cls, v): "modified": "2021-03-01T13:07:34.191610", "cancelled": None, "use_on_demand_clusters": False, + "scheduled": None, + "processed": None, }, { "run_id": 43243, @@ -138,6 +148,8 @@ def convert_null_to_empty_metadata(cls, v): "some-other-metadata-which-is-an-array": [1, 3, 4], }, "use_on_demand_clusters": False, + "scheduled": None, + "processed": None, }, { "run_id": 43243, @@ -153,6 +165,8 @@ def convert_null_to_empty_metadata(cls, v): "cancelled": None, "metadata": None, "use_on_demand_clusters": False, + "scheduled": None, + "processed": None, }, ] }, diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py index 2b29acf16c9..cf3370f4da8 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py @@ -1,12 +1,13 @@ import logging from collections.abc import Callable, Coroutine -from typing import Any, cast +from typing import Any from fastapi import FastAPI from servicelib.logging_utils import log_context -from . 
import _scheduler_factory -from ._base_scheduler import BaseCompScheduler +from ._constants import MODULE_NAME_SCHEDULER +from ._manager import run_new_pipeline, setup_manager, shutdown_manager, stop_pipeline +from ._worker import setup_worker, shutdown_worker _logger = logging.getLogger(__name__) @@ -14,24 +15,25 @@ def on_app_startup(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]: async def start_scheduler() -> None: with log_context( - _logger, level=logging.INFO, msg="starting computational scheduler" + _logger, level=logging.INFO, msg=f"starting {MODULE_NAME_SCHEDULER}" ): - app.state.scheduler = await _scheduler_factory.create_from_db(app) + await setup_worker(app) + await setup_manager(app) return start_scheduler def on_app_shutdown(app: FastAPI) -> Callable[[], Coroutine[Any, Any, None]]: async def stop_scheduler() -> None: - await get_scheduler(app).shutdown() + with log_context( + _logger, level=logging.INFO, msg=f"stopping {MODULE_NAME_SCHEDULER}" + ): + await shutdown_manager(app) + await shutdown_worker(app) return stop_scheduler -def get_scheduler(app: FastAPI) -> BaseCompScheduler: - return cast(BaseCompScheduler, app.state.scheduler) - - def setup(app: FastAPI): app.add_event_handler("startup", on_app_startup(app)) app.add_event_handler("shutdown", on_app_shutdown(app)) @@ -39,6 +41,6 @@ def setup(app: FastAPI): __all__: tuple[str, ...] = ( "setup", - "BaseCompScheduler", - "get_scheduler", + "run_new_pipeline", + "stop_pipeline", ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_constants.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_constants.py new file mode 100644 index 00000000000..45efe93f0b0 --- /dev/null +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_constants.py @@ -0,0 +1,7 @@ +import datetime +from typing import Final + +MODULE_NAME_SCHEDULER: Final[str] = "computational-distributed-scheduler" +MODULE_NAME_WORKER: Final[str] = "computational-distributed-worker" +SCHEDULER_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5) +MAX_CONCURRENT_PIPELINE_SCHEDULING: Final[int] = 10 diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_manager.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_manager.py new file mode 100644 index 00000000000..281c9fc4630 --- /dev/null +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_manager.py @@ -0,0 +1,173 @@ +import logging +from typing import Final + +import networkx as nx +from aiopg.sa import Engine +from fastapi import FastAPI +from models_library.clusters import ClusterID +from models_library.projects import ProjectID +from models_library.users import UserID +from servicelib.background_task import start_periodic_task, stop_periodic_task +from servicelib.exception_utils import silence_exceptions +from servicelib.logging_utils import log_context +from servicelib.redis import CouldNotAcquireLockError +from servicelib.redis_utils import exclusive +from servicelib.utils import limited_gather + +from ...models.comp_runs import RunMetadataDict +from ...utils.rabbitmq import publish_project_log +from ..db import get_db_engine +from ..db.repositories.comp_pipelines import CompPipelinesRepository +from ..db.repositories.comp_runs import CompRunsRepository +from ..rabbitmq import get_rabbitmq_client +from ._constants import ( + MAX_CONCURRENT_PIPELINE_SCHEDULING, + MODULE_NAME_SCHEDULER, + 
SCHEDULER_INTERVAL, +) +from ._publisher import request_pipeline_scheduling +from ._utils import SCHEDULED_STATES, get_redis_client_from_app, get_redis_lock_key + +_logger = logging.getLogger(__name__) + + +async def run_new_pipeline( + app: FastAPI, + *, + user_id: UserID, + project_id: ProjectID, + cluster_id: ClusterID, + run_metadata: RunMetadataDict, + use_on_demand_clusters: bool, +) -> None: + """Sets a new pipeline to be scheduled on the computational resources. + Passing cluster_id=0 will use the default cluster. Passing an existing ID will instruct + the scheduler to run the tasks on the defined cluster""" + # ensure the pipeline exists and is populated with something + db_engine = get_db_engine(app) + dag = await _get_pipeline_dag(project_id, db_engine) + if not dag: + _logger.warning( + "project %s has no computational dag defined. not scheduled for a run.", + f"{project_id=}", + ) + return + + new_run = await CompRunsRepository.instance(db_engine).create( + user_id=user_id, + project_id=project_id, + cluster_id=cluster_id, + metadata=run_metadata, + use_on_demand_clusters=use_on_demand_clusters, + ) + + rabbitmq_client = get_rabbitmq_client(app) + await request_pipeline_scheduling( + rabbitmq_client, + db_engine, + user_id=new_run.user_id, + project_id=new_run.project_uuid, + iteration=new_run.iteration, + ) + await publish_project_log( + rabbitmq_client, + user_id, + project_id, + log=f"Project pipeline scheduled using {'on-demand clusters' if use_on_demand_clusters else 'pre-defined clusters'}, starting soon...", + log_level=logging.INFO, + ) + + +async def stop_pipeline( + app: FastAPI, + *, + user_id: UserID, + project_id: ProjectID, + iteration: int | None = None, +) -> None: + db_engine = get_db_engine(app) + comp_run = await CompRunsRepository.instance(db_engine).get( + user_id, project_id, iteration + ) + + # mark the scheduled pipeline for stopping + updated_comp_run = await CompRunsRepository.instance( + db_engine + ).mark_for_cancellation( + user_id=user_id, project_id=project_id, iteration=comp_run.iteration + ) + if updated_comp_run: + # ensure the scheduler starts right away + rabbitmq_client = get_rabbitmq_client(app) + await request_pipeline_scheduling( + rabbitmq_client, + db_engine, + user_id=updated_comp_run.user_id, + project_id=updated_comp_run.project_uuid, + iteration=updated_comp_run.iteration, + ) + + +async def _get_pipeline_dag(project_id: ProjectID, db_engine: Engine) -> nx.DiGraph: + comp_pipeline_repo = CompPipelinesRepository.instance(db_engine) + pipeline_at_db = await comp_pipeline_repo.get_pipeline(project_id) + return pipeline_at_db.get_graph() + + +_LOST_TASKS_FACTOR: Final[int] = 10 + + +@exclusive( + get_redis_client_from_app, + lock_key=get_redis_lock_key(MODULE_NAME_SCHEDULER, unique_lock_key_builder=None), +) +async def schedule_all_pipelines(app: FastAPI) -> None: + with log_context(_logger, logging.DEBUG, msg="scheduling pipelines"): + db_engine = get_db_engine(app) + runs_to_schedule = await CompRunsRepository.instance(db_engine).list( + filter_by_state=SCHEDULED_STATES, + never_scheduled=True, + processed_since=SCHEDULER_INTERVAL, + ) + possibly_lost_scheduled_pipelines = await CompRunsRepository.instance( + db_engine + ).list( + filter_by_state=SCHEDULED_STATES, + scheduled_since=SCHEDULER_INTERVAL * _LOST_TASKS_FACTOR, + ) + if possibly_lost_scheduled_pipelines: + _logger.error( + "found %d lost pipelines, they will be re-scheduled now", + len(possibly_lost_scheduled_pipelines), + ) + + rabbitmq_client = get_rabbitmq_client(app) + 
with log_context(_logger, logging.DEBUG, msg="distributing pipelines"): + await limited_gather( + *( + request_pipeline_scheduling( + rabbitmq_client, + db_engine, + user_id=run.user_id, + project_id=run.project_uuid, + iteration=run.iteration, + ) + for run in runs_to_schedule + possibly_lost_scheduled_pipelines + ), + limit=MAX_CONCURRENT_PIPELINE_SCHEDULING, + ) + if runs_to_schedule: + _logger.debug("distributed %d pipelines", len(runs_to_schedule)) + + +async def setup_manager(app: FastAPI) -> None: + app.state.scheduler_manager = start_periodic_task( + silence_exceptions((CouldNotAcquireLockError,))(schedule_all_pipelines), + interval=SCHEDULER_INTERVAL, + task_name=MODULE_NAME_SCHEDULER, + app=app, + ) + + +async def shutdown_manager(app: FastAPI) -> None: + await stop_periodic_task(app.state.scheduler_manager) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_models.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_models.py new file mode 100644 index 00000000000..28dca04dc53 --- /dev/null +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_models.py @@ -0,0 +1,19 @@ +from typing import Literal + +from models_library.projects import ProjectID +from models_library.rabbitmq_messages import RabbitMessageBase +from models_library.users import UserID + +from ...models.comp_runs import Iteration + + +class SchedulePipelineRabbitMessage(RabbitMessageBase): + channel_name: Literal[ + "simcore.services.director-v2.scheduling" + ] = "simcore.services.director-v2.scheduling" + user_id: UserID + project_id: ProjectID + iteration: Iteration + + def routing_key(self) -> str | None: # pylint: disable=no-self-use # abstract + return None diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_publisher.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_publisher.py new file mode 100644 index 00000000000..1c7ea23ac43 --- /dev/null +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_publisher.py @@ -0,0 +1,32 @@ +from aiopg.sa import Engine +from models_library.projects import ProjectID +from models_library.users import UserID +from servicelib.rabbitmq import RabbitMQClient + +from ...models.comp_runs import Iteration +from ..db.repositories.comp_runs import CompRunsRepository +from ._models import SchedulePipelineRabbitMessage + + +async def request_pipeline_scheduling( + rabbitmq_client: RabbitMQClient, + db_engine: Engine, + *, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration +) -> None: + # NOTE: we should use the transaction and the asyncpg engine here to ensure 100% consistency + # https://github.com/ITISFoundation/osparc-simcore/issues/6818 + # async with transaction_context(get_asyncpg_engine(app)) as connection: + await rabbitmq_client.publish( + SchedulePipelineRabbitMessage.get_channel_name(), + SchedulePipelineRabbitMessage( + user_id=user_id, + project_id=project_id, + iteration=iteration, + ), + ) + await CompRunsRepository.instance(db_engine).mark_for_scheduling( + user_id=user_id, project_id=project_id, iteration=iteration + ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py similarity index 75% rename from services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py rename to 
services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py index 2d663aec9a1..a16821d0fba 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py @@ -12,32 +12,26 @@ """ import asyncio -import contextlib import datetime -import functools import logging from abc import ABC, abstractmethod from collections.abc import Callable -from dataclasses import dataclass, field -from typing import Final, TypeAlias +from dataclasses import dataclass +from typing import Final import arrow import networkx as nx from aiopg.sa.engine import Engine -from models_library.clusters import ClusterID from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID, NodeIDStr from models_library.projects_state import RunningState -from models_library.services import ServiceKey, ServiceType, ServiceVersion +from models_library.services import ServiceType from models_library.users import UserID from networkx.classes.reportviews import InDegreeView -from pydantic import PositiveInt -from servicelib.background_task import start_periodic_task, stop_periodic_task from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE -from servicelib.logging_utils import log_context +from servicelib.logging_utils import log_catch, log_context from servicelib.rabbitmq import RabbitMQClient, RabbitMQRPCClient -from servicelib.redis import CouldNotAcquireLockError, RedisClientSDK -from servicelib.redis_utils import exclusive +from servicelib.redis import RedisClientSDK from ...constants import UNDEFINED_STR_METADATA from ...core.errors import ( @@ -45,7 +39,6 @@ ComputationalBackendNotConnectedError, ComputationalBackendOnDemandNotReadyError, ComputationalSchedulerChangedError, - ComputationalSchedulerError, DaskClientAcquisisitonError, InvalidPipelineError, PipelineNotFoundError, @@ -53,19 +46,8 @@ ) from ...core.settings import ComputationalBackendSettings from ...models.comp_pipelines import CompPipelineAtDB -from ...models.comp_runs import CompRunsAtDB, RunMetadataDict +from ...models.comp_runs import CompRunsAtDB, Iteration, RunMetadataDict from ...models.comp_tasks import CompTaskAtDB -from ...utils.comp_scheduler import ( - COMPLETED_STATES, - PROCESSING_STATES, - RUNNING_STATES, - SCHEDULED_STATES, - TASK_TO_START_STATES, - WAITING_FOR_START_STATES, - Iteration, - create_service_resources_from_task, - get_resource_tracking_run_id, -) from ...utils.computations import get_pipeline_state_from_task_states from ...utils.rabbitmq import ( publish_project_log, @@ -76,6 +58,16 @@ from ..db.repositories.comp_pipelines import CompPipelinesRepository from ..db.repositories.comp_runs import CompRunsRepository from ..db.repositories.comp_tasks import CompTasksRepository +from ._publisher import request_pipeline_scheduling +from ._utils import ( + COMPLETED_STATES, + PROCESSING_STATES, + RUNNING_STATES, + TASK_TO_START_STATES, + WAITING_FOR_START_STATES, + create_service_resources_from_task, + get_resource_tracking_run_id, +) _logger = logging.getLogger(__name__) @@ -83,13 +75,36 @@ _Previous = CompTaskAtDB _Current = CompTaskAtDB _MAX_WAITING_FOR_CLUSTER_TIMEOUT_IN_MIN: Final[int] = 10 -_SCHEDULER_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5) -_TASK_NAME_TEMPLATE: Final[ - str -] = "computational-scheduler-{user_id}:{project_id}:{iteration}" -PipelineSchedulingTask: TypeAlias = 
asyncio.Task -PipelineSchedulingWakeUpEvent: TypeAlias = asyncio.Event + +def _auto_schedule_callback( + loop: asyncio.AbstractEventLoop, + db_engine: Engine, + rabbit_mq_client: RabbitMQClient, + *, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration, +) -> Callable[[], None]: + """this function is called via Dask-backend from a separate thread. + Therefore the need to use run_coroutine_threadsafe to request a new + pipeline scheduling""" + + def _cb() -> None: + async def _async_cb() -> None: + await request_pipeline_scheduling( + rabbit_mq_client, + db_engine, + user_id=user_id, + project_id=project_id, + iteration=iteration, + ) + + future = asyncio.run_coroutine_threadsafe(_async_cb(), loop) + with log_catch(_logger, reraise=False): + future.result(timeout=10) + + return _cb @dataclass(frozen=True, slots=True) @@ -141,15 +156,6 @@ async def _triage_changed_tasks( ) -@dataclass(kw_only=True) -class ScheduledPipelineParams: - scheduler_task: asyncio.Task - scheduler_waker: asyncio.Event - - def wake_up(self) -> None: - self.scheduler_waker.set() - - @dataclass class BaseCompScheduler(ABC): db_engine: Engine @@ -159,169 +165,6 @@ class BaseCompScheduler(ABC): service_runtime_heartbeat_interval: datetime.timedelta redis_client: RedisClientSDK - # NOTE: this is a trick to be able to inheritate from the class - _scheduled_pipelines: dict[ - tuple[UserID, ProjectID, Iteration], ScheduledPipelineParams - ] = field(default_factory=dict, init=False) - - def __post_init__(self) -> None: - self._scheduled_pipelines = {} - - async def restore_scheduling_from_db(self) -> None: - # get currently scheduled runs - comp_runs = await CompRunsRepository.instance(self.db_engine).list( - filter_by_state=SCHEDULED_STATES - ) - - for run in comp_runs: - task, wake_up_event = self._start_scheduling( - run.user_id, run.project_uuid, run.iteration - ) - self._scheduled_pipelines |= { - ( - run.user_id, - run.project_uuid, - run.iteration, - ): ScheduledPipelineParams( - scheduler_task=task, scheduler_waker=wake_up_event - ) - } - - async def run_new_pipeline( - self, - user_id: UserID, - project_id: ProjectID, - cluster_id: ClusterID, - run_metadata: RunMetadataDict, - *, - use_on_demand_clusters: bool, - ) -> None: - """Sets a new pipeline to be scheduled on the computational resources. - Passing cluster_id=0 will use the default cluster. Passing an existing ID will instruct - the scheduler to run the tasks on the defined cluster""" - # ensure the pipeline exists and is populated with something - dag = await self._get_pipeline_dag(project_id) - if not dag: - _logger.warning( - "project %s has no computational dag defined. 
not scheduled for a run.", - f"{project_id=}", - ) - return - - runs_repo = CompRunsRepository.instance(self.db_engine) - new_run = await runs_repo.create( - user_id=user_id, - project_id=project_id, - cluster_id=cluster_id, - metadata=run_metadata, - use_on_demand_clusters=use_on_demand_clusters, - ) - task, wake_up_event = self._start_scheduling( - user_id, project_id, new_run.iteration - ) - self._scheduled_pipelines[ - (user_id, project_id, new_run.iteration) - ] = ScheduledPipelineParams(scheduler_task=task, scheduler_waker=wake_up_event) - await publish_project_log( - self.rabbitmq_client, - user_id, - project_id, - log=f"Project pipeline scheduled using {'on-demand clusters' if use_on_demand_clusters else 'pre-defined clusters'}, starting soon...", - log_level=logging.INFO, - ) - - async def stop_pipeline( - self, user_id: UserID, project_id: ProjectID, iteration: int | None = None - ) -> None: - if iteration is None: - # if no iteration given find the latest one in the list - possible_iterations = { - it - for u_id, p_id, it in self._scheduled_pipelines - if u_id == user_id and p_id == project_id - } - if not possible_iterations: - msg = f"There are no pipeline scheduled for {user_id}:{project_id}" - raise ComputationalSchedulerError(msg=msg) - current_max_iteration = max(possible_iterations) - selected_iteration = current_max_iteration - else: - selected_iteration = iteration - - # mark the scheduled pipeline for stopping - updated_comp_run = await CompRunsRepository.instance( - self.db_engine - ).mark_for_cancellation( - user_id=user_id, project_id=project_id, iteration=selected_iteration - ) - if updated_comp_run: - assert updated_comp_run.cancelled is not None # nosec - # ensure the scheduler starts right away - self._scheduled_pipelines[ - (user_id, project_id, selected_iteration) - ].wake_up() - - async def shutdown(self) -> None: - # cancel all current scheduling processes - await asyncio.gather( - *( - stop_periodic_task(p.scheduler_task, timeout=3) - for p in self._scheduled_pipelines.values() - if p.scheduler_task - ), - return_exceptions=True, - ) - - def _get_last_iteration(self, user_id: UserID, project_id: ProjectID) -> Iteration: - # if no iteration given find the latest one in the list - possible_iterations = { - it - for u_id, p_id, it in self._scheduled_pipelines - if u_id == user_id and p_id == project_id - } - if not possible_iterations: - msg = f"There are no pipeline scheduled for {user_id}:{project_id}" - raise ComputationalSchedulerError(msg=msg) - return max(possible_iterations) - - def _start_scheduling( - self, - user_id: UserID, - project_id: ProjectID, - iteration: Iteration, - ) -> tuple[PipelineSchedulingTask, PipelineSchedulingWakeUpEvent]: - async def _exclusive_safe_schedule_pipeline( - *, - user_id: UserID, - project_id: ProjectID, - iteration: Iteration, - wake_up_callback: Callable[[], None], - ) -> None: - with contextlib.suppress(CouldNotAcquireLockError): - await self._schedule_pipeline( - user_id=user_id, - project_id=project_id, - iteration=iteration, - wake_up_callback=wake_up_callback, - ) - - pipeline_wake_up_event = asyncio.Event() - pipeline_task = start_periodic_task( - functools.partial( - _exclusive_safe_schedule_pipeline, - user_id=user_id, - project_id=project_id, - iteration=iteration, - wake_up_callback=pipeline_wake_up_event.set, - ), - interval=_SCHEDULER_INTERVAL, - task_name=_TASK_NAME_TEMPLATE.format( - user_id=user_id, project_id=project_id, iteration=iteration - ), - early_wake_up_event=pipeline_wake_up_event, - ) - 
return pipeline_task, pipeline_wake_up_event - async def _get_pipeline_dag(self, project_id: ProjectID) -> nx.DiGraph: comp_pipeline_repo = CompPipelinesRepository.instance(self.db_engine) pipeline_at_db: CompPipelineAtDB = await comp_pipeline_repo.get_pipeline( @@ -336,7 +179,7 @@ async def _get_pipeline_tasks( ) -> dict[NodeIDStr, CompTaskAtDB]: comp_tasks_repo = CompTasksRepository.instance(self.db_engine) pipeline_comp_tasks: dict[NodeIDStr, CompTaskAtDB] = { - NodeIDStr(f"{t.node_id}"): t + f"{t.node_id}": t for t in await comp_tasks_repo.list_computational_tasks(project_id) if (f"{t.node_id}" in list(pipeline_dag.nodes())) } @@ -352,7 +195,7 @@ async def _update_run_result_from_tasks( self, user_id: UserID, project_id: ProjectID, - iteration: PositiveInt, + iteration: Iteration, pipeline_tasks: dict[NodeIDStr, CompTaskAtDB], ) -> RunningState: pipeline_state_from_tasks: RunningState = get_pipeline_state_from_task_states( @@ -372,7 +215,7 @@ async def _set_run_result( self, user_id: UserID, project_id: ProjectID, - iteration: PositiveInt, + iteration: Iteration, run_result: RunningState, ) -> None: comp_runs_repo = CompRunsRepository.instance(self.db_engine) @@ -384,6 +227,18 @@ async def _set_run_result( final_state=(run_result in COMPLETED_STATES), ) + async def _set_schedule_done( + self, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration, + ) -> None: + await CompRunsRepository.instance(self.db_engine).mark_as_processed( + user_id=user_id, + project_id=project_id, + iteration=iteration, + ) + async def _set_states_following_failed_to_aborted( self, project_id: ProjectID, dag: nx.DiGraph ) -> dict[NodeIDStr, CompTaskAtDB]: @@ -394,9 +249,9 @@ async def _set_states_following_failed_to_aborted( for task in tasks.values(): if task.state == RunningState.FAILED: node_ids_to_set_as_aborted.update(nx.bfs_tree(dag, f"{task.node_id}")) - node_ids_to_set_as_aborted.remove(NodeIDStr(f"{task.node_id}")) + node_ids_to_set_as_aborted.remove(f"{task.node_id}") for node_id in node_ids_to_set_as_aborted: - tasks[NodeIDStr(f"{node_id}")].state = RunningState.ABORTED + tasks[f"{node_id}"].state = RunningState.ABORTED if node_ids_to_set_as_aborted: # update the current states back in DB comp_tasks_repo = CompTasksRepository.instance(self.db_engine) @@ -544,8 +399,8 @@ async def _process_started_tasks( root_parent_node_id=run_metadata.get("project_metadata", {}).get( "root_parent_node_id" ), - service_key=ServiceKey(t.image.name), - service_version=ServiceVersion(t.image.tag), + service_key=t.image.name, + service_version=t.image.tag, service_type=ServiceType.COMPUTATIONAL, service_resources=create_service_resources_from_task(t), service_additional_metadata={}, @@ -654,19 +509,19 @@ async def _start_tasks( comp_run: CompRunsAtDB, wake_up_callback: Callable[[], None], ) -> None: - ... + """start tasks in the 3rd party backend""" @abstractmethod async def _get_tasks_status( self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB ) -> list[RunningState]: - ... + """returns tasks status from the 3rd party backend""" @abstractmethod async def _stop_tasks( self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB ) -> None: - ... + """stop tasks in the 3rd party backend""" @abstractmethod async def _process_completed_tasks( @@ -676,32 +531,20 @@ async def _process_completed_tasks( iteration: Iteration, comp_run: CompRunsAtDB, ) -> None: - ... 
- - @staticmethod - def _build_exclusive_lock_key(*args, **kwargs) -> str: - assert args # nosec - return f"{kwargs['user_id']}:{kwargs['project_id']}:{kwargs['iteration']}" - - @staticmethod - def _redis_client_getter(*args, **kwargs) -> RedisClientSDK: - assert kwargs # nosec - zelf = args[0] - assert isinstance(zelf, BaseCompScheduler) # nosec - return zelf.redis_client - - @exclusive( - redis=_redis_client_getter, - lock_key=_build_exclusive_lock_key, - ) - async def _schedule_pipeline( + """process tasks from the 3rd party backend""" + + async def apply( self, *, user_id: UserID, project_id: ProjectID, - iteration: PositiveInt, - wake_up_callback: Callable[[], None], + iteration: Iteration, ) -> None: + """schedules a pipeline for a given user, project and iteration. + + An auto-scheduling callback is created internally and called every time + a pipeline task completes, so the next scheduling round is triggered right away. + """ with log_context( _logger, level=logging.INFO, @@ -734,7 +577,14 @@ async def _schedule_pipeline( comp_tasks=comp_tasks, dag=dag, comp_run=comp_run, - wake_up_callback=wake_up_callback, + wake_up_callback=_auto_schedule_callback( + asyncio.get_running_loop(), + self.db_engine, + self.rabbitmq_client, + user_id=user_id, + project_id=project_id, + iteration=iteration, + ), ) # 4. timeout if waiting for cluster has been there for more than X minutes comp_tasks = await self._timeout_if_waiting_for_cluster_too_long( @@ -753,17 +603,11 @@ async def _schedule_pipeline( # 7. Are we done scheduling that pipeline? if not dag.nodes() or pipeline_result in COMPLETED_STATES: # there is nothing left, the run is completed, we're done here - self._scheduled_pipelines.pop( - (user_id, project_id, iteration), None - ) _logger.info( "pipeline %s scheduling completed with result %s", f"{project_id=}", f"{pipeline_result=}", ) - current_task = asyncio.current_task() - assert current_task is not None # nosec - current_task.cancel() except PipelineNotFoundError: _logger.warning( "pipeline %s does not exist in comp_pipeline table, it will be removed from scheduler", f"{project_id=}", ) await self._set_run_result( user_id, project_id, iteration, RunningState.ABORTED ) - self._scheduled_pipelines.pop((user_id, project_id, iteration), None) except InvalidPipelineError as exc: _logger.warning( "pipeline %s appears to be misconfigured, it will be removed from scheduler. 
Please check pipeline:\n%s", @@ -782,7 +625,6 @@ async def _schedule_pipeline( await self._set_run_result( user_id, project_id, iteration, RunningState.ABORTED ) - self._scheduled_pipelines.pop((user_id, project_id, iteration), None) except (DaskClientAcquisisitonError, ClustersKeeperNotAvailableError): _logger.exception( "Unexpected error while connecting with computational backend, aborting pipeline" @@ -799,9 +641,10 @@ async def _schedule_pipeline( await self._set_run_result( user_id, project_id, iteration, RunningState.FAILED ) - self._scheduled_pipelines.pop((user_id, project_id, iteration), None) except ComputationalBackendNotConnectedError: _logger.exception("Computational backend is not connected!") + finally: + await self._set_schedule_done(user_id, project_id, iteration) async def _schedule_tasks_to_stop( self, @@ -846,9 +689,9 @@ async def _schedule_tasks_to_start( # noqa: C901 # get the tasks to start tasks_ready_to_start: dict[NodeID, CompTaskAtDB] = { - node_id: comp_tasks[NodeIDStr(f"{node_id}")] + node_id: comp_tasks[f"{node_id}"] for node_id in next_task_node_ids - if comp_tasks[NodeIDStr(f"{node_id}")].state in TASK_TO_START_STATES + if comp_tasks[f"{node_id}"].state in TASK_TO_START_STATES } if not tasks_ready_to_start: @@ -879,9 +722,7 @@ async def _schedule_tasks_to_start( # noqa: C901 RunningState.WAITING_FOR_CLUSTER, ) for task in tasks_ready_to_start: - comp_tasks[ - NodeIDStr(f"{task}") - ].state = RunningState.WAITING_FOR_CLUSTER + comp_tasks[f"{task}"].state = RunningState.WAITING_FOR_CLUSTER except ComputationalBackendOnDemandNotReadyError as exc: _logger.info( @@ -903,9 +744,7 @@ async def _schedule_tasks_to_start( # noqa: C901 RunningState.WAITING_FOR_CLUSTER, ) for task in tasks_ready_to_start: - comp_tasks[ - NodeIDStr(f"{task}") - ].state = RunningState.WAITING_FOR_CLUSTER + comp_tasks[f"{task}"].state = RunningState.WAITING_FOR_CLUSTER except ClustersKeeperNotAvailableError: _logger.exception("Unexpected error while starting tasks:") await publish_project_log( @@ -926,7 +765,7 @@ async def _schedule_tasks_to_start( # noqa: C901 optional_stopped=arrow.utcnow().datetime, ) for task in tasks_ready_to_start: - comp_tasks[NodeIDStr(f"{task}")].state = RunningState.FAILED + comp_tasks[f"{task}"].state = RunningState.FAILED raise except TaskSchedulingError as exc: _logger.exception( @@ -944,7 +783,7 @@ async def _schedule_tasks_to_start( # noqa: C901 optional_progress=1.0, optional_stopped=arrow.utcnow().datetime, ) - comp_tasks[NodeIDStr(f"{exc.node_id}")].state = RunningState.FAILED + comp_tasks[f"{exc.node_id}"].state = RunningState.FAILED except Exception: _logger.exception( "Unexpected error for %s with %s on %s happened when scheduling %s:", @@ -963,7 +802,7 @@ async def _schedule_tasks_to_start( # noqa: C901 optional_stopped=arrow.utcnow().datetime, ) for task in tasks_ready_to_start: - comp_tasks[NodeIDStr(f"{task}")].state = RunningState.FAILED + comp_tasks[f"{task}"].state = RunningState.FAILED raise return comp_tasks diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py similarity index 98% rename from services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py rename to services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py index 2fdf7acd2e9..adc67853686 100644 --- 
a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_dask.py @@ -1,10 +1,9 @@ import asyncio import contextlib import logging -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Callable import arrow from dask_task_models_library.container_tasks.errors import TaskCancelledError @@ -28,10 +27,9 @@ ComputationalBackendOnDemandNotReadyError, TaskSchedulingError, ) -from ...models.comp_runs import CompRunsAtDB, RunMetadataDict +from ...models.comp_runs import CompRunsAtDB, Iteration, RunMetadataDict from ...models.comp_tasks import CompTaskAtDB from ...models.dask_subsystem import DaskClientTaskState -from ...utils.comp_scheduler import Iteration, get_resource_tracking_run_id from ...utils.dask import ( clean_task_output_and_log_files_if_invalid, parse_dask_job_id, @@ -50,7 +48,8 @@ from ..db.repositories.clusters import ClustersRepository from ..db.repositories.comp_runs import CompRunsRepository from ..db.repositories.comp_tasks import CompTasksRepository -from ._base_scheduler import BaseCompScheduler +from ._scheduler_base import BaseCompScheduler +from ._utils import get_resource_tracking_run_id _logger = logging.getLogger(__name__) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py index 524dfc7e8ad..edda456f303 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py @@ -4,35 +4,28 @@ from servicelib.logging_utils import log_context from settings_library.redis import RedisDatabase -from ...core.errors import ConfigurationError from ...core.settings import AppSettings from ..dask_clients_pool import DaskClientsPool +from ..db import get_db_engine from ..rabbitmq import get_rabbitmq_client, get_rabbitmq_rpc_client from ..redis import get_redis_client_manager -from ._base_scheduler import BaseCompScheduler -from ._dask_scheduler import DaskScheduler +from ._scheduler_base import BaseCompScheduler +from ._scheduler_dask import DaskScheduler _logger = logging.getLogger(__name__) -async def create_from_db(app: FastAPI) -> BaseCompScheduler: - if not hasattr(app.state, "engine"): - msg = "Database connection is missing. Please check application configuration." 
- raise ConfigurationError(msg=msg) - db_engine = app.state.engine - +def create_scheduler(app: FastAPI) -> BaseCompScheduler: with log_context( _logger, logging.INFO, msg="Creating Dask-based computational scheduler" ): app_settings: AppSettings = app.state.settings - scheduler = DaskScheduler( + return DaskScheduler( settings=app_settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND, dask_clients_pool=DaskClientsPool.instance(app), rabbitmq_client=get_rabbitmq_client(app), rabbitmq_rpc_client=get_rabbitmq_rpc_client(app), redis_client=get_redis_client_manager(app).client(RedisDatabase.LOCKS), - db_engine=db_engine, + db_engine=get_db_engine(app), service_runtime_heartbeat_interval=app_settings.SERVICE_TRACKING_HEARTBEAT, ) - await scheduler.restore_scheduling_from_db() - return scheduler diff --git a/services/director-v2/src/simcore_service_director_v2/utils/comp_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_utils.py similarity index 61% rename from services/director-v2/src/simcore_service_director_v2/utils/comp_scheduler.py rename to services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_utils.py index 15f3481da10..0458b159811 100644 --- a/services/director-v2/src/simcore_service_director_v2/utils/comp_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_utils.py @@ -1,5 +1,6 @@ -from typing import TypeAlias +from typing import Callable +from fastapi import FastAPI from models_library.docker import DockerGenericTag from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID @@ -10,9 +11,12 @@ ServiceResourcesDictHelpers, ) from models_library.users import UserID -from pydantic import PositiveInt +from servicelib.redis import RedisClientSDK +from settings_library.redis import RedisDatabase +from simcore_service_director_v2.modules.redis import get_redis_client_manager -from ..models.comp_tasks import CompTaskAtDB +from ...models.comp_runs import Iteration +from ...models.comp_tasks import CompTaskAtDB SCHEDULED_STATES: set[RunningState] = { RunningState.PUBLISHED, @@ -51,9 +55,6 @@ } -Iteration: TypeAlias = PositiveInt - - def get_resource_tracking_run_id( user_id: UserID, project_id: ProjectID, node_id: NodeID, iteration: Iteration ) -> str: @@ -73,3 +74,34 @@ def create_service_resources_from_task(task: CompTaskAtDB) -> ServiceResourcesDi }, [task.image.boot_mode], ) + + +def _get_app_from_args(*args, **kwargs) -> FastAPI: + assert kwargs is not None # nosec + if args: + app = args[0] + else: + assert "app" in kwargs # nosec + app = kwargs["app"] + assert isinstance(app, FastAPI) # nosec + return app + + +def get_redis_client_from_app(*args, **kwargs) -> RedisClientSDK: + app = _get_app_from_args(*args, **kwargs) + return get_redis_client_manager(app).client(RedisDatabase.LOCKS) + + +def get_redis_lock_key( + suffix: str, *, unique_lock_key_builder: Callable[..., str] | None +) -> Callable[..., str]: + def _(*args, **kwargs) -> str: + app = _get_app_from_args(*args, **kwargs) + unique_lock_part = ( + unique_lock_key_builder(*args, **kwargs) if unique_lock_key_builder else "" + ) + if unique_lock_part: + unique_lock_part = f"-{unique_lock_part}" + return f"{app.title}-{suffix}{unique_lock_part}" + + return _ diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_worker.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_worker.py new file mode 100644 index 00000000000..397b68db0c9 --- 
/dev/null +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_worker.py @@ -0,0 +1,94 @@ +import asyncio +import contextlib +import functools +import logging +from typing import cast + +from fastapi import FastAPI +from models_library.projects import ProjectID +from models_library.users import UserID +from servicelib.logging_utils import log_context +from servicelib.redis import CouldNotAcquireLockError +from servicelib.redis_utils import exclusive + +from ...core.settings import get_application_settings +from ...models.comp_runs import Iteration +from ..rabbitmq import get_rabbitmq_client +from ._constants import MODULE_NAME_WORKER +from ._models import SchedulePipelineRabbitMessage +from ._scheduler_base import BaseCompScheduler +from ._scheduler_factory import create_scheduler +from ._utils import get_redis_client_from_app, get_redis_lock_key + +_logger = logging.getLogger(__name__) + + +def _get_scheduler_worker(app: FastAPI) -> BaseCompScheduler: + return cast(BaseCompScheduler, app.state.scheduler_worker) + + +def _unique_key_builder( + _app, user_id: UserID, project_id: ProjectID, iteration: Iteration +) -> str: + return f"{user_id}:{project_id}:{iteration}" + + +@exclusive( + redis=get_redis_client_from_app, + lock_key=get_redis_lock_key( + MODULE_NAME_WORKER, unique_lock_key_builder=_unique_key_builder + ), +) +async def _exclusively_schedule_pipeline( + app: FastAPI, *, user_id: UserID, project_id: ProjectID, iteration: Iteration +) -> None: + await _get_scheduler_worker(app).apply( + user_id=user_id, + project_id=project_id, + iteration=iteration, + ) + + +async def _handle_apply_distributed_schedule(app: FastAPI, data: bytes) -> bool: + + with log_context(_logger, logging.DEBUG, msg="handling scheduling"): + to_schedule_pipeline = SchedulePipelineRabbitMessage.model_validate_json(data) + with contextlib.suppress(CouldNotAcquireLockError): + await _exclusively_schedule_pipeline( + app, + user_id=to_schedule_pipeline.user_id, + project_id=to_schedule_pipeline.project_id, + iteration=to_schedule_pipeline.iteration, + ) + return True + + +async def setup_worker(app: FastAPI) -> None: + app_settings = get_application_settings(app) + rabbitmq_client = get_rabbitmq_client(app) + app.state.scheduler_worker_consumers = await asyncio.gather( + *( + rabbitmq_client.subscribe( + SchedulePipelineRabbitMessage.get_channel_name(), + functools.partial(_handle_apply_distributed_schedule, app), + exclusive_queue=False, + ) + for _ in range( + app_settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND.COMPUTATIONAL_BACKEND_SCHEDULING_CONCURRENCY + ) + ) + ) + + app.state.scheduler_worker = create_scheduler(app) + + +async def shutdown_worker(app: FastAPI) -> None: + assert app.state.scheduler_worker # nosec + rabbitmq_client = get_rabbitmq_client(app) + await asyncio.gather( + *( + rabbitmq_client.unsubscribe_consumer(*consumer) + for consumer in app.state.scheduler_worker_consumers + ), + return_exceptions=False, + ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/__init__.py b/services/director-v2/src/simcore_service_director_v2/modules/db/__init__.py index 7a5826d1aa3..a112ae63d46 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/__init__.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/__init__.py @@ -1,3 +1,6 @@ +from typing import cast + +from aiopg.sa import Engine from fastapi import FastAPI from settings_library.postgres import PostgresSettings @@ -22,4 +25,11 @@ async def on_shutdown() -> 
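A side note on the exclusive-lock wiring above: the worker composes its Redis lock key from the FastAPI app title, the worker module name and a per-pipeline suffix, so at most one replica runs BaseCompScheduler.apply() for a given (user_id, project_id, iteration). A minimal standalone sketch of that composition follows; FakeApp and the literal "worker" suffix are illustrative assumptions, the real suffix comes from _constants.MODULE_NAME_WORKER which is not shown in this diff.

# Illustrative sketch only: mirrors how get_redis_lock_key() and _unique_key_builder()
# compose the worker's exclusive Redis lock key; not the production implementation.
from dataclasses import dataclass


@dataclass
class FakeApp:
    # stand-in for fastapi.FastAPI, which exposes a .title attribute
    title: str


def build_worker_lock_key(
    app: FakeApp, *, user_id: int, project_id: str, iteration: int
) -> str:
    suffix = "worker"  # assumption: the real value is _constants.MODULE_NAME_WORKER
    unique_part = f"{user_id}:{project_id}:{iteration}"  # what _unique_key_builder returns
    return f"{app.title}-{suffix}-{unique_part}"


# two worker replicas computing the same key contend on the same Redis lock,
# so only one of them schedules this pipeline iteration at any given time
assert (
    build_worker_lock_key(FakeApp(title="director-v2"), user_id=1, project_id="uuid", iteration=1)
    == "director-v2-worker-1:uuid:1"
)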
None: app.add_event_handler("shutdown", on_shutdown) -__all__: tuple[str, ...] = ("get_asyncpg_engine",) +def get_db_engine(app: FastAPI) -> Engine: + return cast(Engine, app.state.engine) + + +__all__: tuple[str, ...] = ( + "get_asyncpg_engine", + "get_db_engine", +) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py index 9ce28bcda8d..b746407a8aa 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_runs.py @@ -1,7 +1,6 @@ import datetime import logging -from collections import deque -from typing import Any +from typing import Any, Final import arrow import sqlalchemy as sa @@ -17,7 +16,13 @@ from sqlalchemy.sql.elements import literal_column from sqlalchemy.sql.expression import desc -from ....core.errors import ClusterNotFoundError, ComputationalRunNotFoundError +from ....core.errors import ( + ClusterNotFoundError, + ComputationalRunNotFoundError, + DirectorError, + ProjectNotFoundError, + UserNotFoundError, +) from ....models.comp_runs import CompRunsAtDB, RunMetadataDict from ....utils.db import RUNNING_STATE_TO_DB from ..tables import comp_runs @@ -25,6 +30,24 @@ logger = logging.getLogger(__name__) +_POSTGRES_FK_COLUMN_TO_ERROR_MAP: Final[ + dict[sa.Column, tuple[type[DirectorError], tuple[str, ...]]] +] = { + comp_runs.c.user_id: (UserNotFoundError, ("users", "user_id")), + comp_runs.c.project_uuid: ( + ProjectNotFoundError, + ("projects", "project_id"), + ), + comp_runs.c.cluster_id: ( + ClusterNotFoundError, + ("clusters", "cluster_id"), + ), +} +_DEFAULT_FK_CONSTRAINT_TO_ERROR: Final[tuple[type[DirectorError], tuple]] = ( + DirectorError, + (), +) + class CompRunsRepository(BaseRepository): async def get( @@ -55,24 +78,74 @@ async def get( return CompRunsAtDB.model_validate(row) async def list( - self, filter_by_state: set[RunningState] | None = None + self, + *, + filter_by_state: set[RunningState] | None = None, + never_scheduled: bool = False, + processed_since: datetime.timedelta | None = None, + scheduled_since: datetime.timedelta | None = None, ) -> list[CompRunsAtDB]: - if not filter_by_state: - filter_by_state = set() - runs_in_db: deque[CompRunsAtDB] = deque() + """lists the computational runs: + filter_by_state AND (never_scheduled OR processed_since OR scheduled_since) + + + Keyword Arguments: + filter_by_state -- will return only the runs with result in filter_by_state (default: {None}) + never_scheduled -- will return the runs which were never scheduled (default: {False}) + processed_since -- will return the runs which were processed since X, which are not re-scheduled since then (default: {None}) + scheduled_since -- will return the runs which were scheduled since X, which are not processed since then (default: {None}) + """ + + conditions = [] + if filter_by_state: + conditions.append( + or_( + *[ + comp_runs.c.result == RUNNING_STATE_TO_DB[s] + for s in filter_by_state + ] + ) + ) + + scheduling_or_conditions = [] + if never_scheduled: + scheduling_or_conditions.append(comp_runs.c.scheduled.is_(None)) + if scheduled_since is not None: + # a scheduled run is a run that has been scheduled but not processed yet + # e.g. 
the processing timepoint is either null or before the scheduling timepoint + scheduled_cutoff = arrow.utcnow().datetime - scheduled_since + scheduling_filter = ( + comp_runs.c.scheduled.is_not(None) + & ( + comp_runs.c.processed.is_(None) + | (comp_runs.c.scheduled > comp_runs.c.processed) + ) + & (comp_runs.c.scheduled <= scheduled_cutoff) + ) + scheduling_or_conditions.append(scheduling_filter) + + if processed_since is not None: + # a processed run is a run that has been scheduled and processed + # and the processing timepoint is after the scheduling timepoint + processed_cutoff = arrow.utcnow().datetime - processed_since + processed_filter = ( + comp_runs.c.processed.is_not(None) + & (comp_runs.c.processed > comp_runs.c.scheduled) + & (comp_runs.c.processed <= processed_cutoff) + ) + + scheduling_or_conditions.append(processed_filter) + + if scheduling_or_conditions: + conditions.append(sa.or_(*scheduling_or_conditions)) + async with self.db_engine.acquire() as conn: - async for row in conn.execute( - sa.select(comp_runs).where( - or_( - *[ - comp_runs.c.result == RUNNING_STATE_TO_DB[s] - for s in filter_by_state - ] - ) + return [ + CompRunsAtDB.model_validate(row) + async for row in conn.execute( + sa.select(comp_runs).where(sa.and_(*conditions)) ) - ): - runs_in_db.append(CompRunsAtDB.model_validate(row)) - return list(runs_in_db) + ] async def create( self, @@ -117,7 +190,17 @@ async def create( row = await result.first() return CompRunsAtDB.model_validate(row) except ForeignKeyViolation as exc: - raise ClusterNotFoundError(cluster_id=cluster_id) from exc + assert exc.diag.constraint_name # nosec # noqa: PT017 + for foreign_key in comp_runs.foreign_keys: + if exc.diag.constraint_name == foreign_key.name: + assert foreign_key.parent is not None # nosec + exc_type, exc_keys = _POSTGRES_FK_COLUMN_TO_ERROR_MAP[ + foreign_key.parent + ] + raise exc_type( + **{f"{k}": locals().get(k) for k in exc_keys} + ) from exc + raise DirectorError from exc async def update( self, user_id: UserID, project_id: ProjectID, iteration: PositiveInt, **values @@ -164,3 +247,24 @@ async def mark_for_cancellation( iteration, cancelled=arrow.utcnow().datetime, ) + + async def mark_for_scheduling( + self, *, user_id: UserID, project_id: ProjectID, iteration: PositiveInt + ) -> CompRunsAtDB | None: + return await self.update( + user_id, + project_id, + iteration, + scheduled=arrow.utcnow().datetime, + processed=None, + ) + + async def mark_as_processed( + self, *, user_id: UserID, project_id: ProjectID, iteration: PositiveInt + ) -> CompRunsAtDB | None: + return await self.update( + user_id, + project_id, + iteration, + processed=arrow.utcnow().datetime, + ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py index 637e0c7faf6..51082b698f1 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/db/repositories/comp_tasks/_utils.py @@ -55,9 +55,9 @@ from .....models.comp_tasks import CompTaskAtDB, Image, NodeSchema from .....models.pricing import PricingInfo from .....modules.resource_usage_tracker_client import ResourceUsageTrackerClient -from .....utils.comp_scheduler import COMPLETED_STATES from .....utils.computations import to_node_class from ....catalog import CatalogClient +from ....comp_scheduler._utils import COMPLETED_STATES 
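For readability, the three run-selection filters introduced in CompRunsRepository.list() above boil down to plain timestamp predicates. The following pure-Python restatement is a sketch for reasoning only, not code from this PR; it mirrors the SQL conditions on the scheduled/processed columns.

# Illustrative sketch only: restates the scheduled/processed predicates used by
# CompRunsRepository.list() so the run-selection logic is easy to reason about.
import datetime


def never_scheduled(scheduled: datetime.datetime | None) -> bool:
    return scheduled is None


def scheduled_but_not_processed_since(
    scheduled: datetime.datetime | None,
    processed: datetime.datetime | None,
    *,
    since: datetime.timedelta,
    now: datetime.datetime,
) -> bool:
    # scheduled, not processed afterwards, and the scheduling is at least `since` old
    # -> the run is considered lost and will be re-distributed
    return (
        scheduled is not None
        and (processed is None or scheduled > processed)
        and scheduled <= now - since
    )


def processed_since(
    scheduled: datetime.datetime | None,
    processed: datetime.datetime | None,
    *,
    since: datetime.timedelta,
    now: datetime.datetime,
) -> bool:
    # processed after the last scheduling and at least `since` ago
    # -> the run is due for its next periodic scheduling round
    return (
        scheduled is not None
        and processed is not None
        and processed > scheduled
        and processed <= now - since
    )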
from ....director_v0 import DirectorV0Client from ...tables import NodeClass @@ -146,12 +146,12 @@ async def _get_node_infos( None, ) - result: tuple[ - ServiceMetaDataPublished, ServiceExtras, SimcoreServiceLabels - ] = await asyncio.gather( - _get_service_details(catalog_client, user_id, product_name, node), - director_client.get_service_extras(node.key, node.version), - director_client.get_service_labels(node), + result: tuple[ServiceMetaDataPublished, ServiceExtras, SimcoreServiceLabels] = ( + await asyncio.gather( + _get_service_details(catalog_client, user_id, product_name, node), + director_client.get_service_extras(node.key, node.version), + director_client.get_service_labels(node), + ) ) return result @@ -247,9 +247,9 @@ async def _get_pricing_and_hardware_infos( return pricing_info, hardware_info -_RAM_SAFE_MARGIN_RATIO: Final[ - float -] = 0.1 # NOTE: machines always have less available RAM than advertised +_RAM_SAFE_MARGIN_RATIO: Final[float] = ( + 0.1 # NOTE: machines always have less available RAM than advertised +) _CPUS_SAFE_MARGIN: Final[float] = 0.1 @@ -267,11 +267,11 @@ async def _update_project_node_resources_from_hardware_info( if not hardware_info.aws_ec2_instances: return try: - unordered_list_ec2_instance_types: list[ - EC2InstanceTypeGet - ] = await get_instance_type_details( - rabbitmq_rpc_client, - instance_type_names=set(hardware_info.aws_ec2_instances), + unordered_list_ec2_instance_types: list[EC2InstanceTypeGet] = ( + await get_instance_type_details( + rabbitmq_rpc_client, + instance_type_names=set(hardware_info.aws_ec2_instances), + ) ) assert unordered_list_ec2_instance_types # nosec diff --git a/services/director-v2/tests/conftest.py b/services/director-v2/tests/conftest.py index fcc0db6dbf1..72b94ec3262 100644 --- a/services/director-v2/tests/conftest.py +++ b/services/director-v2/tests/conftest.py @@ -197,23 +197,24 @@ def mock_env( @pytest.fixture() -async def client(mock_env: EnvVarsDict) -> AsyncIterator[TestClient]: +async def initialized_app(mock_env: EnvVarsDict) -> AsyncIterable[FastAPI]: settings = AppSettings.create_from_envs() app = init_app(settings) print("Application settings\n", settings.model_dump_json(indent=2)) - # NOTE: this way we ensure the events are run in the application - # since it starts the app on a test server - with TestClient(app, raise_server_exceptions=True) as test_client: - yield test_client + async with LifespanManager(app): + yield app @pytest.fixture() -async def initialized_app(mock_env: EnvVarsDict) -> AsyncIterable[FastAPI]: +async def client(mock_env: EnvVarsDict) -> AsyncIterator[TestClient]: + # NOTE: this way we ensure the events are run in the application + # since it starts the app on a test server settings = AppSettings.create_from_envs() app = init_app(settings) + # NOTE: we cannot use the initialized_app fixture here as the TestClient also creates it print("Application settings\n", settings.model_dump_json(indent=2)) - async with LifespanManager(app): - yield app + with TestClient(app, raise_server_exceptions=True) as test_client: + yield test_client @pytest.fixture() diff --git a/services/director-v2/tests/unit/_helpers.py b/services/director-v2/tests/unit/_helpers.py index 779d6cdd117..45632d0454a 100644 --- a/services/director-v2/tests/unit/_helpers.py +++ b/services/director-v2/tests/unit/_helpers.py @@ -1,38 +1,33 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, Callable import aiopg import aiopg.sa -from models_library.projects import ProjectAtDB +import 
sqlalchemy as sa +from models_library.projects import ProjectAtDB, ProjectID from models_library.projects_nodes_io import NodeID -from simcore_postgres_database.models.comp_pipeline import StateType +from models_library.projects_state import RunningState +from pydantic import TypeAdapter +from simcore_postgres_database.models.comp_runs import comp_runs from simcore_postgres_database.models.comp_tasks import comp_tasks from simcore_service_director_v2.models.comp_pipelines import CompPipelineAtDB from simcore_service_director_v2.models.comp_runs import CompRunsAtDB from simcore_service_director_v2.models.comp_tasks import CompTaskAtDB +from sqlalchemy.ext.asyncio import AsyncEngine @dataclass class PublishedProject: + user: dict[str, Any] project: ProjectAtDB pipeline: CompPipelineAtDB tasks: list[CompTaskAtDB] -@dataclass +@dataclass(kw_only=True) class RunningProject(PublishedProject): runs: CompRunsAtDB - - -async def set_comp_task_state( - aiopg_engine: aiopg.sa.engine.Engine, node_id: str, state: StateType -) -> None: - async with aiopg_engine.acquire() as conn: - await conn.execute( - comp_tasks.update() - .where(comp_tasks.c.node_id == node_id) - .values(state=state) - ) + task_to_callback_mapping: dict[NodeID, Callable[[], None]] async def set_comp_task_outputs( @@ -61,3 +56,55 @@ async def set_comp_task_inputs( .where(comp_tasks.c.node_id == f"{node_id}") .values(inputs=inputs, schema={"outputs": {}, "inputs": inputs_schema}) ) + + +async def assert_comp_runs( + sqlalchemy_async_engine: AsyncEngine, + *, + expected_total: int, + expected_state: RunningState | None = None, + where_statement: Any | None = None, +) -> list[CompRunsAtDB]: + async with sqlalchemy_async_engine.connect() as conn: + query = sa.select(comp_runs) + if where_statement is not None: + query = query.where(where_statement) + list_of_comp_runs = [ + CompRunsAtDB.model_validate(row) for row in await conn.execute(query) + ] + assert len(list_of_comp_runs) == expected_total + if list_of_comp_runs and expected_state: + assert all( + r.result is expected_state for r in list_of_comp_runs + ), f"expected state '{expected_state}', got {[r.result for r in list_of_comp_runs]}" + return list_of_comp_runs + + +async def assert_comp_runs_empty(sqlalchemy_async_engine: AsyncEngine) -> None: + await assert_comp_runs(sqlalchemy_async_engine, expected_total=0) + + +async def assert_comp_tasks( + sqlalchemy_async_engine: AsyncEngine, + *, + project_uuid: ProjectID, + task_ids: list[NodeID], + expected_state: RunningState, + expected_progress: float | None, +) -> list[CompTaskAtDB]: + # check the database is correctly updated, the run is published + async with sqlalchemy_async_engine.connect() as conn: + result = await conn.execute( + comp_tasks.select().where( + (comp_tasks.c.project_id == f"{project_uuid}") + & (comp_tasks.c.node_id.in_([f"{n}" for n in task_ids])) + ) # there is only one entry + ) + tasks = TypeAdapter(list[CompTaskAtDB]).validate_python(result.fetchall()) + assert all( + t.state == expected_state for t in tasks + ), f"expected state: {expected_state}, found: {[t.state for t in tasks]}" + assert all( + t.progress == expected_progress for t in tasks + ), f"{expected_progress=}, found: {[t.progress for t in tasks]}" + return tasks diff --git a/services/director-v2/tests/unit/conftest.py b/services/director-v2/tests/unit/conftest.py index 1375795f0cb..cdf0751fab4 100644 --- a/services/director-v2/tests/unit/conftest.py +++ b/services/director-v2/tests/unit/conftest.py @@ -184,6 +184,11 @@ def fake_s3_settings(faker: 
Faker) -> S3Settings: ) +@pytest.fixture +def fake_s3_envs(fake_s3_settings: S3Settings) -> EnvVarsDict: + return fake_s3_settings.model_dump() + + @pytest.fixture def mocked_storage_service_api( fake_s3_settings: S3Settings, diff --git a/services/director-v2/tests/unit/test_utils_comp_scheduler.py b/services/director-v2/tests/unit/test_utils_comp_scheduler.py index dfb7c0326b1..05c899a5e40 100644 --- a/services/director-v2/tests/unit/test_utils_comp_scheduler.py +++ b/services/director-v2/tests/unit/test_utils_comp_scheduler.py @@ -10,7 +10,7 @@ from models_library.projects_state import RunningState from models_library.users import UserID from simcore_service_director_v2.models.comp_tasks import CompTaskAtDB -from simcore_service_director_v2.utils.comp_scheduler import ( +from simcore_service_director_v2.modules.comp_scheduler._utils import ( COMPLETED_STATES, SCHEDULED_STATES, TASK_TO_START_STATES, diff --git a/services/director-v2/tests/unit/with_dbs/comp_scheduler/conftest.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/conftest.py new file mode 100644 index 00000000000..8f1c2898222 --- /dev/null +++ b/services/director-v2/tests/unit/with_dbs/comp_scheduler/conftest.py @@ -0,0 +1,69 @@ +# pylint:disable=unused-variable +# pylint:disable=unused-argument +# pylint:disable=redefined-outer-name +# pylint:disable=no-value-for-parameter +# pylint:disable=protected-access +# pylint:disable=too-many-arguments +# pylint:disable=no-name-in-module +# pylint: disable=too-many-statements + + +from unittest import mock + +import pytest +import sqlalchemy as sa +from pytest_mock.plugin import MockerFixture +from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict +from pytest_simcore.helpers.typing_env import EnvVarsDict +from settings_library.rabbit import RabbitSettings +from settings_library.redis import RedisSettings + + +@pytest.fixture +def mock_env( + mock_env: EnvVarsDict, + monkeypatch: pytest.MonkeyPatch, + fake_s3_envs: EnvVarsDict, + postgres_db: sa.engine.Engine, + postgres_host_config: dict[str, str], + rabbit_service: RabbitSettings, + redis_service: RedisSettings, +) -> EnvVarsDict: + return mock_env | setenvs_from_dict( + monkeypatch, + {k: f"{v}" for k, v in fake_s3_envs.items()} + | { + "COMPUTATIONAL_BACKEND_ENABLED": True, + "COMPUTATIONAL_BACKEND_DASK_CLIENT_ENABLED": True, + }, + ) + + +@pytest.fixture +def with_disabled_auto_scheduling(mocker: MockerFixture) -> mock.Mock: + mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler.shutdown_manager", + ) + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler.setup_manager", + ) + + +@pytest.fixture +def with_disabled_scheduler_worker(mocker: MockerFixture) -> mock.Mock: + mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler.shutdown_worker", + autospec=True, + ) + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler.setup_worker", + autospec=True, + ) + + +@pytest.fixture +def with_disabled_scheduler_publisher(mocker: MockerFixture) -> mock.Mock: + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._manager.request_pipeline_scheduling", + autospec=True, + ) diff --git a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_db_repositories_comp_runs.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_db_repositories_comp_runs.py new file mode 100644 index 00000000000..ba903d1b069 --- /dev/null +++ b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_db_repositories_comp_runs.py @@ 
-0,0 +1,525 @@ +# pylint: disable=no-value-for-parameter +# pylint: disable=protected-access +# pylint: disable=redefined-outer-name +# pylint: disable=too-many-arguments +# pylint: disable=unused-argument +# pylint: disable=unused-variable + +import asyncio +import datetime +import random +from collections.abc import Awaitable, Callable +from typing import cast + +import arrow +import pytest +from _helpers import PublishedProject +from faker import Faker +from models_library.clusters import DEFAULT_CLUSTER_ID, Cluster +from models_library.projects import ProjectID +from models_library.projects_state import RunningState +from models_library.users import UserID +from simcore_service_director_v2.core.errors import ( + ClusterNotFoundError, + ComputationalRunNotFoundError, + ProjectNotFoundError, + UserNotFoundError, +) +from simcore_service_director_v2.models.comp_runs import CompRunsAtDB, RunMetadataDict +from simcore_service_director_v2.modules.comp_scheduler._constants import ( + SCHEDULER_INTERVAL, +) +from simcore_service_director_v2.modules.db.repositories.comp_runs import ( + CompRunsRepository, +) + +pytest_simcore_core_services_selection = [ + "postgres", +] +pytest_simcore_ops_services_selection = [ + "adminer", +] + + +@pytest.fixture +def fake_user_id(faker: Faker) -> UserID: + return faker.pyint(min_value=1) + + +@pytest.fixture +def fake_project_id(faker: Faker) -> ProjectID: + return ProjectID(f"{faker.uuid4(cast_to=None)}") + + +async def test_get( + aiopg_engine, + fake_user_id: UserID, + fake_project_id: ProjectID, + publish_project: Callable[[], Awaitable[PublishedProject]], + create_comp_run: Callable[..., Awaitable[CompRunsAtDB]], +): + with pytest.raises(ComputationalRunNotFoundError): + await CompRunsRepository(aiopg_engine).get(fake_user_id, fake_project_id) + + published_project = await publish_project() + assert published_project.project.prj_owner + # there is still no comp run created + with pytest.raises(ComputationalRunNotFoundError): + await CompRunsRepository(aiopg_engine).get( + published_project.project.prj_owner, published_project.project.uuid + ) + + await create_comp_run(published_project.user, published_project.project) + await CompRunsRepository(aiopg_engine).get( + published_project.project.prj_owner, published_project.project.uuid + ) + + +async def test_list( + aiopg_engine, + publish_project: Callable[[], Awaitable[PublishedProject]], + run_metadata: RunMetadataDict, + faker: Faker, +): + assert await CompRunsRepository(aiopg_engine).list() == [] + + published_project = await publish_project() + assert await CompRunsRepository(aiopg_engine).list() == [] + + created = await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + assert await CompRunsRepository(aiopg_engine).list() == [created] + + created = [created] + await asyncio.gather( + *( + CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=created.iteration + n + 1, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + for n in range(50) + ) + ) + assert sorted( + await CompRunsRepository(aiopg_engine).list(), key=lambda x: x.iteration + ) == sorted(created, key=lambda x: x.iteration) + + # test with filter of state + any_state_but_published = { + s for s in 
RunningState if s is not RunningState.PUBLISHED + } + assert ( + await CompRunsRepository(aiopg_engine).list( + filter_by_state=any_state_but_published + ) + == [] + ) + assert sorted( + await CompRunsRepository(aiopg_engine).list( + filter_by_state={RunningState.PUBLISHED} + ), + key=lambda x: x.iteration, + ) == sorted(created, key=lambda x: x.iteration) + + # test with never scheduled filter, let's create a bunch of scheduled entries, + assert sorted( + await CompRunsRepository(aiopg_engine).list(never_scheduled=True), + key=lambda x: x.iteration, + ) == sorted(created, key=lambda x: x.iteration) + comp_runs_marked_for_scheduling = random.sample(created, k=25) + await asyncio.gather( + *( + CompRunsRepository(aiopg_engine).mark_for_scheduling( + user_id=comp_run.user_id, + project_id=comp_run.project_uuid, + iteration=comp_run.iteration, + ) + for comp_run in comp_runs_marked_for_scheduling + ) + ) + # filter them away + created = [r for r in created if r not in comp_runs_marked_for_scheduling] + assert sorted( + await CompRunsRepository(aiopg_engine).list(never_scheduled=True), + key=lambda x: x.iteration, + ) == sorted(created, key=lambda x: x.iteration) + + # now mark a few of them as processed + comp_runs_marked_as_processed = random.sample(comp_runs_marked_for_scheduling, k=11) + await asyncio.gather( + *( + CompRunsRepository(aiopg_engine).mark_as_processed( + user_id=comp_run.user_id, + project_id=comp_run.project_uuid, + iteration=comp_run.iteration, + ) + for comp_run in comp_runs_marked_as_processed + ) + ) + # filter them away + comp_runs_marked_for_scheduling = [ + r + for r in comp_runs_marked_for_scheduling + if r not in comp_runs_marked_as_processed + ] + # since they were just marked as processed now, we will get nothing + assert ( + sorted( + await CompRunsRepository(aiopg_engine).list( + never_scheduled=False, processed_since=SCHEDULER_INTERVAL + ), + key=lambda x: x.iteration, + ) + == [] + ) + # now we artificially change the scheduled/processed time and set it 2x the scheduler interval + # these are correctly processed ones, so we should get them back + fake_scheduled_time = arrow.utcnow().datetime - 2 * SCHEDULER_INTERVAL + fake_processed_time = fake_scheduled_time + 0.5 * SCHEDULER_INTERVAL + comp_runs_marked_as_processed = ( + cast( # NOTE: the cast here is ok since gather will raise if there is an error + list[CompRunsAtDB], + await asyncio.gather( + *( + CompRunsRepository(aiopg_engine).update( + user_id=comp_run.user_id, + project_id=comp_run.project_uuid, + iteration=comp_run.iteration, + scheduled=fake_scheduled_time, + processed=fake_processed_time, + ) + for comp_run in comp_runs_marked_as_processed + ) + ), + ) + ) + # now we should get them + assert sorted( + await CompRunsRepository(aiopg_engine).list( + never_scheduled=False, processed_since=SCHEDULER_INTERVAL + ), + key=lambda x: x.iteration, + ) == sorted(comp_runs_marked_as_processed, key=lambda x: x.iteration) + + # now some of them were never processed (e.g. 
processed time is either null or before schedule time) + comp_runs_waiting_for_processing_or_never_processed = random.sample( + comp_runs_marked_as_processed, k=6 + ) + comp_runs_marked_as_processed = [ + r + for r in comp_runs_marked_as_processed + if r not in comp_runs_waiting_for_processing_or_never_processed + ] + # now we artificially change the processed time to be before the scheduled time + comp_runs_waiting_for_processing_or_never_processed = cast( + list[CompRunsAtDB], + await asyncio.gather( + *( + CompRunsRepository(aiopg_engine).update( + user_id=comp_run.user_id, + project_id=comp_run.project_uuid, + iteration=comp_run.iteration, + scheduled=fake_processed_time, # NOTE: we invert here the timings + processed=random.choice([fake_scheduled_time, None]), # noqa: S311 + ) + for comp_run in comp_runs_waiting_for_processing_or_never_processed + ) + ), + ) + # so the processed ones shall remain + assert sorted( + await CompRunsRepository(aiopg_engine).list( + never_scheduled=False, processed_since=SCHEDULER_INTERVAL + ), + key=lambda x: x.iteration, + ) == sorted(comp_runs_marked_as_processed, key=lambda x: x.iteration) + # the ones waiting for scheduling now + assert sorted( + await CompRunsRepository(aiopg_engine).list( + never_scheduled=False, scheduled_since=SCHEDULER_INTERVAL + ), + key=lambda x: x.iteration, + ) == sorted( + comp_runs_waiting_for_processing_or_never_processed, key=lambda x: x.iteration + ) + + +async def test_create( + aiopg_engine, + fake_user_id: UserID, + fake_project_id: ProjectID, + run_metadata: RunMetadataDict, + faker: Faker, + publish_project: Callable[[], Awaitable[PublishedProject]], + create_cluster: Callable[..., Awaitable[Cluster]], +): + with pytest.raises(ProjectNotFoundError): + await CompRunsRepository(aiopg_engine).create( + user_id=fake_user_id, + project_id=fake_project_id, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + published_project = await publish_project() + with pytest.raises(UserNotFoundError): + await CompRunsRepository(aiopg_engine).create( + user_id=fake_user_id, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + + created = await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + got = await CompRunsRepository(aiopg_engine).get( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + ) + assert created == got + + # creating a second one auto increment the iteration + created = await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + assert created != got + assert created.iteration == got.iteration + 1 + + # getting without specifying the iteration returns the latest + got = await CompRunsRepository(aiopg_engine).get( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + ) + assert created == got + + with pytest.raises(ClusterNotFoundError): + await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + 
cluster_id=faker.pyint(min_value=1), + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + cluster = await create_cluster(published_project.user) + await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=cluster.id, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + + +async def test_update( + aiopg_engine, + fake_user_id: UserID, + fake_project_id: ProjectID, + run_metadata: RunMetadataDict, + faker: Faker, + publish_project: Callable[[], Awaitable[PublishedProject]], +): + # this updates nothing but also does not complain + updated = await CompRunsRepository(aiopg_engine).update( + fake_user_id, fake_project_id, faker.pyint(min_value=1) + ) + assert updated is None + # now let's create a valid one + published_project = await publish_project() + created = await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + + got = await CompRunsRepository(aiopg_engine).get( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + ) + assert created == got + + updated = await CompRunsRepository(aiopg_engine).update( + created.user_id, + created.project_uuid, + created.iteration, + scheduled=datetime.datetime.now(datetime.UTC), + ) + assert updated is not None + assert created != updated + assert created.scheduled is None + assert updated.scheduled is not None + + +async def test_set_run_result( + aiopg_engine, + run_metadata: RunMetadataDict, + faker: Faker, + publish_project: Callable[[], Awaitable[PublishedProject]], +): + published_project = await publish_project() + created = await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + got = await CompRunsRepository(aiopg_engine).get( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + ) + assert created == got + assert created.result is not RunningState.PENDING + assert created.ended is None + + updated = await CompRunsRepository(aiopg_engine).set_run_result( + user_id=created.user_id, + project_id=created.project_uuid, + iteration=created.iteration, + result_state=RunningState.PENDING, + final_state=False, + ) + assert updated + assert updated != created + assert updated.result is RunningState.PENDING + assert updated.ended is None + + final_updated = await CompRunsRepository(aiopg_engine).set_run_result( + user_id=created.user_id, + project_id=created.project_uuid, + iteration=created.iteration, + result_state=RunningState.ABORTED, + final_state=True, + ) + assert final_updated + assert final_updated != updated + assert final_updated.result is RunningState.ABORTED + assert final_updated.ended is not None + + +async def test_mark_for_cancellation( + aiopg_engine, + run_metadata: RunMetadataDict, + faker: Faker, + publish_project: Callable[[], Awaitable[PublishedProject]], +): + published_project = await publish_project() + created = await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, 
+ use_on_demand_clusters=faker.pybool(), + ) + got = await CompRunsRepository(aiopg_engine).get( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + ) + assert created == got + assert created.cancelled is None + + updated = await CompRunsRepository(aiopg_engine).mark_for_cancellation( + user_id=created.user_id, + project_id=created.project_uuid, + iteration=created.iteration, + ) + assert updated + assert updated != created + assert updated.cancelled is not None + + +async def test_mark_for_scheduling( + aiopg_engine, + run_metadata: RunMetadataDict, + faker: Faker, + publish_project: Callable[[], Awaitable[PublishedProject]], +): + published_project = await publish_project() + created = await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + got = await CompRunsRepository(aiopg_engine).get( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + ) + assert created == got + assert created.scheduled is None + assert created.processed is None + + updated = await CompRunsRepository(aiopg_engine).mark_for_scheduling( + user_id=created.user_id, + project_id=created.project_uuid, + iteration=created.iteration, + ) + assert updated + assert updated != created + assert updated.scheduled is not None + assert updated.processed is None + + +async def test_mark_scheduling_done( + aiopg_engine, + run_metadata: RunMetadataDict, + faker: Faker, + publish_project: Callable[[], Awaitable[PublishedProject]], +): + published_project = await publish_project() + created = await CompRunsRepository(aiopg_engine).create( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + iteration=None, + metadata=run_metadata, + use_on_demand_clusters=faker.pybool(), + ) + got = await CompRunsRepository(aiopg_engine).get( + user_id=published_project.user["id"], + project_id=published_project.project.uuid, + ) + assert created == got + assert created.scheduled is None + assert created.processed is None + + updated = await CompRunsRepository(aiopg_engine).mark_as_processed( + user_id=created.user_id, + project_id=created.project_uuid, + iteration=created.iteration, + ) + assert updated + assert updated != created + assert updated.scheduled is None + assert updated.processed is not None diff --git a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_manager.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_manager.py new file mode 100644 index 00000000000..ac5bbbcc942 --- /dev/null +++ b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_manager.py @@ -0,0 +1,371 @@ +# pylint:disable=unused-variable +# pylint:disable=unused-argument +# pylint:disable=redefined-outer-name +# pylint:disable=no-value-for-parameter +# pylint:disable=protected-access +# pylint:disable=too-many-arguments +# pylint:disable=no-name-in-module +# pylint: disable=too-many-statements + + +import asyncio +import datetime +import logging +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import Any +from unittest import mock + +import pytest +from _helpers import PublishedProject, assert_comp_runs, assert_comp_runs_empty +from fastapi import FastAPI +from models_library.clusters import DEFAULT_CLUSTER_ID +from models_library.projects import ProjectAtDB +from 
models_library.projects_state import RunningState +from pytest_mock.plugin import MockerFixture +from servicelib.rabbitmq._client import RabbitMQClient +from servicelib.redis import CouldNotAcquireLockError +from servicelib.utils import limited_gather +from simcore_service_director_v2.core.errors import PipelineNotFoundError +from simcore_service_director_v2.models.comp_pipelines import CompPipelineAtDB +from simcore_service_director_v2.models.comp_runs import RunMetadataDict +from simcore_service_director_v2.modules.comp_scheduler._manager import ( + _LOST_TASKS_FACTOR, + SCHEDULER_INTERVAL, + run_new_pipeline, + schedule_all_pipelines, + stop_pipeline, +) +from simcore_service_director_v2.modules.comp_scheduler._models import ( + SchedulePipelineRabbitMessage, +) +from simcore_service_director_v2.modules.db.repositories.comp_runs import ( + CompRunsRepository, +) +from sqlalchemy.ext.asyncio import AsyncEngine + +pytest_simcore_core_services_selection = ["postgres", "rabbit", "redis"] +pytest_simcore_ops_services_selection = ["adminer", "redis-commander"] + + +@pytest.fixture +async def scheduler_rabbit_client_parser( + create_rabbitmq_client: Callable[[str], RabbitMQClient], mocker: MockerFixture +) -> AsyncIterator[mock.AsyncMock]: + client = create_rabbitmq_client("scheduling_pytest_consumer") + mock = mocker.AsyncMock(return_value=True) + queue_name, _ = await client.subscribe( + SchedulePipelineRabbitMessage.get_channel_name(), mock, exclusive_queue=False + ) + yield mock + await client.unsubscribe(queue_name) + + +@pytest.fixture +def with_fast_scheduling(mocker: MockerFixture) -> None: + from simcore_service_director_v2.modules.comp_scheduler import _manager + + mocker.patch.object( + _manager, "SCHEDULER_INTERVAL", datetime.timedelta(seconds=0.01) + ) + + +@pytest.fixture +def mocked_schedule_all_pipelines(mocker: MockerFixture) -> mock.Mock: + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._manager.schedule_all_pipelines", + autospec=True, + ) + + +async def test_manager_starts_and_auto_schedules_pipelines( + with_fast_scheduling: None, + with_disabled_scheduler_worker: mock.Mock, + mocked_schedule_all_pipelines: mock.Mock, + initialized_app: FastAPI, + sqlalchemy_async_engine: AsyncEngine, +): + await assert_comp_runs_empty(sqlalchemy_async_engine) + mocked_schedule_all_pipelines.assert_called() + + +async def test_schedule_all_pipelines_empty_db( + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_worker: mock.Mock, + initialized_app: FastAPI, + scheduler_rabbit_client_parser: mock.AsyncMock, + sqlalchemy_async_engine: AsyncEngine, +): + await assert_comp_runs_empty(sqlalchemy_async_engine) + + await schedule_all_pipelines(initialized_app) + + # check nothing was distributed + scheduler_rabbit_client_parser.assert_not_called() + + # check comp_runs is still empty + await assert_comp_runs_empty(sqlalchemy_async_engine) + + +async def test_schedule_all_pipelines_concurently_runs_exclusively_and_raises( + with_disabled_auto_scheduling: mock.Mock, + initialized_app: FastAPI, + mocker: MockerFixture, +): + CONCURRENCY = 5 + # NOTE: this ensure no flakyness as empty scheduling is very fast + # so we slow down the limited_gather function + original_function = limited_gather + + async def slow_limited_gather(*args, **kwargs): + result = await original_function(*args, **kwargs) + await asyncio.sleep(3) # to ensure flakyness does not occur + return result + + mock_function = mocker.patch( + 
"simcore_service_director_v2.modules.comp_scheduler._manager.limited_gather", + autospec=True, + side_effect=slow_limited_gather, + ) + + results = await asyncio.gather( + *(schedule_all_pipelines(initialized_app) for _ in range(CONCURRENCY)), + return_exceptions=True, + ) + + assert results.count(None) == 1, f"Only one task should have run: {results}" + for r in results: + if r: + assert isinstance(r, CouldNotAcquireLockError) + mock_function.assert_called_once() + + +async def test_schedule_all_pipelines( + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_worker: mock.Mock, + initialized_app: FastAPI, + published_project: PublishedProject, + sqlalchemy_async_engine: AsyncEngine, + aiopg_engine, + run_metadata: RunMetadataDict, + scheduler_rabbit_client_parser: mock.AsyncMock, +): + await assert_comp_runs_empty(sqlalchemy_async_engine) + assert published_project.project.prj_owner + # now we schedule a pipeline + await run_new_pipeline( + initialized_app, + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + # this directly schedule a new pipeline + scheduler_rabbit_client_parser.assert_called_once_with( + SchedulePipelineRabbitMessage( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + iteration=1, + ).body() + ) + scheduler_rabbit_client_parser.reset_mock() + comp_run = (await assert_comp_runs(sqlalchemy_async_engine, expected_total=1))[0] + assert comp_run.project_uuid == published_project.project.uuid + assert comp_run.user_id == published_project.project.prj_owner + assert comp_run.iteration == 1 + assert comp_run.cancelled is None + assert comp_run.cluster_id == DEFAULT_CLUSTER_ID + assert comp_run.metadata == run_metadata + assert comp_run.result is RunningState.PUBLISHED + assert comp_run.scheduled is not None + assert comp_run.processed is None + start_schedule_time = comp_run.scheduled + start_modified_time = comp_run.modified + + # this will now not schedule the pipeline since it was already scheduled + await schedule_all_pipelines(initialized_app) + scheduler_rabbit_client_parser.assert_not_called() + comp_runs = await assert_comp_runs(sqlalchemy_async_engine, expected_total=1) + comp_run = comp_runs[0] + assert comp_run.scheduled + assert comp_run.scheduled == start_schedule_time, "scheduled time changed!" 
+ assert comp_run.cancelled is None + assert comp_run.modified == start_modified_time + + # to simulate that the worker did its job we will set times in the past + await CompRunsRepository(aiopg_engine).update( + user_id=comp_run.user_id, + project_id=comp_run.project_uuid, + iteration=comp_run.iteration, + scheduled=comp_run.scheduled - 1.5 * SCHEDULER_INTERVAL, + processed=comp_run.scheduled - 1.1 * SCHEDULER_INTERVAL, + ) + + # now we schedule a pipeline again, but we wait for the scheduler interval to pass + # this will trigger a new schedule + await schedule_all_pipelines(initialized_app) + scheduler_rabbit_client_parser.assert_called_once_with( + SchedulePipelineRabbitMessage( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + iteration=1, + ).body() + ) + scheduler_rabbit_client_parser.reset_mock() + comp_runs = await assert_comp_runs(sqlalchemy_async_engine, expected_total=1) + comp_run = comp_runs[0] + assert comp_run.scheduled is not None + assert comp_run.scheduled > start_schedule_time + last_schedule_time = comp_run.scheduled + assert comp_run.cancelled is None + assert comp_run.modified > start_modified_time + + # now we stop the pipeline, which should instantly trigger a schedule + await stop_pipeline( + initialized_app, + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + ) + await schedule_all_pipelines(initialized_app) + scheduler_rabbit_client_parser.assert_called_once_with( + SchedulePipelineRabbitMessage( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + iteration=1, + ).body() + ) + scheduler_rabbit_client_parser.reset_mock() + comp_runs = await assert_comp_runs(sqlalchemy_async_engine, expected_total=1) + comp_run = comp_runs[0] + assert comp_run.scheduled is not None + assert comp_run.scheduled > last_schedule_time + assert comp_run.cancelled is not None + + +async def test_schedule_all_pipelines_logs_error_if_it_find_old_pipelines( + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_worker: mock.Mock, + initialized_app: FastAPI, + published_project: PublishedProject, + sqlalchemy_async_engine: AsyncEngine, + aiopg_engine, + run_metadata: RunMetadataDict, + scheduler_rabbit_client_parser: mock.AsyncMock, + caplog: pytest.LogCaptureFixture, +): + await assert_comp_runs_empty(sqlalchemy_async_engine) + assert published_project.project.prj_owner + # now we schedule a pipeline + await run_new_pipeline( + initialized_app, + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + # this directly schedule a new pipeline + scheduler_rabbit_client_parser.assert_called_once_with( + SchedulePipelineRabbitMessage( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + iteration=1, + ).body() + ) + scheduler_rabbit_client_parser.reset_mock() + comp_run = (await assert_comp_runs(sqlalchemy_async_engine, expected_total=1))[0] + assert comp_run.project_uuid == published_project.project.uuid + assert comp_run.user_id == published_project.project.prj_owner + assert comp_run.iteration == 1 + assert comp_run.cancelled is None + assert comp_run.cluster_id == DEFAULT_CLUSTER_ID + assert comp_run.metadata == run_metadata + assert comp_run.result is RunningState.PUBLISHED + assert comp_run.scheduled is not None + start_schedule_time = comp_run.scheduled + start_modified_time = 
comp_run.modified + + # this will now not schedule the pipeline since it was already scheduled + await schedule_all_pipelines(initialized_app) + scheduler_rabbit_client_parser.assert_not_called() + comp_runs = await assert_comp_runs(sqlalchemy_async_engine, expected_total=1) + comp_run = comp_runs[0] + assert comp_run.scheduled == start_schedule_time, "scheduled time changed!" + assert comp_run.cancelled is None + assert comp_run.modified == start_modified_time + + # now we artificially set the last_schedule time well in the past + await CompRunsRepository(aiopg_engine).update( + comp_run.user_id, + comp_run.project_uuid, + comp_run.iteration, + scheduled=datetime.datetime.now(tz=datetime.UTC) + - SCHEDULER_INTERVAL * (_LOST_TASKS_FACTOR + 1), + ) + with caplog.at_level(logging.ERROR): + await schedule_all_pipelines(initialized_app) + assert ( + "found 1 lost pipelines, they will be re-scheduled now" in caplog.messages + ) + scheduler_rabbit_client_parser.assert_called_once_with( + SchedulePipelineRabbitMessage( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + iteration=1, + ).body() + ) + scheduler_rabbit_client_parser.reset_mock() + comp_runs = await assert_comp_runs(sqlalchemy_async_engine, expected_total=1) + comp_run = comp_runs[0] + assert comp_run.scheduled is not None + assert comp_run.scheduled > start_schedule_time + assert comp_run.cancelled is None + assert comp_run.modified > start_modified_time + + +async def test_empty_pipeline_is_not_scheduled( + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_worker: mock.Mock, + initialized_app: FastAPI, + registered_user: Callable[..., dict[str, Any]], + project: Callable[..., Awaitable[ProjectAtDB]], + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], + run_metadata: RunMetadataDict, + sqlalchemy_async_engine: AsyncEngine, + scheduler_rabbit_client_parser: mock.AsyncMock, + caplog: pytest.LogCaptureFixture, +): + await assert_comp_runs_empty(sqlalchemy_async_engine) + user = registered_user() + empty_project = await project(user) + + # the project is not in the comp_pipeline, therefore scheduling it should fail + with pytest.raises(PipelineNotFoundError): + await run_new_pipeline( + initialized_app, + user_id=user["id"], + project_id=empty_project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + await assert_comp_runs_empty(sqlalchemy_async_engine) + scheduler_rabbit_client_parser.assert_not_called() + + # create the empty pipeline now + await create_pipeline(project_id=f"{empty_project.uuid}") + + # creating a run with an empty pipeline is useless, check the scheduler is not kicking in + with caplog.at_level(logging.WARNING): + await run_new_pipeline( + initialized_app, + user_id=user["id"], + project_id=empty_project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + assert len(caplog.records) == 1 + assert "no computational dag defined" in caplog.records[0].message + await assert_comp_runs_empty(sqlalchemy_async_engine) + scheduler_rabbit_client_parser.assert_not_called() diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py similarity index 55% rename from services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py rename to 
services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py index 6f016f297c0..f6a041b934e 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_scheduler_dask.py @@ -16,12 +16,14 @@ from typing import Any, cast from unittest import mock -import aiopg -import aiopg.sa -import httpx import pytest -from _helpers import PublishedProject, RunningProject -from dask.distributed import SpecCluster +from _helpers import ( + PublishedProject, + RunningProject, + assert_comp_runs, + assert_comp_runs_empty, + assert_comp_tasks, +) from dask_task_models_library.container_tasks.errors import TaskCancelledError from dask_task_models_library.container_tasks.events import TaskProgressEvent from dask_task_models_library.container_tasks.io import TaskOutputData @@ -43,14 +45,9 @@ from models_library.users import UserID from pydantic import TypeAdapter from pytest_mock.plugin import MockerFixture -from pytest_simcore.helpers.typing_env import EnvVarsDict from servicelib.rabbitmq import RabbitMQClient -from servicelib.redis import CouldNotAcquireLockError -from settings_library.rabbit import RabbitSettings -from settings_library.redis import RedisSettings from simcore_postgres_database.models.comp_runs import comp_runs -from simcore_postgres_database.models.comp_tasks import NodeClass, comp_tasks -from simcore_service_director_v2.core.application import init_app +from simcore_postgres_database.models.comp_tasks import NodeClass from simcore_service_director_v2.core.errors import ( ClustersKeeperNotAvailableError, ComputationalBackendNotConnectedError, @@ -59,28 +56,32 @@ ComputationalBackendTaskResultsNotReadyError, ComputationalSchedulerChangedError, ComputationalSchedulerError, - ConfigurationError, - PipelineNotFoundError, ) -from simcore_service_director_v2.core.settings import AppSettings from simcore_service_director_v2.models.comp_pipelines import CompPipelineAtDB from simcore_service_director_v2.models.comp_runs import CompRunsAtDB, RunMetadataDict from simcore_service_director_v2.models.comp_tasks import CompTaskAtDB, Image from simcore_service_director_v2.models.dask_subsystem import DaskClientTaskState -from simcore_service_director_v2.modules.comp_scheduler import ( +from simcore_service_director_v2.modules.comp_scheduler._manager import ( + run_new_pipeline, + stop_pipeline, +) +from simcore_service_director_v2.modules.comp_scheduler._scheduler_base import ( BaseCompScheduler, - get_scheduler, ) -from simcore_service_director_v2.modules.comp_scheduler._dask_scheduler import ( +from simcore_service_director_v2.modules.comp_scheduler._scheduler_dask import ( DaskScheduler, ) +from simcore_service_director_v2.modules.comp_scheduler._utils import COMPLETED_STATES +from simcore_service_director_v2.modules.comp_scheduler._worker import ( + _get_scheduler_worker, +) from simcore_service_director_v2.modules.dask_client import ( DaskJobID, PublishedComputationTask, ) -from simcore_service_director_v2.utils.comp_scheduler import COMPLETED_STATES from simcore_service_director_v2.utils.dask_client_utils import TaskHandlers -from starlette.testclient import TestClient +from sqlalchemy import and_ +from sqlalchemy.ext.asyncio import AsyncEngine from tenacity.asyncio import AsyncRetrying from tenacity.retry import retry_if_exception_type from tenacity.stop import stop_after_delay @@ -113,119 +114,8 @@ def _assert_dask_client_correctly_initialized( ) -async def _assert_comp_run_db( 
- aiopg_engine: aiopg.sa.engine.Engine, - pub_project: PublishedProject, - expected_state: RunningState, -) -> None: - # check the database is correctly updated, the run is published - async with aiopg_engine.acquire() as conn: - result = await conn.execute( - comp_runs.select().where( - (comp_runs.c.user_id == pub_project.project.prj_owner) - & (comp_runs.c.project_uuid == f"{pub_project.project.uuid}") - ) # there is only one entry - ) - run_entry = CompRunsAtDB.model_validate(await result.first()) - assert ( - run_entry.result == expected_state - ), f"comp_runs: expected state '{expected_state}, found '{run_entry.result}'" - - -async def _assert_comp_tasks_db( - aiopg_engine: aiopg.sa.engine.Engine, - project_uuid: ProjectID, - task_ids: list[NodeID], - *, - expected_state: RunningState, - expected_progress: float | None, -) -> None: - # check the database is correctly updated, the run is published - async with aiopg_engine.acquire() as conn: - result = await conn.execute( - comp_tasks.select().where( - (comp_tasks.c.project_id == f"{project_uuid}") - & (comp_tasks.c.node_id.in_([f"{n}" for n in task_ids])) - ) # there is only one entry - ) - tasks = TypeAdapter(list[CompTaskAtDB]).validate_python(await result.fetchall()) - assert all( - t.state == expected_state for t in tasks - ), f"expected state: {expected_state}, found: {[t.state for t in tasks]}" - assert all( - t.progress == expected_progress for t in tasks - ), f"{expected_progress=}, found: {[t.progress for t in tasks]}" - - -async def schedule_all_pipelines(scheduler: BaseCompScheduler) -> None: - # NOTE: we take a copy of the pipelines, as this could change quickly if there are - # misconfigured pipelines that would be removed from the scheduler - # NOTE: we simulate multiple dv-2 replicas by running several times - # the same pipeline scheduling - local_pipelines = deepcopy(scheduler._scheduled_pipelines) # noqa: SLF001 - results = await asyncio.gather( - *( - scheduler._schedule_pipeline( # noqa: SLF001 - user_id=user_id, - project_id=project_id, - iteration=iteration, - wake_up_callback=params.scheduler_waker.set, - ) - for _ in range(3) - for ( - user_id, - project_id, - iteration, - ), params in local_pipelines.items() - ), - return_exceptions=True, - ) - # we should have exceptions 2/3 of the time - could_not_acquire_lock_count = sum( - isinstance(r, CouldNotAcquireLockError) for r in results - ) - total_results_count = len(results) - - # Check if 2/3 of the results are CouldNotAcquireLockError - # checks that scheduling is done exclusively - assert could_not_acquire_lock_count == (2 / 3) * total_results_count - - -@pytest.fixture -def minimal_dask_scheduler_config( - mock_env: EnvVarsDict, - postgres_host_config: dict[str, str], - monkeypatch: pytest.MonkeyPatch, - rabbit_service: RabbitSettings, - redis_service: RedisSettings, - faker: Faker, -) -> None: - """set a minimal configuration for testing the dask connection only""" - monkeypatch.setenv("DIRECTOR_V2_DYNAMIC_SIDECAR_ENABLED", "false") - monkeypatch.setenv("DIRECTOR_V0_ENABLED", "0") - monkeypatch.setenv("COMPUTATIONAL_BACKEND_DASK_CLIENT_ENABLED", "1") - monkeypatch.setenv("COMPUTATIONAL_BACKEND_ENABLED", "1") - monkeypatch.setenv("R_CLONE_PROVIDER", "MINIO") - monkeypatch.setenv("S3_ENDPOINT", faker.url()) - monkeypatch.setenv("S3_ACCESS_KEY", faker.pystr()) - monkeypatch.setenv("S3_REGION", faker.pystr()) - monkeypatch.setenv("S3_SECRET_KEY", faker.pystr()) - monkeypatch.setenv("S3_BUCKET_NAME", faker.pystr()) - - -@pytest.fixture -def scheduler( - 
minimal_dask_scheduler_config: None, - aiopg_engine: aiopg.sa.engine.Engine, - minimal_app: FastAPI, -) -> BaseCompScheduler: - scheduler = get_scheduler(minimal_app) - assert scheduler is not None - return scheduler - - @pytest.fixture -def mocked_dask_client(mocker: MockerFixture) -> mock.MagicMock: +def mocked_dask_client(mocker: MockerFixture) -> mock.Mock: mocked_dask_client = mocker.patch( "simcore_service_director_v2.modules.dask_clients_pool.DaskClient", autospec=True, @@ -237,246 +127,80 @@ def mocked_dask_client(mocker: MockerFixture) -> mock.MagicMock: @pytest.fixture def mocked_parse_output_data_fct(mocker: MockerFixture) -> mock.Mock: return mocker.patch( - "simcore_service_director_v2.modules.comp_scheduler._dask_scheduler.parse_output_data", + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.parse_output_data", autospec=True, ) @pytest.fixture -def mocked_clean_task_output_fct(mocker: MockerFixture) -> mock.MagicMock: +def mocked_clean_task_output_fct(mocker: MockerFixture) -> mock.Mock: return mocker.patch( - "simcore_service_director_v2.modules.comp_scheduler._dask_scheduler.clean_task_output_and_log_files_if_invalid", + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.clean_task_output_and_log_files_if_invalid", return_value=None, autospec=True, ) @pytest.fixture -def with_disabled_auto_scheduling(mocker: MockerFixture) -> mock.MagicMock: - """disables the scheduler task, note that it needs to be triggered manually then""" - - def _fake_starter( - self: BaseCompScheduler, - *args, - **kwargs, - ): - scheduler_task = mocker.MagicMock() - scheduler_task_wake_up_event = mocker.MagicMock() - return scheduler_task, scheduler_task_wake_up_event - +def mocked_clean_task_output_and_log_files_if_invalid( + mocker: MockerFixture, +) -> mock.Mock: return mocker.patch( - "simcore_service_director_v2.modules.comp_scheduler._base_scheduler.BaseCompScheduler._start_scheduling", + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.clean_task_output_and_log_files_if_invalid", autospec=True, - side_effect=_fake_starter, ) @pytest.fixture -async def minimal_app(async_client: httpx.AsyncClient) -> FastAPI: - # must use the minimal app from from the `async_client`` - # the`client` uses starlette's TestClient which spawns - # a new thread on which it creates a new loop - # causing issues downstream with coroutines not - # being created on the same loop - return async_client._transport.app # type: ignore # noqa: SLF001 - - -@pytest.fixture -def mocked_clean_task_output_and_log_files_if_invalid(mocker: MockerFixture) -> None: - mocker.patch( - "simcore_service_director_v2.modules.comp_scheduler._dask_scheduler.clean_task_output_and_log_files_if_invalid", - autospec=True, - ) - - -async def test_scheduler_gracefully_starts_and_stops( - minimal_dask_scheduler_config: None, - aiopg_engine: aiopg.sa.engine.Engine, - dask_spec_local_cluster: SpecCluster, - minimal_app: FastAPI, -): - # check it started correctly - assert get_scheduler(minimal_app) is not None - - -@pytest.mark.parametrize( - "missing_dependency", - [ - "COMPUTATIONAL_BACKEND_DASK_CLIENT_ENABLED", - ], -) -def test_scheduler_raises_exception_for_missing_dependencies( - minimal_dask_scheduler_config: None, - aiopg_engine: aiopg.sa.engine.Engine, - dask_spec_local_cluster: SpecCluster, - monkeypatch: pytest.MonkeyPatch, - missing_dependency: str, -): - # disable the dependency - monkeypatch.setenv(missing_dependency, "0") - # create the client - settings = 
AppSettings.create_from_envs() - app = init_app(settings) - - with pytest.raises(ConfigurationError), TestClient( - app, raise_server_exceptions=True - ) as _: - pass - - -async def test_empty_pipeline_is_not_scheduled( - with_disabled_auto_scheduling: None, - scheduler: BaseCompScheduler, - registered_user: Callable[..., dict[str, Any]], - project: Callable[..., Awaitable[ProjectAtDB]], - pipeline: Callable[..., CompPipelineAtDB], - aiopg_engine: aiopg.sa.engine.Engine, - run_metadata: RunMetadataDict, -): - user = registered_user() - empty_project = await project(user) - - # the project is not in the comp_pipeline, therefore scheduling it should fail - with pytest.raises(PipelineNotFoundError): - await scheduler.run_new_pipeline( - user_id=user["id"], - project_id=empty_project.uuid, - cluster_id=DEFAULT_CLUSTER_ID, - run_metadata=run_metadata, - use_on_demand_clusters=False, - ) - # create the empty pipeline now - pipeline(project_id=f"{empty_project.uuid}") - - # creating a run with an empty pipeline is useless, check the scheduler is not kicking in - await scheduler.run_new_pipeline( - user_id=user["id"], - project_id=empty_project.uuid, - cluster_id=DEFAULT_CLUSTER_ID, - run_metadata=run_metadata, - use_on_demand_clusters=False, - ) - assert len(scheduler._scheduled_pipelines) == 0 # noqa: SLF001 - # check the database is empty - async with aiopg_engine.acquire() as conn: - result = await conn.scalar( - comp_runs.select().where( - (comp_runs.c.user_id == user["id"]) - & (comp_runs.c.project_uuid == f"{empty_project.uuid}") - ) # there is only one entry - ) - assert result is None - - -async def test_misconfigured_pipeline_is_not_scheduled( - with_disabled_auto_scheduling: None, - scheduler: BaseCompScheduler, - registered_user: Callable[..., dict[str, Any]], - project: Callable[..., Awaitable[ProjectAtDB]], - pipeline: Callable[..., CompPipelineAtDB], - fake_workbench_without_outputs: dict[str, Any], - fake_workbench_adjacency: dict[str, Any], - aiopg_engine: aiopg.sa.engine.Engine, - run_metadata: RunMetadataDict, -): - """A pipeline which comp_tasks are missing should not be scheduled. 
- It shall be aborted and shown as such in the comp_runs db""" - user = registered_user() - sleepers_project = await project(user, workbench=fake_workbench_without_outputs) - pipeline( - project_id=f"{sleepers_project.uuid}", - dag_adjacency_list=fake_workbench_adjacency, - ) - # check the pipeline is correctly added to the scheduled pipelines - await scheduler.run_new_pipeline( - user_id=user["id"], - project_id=sleepers_project.uuid, - cluster_id=DEFAULT_CLUSTER_ID, - run_metadata=run_metadata, - use_on_demand_clusters=False, - ) - assert len(scheduler._scheduled_pipelines) == 1 # noqa: SLF001 - for ( - u_id, - p_id, - it, - ) in scheduler._scheduled_pipelines: # noqa: SLF001 - assert u_id == user["id"] - assert p_id == sleepers_project.uuid - assert it > 0 - # check the database was properly updated - async with aiopg_engine.acquire() as conn: - result = await conn.execute( - comp_runs.select().where( - (comp_runs.c.user_id == user["id"]) - & (comp_runs.c.project_uuid == f"{sleepers_project.uuid}") - ) # there is only one entry - ) - run_entry = CompRunsAtDB.model_validate(await result.first()) - assert run_entry.result == RunningState.PUBLISHED - # let the scheduler kick in - await schedule_all_pipelines(scheduler) - # check the scheduled pipelines is again empty since it's misconfigured - assert len(scheduler._scheduled_pipelines) == 0 # noqa: SLF001 - # check the database entry is correctly updated - async with aiopg_engine.acquire() as conn: - result = await conn.execute( - comp_runs.select().where( - (comp_runs.c.user_id == user["id"]) - & (comp_runs.c.project_uuid == f"{sleepers_project.uuid}") - ) # there is only one entry - ) - run_entry = CompRunsAtDB.model_validate(await result.first()) - assert run_entry.result == RunningState.ABORTED - assert run_entry.metadata == run_metadata +def scheduler_api(initialized_app: FastAPI) -> BaseCompScheduler: + return _get_scheduler_worker(initialized_app) async def _assert_start_pipeline( - aiopg_engine, + app: FastAPI, + *, + sqlalchemy_async_engine: AsyncEngine, published_project: PublishedProject, - scheduler: BaseCompScheduler, run_metadata: RunMetadataDict, -) -> list[CompTaskAtDB]: +) -> tuple[CompRunsAtDB, list[CompTaskAtDB]]: exp_published_tasks = deepcopy(published_project.tasks) assert published_project.project.prj_owner - await scheduler.run_new_pipeline( + await run_new_pipeline( + app, user_id=published_project.project.prj_owner, project_id=published_project.project.uuid, cluster_id=DEFAULT_CLUSTER_ID, run_metadata=run_metadata, use_on_demand_clusters=False, ) - assert ( - len(scheduler._scheduled_pipelines) == 1 # noqa: SLF001 - ), "the pipeline is not scheduled!" 
- for ( - u_id, - p_id, - it, - ) in scheduler._scheduled_pipelines: # noqa: SLF001 - assert u_id == published_project.project.prj_owner - assert p_id == published_project.project.uuid - assert it > 0 # check the database is correctly updated, the run is published - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in exp_published_tasks], + runs = await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.PUBLISHED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in exp_published_tasks], expected_state=RunningState.PUBLISHED, expected_progress=None, ) - return exp_published_tasks + return runs[0], exp_published_tasks -async def _assert_schedule_pipeline_PENDING( # noqa: N802 - aiopg_engine, +async def _assert_publish_in_dask_backend( + sqlalchemy_async_engine: AsyncEngine, published_project: PublishedProject, published_tasks: list[CompTaskAtDB], mocked_dask_client: mock.MagicMock, scheduler: BaseCompScheduler, -) -> list[CompTaskAtDB]: +) -> tuple[list[CompTaskAtDB], dict[NodeID, Callable[[], None]]]: expected_pending_tasks = [ published_tasks[1], published_tasks[3], @@ -488,26 +212,42 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState] return [DaskClientTaskState.PENDING for job_id in job_ids] mocked_dask_client.get_tasks_status.side_effect = _return_tasks_pending - await schedule_all_pipelines(scheduler) + assert published_project.project.prj_owner + await scheduler.apply( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + iteration=1, + ) _assert_dask_client_correctly_initialized(mocked_dask_client, scheduler) - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in expected_pending_tasks], + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.PUBLISHED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in expected_pending_tasks], expected_state=RunningState.PENDING, expected_progress=None, ) # the other tasks are still waiting in published state - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in published_tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in published_tasks], expected_state=RunningState.PUBLISHED, expected_progress=None, # since we bypass the API entrypoint this is correct ) # tasks were send to the backend assert published_project.project.prj_owner is not None + assert isinstance(mocked_dask_client.send_computation_tasks, mock.Mock) + assert isinstance(mocked_dask_client.get_tasks_status, mock.Mock) + assert isinstance(mocked_dask_client.get_task_result, mock.Mock) mocked_dask_client.send_computation_tasks.assert_has_calls( calls=[ 
mock.call( @@ -523,23 +263,39 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState] ], any_order=True, ) + task_to_callback_mapping = { + task.node_id: mocked_dask_client.send_computation_tasks.call_args_list[ + i + ].kwargs["callback"] + for i, task in enumerate(expected_pending_tasks) + } mocked_dask_client.send_computation_tasks.reset_mock() mocked_dask_client.get_tasks_status.assert_not_called() mocked_dask_client.get_task_result.assert_not_called() # there is a second run of the scheduler to move comp_runs to pending, the rest does not change - await schedule_all_pipelines(scheduler) - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PENDING) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in expected_pending_tasks], + await scheduler.apply( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + iteration=1, + ) + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.PENDING, + where_statement=(comp_runs.c.user_id == published_project.project.prj_owner) + & (comp_runs.c.project_uuid == f"{published_project.project.uuid}"), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in expected_pending_tasks], expected_state=RunningState.PENDING, expected_progress=None, ) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in published_tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in published_tasks], expected_state=RunningState.PUBLISHED, expected_progress=None, ) @@ -549,7 +305,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState] ) mocked_dask_client.get_tasks_status.reset_mock() mocked_dask_client.get_task_result.assert_not_called() - return expected_pending_tasks + return expected_pending_tasks, task_to_callback_mapping @pytest.fixture @@ -606,9 +362,9 @@ async def _assert_message_received( return parsed_messages -def _mock_send_computation_tasks( +def _with_mock_send_computation_tasks( tasks: list[CompTaskAtDB], mocked_dask_client: mock.MagicMock -) -> None: +) -> mock.Mock: node_id_to_job_id_map = {task.node_id: task.job_id for task in tasks} async def _send_computation_tasks( @@ -625,6 +381,7 @@ async def _send_computation_tasks( ] # type: ignore mocked_dask_client.send_computation_tasks.side_effect = _send_computation_tasks + return mocked_dask_client.send_computation_tasks async def _trigger_progress_event( @@ -653,35 +410,45 @@ async def _trigger_progress_event( @pytest.mark.acceptance_test() async def test_proper_pipeline_is_scheduled( # noqa: PLR0915 - with_disabled_auto_scheduling: None, + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, + initialized_app: FastAPI, mocked_dask_client: mock.MagicMock, - scheduler: BaseCompScheduler, - aiopg_engine: aiopg.sa.engine.Engine, + scheduler_api: BaseCompScheduler, + sqlalchemy_async_engine: AsyncEngine, published_project: PublishedProject, mocked_parse_output_data_fct: mock.Mock, - mocked_clean_task_output_and_log_files_if_invalid: None, + mocked_clean_task_output_and_log_files_if_invalid: mock.Mock, instrumentation_rabbit_client_parser: mock.AsyncMock, resource_tracking_rabbit_client_parser: mock.AsyncMock, run_metadata: RunMetadataDict, ): - 
_mock_send_computation_tasks(published_project.tasks, mocked_dask_client) - - expected_published_tasks = await _assert_start_pipeline( - aiopg_engine, published_project, scheduler, run_metadata + with_disabled_auto_scheduling.assert_called_once() + _with_mock_send_computation_tasks(published_project.tasks, mocked_dask_client) + + # + # Initiate new pipeline run + # + run_in_db, expected_published_tasks = await _assert_start_pipeline( + initialized_app, + sqlalchemy_async_engine=sqlalchemy_async_engine, + published_project=published_project, + run_metadata=run_metadata, ) + with_disabled_scheduler_publisher.assert_called() # ------------------------------------------------------------------------------- - # 1. first run will move comp_tasks to PENDING so the worker can take them - expected_pending_tasks = await _assert_schedule_pipeline_PENDING( - aiopg_engine, + # 1. first run will move comp_tasks to PENDING so the dask-worker can take them + expected_pending_tasks, _ = await _assert_publish_in_dask_backend( + sqlalchemy_async_engine, published_project, expected_published_tasks, mocked_dask_client, - scheduler, + scheduler_api, ) # ------------------------------------------------------------------------------- - # 2.1. the worker might be taking the task, until we get a progress we do not know + # 2.1. the dask-worker might be taking the task, until we get a progress we do not know # whether it effectively started or it is still queued in the worker process exp_started_task = expected_pending_tasks[0] expected_pending_tasks.remove(exp_started_task) @@ -697,28 +464,32 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta ] mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running - - await schedule_all_pipelines(scheduler) - - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PENDING) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [exp_started_task.node_id], + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, expected_state=RunningState.PENDING, - expected_progress=None, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), ) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in expected_pending_tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[exp_started_task.node_id] + + [p.node_id for p in expected_pending_tasks], expected_state=RunningState.PENDING, expected_progress=None, ) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in expected_published_tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in expected_published_tasks], expected_state=RunningState.PUBLISHED, expected_progress=None, # since we bypass the API entrypoint this is correct ) @@ -730,41 +501,52 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta mocked_dask_client.get_task_result.assert_not_called() # ------------------------------------------------------------------------------- - # 3. the "worker" starts processing a task - # here we trigger a progress from the worker + # 3. 
the dask-worker starts processing a task here we simulate a progress event assert exp_started_task.job_id assert exp_started_task.project_id assert exp_started_task.node_id assert published_project.project.prj_owner await _trigger_progress_event( - scheduler, + scheduler_api, job_id=exp_started_task.job_id, user_id=published_project.project.prj_owner, project_id=exp_started_task.project_id, node_id=exp_started_task.node_id, ) - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) # comp_run, the comp_task switch to STARTED - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [exp_started_task.node_id], + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.STARTED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[exp_started_task.node_id], expected_state=RunningState.STARTED, expected_progress=0, ) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in expected_pending_tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in expected_pending_tasks], expected_state=RunningState.PENDING, expected_progress=None, ) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in expected_published_tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in expected_published_tasks], expected_state=RunningState.PUBLISHED, expected_progress=None, ) @@ -774,6 +556,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta ) mocked_dask_client.get_tasks_status.reset_mock() mocked_dask_client.get_task_result.assert_not_called() + # check the metrics are properly published messages = await _assert_message_received( instrumentation_rabbit_client_parser, 1, @@ -782,9 +565,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta assert messages[0].metrics == "service_started" assert messages[0].service_uuid == exp_started_task.node_id - def _parser(x) -> RabbitResourceTrackingMessages: - return TypeAdapter(RabbitResourceTrackingMessages).validate_json(x) - + # check the RUT messages are properly published messages = await _assert_message_received( resource_tracking_rabbit_client_parser, 1, @@ -793,7 +574,7 @@ def _parser(x) -> RabbitResourceTrackingMessages: assert messages[0].node_id == exp_started_task.node_id # ------------------------------------------------------------------------------- - # 4. the "worker" completed the task successfully + # 4. 
the dask-worker completed the task successfully async def _return_1st_task_success(job_ids: list[str]) -> list[DaskClientTaskState]: return [ ( @@ -810,15 +591,28 @@ async def _return_random_task_result(job_id) -> TaskOutputData: return TaskOutputData.model_validate({"out_1": None, "out_2": 45}) mocked_dask_client.get_task_result.side_effect = _return_random_task_result - await schedule_all_pipelines(scheduler) - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [exp_started_task.node_id], + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.STARTED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[exp_started_task.node_id], expected_state=RunningState.SUCCESS, expected_progress=1, ) + # check metrics are published messages = await _assert_message_received( instrumentation_rabbit_client_parser, 1, @@ -826,6 +620,7 @@ async def _return_random_task_result(job_id) -> TaskOutputData: ) assert messages[0].metrics == "service_stopped" assert messages[0].service_uuid == exp_started_task.node_id + # check RUT messages are published messages = await _assert_message_received( resource_tracking_rabbit_client_parser, 1, @@ -835,17 +630,17 @@ async def _return_random_task_result(job_id) -> TaskOutputData: completed_tasks = [exp_started_task] next_pending_task = published_project.tasks[2] expected_pending_tasks.append(next_pending_task) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [p.node_id for p in expected_pending_tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in expected_pending_tasks], expected_state=RunningState.PENDING, expected_progress=None, ) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [ + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[ p.node_id for p in published_project.tasks if p not in expected_pending_tasks + completed_tasks @@ -884,7 +679,7 @@ async def _return_random_task_result(job_id) -> TaskOutputData: mocked_parse_output_data_fct.reset_mock() # ------------------------------------------------------------------------------- - # 6. the "worker" starts processing a task + # 6. 
the dask-worker starts processing a task exp_started_task = next_pending_task async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskState]: @@ -901,18 +696,30 @@ async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskSta # trigger the scheduler, run state should keep to STARTED, task should be as well assert exp_started_task.job_id await _trigger_progress_event( - scheduler, + scheduler_api, job_id=exp_started_task.job_id, user_id=published_project.project.prj_owner, project_id=exp_started_task.project_id, node_id=exp_started_task.node_id, ) - await schedule_all_pipelines(scheduler) - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [exp_started_task.node_id], + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.STARTED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[exp_started_task.node_id], expected_state=RunningState.STARTED, expected_progress=0, ) @@ -951,12 +758,27 @@ async def _return_2nd_task_failed(job_ids: list[str]) -> list[DaskClientTaskStat mocked_dask_client.get_tasks_status.side_effect = _return_2nd_task_failed mocked_dask_client.get_task_result.side_effect = None - await schedule_all_pipelines(scheduler) - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [exp_started_task.node_id], + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) + mocked_clean_task_output_and_log_files_if_invalid.assert_called_once() + mocked_clean_task_output_and_log_files_if_invalid.reset_mock() + + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.STARTED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[exp_started_task.node_id], expected_state=RunningState.FAILED, expected_progress=1, ) @@ -1000,13 +822,26 @@ async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskSta mocked_dask_client.get_task_result.side_effect = _return_random_task_result # trigger the scheduler, it should switch to FAILED, as we are done - await schedule_all_pipelines(scheduler) - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.FAILED) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) + mocked_clean_task_output_and_log_files_if_invalid.assert_not_called() + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.FAILED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) - await _assert_comp_tasks_db( - aiopg_engine, - 
published_project.project.uuid, - [exp_started_task.node_id], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[exp_started_task.node_id], expected_state=RunningState.SUCCESS, expected_progress=1, ) @@ -1020,7 +855,11 @@ async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskSta 2, InstrumentationRabbitMessage.model_validate_json, ) + # NOTE: the service was fast and went directly to success + def _parser(x) -> RabbitResourceTrackingMessages: + return TypeAdapter(RabbitResourceTrackingMessages).validate_json(x) + assert messages[0].metrics == "service_started" assert messages[0].service_uuid == exp_started_task.node_id assert messages[1].metrics == "service_stopped" @@ -1033,32 +872,309 @@ async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskSta assert isinstance(messages[0], RabbitResourceTrackingStartedMessage) assert isinstance(messages[1], RabbitResourceTrackingStoppedMessage) - # the scheduled pipeline shall be removed - assert scheduler._scheduled_pipelines == {} # noqa: SLF001 + +@pytest.fixture +async def with_started_project( + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, + initialized_app: FastAPI, + sqlalchemy_async_engine: AsyncEngine, + publish_project: Callable[[], Awaitable[PublishedProject]], + mocked_dask_client: mock.Mock, + run_metadata: RunMetadataDict, + scheduler_api: BaseCompScheduler, + instrumentation_rabbit_client_parser: mock.AsyncMock, + resource_tracking_rabbit_client_parser: mock.AsyncMock, +) -> RunningProject: + with_disabled_auto_scheduling.assert_called_once() + published_project = await publish_project() + # + # 1. Initiate new pipeline run + # + run_in_db, expected_published_tasks = await _assert_start_pipeline( + initialized_app, + sqlalchemy_async_engine=sqlalchemy_async_engine, + published_project=published_project, + run_metadata=run_metadata, + ) + with_disabled_scheduler_publisher.assert_called_once() + + # + # 2. 
This runs the scheduler until the project is started scheduled in the back-end + # + ( + expected_pending_tasks, + task_to_callback_mapping, + ) = await _assert_publish_in_dask_backend( + sqlalchemy_async_engine, + published_project, + expected_published_tasks, + mocked_dask_client, + scheduler_api, + ) + + # + # The dask-worker can take a job when it is PENDING, but the dask scheduler makes + # no difference between PENDING and STARTED + # + exp_started_task = expected_pending_tasks[0] + expected_pending_tasks.remove(exp_started_task) + + async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]: + return [ + ( + DaskClientTaskState.PENDING_OR_STARTED + if job_id == exp_started_task.job_id + else DaskClientTaskState.PENDING + ) + for job_id in job_ids + ] + + assert isinstance(mocked_dask_client.get_tasks_status, mock.Mock) + assert isinstance(mocked_dask_client.send_computation_tasks, mock.Mock) + assert isinstance(mocked_dask_client.get_task_result, mock.Mock) + mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.PENDING, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[exp_started_task.node_id] + + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=None, + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in expected_published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, # since we bypass the API entrypoint this is correct + ) + mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in (exp_started_task, *expected_pending_tasks)], + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_not_called() + + # ------------------------------------------------------------------------------- + # 4. 
the dask-worker starts processing a task here we simulate a progress event + assert exp_started_task.job_id + assert exp_started_task.project_id + assert exp_started_task.node_id + assert published_project.project.prj_owner + await _trigger_progress_event( + scheduler_api, + job_id=exp_started_task.job_id, + user_id=published_project.project.prj_owner, + project_id=exp_started_task.project_id, + node_id=exp_started_task.node_id, + ) + + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) + # comp_run, the comp_task switch to STARTED + run_in_db = ( + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.STARTED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + )[0] + tasks_in_db = await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[exp_started_task.node_id], + expected_state=RunningState.STARTED, + expected_progress=0, + ) + tasks_in_db += await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=None, + ) + tasks_in_db += await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[p.node_id for p in expected_published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, + ) + mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in (exp_started_task, *expected_pending_tasks)], + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_not_called() + # check the metrics are properly published + messages = await _assert_message_received( + instrumentation_rabbit_client_parser, + 1, + InstrumentationRabbitMessage.model_validate_json, + ) + assert messages[0].metrics == "service_started" + assert messages[0].service_uuid == exp_started_task.node_id + + # check the RUT messages are properly published + messages = await _assert_message_received( + resource_tracking_rabbit_client_parser, + 1, + RabbitResourceTrackingStartedMessage.model_validate_json, + ) + assert messages[0].node_id == exp_started_task.node_id + + return RunningProject( + published_project.user, + published_project.project, + published_project.pipeline, + tasks_in_db, + runs=run_in_db, + task_to_callback_mapping=task_to_callback_mapping, + ) + + +@pytest.fixture +def mocked_worker_publisher(mocker: MockerFixture) -> mock.Mock: + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._scheduler_base.request_pipeline_scheduling", + autospec=True, + ) + + +async def test_completed_task_triggers_new_scheduling_task( + mocked_worker_publisher: mock.Mock, + with_started_project: RunningProject, +): + """When a pipeline job completes, the Dask backend provides a callback + that runs in a separate thread. 
We use that callback to ask the + director-v2 computational scheduler manager for a new schedule. + After fiddling in distributed source code, here is a similar way to trigger that callback. + """ + completed_node_id = with_started_project.tasks[0].node_id + callback = with_started_project.task_to_callback_mapping[completed_node_id] + await asyncio.to_thread(callback) + + mocked_worker_publisher.assert_called_once_with( + mock.ANY, + mock.ANY, + user_id=with_started_project.runs.user_id, + project_id=with_started_project.runs.project_uuid, + iteration=with_started_project.runs.iteration, + ) + + +async def test_broken_pipeline_configuration_is_not_scheduled_and_aborted( + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, + initialized_app: FastAPI, + scheduler_api: BaseCompScheduler, + registered_user: Callable[..., dict[str, Any]], + project: Callable[..., Awaitable[ProjectAtDB]], + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], + fake_workbench_without_outputs: dict[str, Any], + fake_workbench_adjacency: dict[str, Any], + sqlalchemy_async_engine: AsyncEngine, + run_metadata: RunMetadataDict, +): + """A pipeline whose comp_tasks are missing should not be scheduled. + It shall be aborted and shown as such in the comp_runs db""" + user = registered_user() + sleepers_project = await project(user, workbench=fake_workbench_without_outputs) + await create_pipeline( + project_id=f"{sleepers_project.uuid}", + dag_adjacency_list=fake_workbench_adjacency, + ) + await assert_comp_runs_empty(sqlalchemy_async_engine) + + # + # Initiate new pipeline scheduling + # + await run_new_pipeline( + initialized_app, + user_id=user["id"], + project_id=sleepers_project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + with_disabled_scheduler_publisher.assert_called_once() + # we shall have a new comp_runs row for the new pipeline job + run_entry = ( + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.PUBLISHED, + where_statement=(comp_runs.c.user_id == user["id"]) + & (comp_runs.c.project_uuid == f"{sleepers_project.uuid}"), + ) + )[0] + + # + # Trigger scheduling manually: 
since the pipeline is broken, it shall be aborted + # + await scheduler_api.apply( + user_id=run_entry.user_id, + project_id=run_entry.project_uuid, + iteration=run_entry.iteration, + ) + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.ABORTED, + where_statement=(comp_runs.c.user_id == user["id"]) + & (comp_runs.c.project_uuid == f"{sleepers_project.uuid}"), + ) async def test_task_progress_triggers( - with_disabled_auto_scheduling: None, + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, + initialized_app: FastAPI, mocked_dask_client: mock.MagicMock, - scheduler: BaseCompScheduler, - aiopg_engine: aiopg.sa.engine.Engine, + scheduler_api: BaseCompScheduler, + sqlalchemy_async_engine: AsyncEngine, published_project: PublishedProject, - mocked_parse_output_data_fct: None, - mocked_clean_task_output_and_log_files_if_invalid: None, + mocked_parse_output_data_fct: mock.Mock, + mocked_clean_task_output_and_log_files_if_invalid: mock.Mock, run_metadata: RunMetadataDict, ): - _mock_send_computation_tasks(published_project.tasks, mocked_dask_client) - expected_published_tasks = await _assert_start_pipeline( - aiopg_engine, published_project, scheduler, run_metadata + _with_mock_send_computation_tasks(published_project.tasks, mocked_dask_client) + _run_in_db, expected_published_tasks = await _assert_start_pipeline( + initialized_app, + sqlalchemy_async_engine=sqlalchemy_async_engine, + published_project=published_project, + run_metadata=run_metadata, ) # ------------------------------------------------------------------------------- - # 1. first run will move comp_tasks to PENDING so the worker can take them - expected_pending_tasks = await _assert_schedule_pipeline_PENDING( - aiopg_engine, + # 1. first run will move comp_tasks to PENDING so the dask-worker can take them + expected_pending_tasks, _ = await _assert_publish_in_dask_backend( + sqlalchemy_async_engine, published_project, expected_published_tasks, mocked_dask_client, - scheduler, + scheduler_api, ) # send some progress @@ -1078,13 +1194,13 @@ async def test_task_progress_triggers( ), ) await cast( # noqa: SLF001 - DaskScheduler, scheduler + DaskScheduler, scheduler_api )._task_progress_change_handler(progress_event.model_dump_json()) # NOTE: not sure whether it should switch to STARTED.. 
it would make sense - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [started_task.node_id], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[started_task.node_id], expected_state=RunningState.STARTED, expected_progress=min(max(0, progress), 1), ) @@ -1100,11 +1216,13 @@ async def test_task_progress_triggers( ), ], ) -async def test_handling_of_disconnected_dask_scheduler( - with_disabled_auto_scheduling: None, +async def test_handling_of_disconnected_scheduler_dask( + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, + initialized_app: FastAPI, mocked_dask_client: mock.MagicMock, - scheduler: BaseCompScheduler, - aiopg_engine: aiopg.sa.engine.Engine, + scheduler_api: BaseCompScheduler, + sqlalchemy_async_engine: AsyncEngine, mocker: MockerFixture, published_project: PublishedProject, backend_error: ComputationalSchedulerError, @@ -1112,14 +1230,15 @@ async def test_handling_of_disconnected_dask_scheduler( ): # this will create a non connected backend issue that will trigger re-connection mocked_dask_client_send_task = mocker.patch( - "simcore_service_director_v2.modules.comp_scheduler._dask_scheduler.DaskClient.send_computation_tasks", + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.DaskClient.send_computation_tasks", side_effect=backend_error, ) assert mocked_dask_client_send_task # running the pipeline will now raise and the tasks are set back to PUBLISHED assert published_project.project.prj_owner - await scheduler.run_new_pipeline( + await run_new_pipeline( + initialized_app, user_id=published_project.project.prj_owner, project_id=published_project.project.uuid, cluster_id=DEFAULT_CLUSTER_ID, @@ -1129,28 +1248,42 @@ async def test_handling_of_disconnected_dask_scheduler( # since there is no cluster, there is no dask-scheduler, # the tasks shall all still be in PUBLISHED state now - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) + runs_in_db = await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.PUBLISHED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + run_in_db = runs_in_db[0] - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [t.node_id for t in published_project.tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[t.node_id for t in published_project.tasks], expected_state=RunningState.PUBLISHED, expected_progress=None, ) # on the next iteration of the pipeline it will try to re-connect # now try to abort the tasks since we are wondering what is happening, this should auto-trigger the scheduler - await scheduler.stop_pipeline( + await stop_pipeline( + initialized_app, user_id=published_project.project.prj_owner, project_id=published_project.project.uuid, ) # we ensure the scheduler was run - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) # after this step the tasks are marked as ABORTED - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [ + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[ t.node_id for t in 
published_project.tasks if t.node_class == NodeClass.COMPUTATIONAL @@ -1159,9 +1292,21 @@ async def test_handling_of_disconnected_dask_scheduler( expected_progress=1, ) # then we have another scheduler run - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) # now the run should be ABORTED - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.ABORTED) + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.ABORTED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) @dataclass(frozen=True, kw_only=True) @@ -1175,9 +1320,6 @@ class RebootState: expected_run_state: RunningState -@pytest.mark.skip( - reason="awaiting refactor in https://github.com/ITISFoundation/osparc-simcore/pull/6736" -) @pytest.mark.parametrize( "reboot_state", [ @@ -1245,14 +1387,15 @@ class RebootState: ), ], ) -async def test_handling_scheduling_after_reboot( - with_disabled_auto_scheduling: None, +async def test_handling_scheduled_tasks_after_director_reboots( + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, mocked_dask_client: mock.MagicMock, - aiopg_engine: aiopg.sa.engine.Engine, + sqlalchemy_async_engine: AsyncEngine, running_project: RunningProject, - scheduler: BaseCompScheduler, - mocked_parse_output_data_fct: mock.MagicMock, - mocked_clean_task_output_fct: mock.MagicMock, + scheduler_api: BaseCompScheduler, + mocked_parse_output_data_fct: mock.Mock, + mocked_clean_task_output_fct: mock.Mock, reboot_state: RebootState, ): """After the dask client is rebooted, or that the director-v2 reboots the dv-2 internal scheduler @@ -1270,8 +1413,12 @@ async def mocked_get_task_result(_job_id: str) -> TaskOutputData: return reboot_state.task_result mocked_dask_client.get_task_result.side_effect = mocked_get_task_result - - await schedule_all_pipelines(scheduler) + assert running_project.project.prj_owner + await scheduler_api.apply( + user_id=running_project.project.prj_owner, + project_id=running_project.project.uuid, + iteration=1, + ) # the status will be called once for all RUNNING tasks mocked_dask_client.get_tasks_status.assert_called_once() if reboot_state.expected_run_state in COMPLETED_STATES: @@ -1303,10 +1450,10 @@ async def mocked_get_task_result(_job_id: str) -> TaskOutputData: else: mocked_clean_task_output_fct.assert_not_called() - await _assert_comp_tasks_db( - aiopg_engine, - running_project.project.uuid, - [ + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=running_project.project.uuid, + task_ids=[ running_project.tasks[1].node_id, running_project.tasks[2].node_id, running_project.tasks[3].node_id, @@ -1314,40 +1461,58 @@ async def mocked_get_task_result(_job_id: str) -> TaskOutputData: expected_state=reboot_state.expected_task_state_group1, expected_progress=reboot_state.expected_task_progress_group1, ) - await _assert_comp_tasks_db( - aiopg_engine, - running_project.project.uuid, - [running_project.tasks[4].node_id], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=running_project.project.uuid, + task_ids=[running_project.tasks[4].node_id], expected_state=reboot_state.expected_task_state_group2, expected_progress=reboot_state.expected_task_progress_group2, ) assert running_project.project.prj_owner - await _assert_comp_run_db( - 
aiopg_engine, running_project, reboot_state.expected_run_state + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=reboot_state.expected_run_state, + where_statement=and_( + comp_runs.c.user_id == running_project.project.prj_owner, + comp_runs.c.project_uuid == f"{running_project.project.uuid}", + ), ) async def test_handling_cancellation_of_jobs_after_reboot( - with_disabled_auto_scheduling: None, + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, mocked_dask_client: mock.MagicMock, - aiopg_engine: aiopg.sa.engine.Engine, + sqlalchemy_async_engine: AsyncEngine, running_project_mark_for_cancellation: RunningProject, - scheduler: BaseCompScheduler, - mocked_parse_output_data_fct: mock.MagicMock, - mocked_clean_task_output_fct: mock.MagicMock, + scheduler_api: BaseCompScheduler, + mocked_parse_output_data_fct: mock.Mock, + mocked_clean_task_output_fct: mock.Mock, ): """A running pipeline was cancelled by a user and the DV-2 was restarted BEFORE It could actually cancel the task. On reboot the DV-2 shall recover and actually cancel the pipeline properly""" # check initial status - await _assert_comp_run_db( - aiopg_engine, running_project_mark_for_cancellation, RunningState.STARTED - ) - await _assert_comp_tasks_db( - aiopg_engine, - running_project_mark_for_cancellation.project.uuid, - [t.node_id for t in running_project_mark_for_cancellation.tasks], + run_in_db = ( + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.STARTED, + where_statement=and_( + comp_runs.c.user_id + == running_project_mark_for_cancellation.project.prj_owner, + comp_runs.c.project_uuid + == f"{running_project_mark_for_cancellation.project.uuid}", + ), + ) + )[0] + + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=running_project_mark_for_cancellation.project.uuid, + task_ids=[t.node_id for t in running_project_mark_for_cancellation.tasks], expected_state=RunningState.STARTED, expected_progress=0, ) @@ -1358,7 +1523,11 @@ async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskStat mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status # Running the scheduler, should actually cancel the run now - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) mocked_dask_client.abort_computation_task.assert_called() assert mocked_dask_client.abort_computation_task.call_count == len( [ @@ -1368,10 +1537,10 @@ async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskStat ] ) # in the DB they are still running, they will be stopped in the next iteration - await _assert_comp_tasks_db( - aiopg_engine, - running_project_mark_for_cancellation.project.uuid, - [ + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=running_project_mark_for_cancellation.project.uuid, + task_ids=[ t.node_id for t in running_project_mark_for_cancellation.tasks if t.node_class == NodeClass.COMPUTATIONAL @@ -1379,8 +1548,16 @@ async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskStat expected_state=RunningState.STARTED, expected_progress=0, ) - await _assert_comp_run_db( - aiopg_engine, running_project_mark_for_cancellation, RunningState.STARTED + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.STARTED, + where_statement=and_( + comp_runs.c.user_id + 
== running_project_mark_for_cancellation.project.prj_owner, + comp_runs.c.project_uuid + == f"{running_project_mark_for_cancellation.project.uuid}", + ), ) # the backend shall now report the tasks as aborted @@ -1395,12 +1572,16 @@ async def _return_random_task_result(job_id) -> TaskOutputData: raise TaskCancelledError mocked_dask_client.get_task_result.side_effect = _return_random_task_result - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) # now should be stopped - await _assert_comp_tasks_db( - aiopg_engine, - running_project_mark_for_cancellation.project.uuid, - [ + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=running_project_mark_for_cancellation.project.uuid, + task_ids=[ t.node_id for t in running_project_mark_for_cancellation.tasks if t.node_class == NodeClass.COMPUTATIONAL @@ -1408,8 +1589,16 @@ async def _return_random_task_result(job_id) -> TaskOutputData: expected_state=RunningState.ABORTED, expected_progress=1, ) - await _assert_comp_run_db( - aiopg_engine, running_project_mark_for_cancellation, RunningState.ABORTED + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.ABORTED, + where_statement=and_( + comp_runs.c.user_id + == running_project_mark_for_cancellation.project.prj_owner, + comp_runs.c.project_uuid + == f"{running_project_mark_for_cancellation.project.uuid}", + ), ) mocked_clean_task_output_fct.assert_called() @@ -1422,27 +1611,32 @@ def with_fast_service_heartbeat_s(monkeypatch: pytest.MonkeyPatch) -> int: async def test_running_pipeline_triggers_heartbeat( - with_disabled_auto_scheduling: None, + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, with_fast_service_heartbeat_s: int, + initialized_app: FastAPI, mocked_dask_client: mock.MagicMock, - scheduler: BaseCompScheduler, - aiopg_engine: aiopg.sa.engine.Engine, + scheduler_api: BaseCompScheduler, + sqlalchemy_async_engine: AsyncEngine, published_project: PublishedProject, resource_tracking_rabbit_client_parser: mock.AsyncMock, run_metadata: RunMetadataDict, ): - _mock_send_computation_tasks(published_project.tasks, mocked_dask_client) - expected_published_tasks = await _assert_start_pipeline( - aiopg_engine, published_project, scheduler, run_metadata + _with_mock_send_computation_tasks(published_project.tasks, mocked_dask_client) + run_in_db, expected_published_tasks = await _assert_start_pipeline( + initialized_app, + sqlalchemy_async_engine=sqlalchemy_async_engine, + published_project=published_project, + run_metadata=run_metadata, ) # ------------------------------------------------------------------------------- - # 1. first run will move comp_tasks to PENDING so the worker can take them - expected_pending_tasks = await _assert_schedule_pipeline_PENDING( - aiopg_engine, + # 1. first run will move comp_tasks to PENDING so the dask-worker can take them + expected_pending_tasks, _ = await _assert_publish_in_dask_backend( + sqlalchemy_async_engine, published_project, expected_published_tasks, mocked_dask_client, - scheduler, + scheduler_api, ) # ------------------------------------------------------------------------------- # 2. 
the "worker" starts processing a task @@ -1463,13 +1657,17 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta assert exp_started_task.job_id assert published_project.project.prj_owner await _trigger_progress_event( - scheduler, + scheduler_api, job_id=exp_started_task.job_id, user_id=published_project.project.prj_owner, project_id=exp_started_task.project_id, node_id=exp_started_task.node_id, ) - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) messages = await _assert_message_received( resource_tracking_rabbit_client_parser, @@ -1481,8 +1679,16 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta # ------------------------------------------------------------------------------- # 3. wait a bit and run again we should get another heartbeat, but only one! await asyncio.sleep(with_fast_service_heartbeat_s + 1) - await schedule_all_pipelines(scheduler) - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) messages = await _assert_message_received( resource_tracking_rabbit_client_parser, 1, @@ -1493,8 +1699,16 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta # ------------------------------------------------------------------------------- # 4. wait a bit and run again we should get another heartbeat, but only one! await asyncio.sleep(with_fast_service_heartbeat_s + 1) - await schedule_all_pipelines(scheduler) - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) messages = await _assert_message_received( resource_tracking_rabbit_client_parser, 1, @@ -1506,15 +1720,17 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskSta @pytest.fixture async def mocked_get_or_create_cluster(mocker: MockerFixture) -> mock.Mock: return mocker.patch( - "simcore_service_director_v2.modules.comp_scheduler._dask_scheduler.get_or_create_on_demand_cluster", + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.get_or_create_on_demand_cluster", autospec=True, ) async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( - with_disabled_auto_scheduling: None, - scheduler: BaseCompScheduler, - aiopg_engine: aiopg.sa.engine.Engine, + with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, + initialized_app: FastAPI, + scheduler_api: BaseCompScheduler, + sqlalchemy_async_engine: AsyncEngine, published_project: PublishedProject, run_metadata: RunMetadataDict, mocked_get_or_create_cluster: mock.Mock, @@ -1527,7 +1743,8 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( ) # running the pipeline will trigger a call to the clusters-keeper assert published_project.project.prj_owner - await scheduler.run_new_pipeline( + await run_new_pipeline( + initialized_app, user_id=published_project.project.prj_owner, project_id=published_project.project.uuid, cluster_id=DEFAULT_CLUSTER_ID, @@ -1536,11 
+1753,21 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( ) # we ask to use an on-demand cluster, therefore the tasks are published first - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [t.node_id for t in published_project.tasks], + run_in_db = ( + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.PUBLISHED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + )[0] + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[t.node_id for t in published_project.tasks], expected_state=RunningState.PUBLISHED, expected_progress=None, ) @@ -1550,32 +1777,52 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( published_project.tasks[1], published_project.tasks[3], ] - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) mocked_get_or_create_cluster.assert_called() assert mocked_get_or_create_cluster.call_count == 1 mocked_get_or_create_cluster.reset_mock() - await _assert_comp_run_db( - aiopg_engine, published_project, RunningState.WAITING_FOR_CLUSTER + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.WAITING_FOR_CLUSTER, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), ) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [t.node_id for t in expected_waiting_tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[t.node_id for t in expected_waiting_tasks], expected_state=RunningState.WAITING_FOR_CLUSTER, expected_progress=None, ) # again will trigger the same response - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) mocked_get_or_create_cluster.assert_called() assert mocked_get_or_create_cluster.call_count == 1 mocked_get_or_create_cluster.reset_mock() - await _assert_comp_run_db( - aiopg_engine, published_project, RunningState.WAITING_FOR_CLUSTER + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.WAITING_FOR_CLUSTER, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), ) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [t.node_id for t in expected_waiting_tasks], + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[t.node_id for t in expected_waiting_tasks], expected_state=RunningState.WAITING_FOR_CLUSTER, expected_progress=None, ) @@ -1586,18 +1833,23 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( [ClustersKeeperNotAvailableError], ) async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_fails( - with_disabled_auto_scheduling: None, - scheduler: BaseCompScheduler, - aiopg_engine: aiopg.sa.engine.Engine, + 
with_disabled_auto_scheduling: mock.Mock, + with_disabled_scheduler_publisher: mock.Mock, + initialized_app: FastAPI, + scheduler_api: BaseCompScheduler, + sqlalchemy_async_engine: AsyncEngine, published_project: PublishedProject, run_metadata: RunMetadataDict, mocked_get_or_create_cluster: mock.Mock, get_or_create_exception: Exception, ): + # needs to change: https://github.com/ITISFoundation/osparc-simcore/issues/6817 + mocked_get_or_create_cluster.side_effect = get_or_create_exception # running the pipeline will trigger a call to the clusters-keeper assert published_project.project.prj_owner - await scheduler.run_new_pipeline( + await run_new_pipeline( + initialized_app, user_id=published_project.project.prj_owner, project_id=published_project.project.uuid, cluster_id=DEFAULT_CLUSTER_ID, @@ -1606,11 +1858,21 @@ async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_fails( ) # we ask to use an on-demand cluster, therefore the tasks are published first - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [t.node_id for t in published_project.tasks], + run_in_db = ( + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.PUBLISHED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + )[0] + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[t.node_id for t in published_project.tasks], expected_state=RunningState.PUBLISHED, expected_progress=None, ) @@ -1619,26 +1881,50 @@ async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_fails( published_project.tasks[1], published_project.tasks[3], ] - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) mocked_get_or_create_cluster.assert_called() assert mocked_get_or_create_cluster.call_count == 1 mocked_get_or_create_cluster.reset_mock() - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.FAILED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [t.node_id for t in expected_failed_tasks], + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.FAILED, + where_statement=and_( + comp_runs.c.user_id == published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[t.node_id for t in expected_failed_tasks], expected_state=RunningState.FAILED, expected_progress=1.0, ) # again will not re-trigger the call to clusters-keeper - await schedule_all_pipelines(scheduler) + await scheduler_api.apply( + user_id=run_in_db.user_id, + project_id=run_in_db.project_uuid, + iteration=run_in_db.iteration, + ) mocked_get_or_create_cluster.assert_not_called() - await _assert_comp_run_db(aiopg_engine, published_project, RunningState.FAILED) - await _assert_comp_tasks_db( - aiopg_engine, - published_project.project.uuid, - [t.node_id for t in expected_failed_tasks], + await assert_comp_runs( + sqlalchemy_async_engine, + expected_total=1, + expected_state=RunningState.FAILED, + where_statement=and_( + comp_runs.c.user_id == 
published_project.project.prj_owner, + comp_runs.c.project_uuid == f"{published_project.project.uuid}", + ), + ) + await assert_comp_tasks( + sqlalchemy_async_engine, + project_uuid=published_project.project.uuid, + task_ids=[t.node_id for t in expected_failed_tasks], expected_state=RunningState.FAILED, expected_progress=1.0, ) diff --git a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_worker.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_worker.py new file mode 100644 index 00000000000..9eb301e0910 --- /dev/null +++ b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_worker.py @@ -0,0 +1,135 @@ +# pylint:disable=unused-variable +# pylint:disable=unused-argument +# pylint:disable=redefined-outer-name +# pylint:disable=no-value-for-parameter +# pylint:disable=protected-access +# pylint:disable=too-many-arguments +# pylint:disable=no-name-in-module +# pylint: disable=too-many-statements + +import asyncio +from collections.abc import Awaitable, Callable +from unittest import mock + +import pytest +from _helpers import PublishedProject +from fastapi import FastAPI +from models_library.clusters import DEFAULT_CLUSTER_ID +from pytest_mock import MockerFixture +from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict +from pytest_simcore.helpers.typing_env import EnvVarsDict +from simcore_service_director_v2.models.comp_runs import RunMetadataDict +from simcore_service_director_v2.modules.comp_scheduler._manager import run_new_pipeline +from simcore_service_director_v2.modules.comp_scheduler._models import ( + SchedulePipelineRabbitMessage, +) +from simcore_service_director_v2.modules.comp_scheduler._worker import ( + _get_scheduler_worker, +) + +pytest_simcore_core_services_selection = ["postgres", "rabbit", "redis"] +pytest_simcore_ops_services_selection = ["adminer"] + + +async def test_worker_starts_and_stops(initialized_app: FastAPI): + assert _get_scheduler_worker(initialized_app) is not None + + +@pytest.fixture +def mock_schedule_pipeline(mocker: MockerFixture) -> mock.Mock: + mock_scheduler_worker = mock.Mock() + mock_scheduler_worker.schedule_pipeline = mocker.AsyncMock(return_value=True) + return mock_scheduler_worker + + +@pytest.fixture +def mocked_get_scheduler_worker( + mocker: MockerFixture, + mock_schedule_pipeline: mock.Mock, +) -> mock.Mock: + # Mock `_get_scheduler_worker` to return our mock scheduler + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._worker._get_scheduler_worker", + return_value=mock_schedule_pipeline, + ) + + +async def test_worker_properly_autocalls_scheduler_api( + with_disabled_auto_scheduling: mock.Mock, + initialized_app: FastAPI, + mocked_get_scheduler_worker: mock.Mock, + published_project: PublishedProject, + run_metadata: RunMetadataDict, +): + assert published_project.project.prj_owner + await run_new_pipeline( + initialized_app, + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + mocked_get_scheduler_worker.assert_called_once_with(initialized_app) + mocked_get_scheduler_worker.return_value.apply.assert_called_once_with( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + iteration=1, + ) + + +@pytest.fixture +async def mocked_scheduler_api(mocker: MockerFixture) -> mock.Mock: + return mocker.patch( + 
"simcore_service_director_v2.modules.comp_scheduler._scheduler_base.BaseCompScheduler.apply" + ) + + +@pytest.fixture +def with_scheduling_concurrency( + mock_env: EnvVarsDict, monkeypatch: pytest.MonkeyPatch, scheduling_concurrency: int +) -> EnvVarsDict: + return mock_env | setenvs_from_dict( + monkeypatch, + {"COMPUTATIONAL_BACKEND_SCHEDULING_CONCURRENCY": f"{scheduling_concurrency}"}, + ) + + +@pytest.mark.parametrize("scheduling_concurrency", [1, 50, 100]) +@pytest.mark.parametrize( + "queue_name", [SchedulePipelineRabbitMessage.get_channel_name()] +) +async def test_worker_scheduling_parallelism( + scheduling_concurrency: int, + with_scheduling_concurrency: EnvVarsDict, + with_disabled_auto_scheduling: mock.Mock, + mocked_scheduler_api: mock.Mock, + initialized_app: FastAPI, + publish_project: Callable[[], Awaitable[PublishedProject]], + run_metadata: RunMetadataDict, + ensure_parametrized_queue_is_empty: None, +): + with_disabled_auto_scheduling.assert_called_once() + + async def _side_effect(*args, **kwargs): + await asyncio.sleep(10) + + mocked_scheduler_api.side_effect = _side_effect + + async def _project_pipeline_creation_workflow() -> None: + published_project = await publish_project() + assert published_project.project.prj_owner + await run_new_pipeline( + initialized_app, + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + + await asyncio.gather( + *(_project_pipeline_creation_workflow() for _ in range(scheduling_concurrency)) + ) + mocked_scheduler_api.assert_called() + assert mocked_scheduler_api.call_count == scheduling_concurrency diff --git a/services/director-v2/tests/unit/with_dbs/conftest.py b/services/director-v2/tests/unit/with_dbs/conftest.py index fdb3b7d5a64..ee8259f9f5b 100644 --- a/services/director-v2/tests/unit/with_dbs/conftest.py +++ b/services/director-v2/tests/unit/with_dbs/conftest.py @@ -6,7 +6,7 @@ import datetime -from collections.abc import Awaitable, Callable, Iterator +from collections.abc import AsyncIterator, Awaitable, Callable from typing import Any, cast from uuid import uuid4 @@ -36,23 +36,24 @@ from simcore_service_director_v2.utils.dask import generate_dask_job_id from simcore_service_director_v2.utils.db import to_clusters_db from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncEngine @pytest.fixture -def pipeline( - postgres_db: sa.engine.Engine, -) -> Iterator[Callable[..., CompPipelineAtDB]]: +async def create_pipeline( + sqlalchemy_async_engine: AsyncEngine, +) -> AsyncIterator[Callable[..., Awaitable[CompPipelineAtDB]]]: created_pipeline_ids: list[str] = [] - def creator(**pipeline_kwargs) -> CompPipelineAtDB: + async def _(**pipeline_kwargs) -> CompPipelineAtDB: pipeline_config = { "project_id": f"{uuid4()}", "dag_adjacency_list": {}, "state": StateType.NOT_STARTED, } pipeline_config.update(**pipeline_kwargs) - with postgres_db.begin() as conn: - result = conn.execute( + async with sqlalchemy_async_engine.begin() as conn: + result = await conn.execute( comp_pipeline.insert() .values(**pipeline_config) .returning(sa.literal_column("*")) @@ -63,11 +64,11 @@ def creator(**pipeline_kwargs) -> CompPipelineAtDB: created_pipeline_ids.append(f"{new_pipeline.project_id}") return new_pipeline - yield creator + yield _ # cleanup - with postgres_db.connect() as conn: - conn.execute( + async with sqlalchemy_async_engine.begin() as conn: + await conn.execute( 
comp_pipeline.delete().where( comp_pipeline.c.project_id.in_(created_pipeline_ids) ) @@ -75,12 +76,12 @@ def creator(**pipeline_kwargs) -> CompPipelineAtDB: @pytest.fixture -def tasks( - postgres_db: sa.engine.Engine, -) -> Iterator[Callable[..., list[CompTaskAtDB]]]: +async def create_tasks( + sqlalchemy_async_engine: AsyncEngine, +) -> AsyncIterator[Callable[..., Awaitable[list[CompTaskAtDB]]]]: created_task_ids: list[int] = [] - def creator( + async def _( user: dict[str, Any], project: ProjectAtDB, **overrides_kwargs ) -> list[CompTaskAtDB]: created_tasks: list[CompTaskAtDB] = [] @@ -122,7 +123,7 @@ def creator( ), "node_class": to_node_class(node_data.key), "internal_id": internal_id + 1, - "submit": datetime.datetime.now(tz=datetime.UTC), + "submit": datetime.datetime.now(datetime.UTC), "job_id": generate_dask_job_id( service_key=node_data.key, service_version=node_data.version, @@ -132,8 +133,8 @@ def creator( ), } task_config.update(**overrides_kwargs) - with postgres_db.connect() as conn: - result = conn.execute( + async with sqlalchemy_async_engine.begin() as conn: + result = await conn.execute( comp_tasks.insert() .values(**task_config) .returning(sa.literal_column("*")) @@ -143,11 +144,11 @@ def creator( created_task_ids.extend([t.task_id for t in created_tasks if t.task_id]) return created_tasks - yield creator + yield _ # cleanup - with postgres_db.connect() as conn: - conn.execute( + async with sqlalchemy_async_engine.begin() as conn: + await conn.execute( comp_tasks.delete().where(comp_tasks.c.task_id.in_(created_task_ids)) ) @@ -186,64 +187,74 @@ def run_metadata( @pytest.fixture -def runs( - postgres_db: sa.engine.Engine, run_metadata: RunMetadataDict -) -> Iterator[Callable[..., CompRunsAtDB]]: +async def create_comp_run( + sqlalchemy_async_engine: AsyncEngine, run_metadata: RunMetadataDict +) -> AsyncIterator[Callable[..., Awaitable[CompRunsAtDB]]]: created_run_ids: list[int] = [] - def creator( + async def _( user: dict[str, Any], project: ProjectAtDB, **run_kwargs ) -> CompRunsAtDB: run_config = { "project_uuid": f"{project.uuid}", - "user_id": f"{user['id']}", + "user_id": user["id"], "iteration": 1, "result": StateType.NOT_STARTED, - "metadata": run_metadata, + "metadata": jsonable_encoder(run_metadata), "use_on_demand_clusters": False, } run_config.update(**run_kwargs) - with postgres_db.connect() as conn: - result = conn.execute( + async with sqlalchemy_async_engine.begin() as conn: + result = await conn.execute( comp_runs.insert() - .values(**jsonable_encoder(run_config)) + .values(**run_config) .returning(sa.literal_column("*")) ) new_run = CompRunsAtDB.model_validate(result.first()) created_run_ids.append(new_run.run_id) return new_run - yield creator + yield _ # cleanup - with postgres_db.connect() as conn: - conn.execute(comp_runs.delete().where(comp_runs.c.run_id.in_(created_run_ids))) + async with sqlalchemy_async_engine.begin() as conn: + await conn.execute( + comp_runs.delete().where(comp_runs.c.run_id.in_(created_run_ids)) + ) @pytest.fixture -def cluster( - postgres_db: sa.engine.Engine, -) -> Iterator[Callable[..., Cluster]]: +async def create_cluster( + sqlalchemy_async_engine: AsyncEngine, +) -> AsyncIterator[Callable[..., Awaitable[Cluster]]]: created_cluster_ids: list[str] = [] - def creator(user: dict[str, Any], **cluster_kwargs) -> Cluster: + async def _(user: dict[str, Any], **cluster_kwargs) -> Cluster: + assert "json_schema_extra" in Cluster.model_config + assert isinstance(Cluster.model_config["json_schema_extra"], dict) + assert 
isinstance(Cluster.model_config["json_schema_extra"]["examples"], list) + assert isinstance( + Cluster.model_config["json_schema_extra"]["examples"][1], dict + ) cluster_config = Cluster.model_config["json_schema_extra"]["examples"][1] cluster_config["owner"] = user["primary_gid"] cluster_config.update(**cluster_kwargs) new_cluster = Cluster.model_validate(cluster_config) assert new_cluster - with postgres_db.connect() as conn: + async with sqlalchemy_async_engine.begin() as conn: # insert basic cluster - created_cluster = conn.execute( - sa.insert(clusters) - .values(to_clusters_db(new_cluster, only_update=False)) - .returning(sa.literal_column("*")) + created_cluster = ( + await conn.execute( + sa.insert(clusters) + .values(to_clusters_db(new_cluster, only_update=False)) + .returning(sa.literal_column("*")) + ) ).one() created_cluster_ids.append(created_cluster.id) if "access_rights" in cluster_kwargs: for gid, rights in cluster_kwargs["access_rights"].items(): - conn.execute( + await conn.execute( pg_insert(cluster_to_groups) .values( cluster_id=created_cluster.id, @@ -256,7 +267,7 @@ def creator(user: dict[str, Any], **cluster_kwargs) -> Cluster: ) ) access_rights_in_db = {} - for row in conn.execute( + for row in await conn.execute( sa.select( cluster_to_groups.c.gid, cluster_to_groups.c.read, @@ -284,44 +295,57 @@ def creator(user: dict[str, Any], **cluster_kwargs) -> Cluster: thumbnail=None, ) - yield creator + yield _ # cleanup - with postgres_db.connect() as conn: - conn.execute( - # pylint: disable=no-value-for-parameter + async with sqlalchemy_async_engine.begin() as conn: + await conn.execute( clusters.delete().where(clusters.c.id.in_(created_cluster_ids)) ) @pytest.fixture -async def published_project( +async def publish_project( registered_user: Callable[..., dict[str, Any]], project: Callable[..., Awaitable[ProjectAtDB]], - pipeline: Callable[..., CompPipelineAtDB], - tasks: Callable[..., list[CompTaskAtDB]], + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], + create_tasks: Callable[..., Awaitable[list[CompTaskAtDB]]], fake_workbench_without_outputs: dict[str, Any], fake_workbench_adjacency: dict[str, Any], -) -> PublishedProject: +) -> Callable[[], Awaitable[PublishedProject]]: user = registered_user() - created_project = await project(user, workbench=fake_workbench_without_outputs) - return PublishedProject( - project=created_project, - pipeline=pipeline( - project_id=f"{created_project.uuid}", - dag_adjacency_list=fake_workbench_adjacency, - ), - tasks=tasks(user=user, project=created_project, state=StateType.PUBLISHED), - ) + + async def _() -> PublishedProject: + created_project = await project(user, workbench=fake_workbench_without_outputs) + return PublishedProject( + user=user, + project=created_project, + pipeline=await create_pipeline( + project_id=f"{created_project.uuid}", + dag_adjacency_list=fake_workbench_adjacency, + ), + tasks=await create_tasks( + user=user, project=created_project, state=StateType.PUBLISHED + ), + ) + + return _ + + +@pytest.fixture +async def published_project( + publish_project: Callable[[], Awaitable[PublishedProject]] +) -> PublishedProject: + return await publish_project() @pytest.fixture async def running_project( registered_user: Callable[..., dict[str, Any]], project: Callable[..., Awaitable[ProjectAtDB]], - pipeline: Callable[..., CompPipelineAtDB], - tasks: Callable[..., list[CompTaskAtDB]], - runs: Callable[..., CompRunsAtDB], + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], + create_tasks: 
Callable[..., Awaitable[list[CompTaskAtDB]]], + create_comp_run: Callable[..., Awaitable[CompRunsAtDB]], fake_workbench_without_outputs: dict[str, Any], fake_workbench_adjacency: dict[str, Any], ) -> RunningProject: @@ -329,24 +353,26 @@ async def running_project( created_project = await project(user, workbench=fake_workbench_without_outputs) now_time = arrow.utcnow().datetime return RunningProject( + user=user, project=created_project, - pipeline=pipeline( + pipeline=await create_pipeline( project_id=f"{created_project.uuid}", dag_adjacency_list=fake_workbench_adjacency, ), - tasks=tasks( + tasks=await create_tasks( user=user, project=created_project, state=StateType.RUNNING, progress=0.0, start=now_time, ), - runs=runs( + runs=await create_comp_run( user=user, project=created_project, started=now_time, result=StateType.RUNNING, ), + task_to_callback_mapping={}, ) @@ -354,9 +380,9 @@ async def running_project( async def running_project_mark_for_cancellation( registered_user: Callable[..., dict[str, Any]], project: Callable[..., Awaitable[ProjectAtDB]], - pipeline: Callable[..., CompPipelineAtDB], - tasks: Callable[..., list[CompTaskAtDB]], - runs: Callable[..., CompRunsAtDB], + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], + create_tasks: Callable[..., Awaitable[list[CompTaskAtDB]]], + create_comp_run: Callable[..., Awaitable[CompRunsAtDB]], fake_workbench_without_outputs: dict[str, Any], fake_workbench_adjacency: dict[str, Any], ) -> RunningProject: @@ -364,25 +390,27 @@ async def running_project_mark_for_cancellation( created_project = await project(user, workbench=fake_workbench_without_outputs) now_time = arrow.utcnow().datetime return RunningProject( + user=user, project=created_project, - pipeline=pipeline( + pipeline=await create_pipeline( project_id=f"{created_project.uuid}", dag_adjacency_list=fake_workbench_adjacency, ), - tasks=tasks( + tasks=await create_tasks( user=user, project=created_project, state=StateType.RUNNING, progress=0.0, start=now_time, ), - runs=runs( + runs=await create_comp_run( user=user, project=created_project, result=StateType.RUNNING, started=now_time, cancelled=now_time + datetime.timedelta(seconds=5), ), + task_to_callback_mapping={}, ) diff --git a/services/director-v2/tests/unit/with_dbs/test_api_route_clusters.py b/services/director-v2/tests/unit/with_dbs/test_api_route_clusters.py index 19ab0ea2df3..9f55e71f935 100644 --- a/services/director-v2/tests/unit/with_dbs/test_api_route_clusters.py +++ b/services/director-v2/tests/unit/with_dbs/test_api_route_clusters.py @@ -4,7 +4,7 @@ import random from collections.abc import Callable, Iterator -from typing import Any +from typing import Any, Awaitable import httpx import pytest @@ -85,7 +85,7 @@ def clusters_cleaner(postgres_db: sa.engine.Engine) -> Iterator: async def test_list_clusters( clusters_config: None, registered_user: Callable[..., dict], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], async_client: httpx.AsyncClient, ): user_1 = registered_user() @@ -106,7 +106,7 @@ async def test_list_clusters( # let's create some clusters NUM_CLUSTERS = 111 for n in range(NUM_CLUSTERS): - cluster(user_1, name=f"pytest cluster{n:04}") + await create_cluster(user_1, name=f"pytest cluster{n:04}") response = await async_client.get(list_clusters_url) assert response.status_code == status.HTTP_200_OK @@ -141,7 +141,7 @@ async def test_list_clusters( (CLUSTER_MANAGER_RIGHTS, "manager rights"), (CLUSTER_ADMIN_RIGHTS, "admin rights"), ]: - cluster( + await 
create_cluster( user_1, # cluster is owned by user_1 name=f"cluster with {name}", access_rights={ @@ -172,7 +172,7 @@ async def test_list_clusters( async def test_get_cluster( clusters_config: None, registered_user: Callable[..., dict], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], async_client: httpx.AsyncClient, ): user_1 = registered_user() @@ -183,7 +183,7 @@ async def test_get_cluster( assert response.status_code == status.HTTP_404_NOT_FOUND # let's create some clusters a_bunch_of_clusters = [ - cluster(user_1, name=f"pytest cluster{n:04}") for n in range(111) + await create_cluster(user_1, name=f"pytest cluster{n:04}") for n in range(111) ] the_cluster = random.choice(a_bunch_of_clusters) @@ -213,7 +213,7 @@ async def test_get_cluster( (CLUSTER_MANAGER_RIGHTS, True), (CLUSTER_ADMIN_RIGHTS, True), ]: - a_cluster = cluster( + a_cluster = await create_cluster( user_2, # cluster is owned by user_2 access_rights={ user_2["primary_gid"]: CLUSTER_ADMIN_RIGHTS, @@ -243,7 +243,7 @@ async def test_get_cluster( async def test_get_another_cluster( clusters_config: None, registered_user: Callable[..., dict], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], async_client: httpx.AsyncClient, cluster_sharing_rights: ClusterAccessRights, can_use: bool, @@ -252,7 +252,7 @@ async def test_get_another_cluster( user_2 = registered_user() # let's create some clusters a_bunch_of_clusters = [ - cluster( + await create_cluster( user_1, name=f"pytest cluster{n:04}", access_rights={ @@ -349,7 +349,7 @@ async def test_create_cluster( async def test_update_own_cluster( clusters_config: None, registered_user: Callable[..., dict], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], cluster_simple_authentication: Callable, async_client: httpx.AsyncClient, faker: Faker, @@ -366,7 +366,7 @@ async def test_update_own_cluster( assert response.status_code == status.HTTP_404_NOT_FOUND # let's create some clusters a_bunch_of_clusters = [ - cluster(user_1, name=f"pytest cluster{n:04}") for n in range(111) + await create_cluster(user_1, name=f"pytest cluster{n:04}") for n in range(111) ] the_cluster = random.choice(a_bunch_of_clusters) # get the original one @@ -471,7 +471,7 @@ async def test_update_own_cluster( async def test_update_default_cluster_fails( clusters_config: None, registered_user: Callable[..., dict], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], cluster_simple_authentication: Callable, async_client: httpx.AsyncClient, faker: Faker, @@ -506,7 +506,7 @@ async def test_update_default_cluster_fails( async def test_update_another_cluster( clusters_config: None, registered_user: Callable[..., dict], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], cluster_simple_authentication: Callable, async_client: httpx.AsyncClient, faker: Faker, @@ -522,7 +522,7 @@ async def test_update_another_cluster( user_2 = registered_user() # let's create some clusters a_bunch_of_clusters = [ - cluster( + await create_cluster( user_1, name=f"pytest cluster{n:04}", access_rights={ @@ -603,13 +603,13 @@ async def test_update_another_cluster( async def test_delete_cluster( clusters_config: None, registered_user: Callable[..., dict], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], async_client: httpx.AsyncClient, ): user_1 = registered_user() # let's create some clusters a_bunch_of_clusters = [ - cluster( + 
await create_cluster( user_1, name=f"pytest cluster{n:04}", access_rights={ @@ -647,7 +647,7 @@ async def test_delete_cluster( async def test_delete_another_cluster( clusters_config: None, registered_user: Callable[..., dict], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], cluster_simple_authentication: Callable, async_client: httpx.AsyncClient, faker: Faker, @@ -658,7 +658,7 @@ async def test_delete_another_cluster( user_2 = registered_user() # let's create some clusters a_bunch_of_clusters = [ - cluster( + await create_cluster( user_1, name=f"pytest cluster{n:04}", access_rights={ @@ -754,7 +754,7 @@ async def test_ping_cluster( async def test_ping_specific_cluster( clusters_config: None, registered_user: Callable[..., dict], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], async_client: httpx.AsyncClient, local_dask_gateway_server: DaskGatewayServer, ): @@ -767,7 +767,7 @@ async def test_ping_specific_cluster( # let's create some clusters and ping one a_bunch_of_clusters = [ - cluster( + await create_cluster( user_1, name=f"pytest cluster{n:04}", endpoint=local_dask_gateway_server.address, diff --git a/services/director-v2/tests/unit/with_dbs/test_api_route_clusters_details.py b/services/director-v2/tests/unit/with_dbs/test_api_route_clusters_details.py index 5dd1abaa594..357f3b7647a 100644 --- a/services/director-v2/tests/unit/with_dbs/test_api_route_clusters_details.py +++ b/services/director-v2/tests/unit/with_dbs/test_api_route_clusters_details.py @@ -4,7 +4,7 @@ import json from collections.abc import Callable -from typing import Any +from typing import Any, Awaitable import httpx import pytest @@ -142,14 +142,14 @@ async def test_get_cluster_details( registered_user: Callable[..., dict[str, Any]], async_client: httpx.AsyncClient, local_dask_gateway_server: DaskGatewayServer, - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], dask_gateway_cluster: GatewayCluster, dask_gateway_cluster_client: DaskClient, gateway_username: str, ): user_1 = registered_user() # define the cluster in the DB - some_cluster = cluster( + some_cluster = await create_cluster( user_1, endpoint=local_dask_gateway_server.address, authentication=SimpleAuthentication( diff --git a/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py b/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py index add9c4d77d3..2e75b18c009 100644 --- a/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py +++ b/services/director-v2/tests/unit/with_dbs/test_api_route_computations.py @@ -87,7 +87,6 @@ def minimal_configuration( rabbit_service: RabbitSettings, redis_service: RedisSettings, monkeypatch: pytest.MonkeyPatch, - mocked_rabbit_mq_client: None, faker: Faker, ): monkeypatch.setenv("DIRECTOR_V2_DYNAMIC_SIDECAR_ENABLED", "false") @@ -789,12 +788,12 @@ async def test_start_computation_with_deprecated_services_raises_406( @pytest.fixture -def unusable_cluster( +async def unusable_cluster( registered_user: Callable[..., dict[str, Any]], - cluster: Callable[..., Cluster], + create_cluster: Callable[..., Awaitable[Cluster]], ) -> ClusterID: user = registered_user() - created_cluster = cluster(user) + created_cluster = await create_cluster(user) return created_cluster.id @@ -865,7 +864,7 @@ async def test_get_computation_from_empty_project( fake_workbench_adjacency: dict[str, Any], registered_user: Callable[..., dict[str, Any]], project: Callable[..., 
Awaitable[ProjectAtDB]], - pipeline: Callable[..., CompPipelineAtDB], + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], faker: Faker, async_client: httpx.AsyncClient, ): @@ -884,8 +883,8 @@ async def test_get_computation_from_empty_project( response = await async_client.get(get_computation_url) assert response.status_code == status.HTTP_404_NOT_FOUND, response.text # create an empty pipeline - pipeline( - project_id=proj.uuid, + await create_pipeline( + project_id=f"{proj.uuid}", ) response = await async_client.get(get_computation_url) assert response.status_code == status.HTTP_200_OK, response.text @@ -917,8 +916,8 @@ async def test_get_computation_from_not_started_computation_task( fake_workbench_adjacency: dict[str, Any], registered_user: Callable[..., dict[str, Any]], project: Callable[..., Awaitable[ProjectAtDB]], - pipeline: Callable[..., CompPipelineAtDB], - tasks: Callable[..., list[CompTaskAtDB]], + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], + create_tasks: Callable[..., Awaitable[list[CompTaskAtDB]]], async_client: httpx.AsyncClient, ): user = registered_user() @@ -926,8 +925,8 @@ async def test_get_computation_from_not_started_computation_task( get_computation_url = httpx.URL( f"/v2/computations/{proj.uuid}?user_id={user['id']}" ) - pipeline( - project_id=proj.uuid, + await create_pipeline( + project_id=f"{proj.uuid}", dag_adjacency_list=fake_workbench_adjacency, ) # create no task this should trigger an exception @@ -935,7 +934,7 @@ async def test_get_computation_from_not_started_computation_task( assert response.status_code == status.HTTP_409_CONFLICT, response.text # now create the expected tasks and the state is good again - comp_tasks = tasks(user=user, project=proj) + comp_tasks = await create_tasks(user=user, project=proj) response = await async_client.get(get_computation_url) assert response.status_code == status.HTTP_200_OK, response.text returned_computation = ComputationGet.model_validate(response.json()) @@ -989,19 +988,23 @@ async def test_get_computation_from_published_computation_task( fake_workbench_adjacency: dict[str, Any], registered_user: Callable[..., dict[str, Any]], project: Callable[..., Awaitable[ProjectAtDB]], - pipeline: Callable[..., CompPipelineAtDB], - tasks: Callable[..., list[CompTaskAtDB]], - runs: Callable[..., CompRunsAtDB], + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], + create_tasks: Callable[..., Awaitable[list[CompTaskAtDB]]], + create_comp_run: Callable[..., Awaitable[CompRunsAtDB]], async_client: httpx.AsyncClient, ): user = registered_user() proj = await project(user, workbench=fake_workbench_without_outputs) - pipeline( - project_id=proj.uuid, + await create_pipeline( + project_id=f"{proj.uuid}", dag_adjacency_list=fake_workbench_adjacency, ) - comp_tasks = tasks(user=user, project=proj, state=StateType.PUBLISHED, progress=0) - comp_runs = runs(user=user, project=proj, result=StateType.PUBLISHED) + comp_tasks = await create_tasks( + user=user, project=proj, state=StateType.PUBLISHED, progress=0 + ) + comp_runs = await create_comp_run( + user=user, project=proj, result=StateType.PUBLISHED + ) assert comp_runs get_computation_url = httpx.URL( f"/v2/computations/{proj.uuid}?user_id={user['id']}" diff --git a/services/director-v2/tests/unit/with_dbs/test_api_route_computations_tasks.py b/services/director-v2/tests/unit/with_dbs/test_api_route_computations_tasks.py index 10bd1ba3a2f..845983b99cb 100644 --- a/services/director-v2/tests/unit/with_dbs/test_api_route_computations_tasks.py +++ 
b/services/director-v2/tests/unit/with_dbs/test_api_route_computations_tasks.py @@ -9,7 +9,6 @@ from uuid import uuid4 import httpx -from pydantic import TypeAdapter import pytest from faker import Faker from fastapi import FastAPI, status @@ -22,6 +21,7 @@ from models_library.projects import ProjectAtDB, ProjectID from models_library.projects_nodes_io import NodeID from models_library.users import UserID +from pydantic import TypeAdapter from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict from pytest_simcore.helpers.typing_env import EnvVarsDict from simcore_service_director_v2.core.settings import AppSettings @@ -116,21 +116,21 @@ async def project_id( fake_workbench_adjacency: dict[str, Any], user: dict[str, Any], project: Callable[..., Awaitable[ProjectAtDB]], - pipeline: Callable[..., CompPipelineAtDB], - tasks: Callable[..., list[CompTaskAtDB]], -): + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], + create_tasks: Callable[..., Awaitable[list[CompTaskAtDB]]], +) -> ProjectID: """project uuid of a saved project (w/ tasks up-to-date)""" # insert project -> db proj = await project(user, workbench=fake_workbench_without_outputs) # insert pipeline -> comp_pipeline - pipeline( - project_id=proj.uuid, + await create_pipeline( + project_id=f"{proj.uuid}", dag_adjacency_list=fake_workbench_adjacency, ) # insert tasks -> comp_tasks - comp_tasks = tasks(user=user, project=proj) + comp_tasks = await create_tasks(user=user, project=proj) return proj.uuid diff --git a/services/director-v2/tests/unit/with_dbs/test_utils_rabbitmq.py b/services/director-v2/tests/unit/with_dbs/test_utils_rabbitmq.py index a041f70ecc7..8778d17245e 100644 --- a/services/director-v2/tests/unit/with_dbs/test_utils_rabbitmq.py +++ b/services/director-v2/tests/unit/with_dbs/test_utils_rabbitmq.py @@ -93,18 +93,18 @@ async def project( @pytest.fixture -def tasks( +async def tasks( user: dict[str, Any], project: ProjectAtDB, fake_workbench_adjacency: dict[str, Any], - pipeline: Callable[..., CompPipelineAtDB], - tasks: Callable[..., list[CompTaskAtDB]], + create_pipeline: Callable[..., Awaitable[CompPipelineAtDB]], + create_tasks: Callable[..., Awaitable[list[CompTaskAtDB]]], ) -> list[CompTaskAtDB]: - pipeline( - project_id=project.uuid, + await create_pipeline( + project_id=f"{project.uuid}", dag_adjacency_list=fake_workbench_adjacency, ) - comp_tasks = tasks(user, project) + comp_tasks = await create_tasks(user, project) assert len(comp_tasks) > 0 return comp_tasks
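
The test fixtures touched above all follow the same migration pattern: blocking factory fixtures built on the sync `postgres_db` engine become async factory fixtures built on `sqlalchemy_async_engine`, returning `Callable[..., Awaitable[...]]` and deleting on teardown only the rows they created. A minimal sketch of that pattern, assuming pytest-asyncio and SQLAlchemy 2.x; the `widgets` table and `create_widget` fixture are hypothetical placeholders standing in for `comp_pipeline` / `comp_tasks` / `comp_runs`, not part of this patch:

    from collections.abc import AsyncIterator, Awaitable, Callable
    from typing import Any

    import pytest
    import sqlalchemy as sa
    from sqlalchemy.ext.asyncio import AsyncEngine

    # hypothetical table used only for illustration
    metadata = sa.MetaData()
    widgets = sa.Table(
        "widgets",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.String, nullable=False),
    )


    @pytest.fixture
    async def create_widget(
        sqlalchemy_async_engine: AsyncEngine,
    ) -> AsyncIterator[Callable[..., Awaitable[dict[str, Any]]]]:
        created_ids: list[int] = []

        async def _(**overrides: Any) -> dict[str, Any]:
            values: dict[str, Any] = {"name": "pytest widget"}
            values.update(**overrides)
            async with sqlalchemy_async_engine.begin() as conn:
                result = await conn.execute(
                    widgets.insert()
                    .values(**values)
                    .returning(sa.literal_column("*"))
                )
                row = result.first()._asdict()
            created_ids.append(row["id"])
            return row

        yield _

        # cleanup: remove only what this factory inserted
        async with sqlalchemy_async_engine.begin() as conn:
            await conn.execute(
                widgets.delete().where(widgets.c.id.in_(created_ids))
            )

A test then requests the factory and awaits it per row, e.g. `row = await create_widget(name="pytest widget 0001")`, mirroring how `create_pipeline`, `create_tasks`, `create_comp_run` and `create_cluster` are awaited throughout the updated tests.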