From 6b06e28f3ce276a215745a369e45147c9dba9536 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Tue, 12 Nov 2024 18:07:49 +0100 Subject: [PATCH 1/9] ensure cancellation of task is not only in 1 process --- .../modules/comp_scheduler/_base_scheduler.py | 10 ++++------ .../modules/comp_scheduler/_scheduler_factory.py | 1 - .../test_modules_comp_scheduler_dask_scheduler.py | 2 -- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py index 097afd95288..1d3d97be460 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py @@ -140,7 +140,6 @@ async def _triage_changed_tasks( class ScheduledPipelineParams: cluster_id: ClusterID run_metadata: RunMetadataDict - mark_for_cancellation: datetime.datetime | None use_on_demand_clusters: bool scheduler_task: asyncio.Task | None = None @@ -198,7 +197,6 @@ async def run_new_pipeline( cluster_id=cluster_id, run_metadata=new_run.metadata, use_on_demand_clusters=use_on_demand_clusters, - mark_for_cancellation=None, ) await publish_project_log( self.rabbitmq_client, @@ -236,9 +234,6 @@ async def stop_pipeline( ) if updated_comp_run: assert updated_comp_run.cancelled is not None # nosec - self.scheduled_pipelines[ - (user_id, project_id, selected_iteration) - ].mark_for_cancellation = updated_comp_run.cancelled # ensure the scheduler starts right away self.scheduled_pipelines[ (user_id, project_id, selected_iteration) @@ -714,7 +709,10 @@ async def _schedule_pipeline( project_id, dag ) # 3. do we want to stop the pipeline now? - if pipeline_params.mark_for_cancellation: + comp_run = await CompRunsRepository.instance(self.db_engine).get( + user_id, project_id, iteration + ) + if comp_run.cancelled: await self._schedule_tasks_to_stop( user_id, project_id, comp_tasks, pipeline_params ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py index 4f7812816cc..d0dc66e0345 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py @@ -53,7 +53,6 @@ async def create_from_db(app: FastAPI) -> BaseCompScheduler: r.cluster_id if r.cluster_id is not None else DEFAULT_CLUSTER_ID ), run_metadata=r.metadata, - mark_for_cancellation=r.cancelled, use_on_demand_clusters=r.use_on_demand_clusters, ) for r in runs diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py index 1df1ae09d39..69f10cc9022 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py @@ -402,7 +402,6 @@ async def test_misconfigured_pipeline_is_not_scheduled( assert u_id == user["id"] assert p_id == sleepers_project.uuid assert it > 0 - assert params.mark_for_cancellation is None # check the database was properly updated async with aiopg_engine.acquire() as conn: result = await conn.execute( @@ -450,7 +449,6 @@ async def _assert_start_pipeline( assert u_id == published_project.project.prj_owner assert p_id == published_project.project.uuid assert it > 0 - assert params.mark_for_cancellation is None assert params.run_metadata == run_metadata # check the database is correctly updated, the run is published From 9c362f57b1535a17f2abe2f8e0cee483c5f47b66 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Thu, 14 Nov 2024 13:12:33 +0100 Subject: [PATCH 2/9] refactor --- packages/models-library/src/models_library/clusters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/models-library/src/models_library/clusters.py b/packages/models-library/src/models_library/clusters.py index 1856dc5c287..5fa14c2daa1 100644 --- a/packages/models-library/src/models_library/clusters.py +++ b/packages/models-library/src/models_library/clusters.py @@ -154,7 +154,7 @@ class Config: ClusterID: TypeAlias = NonNegativeInt -DEFAULT_CLUSTER_ID: Final[NonNegativeInt] = 0 +DEFAULT_CLUSTER_ID: Final[ClusterID] = 0 class Cluster(BaseCluster): From 57e37e322125f7922ff9b35b04a79da2f97beef1 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Thu, 14 Nov 2024 13:56:18 +0100 Subject: [PATCH 3/9] removed usage of in process variable --- .../modules/comp_scheduler/_base_scheduler.py | 79 ++++++------- .../modules/comp_scheduler/_dask_scheduler.py | 104 ++++++++++++++---- .../comp_scheduler/_scheduler_factory.py | 9 +- 3 files changed, 116 insertions(+), 76 deletions(-) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py index 1d3d97be460..7de3bcff3ef 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py @@ -18,7 +18,7 @@ import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import Final +from typing import Callable, Final import arrow import networkx as nx @@ -52,7 +52,7 @@ ) from ...core.settings import ComputationalBackendSettings from ...models.comp_pipelines import CompPipelineAtDB -from ...models.comp_runs import RunMetadataDict +from ...models.comp_runs import CompRunsAtDB, RunMetadataDict from ...models.comp_tasks import CompTaskAtDB from ...utils.comp_scheduler import ( COMPLETED_STATES, @@ -138,10 +138,6 @@ async def _triage_changed_tasks( @dataclass(kw_only=True) class ScheduledPipelineParams: - cluster_id: ClusterID - run_metadata: RunMetadataDict - use_on_demand_clusters: bool - scheduler_task: asyncio.Task | None = None scheduler_waker: asyncio.Event = field(default_factory=asyncio.Event) @@ -193,11 +189,7 @@ async def run_new_pipeline( ) self.scheduled_pipelines[ (user_id, project_id, new_run.iteration) - ] = pipeline_params = ScheduledPipelineParams( - cluster_id=cluster_id, - run_metadata=new_run.metadata, - use_on_demand_clusters=use_on_demand_clusters, - ) + ] = pipeline_params = ScheduledPipelineParams() await publish_project_log( self.rabbitmq_client, user_id, @@ -282,14 +274,14 @@ async def _exclusive_safe_schedule_pipeline( user_id: UserID, project_id: ProjectID, iteration: Iteration, - pipeline_params: ScheduledPipelineParams, + wake_up_callback: Callable[[], None], ) -> None: with contextlib.suppress(CouldNotAcquireLockError): await self._schedule_pipeline( user_id=user_id, project_id=project_id, iteration=iteration, - pipeline_params=pipeline_params, + wake_up_callback=wake_up_callback, ) pipeline_params.scheduler_task = start_periodic_task( @@ -298,7 +290,7 @@ async def _exclusive_safe_schedule_pipeline( user_id=user_id, project_id=project_id, iteration=iteration, - pipeline_params=pipeline_params, + wake_up_callback=pipeline_params.wake_up, ), interval=_SCHEDULER_INTERVAL, task_name=_TASK_NAME_TEMPLATE.format( @@ -446,10 +438,10 @@ async def _get_changed_tasks_from_backend( self, user_id: UserID, processing_tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> list[tuple[_Previous, _Current]]: tasks_backend_status = await self._get_tasks_status( - user_id, processing_tasks, pipeline_params + user_id, processing_tasks, comp_run ) return [ @@ -587,7 +579,7 @@ async def _update_states_from_comp_backend( project_id: ProjectID, iteration: Iteration, pipeline_dag: nx.DiGraph, - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> None: tasks = await self._get_pipeline_tasks(project_id, pipeline_dag) tasks_inprocess = [t for t in tasks.values() if t.state in PROCESSING_STATES] @@ -596,7 +588,7 @@ async def _update_states_from_comp_backend( # get the tasks which state actually changed since last check tasks_with_changed_states = await self._get_changed_tasks_from_backend( - user_id, tasks_inprocess, pipeline_params + user_id, tasks_inprocess, comp_run ) # NOTE: typical states a task goes through # NOT_STARTED (initial state) -> PUBLISHED (user press run/API call) -> PENDING -> WAITING_FOR_CLUSTER (cluster creation) -> @@ -615,7 +607,7 @@ async def _update_states_from_comp_backend( sorted_tasks.started, user_id=user_id, iteration=iteration, - run_metadata=pipeline_params.run_metadata, + run_metadata=comp_run.metadata, ) if sorted_tasks.completed or sorted_tasks.potentially_lost: @@ -623,7 +615,7 @@ async def _update_states_from_comp_backend( user_id, sorted_tasks.completed + sorted_tasks.potentially_lost, iteration, - pipeline_params=pipeline_params, + comp_run=comp_run, ) if sorted_tasks.waiting: @@ -636,25 +628,20 @@ async def _start_tasks( user_id: UserID, project_id: ProjectID, scheduled_tasks: dict[NodeID, CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, + wake_up_callback: Callable[[], None], ) -> None: ... @abstractmethod async def _get_tasks_status( - self, - user_id: UserID, - tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB ) -> list[RunningState]: ... @abstractmethod async def _stop_tasks( - self, - user_id: UserID, - tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB ) -> None: ... @@ -664,7 +651,7 @@ async def _process_completed_tasks( user_id: UserID, tasks: list[CompTaskAtDB], iteration: Iteration, - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> None: ... @@ -690,7 +677,7 @@ async def _schedule_pipeline( user_id: UserID, project_id: ProjectID, iteration: PositiveInt, - pipeline_params: ScheduledPipelineParams, + wake_up_callback: Callable[[], None], ) -> None: with log_context( _logger, @@ -699,22 +686,22 @@ async def _schedule_pipeline( ): dag: nx.DiGraph = nx.DiGraph() try: + comp_run = await CompRunsRepository.instance(self.db_engine).get( + user_id, project_id, iteration + ) dag = await self._get_pipeline_dag(project_id) # 1. Update our list of tasks with data from backend (state, results) await self._update_states_from_comp_backend( - user_id, project_id, iteration, dag, pipeline_params=pipeline_params + user_id, project_id, iteration, dag, comp_run ) # 2. Any task following a FAILED task shall be ABORTED comp_tasks = await self._set_states_following_failed_to_aborted( project_id, dag ) # 3. do we want to stop the pipeline now? - comp_run = await CompRunsRepository.instance(self.db_engine).get( - user_id, project_id, iteration - ) if comp_run.cancelled: await self._schedule_tasks_to_stop( - user_id, project_id, comp_tasks, pipeline_params + user_id, project_id, comp_tasks, comp_run ) else: # let's get the tasks to schedule then @@ -723,7 +710,8 @@ async def _schedule_pipeline( project_id=project_id, comp_tasks=comp_tasks, dag=dag, - pipeline_params=pipeline_params, + comp_run=comp_run, + wake_up_callback=wake_up_callback, ) # 4. timeout if waiting for cluster has been there for more than X minutes comp_tasks = await self._timeout_if_waiting_for_cluster_too_long( @@ -748,8 +736,9 @@ async def _schedule_pipeline( f"{project_id=}", f"{pipeline_result=}", ) - assert pipeline_params.scheduler_task is not None # nosec - pipeline_params.scheduler_task.cancel() + current_task = asyncio.current_task() + assert current_task is not None # nosec + current_task.cancel() except PipelineNotFoundError: _logger.warning( "pipeline %s does not exist in comp_pipeline table, it will be removed from scheduler", @@ -794,7 +783,7 @@ async def _schedule_tasks_to_stop( user_id: UserID, project_id: ProjectID, comp_tasks: dict[NodeIDStr, CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> None: # get any running task and stop them comp_tasks_repo = CompTasksRepository.instance(self.db_engine) @@ -803,7 +792,7 @@ async def _schedule_tasks_to_stop( ) # stop any remaining running task, these are already submitted tasks_to_stop = [t for t in comp_tasks.values() if t.state in PROCESSING_STATES] - await self._stop_tasks(user_id, tasks_to_stop, pipeline_params) + await self._stop_tasks(user_id, tasks_to_stop, comp_run) async def _schedule_tasks_to_start( # noqa: C901 self, @@ -811,7 +800,8 @@ async def _schedule_tasks_to_start( # noqa: C901 project_id: ProjectID, comp_tasks: dict[NodeIDStr, CompTaskAtDB], dag: nx.DiGraph, - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, + wake_up_callback: Callable[[], None], ) -> dict[NodeIDStr, CompTaskAtDB]: # filter out the successfully completed tasks dag.remove_nodes_from( @@ -843,7 +833,8 @@ async def _schedule_tasks_to_start( # noqa: C901 user_id=user_id, project_id=project_id, scheduled_tasks=tasks_ready_to_start, - pipeline_params=pipeline_params, + comp_run=comp_run, + wake_up_callback=wake_up_callback, ) except ( ComputationalBackendNotConnectedError, @@ -932,7 +923,7 @@ async def _schedule_tasks_to_start( # noqa: C901 "Unexpected error for %s with %s on %s happened when scheduling %s:", f"{user_id=}", f"{project_id=}", - f"{pipeline_params.cluster_id=}", + f"{comp_run.cluster_id=}", f"{tasks_ready_to_start.keys()=}", ) await CompTasksRepository.instance( diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py index 512df1b1712..2f0c8c4eab4 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py @@ -4,6 +4,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from dataclasses import dataclass +from typing import Callable import arrow from dask_task_models_library.container_tasks.errors import TaskCancelledError @@ -12,7 +13,7 @@ TaskProgressEvent, ) from dask_task_models_library.container_tasks.io import TaskOutputData -from models_library.clusters import DEFAULT_CLUSTER_ID, BaseCluster +from models_library.clusters import DEFAULT_CLUSTER_ID, BaseCluster, ClusterID from models_library.errors import ErrorDict from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID @@ -27,7 +28,7 @@ ComputationalBackendOnDemandNotReadyError, TaskSchedulingError, ) -from ...models.comp_runs import RunMetadataDict +from ...models.comp_runs import CompRunsAtDB, RunMetadataDict from ...models.comp_tasks import CompTaskAtDB from ...models.dask_subsystem import DaskClientTaskState from ...utils.comp_scheduler import Iteration, get_resource_tracking_run_id @@ -49,7 +50,7 @@ from ..db.repositories.clusters import ClustersRepository from ..db.repositories.comp_runs import CompRunsRepository from ..db.repositories.comp_tasks import CompTasksRepository -from ._base_scheduler import BaseCompScheduler, ScheduledPipelineParams +from ._base_scheduler import BaseCompScheduler _logger = logging.getLogger(__name__) @@ -69,19 +70,22 @@ @asynccontextmanager async def _cluster_dask_client( user_id: UserID, - pipeline_params: ScheduledPipelineParams, scheduler: "DaskScheduler", + *, + use_on_demand_clusters: bool, + cluster_id: ClusterID, + run_metadata: RunMetadataDict, ) -> AsyncIterator[DaskClient]: cluster: BaseCluster = scheduler.settings.default_cluster - if pipeline_params.use_on_demand_clusters: + if use_on_demand_clusters: cluster = await get_or_create_on_demand_cluster( scheduler.rabbitmq_rpc_client, user_id=user_id, - wallet_id=pipeline_params.run_metadata.get("wallet_id"), + wallet_id=run_metadata.get("wallet_id"), ) - if pipeline_params.cluster_id != DEFAULT_CLUSTER_ID: + if cluster_id != DEFAULT_CLUSTER_ID: clusters_repo = ClustersRepository.instance(scheduler.db_engine) - cluster = await clusters_repo.get_cluster(user_id, pipeline_params.cluster_id) + cluster = await clusters_repo.get_cluster(user_id, cluster_id) async with scheduler.dask_clients_pool.acquire(cluster) as client: yield client @@ -104,10 +108,21 @@ async def _start_tasks( user_id: UserID, project_id: ProjectID, scheduled_tasks: dict[NodeID, CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, + wake_up_callback: Callable[[], None], ) -> None: # now transfer the pipeline to the dask scheduler - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: # Change the tasks state to PENDING comp_tasks_repo = CompTasksRepository.instance(self.db_engine) await comp_tasks_repo.update_project_tasks_state( @@ -121,11 +136,15 @@ async def _start_tasks( client.send_computation_tasks( user_id=user_id, project_id=project_id, - cluster_id=pipeline_params.cluster_id, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), tasks={node_id: task.image}, hardware_info=task.hardware_info, - callback=pipeline_params.wake_up, - metadata=pipeline_params.run_metadata, + callback=wake_up_callback, + metadata=comp_run.metadata, ) for node_id, task in scheduled_tasks.items() ), @@ -146,10 +165,20 @@ async def _get_tasks_status( self, user_id: UserID, tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> list[RunningState]: try: - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: tasks_statuses = await client.get_tasks_status( [f"{t.job_id}" for t in tasks] ) @@ -177,14 +206,21 @@ async def _get_tasks_status( return [RunningState.WAITING_FOR_CLUSTER] * len(tasks) async def _stop_tasks( - self, - user_id: UserID, - tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB ) -> None: # NOTE: if this exception raises, it means the backend was anyway not up with contextlib.suppress(ComputationalBackendOnDemandNotReadyError): - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: await asyncio.gather( *[ client.abort_computation_task(t.job_id) @@ -209,10 +245,20 @@ async def _process_completed_tasks( user_id: UserID, tasks: list[CompTaskAtDB], iteration: Iteration, - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> None: try: - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: tasks_results = await asyncio.gather( *[client.get_task_result(t.job_id or "undefined") for t in tasks], return_exceptions=True, @@ -220,13 +266,23 @@ async def _process_completed_tasks( await asyncio.gather( *[ self._process_task_result( - task, result, pipeline_params.run_metadata, iteration + task, result, comp_run.metadata, iteration ) for task, result in zip(tasks, tasks_results, strict=True) ] ) finally: - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: await asyncio.gather( *[client.release_task_result(t.job_id) for t in tasks if t.job_id] ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py index d0dc66e0345..e0521bbb491 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py @@ -1,7 +1,6 @@ import logging from fastapi import FastAPI -from models_library.clusters import DEFAULT_CLUSTER_ID from servicelib.logging_utils import log_context from settings_library.redis import RedisDatabase @@ -48,13 +47,7 @@ async def create_from_db(app: FastAPI) -> BaseCompScheduler: redis_client=get_redis_client_manager(app).client(RedisDatabase.LOCKS), db_engine=db_engine, scheduled_pipelines={ - (r.user_id, r.project_uuid, r.iteration): ScheduledPipelineParams( - cluster_id=( - r.cluster_id if r.cluster_id is not None else DEFAULT_CLUSTER_ID - ), - run_metadata=r.metadata, - use_on_demand_clusters=r.use_on_demand_clusters, - ) + (r.user_id, r.project_uuid, r.iteration): ScheduledPipelineParams() for r in runs }, service_runtime_heartbeat_interval=app_settings.SERVICE_TRACKING_HEARTBEAT, From 2482929ebd7d8713b4dfb6b017d2ce897018e24a Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:16:58 +0100 Subject: [PATCH 4/9] refactor --- .../modules/comp_scheduler/__init__.py | 5 +- .../modules/comp_scheduler/_base_scheduler.py | 73 +++++++++++++------ .../comp_scheduler/_scheduler_factory.py | 24 +----- 3 files changed, 57 insertions(+), 45 deletions(-) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py index d06c37457b7..2b29acf16c9 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py @@ -16,10 +16,7 @@ async def start_scheduler() -> None: with log_context( _logger, level=logging.INFO, msg="starting computational scheduler" ): - app.state.scheduler = scheduler = await _scheduler_factory.create_from_db( - app - ) - scheduler.recover_scheduling() + app.state.scheduler = await _scheduler_factory.create_from_db(app) return start_scheduler diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py index 7de3bcff3ef..d8ec72e22eb 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py @@ -58,6 +58,7 @@ COMPLETED_STATES, PROCESSING_STATES, RUNNING_STATES, + SCHEDULED_STATES, TASK_TO_START_STATES, WAITING_FOR_START_STATES, Iteration, @@ -147,17 +148,53 @@ def wake_up(self) -> None: @dataclass class BaseCompScheduler(ABC): - scheduled_pipelines: dict[ - tuple[UserID, ProjectID, Iteration], ScheduledPipelineParams - ] db_engine: Engine - wake_up_event: asyncio.Event = field(default_factory=asyncio.Event, init=False) rabbitmq_client: RabbitMQClient rabbitmq_rpc_client: RabbitMQRPCClient settings: ComputationalBackendSettings service_runtime_heartbeat_interval: datetime.timedelta redis_client: RedisClientSDK + # NOTE: this is a trick to be able to inheritate from the class + _scheduled_pipelines: dict[ + tuple[UserID, ProjectID, Iteration], ScheduledPipelineParams + ] = field(default_factory=dict, init=False) + + def __post_init__(self) -> None: + self._scheduled_pipelines = {} + + async def restore_scheduling_from_db(self) -> None: + # get currently scheduled runs + comp_runs = await CompRunsRepository.instance(self.db_engine).list( + filter_by_state=SCHEDULED_STATES + ) + + self._scheduled_pipelines |= { + ( + run.user_id, + run.project_uuid, + run.iteration, + ): ScheduledPipelineParams() + for run in comp_runs + } + + for ( + user_id, + project_id, + iteration, + ), params in self._scheduled_pipelines.items(): + self._start_scheduling(params, user_id, project_id, iteration) + + async def start_scheduling(self) -> None: + await self.restore_scheduling_from_db() + + for ( + user_id, + project_id, + iteration, + ), params in self._scheduled_pipelines.items(): + self._start_scheduling(params, user_id, project_id, iteration) + async def run_new_pipeline( self, user_id: UserID, @@ -187,7 +224,7 @@ async def run_new_pipeline( metadata=run_metadata, use_on_demand_clusters=use_on_demand_clusters, ) - self.scheduled_pipelines[ + self._scheduled_pipelines[ (user_id, project_id, new_run.iteration) ] = pipeline_params = ScheduledPipelineParams() await publish_project_log( @@ -207,7 +244,7 @@ async def stop_pipeline( # if no iteration given find the latest one in the list possible_iterations = { it - for u_id, p_id, it in self.scheduled_pipelines + for u_id, p_id, it in self._scheduled_pipelines if u_id == user_id and p_id == project_id } if not possible_iterations: @@ -227,24 +264,16 @@ async def stop_pipeline( if updated_comp_run: assert updated_comp_run.cancelled is not None # nosec # ensure the scheduler starts right away - self.scheduled_pipelines[ + self._scheduled_pipelines[ (user_id, project_id, selected_iteration) ].wake_up() - def recover_scheduling(self) -> None: - for ( - user_id, - project_id, - iteration, - ), params in self.scheduled_pipelines.items(): - self._start_scheduling(params, user_id, project_id, iteration) - async def shutdown(self) -> None: # cancel all current scheduling processes await asyncio.gather( *( stop_periodic_task(p.scheduler_task, timeout=3) - for p in self.scheduled_pipelines.values() + for p in self._scheduled_pipelines.values() if p.scheduler_task ), return_exceptions=True, @@ -254,7 +283,7 @@ def _get_last_iteration(self, user_id: UserID, project_id: ProjectID) -> Iterati # if no iteration given find the latest one in the list possible_iterations = { it - for u_id, p_id, it in self.scheduled_pipelines + for u_id, p_id, it in self._scheduled_pipelines if u_id == user_id and p_id == project_id } if not possible_iterations: @@ -730,7 +759,9 @@ async def _schedule_pipeline( # 7. Are we done scheduling that pipeline? if not dag.nodes() or pipeline_result in COMPLETED_STATES: # there is nothing left, the run is completed, we're done here - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + self._scheduled_pipelines.pop( + (user_id, project_id, iteration), None + ) _logger.info( "pipeline %s scheduling completed with result %s", f"{project_id=}", @@ -747,7 +778,7 @@ async def _schedule_pipeline( await self._set_run_result( user_id, project_id, iteration, RunningState.ABORTED ) - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + self._scheduled_pipelines.pop((user_id, project_id, iteration), None) except InvalidPipelineError as exc: _logger.warning( "pipeline %s appears to be misconfigured, it will be removed from scheduler. Please check pipeline:\n%s", @@ -757,7 +788,7 @@ async def _schedule_pipeline( await self._set_run_result( user_id, project_id, iteration, RunningState.ABORTED ) - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + self._scheduled_pipelines.pop((user_id, project_id, iteration), None) except (DaskClientAcquisisitonError, ClustersKeeperNotAvailableError): _logger.exception( "Unexpected error while connecting with computational backend, aborting pipeline" @@ -774,7 +805,7 @@ async def _schedule_pipeline( await self._set_run_result( user_id, project_id, iteration, RunningState.FAILED ) - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + self._scheduled_pipelines.pop((user_id, project_id, iteration), None) except ComputationalBackendNotConnectedError: _logger.exception("Computational backend is not connected!") diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py index e0521bbb491..39b432b9492 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py @@ -6,13 +6,10 @@ from ...core.errors import ConfigurationError from ...core.settings import AppSettings -from ...models.comp_runs import CompRunsAtDB -from ...utils.comp_scheduler import SCHEDULED_STATES from ..dask_clients_pool import DaskClientsPool -from ..db.repositories.comp_runs import CompRunsRepository from ..rabbitmq import get_rabbitmq_client, get_rabbitmq_rpc_client from ..redis import get_redis_client_manager -from ._base_scheduler import BaseCompScheduler, ScheduledPipelineParams +from ._base_scheduler import BaseCompScheduler from ._dask_scheduler import DaskScheduler _logger = logging.getLogger(__name__) @@ -23,32 +20,19 @@ async def create_from_db(app: FastAPI) -> BaseCompScheduler: msg = "Database connection is missing. Please check application configuration." raise ConfigurationError(msg) db_engine = app.state.engine - runs_repository = CompRunsRepository.instance(db_engine) - - # get currently scheduled runs - runs: list[CompRunsAtDB] = await runs_repository.list( - filter_by_state=SCHEDULED_STATES - ) - - _logger.debug( - "Following scheduled comp_runs found still to be scheduled: %s", - runs if runs else "NONE", - ) with log_context( _logger, logging.INFO, msg="Creating Dask-based computational scheduler" ): app_settings: AppSettings = app.state.settings - return DaskScheduler( + scheduler = DaskScheduler( settings=app_settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND, dask_clients_pool=DaskClientsPool.instance(app), rabbitmq_client=get_rabbitmq_client(app), rabbitmq_rpc_client=get_rabbitmq_rpc_client(app), redis_client=get_redis_client_manager(app).client(RedisDatabase.LOCKS), db_engine=db_engine, - scheduled_pipelines={ - (r.user_id, r.project_uuid, r.iteration): ScheduledPipelineParams() - for r in runs - }, service_runtime_heartbeat_interval=app_settings.SERVICE_TRACKING_HEARTBEAT, ) + await scheduler.restore_scheduling_from_db() + return scheduler From 7634b8cf688a97a4a56c76cc1729c61bb6cc7409 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:27:29 +0100 Subject: [PATCH 5/9] simplify --- .../modules/comp_scheduler/_base_scheduler.py | 66 +++++++++---------- 1 file changed, 30 insertions(+), 36 deletions(-) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py index d8ec72e22eb..91db4f4fcca 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py @@ -17,8 +17,9 @@ import functools import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Callable, Final +from typing import Final, TypeAlias import arrow import networkx as nx @@ -87,6 +88,9 @@ str ] = "computational-scheduler-{user_id}:{project_id}:{iteration}" +PipelineSchedulingTask: TypeAlias = asyncio.Task +PipelineSchedulingWakeUpEvent: TypeAlias = asyncio.Event + @dataclass(frozen=True, slots=True) class SortedTasks: @@ -139,8 +143,8 @@ async def _triage_changed_tasks( @dataclass(kw_only=True) class ScheduledPipelineParams: - scheduler_task: asyncio.Task | None = None - scheduler_waker: asyncio.Event = field(default_factory=asyncio.Event) + scheduler_task: asyncio.Task + scheduler_waker: asyncio.Event def wake_up(self) -> None: self.scheduler_waker.set() @@ -169,31 +173,19 @@ async def restore_scheduling_from_db(self) -> None: filter_by_state=SCHEDULED_STATES ) - self._scheduled_pipelines |= { - ( - run.user_id, - run.project_uuid, - run.iteration, - ): ScheduledPipelineParams() - for run in comp_runs - } - - for ( - user_id, - project_id, - iteration, - ), params in self._scheduled_pipelines.items(): - self._start_scheduling(params, user_id, project_id, iteration) - - async def start_scheduling(self) -> None: - await self.restore_scheduling_from_db() - - for ( - user_id, - project_id, - iteration, - ), params in self._scheduled_pipelines.items(): - self._start_scheduling(params, user_id, project_id, iteration) + for run in comp_runs: + task, wake_up_event = self._start_scheduling( + run.user_id, run.project_uuid, run.iteration + ) + self._scheduled_pipelines |= { + ( + run.user_id, + run.project_uuid, + run.iteration, + ): ScheduledPipelineParams( + scheduler_task=task, scheduler_waker=wake_up_event + ) + } async def run_new_pipeline( self, @@ -224,9 +216,12 @@ async def run_new_pipeline( metadata=run_metadata, use_on_demand_clusters=use_on_demand_clusters, ) + task, wake_up_event = self._start_scheduling( + user_id, project_id, new_run.iteration + ) self._scheduled_pipelines[ (user_id, project_id, new_run.iteration) - ] = pipeline_params = ScheduledPipelineParams() + ] = ScheduledPipelineParams(scheduler_task=task, scheduler_waker=wake_up_event) await publish_project_log( self.rabbitmq_client, user_id, @@ -235,8 +230,6 @@ async def run_new_pipeline( log_level=logging.INFO, ) - self._start_scheduling(pipeline_params, user_id, project_id, new_run.iteration) - async def stop_pipeline( self, user_id: UserID, project_id: ProjectID, iteration: int | None = None ) -> None: @@ -293,11 +286,10 @@ def _get_last_iteration(self, user_id: UserID, project_id: ProjectID) -> Iterati def _start_scheduling( self, - pipeline_params: ScheduledPipelineParams, user_id: UserID, project_id: ProjectID, iteration: Iteration, - ) -> None: + ) -> tuple[PipelineSchedulingTask, PipelineSchedulingWakeUpEvent]: async def _exclusive_safe_schedule_pipeline( *, user_id: UserID, @@ -313,20 +305,22 @@ async def _exclusive_safe_schedule_pipeline( wake_up_callback=wake_up_callback, ) - pipeline_params.scheduler_task = start_periodic_task( + pipeline_wake_up_event = asyncio.Event() + pipeline_task = start_periodic_task( functools.partial( _exclusive_safe_schedule_pipeline, user_id=user_id, project_id=project_id, iteration=iteration, - wake_up_callback=pipeline_params.wake_up, + wake_up_callback=pipeline_wake_up_event.set, ), interval=_SCHEDULER_INTERVAL, task_name=_TASK_NAME_TEMPLATE.format( user_id=user_id, project_id=project_id, iteration=iteration ), - early_wake_up_event=pipeline_params.scheduler_waker, + early_wake_up_event=pipeline_wake_up_event, ) + return pipeline_task, pipeline_wake_up_event async def _get_pipeline_dag(self, project_id: ProjectID) -> nx.DiGraph: comp_pipeline_repo = CompPipelinesRepository.instance(self.db_engine) From 6540fb8c3c3dd645f1e848f0fef93bff7c5cdbaa Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:03:11 +0100 Subject: [PATCH 6/9] fixed test --- ...t_modules_comp_scheduler_dask_scheduler.py | 41 +++++++++++-------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py index 69f10cc9022..a54bb3fbf68 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py @@ -71,9 +71,6 @@ BaseCompScheduler, get_scheduler, ) -from simcore_service_director_v2.modules.comp_scheduler._base_scheduler import ( - ScheduledPipelineParams, -) from simcore_service_director_v2.modules.comp_scheduler._dask_scheduler import ( DaskScheduler, ) @@ -165,14 +162,14 @@ async def schedule_all_pipelines(scheduler: BaseCompScheduler) -> None: # misconfigured pipelines that would be removed from the scheduler # NOTE: we simulate multiple dv-2 replicas by running several times # the same pipeline scheduling - local_pipelines = deepcopy(scheduler.scheduled_pipelines) + local_pipelines = deepcopy(scheduler._scheduled_pipelines) # noqa: SLF001 results = await asyncio.gather( *( scheduler._schedule_pipeline( # noqa: SLF001 user_id=user_id, project_id=project_id, iteration=iteration, - pipeline_params=params, + wake_up_callback=params.scheduler_waker.set, ) for _ in range(3) for ( @@ -256,15 +253,16 @@ def mocked_clean_task_output_fct(mocker: MockerFixture) -> mock.MagicMock: @pytest.fixture def with_disabled_auto_scheduling(mocker: MockerFixture) -> mock.MagicMock: - """disables the scheduler task, note that it needs to be triggered manually then""" + """disables the scheduler task, note that it needs to be triggered manu>ally then""" def _fake_starter( self: BaseCompScheduler, - pipeline_params: ScheduledPipelineParams, *args, **kwargs, - ) -> None: - pipeline_params.scheduler_task = mocker.MagicMock() + ): + scheduler_task = mocker.MagicMock() + scheduler_task_wake_up_event = mocker.MagicMock() + return scheduler_task, scheduler_task_wake_up_event return mocker.patch( "simcore_service_director_v2.modules.comp_scheduler._base_scheduler.BaseCompScheduler._start_scheduling", @@ -358,7 +356,7 @@ async def test_empty_pipeline_is_not_scheduled( run_metadata=run_metadata, use_on_demand_clusters=False, ) - assert len(scheduler.scheduled_pipelines) == 0 + assert len(scheduler._scheduled_pipelines) == 0 # noqa: SLF001 # check the database is empty async with aiopg_engine.acquire() as conn: result = await conn.scalar( @@ -397,8 +395,12 @@ async def test_misconfigured_pipeline_is_not_scheduled( run_metadata=run_metadata, use_on_demand_clusters=False, ) - assert len(scheduler.scheduled_pipelines) == 1 - for (u_id, p_id, it), params in scheduler.scheduled_pipelines.items(): + assert len(scheduler._scheduled_pipelines) == 1 # noqa: SLF001 + for ( + u_id, + p_id, + it, + ), params in scheduler._scheduled_pipelines.items(): # noqa: SLF001 assert u_id == user["id"] assert p_id == sleepers_project.uuid assert it > 0 @@ -415,7 +417,7 @@ async def test_misconfigured_pipeline_is_not_scheduled( # let the scheduler kick in await schedule_all_pipelines(scheduler) # check the scheduled pipelines is again empty since it's misconfigured - assert len(scheduler.scheduled_pipelines) == 0 + assert len(scheduler._scheduled_pipelines) == 0 # noqa: SLF001 # check the database entry is correctly updated async with aiopg_engine.acquire() as conn: result = await conn.execute( @@ -444,12 +446,17 @@ async def _assert_start_pipeline( run_metadata=run_metadata, use_on_demand_clusters=False, ) - assert len(scheduler.scheduled_pipelines) == 1, "the pipeline is not scheduled!" - for (u_id, p_id, it), params in scheduler.scheduled_pipelines.items(): + assert ( + len(scheduler._scheduled_pipelines) == 1 # noqa: SLF001 + ), "the pipeline is not scheduled!" + for ( + u_id, + p_id, + it, + ) in scheduler._scheduled_pipelines: # noqa: SLF001 assert u_id == published_project.project.prj_owner assert p_id == published_project.project.uuid assert it > 0 - assert params.run_metadata == run_metadata # check the database is correctly updated, the run is published await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) @@ -1017,7 +1024,7 @@ async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskSta assert isinstance(messages[1], RabbitResourceTrackingStoppedMessage) # the scheduled pipeline shall be removed - assert scheduler.scheduled_pipelines == {} + assert scheduler._scheduled_pipelines == {} # noqa: SLF001 async def test_task_progress_triggers( From b824819605d790e891399d367f7cd76327a1efda Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:12:30 +0100 Subject: [PATCH 7/9] cleanup --- .../unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py index a54bb3fbf68..4f644fcaf6f 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py @@ -400,7 +400,7 @@ async def test_misconfigured_pipeline_is_not_scheduled( u_id, p_id, it, - ), params in scheduler._scheduled_pipelines.items(): # noqa: SLF001 + ) in scheduler._scheduled_pipelines: # noqa: SLF001 assert u_id == user["id"] assert p_id == sleepers_project.uuid assert it > 0 From bf896c977c6c6c92bf292f3e6ccf01d682f7f555 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:27:43 +0100 Subject: [PATCH 8/9] remove flaky flag --- .../unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py index 4f644fcaf6f..927476e851d 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py @@ -1165,7 +1165,6 @@ class RebootState: expected_run_state: RunningState -@pytest.mark.flaky(max_runs=3) @pytest.mark.parametrize( "reboot_state", [ From 13af0393d0a361a10ff1cf36bc9fb055a2b568a7 Mon Sep 17 00:00:00 2001 From: sanderegg <35365065+sanderegg@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:18:56 +0100 Subject: [PATCH 9/9] only stop if there are tasks to stop --- .../modules/comp_scheduler/_base_scheduler.py | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py index 91db4f4fcca..e6d8e6da491 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py @@ -84,9 +84,9 @@ _Current = CompTaskAtDB _MAX_WAITING_FOR_CLUSTER_TIMEOUT_IN_MIN: Final[int] = 10 _SCHEDULER_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5) -_TASK_NAME_TEMPLATE: Final[ - str -] = "computational-scheduler-{user_id}:{project_id}:{iteration}" +_TASK_NAME_TEMPLATE: Final[str] = ( + "computational-scheduler-{user_id}:{project_id}:{iteration}" +) PipelineSchedulingTask: TypeAlias = asyncio.Task PipelineSchedulingWakeUpEvent: TypeAlias = asyncio.Event @@ -219,9 +219,9 @@ async def run_new_pipeline( task, wake_up_event = self._start_scheduling( user_id, project_id, new_run.iteration ) - self._scheduled_pipelines[ - (user_id, project_id, new_run.iteration) - ] = ScheduledPipelineParams(scheduler_task=task, scheduler_waker=wake_up_event) + self._scheduled_pipelines[(user_id, project_id, new_run.iteration)] = ( + ScheduledPipelineParams(scheduler_task=task, scheduler_waker=wake_up_event) + ) await publish_project_log( self.rabbitmq_client, user_id, @@ -653,20 +653,17 @@ async def _start_tasks( scheduled_tasks: dict[NodeID, CompTaskAtDB], comp_run: CompRunsAtDB, wake_up_callback: Callable[[], None], - ) -> None: - ... + ) -> None: ... @abstractmethod async def _get_tasks_status( self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB - ) -> list[RunningState]: - ... + ) -> list[RunningState]: ... @abstractmethod async def _stop_tasks( self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB - ) -> None: - ... + ) -> None: ... @abstractmethod async def _process_completed_tasks( @@ -675,8 +672,7 @@ async def _process_completed_tasks( tasks: list[CompTaskAtDB], iteration: Iteration, comp_run: CompRunsAtDB, - ) -> None: - ... + ) -> None: ... @staticmethod def _build_exclusive_lock_key(*args, **kwargs) -> str: @@ -816,8 +812,10 @@ async def _schedule_tasks_to_stop( project_id ) # stop any remaining running task, these are already submitted - tasks_to_stop = [t for t in comp_tasks.values() if t.state in PROCESSING_STATES] - await self._stop_tasks(user_id, tasks_to_stop, comp_run) + if tasks_to_stop := [ + t for t in comp_tasks.values() if t.state in PROCESSING_STATES + ]: + await self._stop_tasks(user_id, tasks_to_stop, comp_run) async def _schedule_tasks_to_start( # noqa: C901 self, @@ -877,9 +875,9 @@ async def _schedule_tasks_to_start( # noqa: C901 RunningState.WAITING_FOR_CLUSTER, ) for task in tasks_ready_to_start: - comp_tasks[ - NodeIDStr(f"{task}") - ].state = RunningState.WAITING_FOR_CLUSTER + comp_tasks[NodeIDStr(f"{task}")].state = ( + RunningState.WAITING_FOR_CLUSTER + ) except ComputationalBackendOnDemandNotReadyError as exc: _logger.info( @@ -901,9 +899,9 @@ async def _schedule_tasks_to_start( # noqa: C901 RunningState.WAITING_FOR_CLUSTER, ) for task in tasks_ready_to_start: - comp_tasks[ - NodeIDStr(f"{task}") - ].state = RunningState.WAITING_FOR_CLUSTER + comp_tasks[NodeIDStr(f"{task}")].state = ( + RunningState.WAITING_FOR_CLUSTER + ) except ClustersKeeperNotAvailableError: _logger.exception("Unexpected error while starting tasks:") await publish_project_log(