diff --git a/packages/models-library/src/models_library/clusters.py b/packages/models-library/src/models_library/clusters.py index 1856dc5c287..5fa14c2daa1 100644 --- a/packages/models-library/src/models_library/clusters.py +++ b/packages/models-library/src/models_library/clusters.py @@ -154,7 +154,7 @@ class Config: ClusterID: TypeAlias = NonNegativeInt -DEFAULT_CLUSTER_ID: Final[NonNegativeInt] = 0 +DEFAULT_CLUSTER_ID: Final[ClusterID] = 0 class Cluster(BaseCluster): diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py index d06c37457b7..2b29acf16c9 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/__init__.py @@ -16,10 +16,7 @@ async def start_scheduler() -> None: with log_context( _logger, level=logging.INFO, msg="starting computational scheduler" ): - app.state.scheduler = scheduler = await _scheduler_factory.create_from_db( - app - ) - scheduler.recover_scheduling() + app.state.scheduler = await _scheduler_factory.create_from_db(app) return start_scheduler diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py index 097afd95288..e6d8e6da491 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py @@ -17,8 +17,9 @@ import functools import logging from abc import ABC, abstractmethod +from collections.abc import Callable from dataclasses import dataclass, field -from typing import Final +from typing import Final, TypeAlias import arrow import networkx as nx @@ -52,12 +53,13 @@ ) from ...core.settings 
import ComputationalBackendSettings from ...models.comp_pipelines import CompPipelineAtDB -from ...models.comp_runs import RunMetadataDict +from ...models.comp_runs import CompRunsAtDB, RunMetadataDict from ...models.comp_tasks import CompTaskAtDB from ...utils.comp_scheduler import ( COMPLETED_STATES, PROCESSING_STATES, RUNNING_STATES, + SCHEDULED_STATES, TASK_TO_START_STATES, WAITING_FOR_START_STATES, Iteration, @@ -82,9 +84,12 @@ _Current = CompTaskAtDB _MAX_WAITING_FOR_CLUSTER_TIMEOUT_IN_MIN: Final[int] = 10 _SCHEDULER_INTERVAL: Final[datetime.timedelta] = datetime.timedelta(seconds=5) -_TASK_NAME_TEMPLATE: Final[ - str -] = "computational-scheduler-{user_id}:{project_id}:{iteration}" +_TASK_NAME_TEMPLATE: Final[str] = ( + "computational-scheduler-{user_id}:{project_id}:{iteration}" +) + +PipelineSchedulingTask: TypeAlias = asyncio.Task +PipelineSchedulingWakeUpEvent: TypeAlias = asyncio.Event @dataclass(frozen=True, slots=True) @@ -138,13 +143,8 @@ async def _triage_changed_tasks( @dataclass(kw_only=True) class ScheduledPipelineParams: - cluster_id: ClusterID - run_metadata: RunMetadataDict - mark_for_cancellation: datetime.datetime | None - use_on_demand_clusters: bool - - scheduler_task: asyncio.Task | None = None - scheduler_waker: asyncio.Event = field(default_factory=asyncio.Event) + scheduler_task: asyncio.Task + scheduler_waker: asyncio.Event def wake_up(self) -> None: self.scheduler_waker.set() @@ -152,17 +152,41 @@ def wake_up(self) -> None: @dataclass class BaseCompScheduler(ABC): - scheduled_pipelines: dict[ - tuple[UserID, ProjectID, Iteration], ScheduledPipelineParams - ] db_engine: Engine - wake_up_event: asyncio.Event = field(default_factory=asyncio.Event, init=False) rabbitmq_client: RabbitMQClient rabbitmq_rpc_client: RabbitMQRPCClient settings: ComputationalBackendSettings service_runtime_heartbeat_interval: datetime.timedelta redis_client: RedisClientSDK + # NOTE: this is a trick to be able to inherit from the class +
_scheduled_pipelines: dict[ + tuple[UserID, ProjectID, Iteration], ScheduledPipelineParams + ] = field(default_factory=dict, init=False) + + def __post_init__(self) -> None: + self._scheduled_pipelines = {} + + async def restore_scheduling_from_db(self) -> None: + # get currently scheduled runs + comp_runs = await CompRunsRepository.instance(self.db_engine).list( + filter_by_state=SCHEDULED_STATES + ) + + for run in comp_runs: + task, wake_up_event = self._start_scheduling( + run.user_id, run.project_uuid, run.iteration + ) + self._scheduled_pipelines |= { + ( + run.user_id, + run.project_uuid, + run.iteration, + ): ScheduledPipelineParams( + scheduler_task=task, scheduler_waker=wake_up_event + ) + } + async def run_new_pipeline( self, user_id: UserID, @@ -192,13 +216,11 @@ async def run_new_pipeline( metadata=run_metadata, use_on_demand_clusters=use_on_demand_clusters, ) - self.scheduled_pipelines[ - (user_id, project_id, new_run.iteration) - ] = pipeline_params = ScheduledPipelineParams( - cluster_id=cluster_id, - run_metadata=new_run.metadata, - use_on_demand_clusters=use_on_demand_clusters, - mark_for_cancellation=None, + task, wake_up_event = self._start_scheduling( + user_id, project_id, new_run.iteration + ) + self._scheduled_pipelines[(user_id, project_id, new_run.iteration)] = ( + ScheduledPipelineParams(scheduler_task=task, scheduler_waker=wake_up_event) ) await publish_project_log( self.rabbitmq_client, @@ -208,8 +230,6 @@ async def run_new_pipeline( log_level=logging.INFO, ) - self._start_scheduling(pipeline_params, user_id, project_id, new_run.iteration) - async def stop_pipeline( self, user_id: UserID, project_id: ProjectID, iteration: int | None = None ) -> None: @@ -217,7 +237,7 @@ async def stop_pipeline( # if no iteration given find the latest one in the list possible_iterations = { it - for u_id, p_id, it in self.scheduled_pipelines + for u_id, p_id, it in self._scheduled_pipelines if u_id == user_id and p_id == project_id } if not 
possible_iterations: @@ -236,28 +256,17 @@ async def stop_pipeline( ) if updated_comp_run: assert updated_comp_run.cancelled is not None # nosec - self.scheduled_pipelines[ - (user_id, project_id, selected_iteration) - ].mark_for_cancellation = updated_comp_run.cancelled # ensure the scheduler starts right away - self.scheduled_pipelines[ + self._scheduled_pipelines[ (user_id, project_id, selected_iteration) ].wake_up() - def recover_scheduling(self) -> None: - for ( - user_id, - project_id, - iteration, - ), params in self.scheduled_pipelines.items(): - self._start_scheduling(params, user_id, project_id, iteration) - async def shutdown(self) -> None: # cancel all current scheduling processes await asyncio.gather( *( stop_periodic_task(p.scheduler_task, timeout=3) - for p in self.scheduled_pipelines.values() + for p in self._scheduled_pipelines.values() if p.scheduler_task ), return_exceptions=True, @@ -267,7 +276,7 @@ def _get_last_iteration(self, user_id: UserID, project_id: ProjectID) -> Iterati # if no iteration given find the latest one in the list possible_iterations = { it - for u_id, p_id, it in self.scheduled_pipelines + for u_id, p_id, it in self._scheduled_pipelines if u_id == user_id and p_id == project_id } if not possible_iterations: @@ -277,40 +286,41 @@ def _get_last_iteration(self, user_id: UserID, project_id: ProjectID) -> Iterati def _start_scheduling( self, - pipeline_params: ScheduledPipelineParams, user_id: UserID, project_id: ProjectID, iteration: Iteration, - ) -> None: + ) -> tuple[PipelineSchedulingTask, PipelineSchedulingWakeUpEvent]: async def _exclusive_safe_schedule_pipeline( *, user_id: UserID, project_id: ProjectID, iteration: Iteration, - pipeline_params: ScheduledPipelineParams, + wake_up_callback: Callable[[], None], ) -> None: with contextlib.suppress(CouldNotAcquireLockError): await self._schedule_pipeline( user_id=user_id, project_id=project_id, iteration=iteration, - pipeline_params=pipeline_params, + 
wake_up_callback=wake_up_callback, ) - pipeline_params.scheduler_task = start_periodic_task( + pipeline_wake_up_event = asyncio.Event() + pipeline_task = start_periodic_task( functools.partial( _exclusive_safe_schedule_pipeline, user_id=user_id, project_id=project_id, iteration=iteration, - pipeline_params=pipeline_params, + wake_up_callback=pipeline_wake_up_event.set, ), interval=_SCHEDULER_INTERVAL, task_name=_TASK_NAME_TEMPLATE.format( user_id=user_id, project_id=project_id, iteration=iteration ), - early_wake_up_event=pipeline_params.scheduler_waker, + early_wake_up_event=pipeline_wake_up_event, ) + return pipeline_task, pipeline_wake_up_event async def _get_pipeline_dag(self, project_id: ProjectID) -> nx.DiGraph: comp_pipeline_repo = CompPipelinesRepository.instance(self.db_engine) @@ -451,10 +461,10 @@ async def _get_changed_tasks_from_backend( self, user_id: UserID, processing_tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> list[tuple[_Previous, _Current]]: tasks_backend_status = await self._get_tasks_status( - user_id, processing_tasks, pipeline_params + user_id, processing_tasks, comp_run ) return [ @@ -592,7 +602,7 @@ async def _update_states_from_comp_backend( project_id: ProjectID, iteration: Iteration, pipeline_dag: nx.DiGraph, - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> None: tasks = await self._get_pipeline_tasks(project_id, pipeline_dag) tasks_inprocess = [t for t in tasks.values() if t.state in PROCESSING_STATES] @@ -601,7 +611,7 @@ async def _update_states_from_comp_backend( # get the tasks which state actually changed since last check tasks_with_changed_states = await self._get_changed_tasks_from_backend( - user_id, tasks_inprocess, pipeline_params + user_id, tasks_inprocess, comp_run ) # NOTE: typical states a task goes through # NOT_STARTED (initial state) -> PUBLISHED (user press run/API call) -> PENDING -> WAITING_FOR_CLUSTER (cluster creation) -> @@ -620,7 
+630,7 @@ async def _update_states_from_comp_backend( sorted_tasks.started, user_id=user_id, iteration=iteration, - run_metadata=pipeline_params.run_metadata, + run_metadata=comp_run.metadata, ) if sorted_tasks.completed or sorted_tasks.potentially_lost: @@ -628,7 +638,7 @@ async def _update_states_from_comp_backend( user_id, sorted_tasks.completed + sorted_tasks.potentially_lost, iteration, - pipeline_params=pipeline_params, + comp_run=comp_run, ) if sorted_tasks.waiting: @@ -641,27 +651,19 @@ async def _start_tasks( user_id: UserID, project_id: ProjectID, scheduled_tasks: dict[NodeID, CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, - ) -> None: - ... + comp_run: CompRunsAtDB, + wake_up_callback: Callable[[], None], + ) -> None: ... @abstractmethod async def _get_tasks_status( - self, - user_id: UserID, - tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, - ) -> list[RunningState]: - ... + self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB + ) -> list[RunningState]: ... @abstractmethod async def _stop_tasks( - self, - user_id: UserID, - tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, - ) -> None: - ... + self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB + ) -> None: ... @abstractmethod async def _process_completed_tasks( @@ -669,9 +671,8 @@ async def _process_completed_tasks( user_id: UserID, tasks: list[CompTaskAtDB], iteration: Iteration, - pipeline_params: ScheduledPipelineParams, - ) -> None: - ... + comp_run: CompRunsAtDB, + ) -> None: ... 
@staticmethod def _build_exclusive_lock_key(*args, **kwargs) -> str: @@ -695,7 +696,7 @@ async def _schedule_pipeline( user_id: UserID, project_id: ProjectID, iteration: PositiveInt, - pipeline_params: ScheduledPipelineParams, + wake_up_callback: Callable[[], None], ) -> None: with log_context( _logger, @@ -704,19 +705,22 @@ async def _schedule_pipeline( ): dag: nx.DiGraph = nx.DiGraph() try: + comp_run = await CompRunsRepository.instance(self.db_engine).get( + user_id, project_id, iteration + ) dag = await self._get_pipeline_dag(project_id) # 1. Update our list of tasks with data from backend (state, results) await self._update_states_from_comp_backend( - user_id, project_id, iteration, dag, pipeline_params=pipeline_params + user_id, project_id, iteration, dag, comp_run ) # 2. Any task following a FAILED task shall be ABORTED comp_tasks = await self._set_states_following_failed_to_aborted( project_id, dag ) # 3. do we want to stop the pipeline now? - if pipeline_params.mark_for_cancellation: + if comp_run.cancelled: await self._schedule_tasks_to_stop( - user_id, project_id, comp_tasks, pipeline_params + user_id, project_id, comp_tasks, comp_run ) else: # let's get the tasks to schedule then @@ -725,7 +729,8 @@ async def _schedule_pipeline( project_id=project_id, comp_tasks=comp_tasks, dag=dag, - pipeline_params=pipeline_params, + comp_run=comp_run, + wake_up_callback=wake_up_callback, ) # 4. timeout if waiting for cluster has been there for more than X minutes comp_tasks = await self._timeout_if_waiting_for_cluster_too_long( @@ -744,14 +749,17 @@ async def _schedule_pipeline( # 7. Are we done scheduling that pipeline? 
if not dag.nodes() or pipeline_result in COMPLETED_STATES: # there is nothing left, the run is completed, we're done here - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + self._scheduled_pipelines.pop( + (user_id, project_id, iteration), None + ) _logger.info( "pipeline %s scheduling completed with result %s", f"{project_id=}", f"{pipeline_result=}", ) - assert pipeline_params.scheduler_task is not None # nosec - pipeline_params.scheduler_task.cancel() + current_task = asyncio.current_task() + assert current_task is not None # nosec + current_task.cancel() except PipelineNotFoundError: _logger.warning( "pipeline %s does not exist in comp_pipeline table, it will be removed from scheduler", @@ -760,7 +768,7 @@ async def _schedule_pipeline( await self._set_run_result( user_id, project_id, iteration, RunningState.ABORTED ) - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + self._scheduled_pipelines.pop((user_id, project_id, iteration), None) except InvalidPipelineError as exc: _logger.warning( "pipeline %s appears to be misconfigured, it will be removed from scheduler. 
Please check pipeline:\n%s", @@ -770,7 +778,7 @@ async def _schedule_pipeline( await self._set_run_result( user_id, project_id, iteration, RunningState.ABORTED ) - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + self._scheduled_pipelines.pop((user_id, project_id, iteration), None) except (DaskClientAcquisisitonError, ClustersKeeperNotAvailableError): _logger.exception( "Unexpected error while connecting with computational backend, aborting pipeline" @@ -787,7 +795,7 @@ async def _schedule_pipeline( await self._set_run_result( user_id, project_id, iteration, RunningState.FAILED ) - self.scheduled_pipelines.pop((user_id, project_id, iteration), None) + self._scheduled_pipelines.pop((user_id, project_id, iteration), None) except ComputationalBackendNotConnectedError: _logger.exception("Computational backend is not connected!") @@ -796,7 +804,7 @@ async def _schedule_tasks_to_stop( user_id: UserID, project_id: ProjectID, comp_tasks: dict[NodeIDStr, CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> None: # get any running task and stop them comp_tasks_repo = CompTasksRepository.instance(self.db_engine) @@ -804,8 +812,10 @@ async def _schedule_tasks_to_stop( project_id ) # stop any remaining running task, these are already submitted - tasks_to_stop = [t for t in comp_tasks.values() if t.state in PROCESSING_STATES] - await self._stop_tasks(user_id, tasks_to_stop, pipeline_params) + if tasks_to_stop := [ + t for t in comp_tasks.values() if t.state in PROCESSING_STATES + ]: + await self._stop_tasks(user_id, tasks_to_stop, comp_run) async def _schedule_tasks_to_start( # noqa: C901 self, @@ -813,7 +823,8 @@ async def _schedule_tasks_to_start( # noqa: C901 project_id: ProjectID, comp_tasks: dict[NodeIDStr, CompTaskAtDB], dag: nx.DiGraph, - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, + wake_up_callback: Callable[[], None], ) -> dict[NodeIDStr, CompTaskAtDB]: # filter out the successfully 
completed tasks dag.remove_nodes_from( @@ -845,7 +856,8 @@ async def _schedule_tasks_to_start( # noqa: C901 user_id=user_id, project_id=project_id, scheduled_tasks=tasks_ready_to_start, - pipeline_params=pipeline_params, + comp_run=comp_run, + wake_up_callback=wake_up_callback, ) except ( ComputationalBackendNotConnectedError, @@ -863,9 +875,9 @@ async def _schedule_tasks_to_start( # noqa: C901 RunningState.WAITING_FOR_CLUSTER, ) for task in tasks_ready_to_start: - comp_tasks[ - NodeIDStr(f"{task}") - ].state = RunningState.WAITING_FOR_CLUSTER + comp_tasks[NodeIDStr(f"{task}")].state = ( + RunningState.WAITING_FOR_CLUSTER + ) except ComputationalBackendOnDemandNotReadyError as exc: _logger.info( @@ -887,9 +899,9 @@ async def _schedule_tasks_to_start( # noqa: C901 RunningState.WAITING_FOR_CLUSTER, ) for task in tasks_ready_to_start: - comp_tasks[ - NodeIDStr(f"{task}") - ].state = RunningState.WAITING_FOR_CLUSTER + comp_tasks[NodeIDStr(f"{task}")].state = ( + RunningState.WAITING_FOR_CLUSTER + ) except ClustersKeeperNotAvailableError: _logger.exception("Unexpected error while starting tasks:") await publish_project_log( @@ -934,7 +946,7 @@ async def _schedule_tasks_to_start( # noqa: C901 "Unexpected error for %s with %s on %s happened when scheduling %s:", f"{user_id=}", f"{project_id=}", - f"{pipeline_params.cluster_id=}", + f"{comp_run.cluster_id=}", f"{tasks_ready_to_start.keys()=}", ) await CompTasksRepository.instance( diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py index 512df1b1712..2f0c8c4eab4 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py @@ -4,6 +4,7 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from 
dataclasses import dataclass +from typing import Callable import arrow from dask_task_models_library.container_tasks.errors import TaskCancelledError @@ -12,7 +13,7 @@ TaskProgressEvent, ) from dask_task_models_library.container_tasks.io import TaskOutputData -from models_library.clusters import DEFAULT_CLUSTER_ID, BaseCluster +from models_library.clusters import DEFAULT_CLUSTER_ID, BaseCluster, ClusterID from models_library.errors import ErrorDict from models_library.projects import ProjectID from models_library.projects_nodes_io import NodeID @@ -27,7 +28,7 @@ ComputationalBackendOnDemandNotReadyError, TaskSchedulingError, ) -from ...models.comp_runs import RunMetadataDict +from ...models.comp_runs import CompRunsAtDB, RunMetadataDict from ...models.comp_tasks import CompTaskAtDB from ...models.dask_subsystem import DaskClientTaskState from ...utils.comp_scheduler import Iteration, get_resource_tracking_run_id @@ -49,7 +50,7 @@ from ..db.repositories.clusters import ClustersRepository from ..db.repositories.comp_runs import CompRunsRepository from ..db.repositories.comp_tasks import CompTasksRepository -from ._base_scheduler import BaseCompScheduler, ScheduledPipelineParams +from ._base_scheduler import BaseCompScheduler _logger = logging.getLogger(__name__) @@ -69,19 +70,22 @@ @asynccontextmanager async def _cluster_dask_client( user_id: UserID, - pipeline_params: ScheduledPipelineParams, scheduler: "DaskScheduler", + *, + use_on_demand_clusters: bool, + cluster_id: ClusterID, + run_metadata: RunMetadataDict, ) -> AsyncIterator[DaskClient]: cluster: BaseCluster = scheduler.settings.default_cluster - if pipeline_params.use_on_demand_clusters: + if use_on_demand_clusters: cluster = await get_or_create_on_demand_cluster( scheduler.rabbitmq_rpc_client, user_id=user_id, - wallet_id=pipeline_params.run_metadata.get("wallet_id"), + wallet_id=run_metadata.get("wallet_id"), ) - if pipeline_params.cluster_id != DEFAULT_CLUSTER_ID: + if cluster_id != DEFAULT_CLUSTER_ID: 
clusters_repo = ClustersRepository.instance(scheduler.db_engine) - cluster = await clusters_repo.get_cluster(user_id, pipeline_params.cluster_id) + cluster = await clusters_repo.get_cluster(user_id, cluster_id) async with scheduler.dask_clients_pool.acquire(cluster) as client: yield client @@ -104,10 +108,21 @@ async def _start_tasks( user_id: UserID, project_id: ProjectID, scheduled_tasks: dict[NodeID, CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, + wake_up_callback: Callable[[], None], ) -> None: # now transfer the pipeline to the dask scheduler - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: # Change the tasks state to PENDING comp_tasks_repo = CompTasksRepository.instance(self.db_engine) await comp_tasks_repo.update_project_tasks_state( @@ -121,11 +136,15 @@ async def _start_tasks( client.send_computation_tasks( user_id=user_id, project_id=project_id, - cluster_id=pipeline_params.cluster_id, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), tasks={node_id: task.image}, hardware_info=task.hardware_info, - callback=pipeline_params.wake_up, - metadata=pipeline_params.run_metadata, + callback=wake_up_callback, + metadata=comp_run.metadata, ) for node_id, task in scheduled_tasks.items() ), @@ -146,10 +165,20 @@ async def _get_tasks_status( self, user_id: UserID, tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> list[RunningState]: try: - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + 
cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: tasks_statuses = await client.get_tasks_status( [f"{t.job_id}" for t in tasks] ) @@ -177,14 +206,21 @@ async def _get_tasks_status( return [RunningState.WAITING_FOR_CLUSTER] * len(tasks) async def _stop_tasks( - self, - user_id: UserID, - tasks: list[CompTaskAtDB], - pipeline_params: ScheduledPipelineParams, + self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB ) -> None: # NOTE: if this exception raises, it means the backend was anyway not up with contextlib.suppress(ComputationalBackendOnDemandNotReadyError): - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: await asyncio.gather( *[ client.abort_computation_task(t.job_id) @@ -209,10 +245,20 @@ async def _process_completed_tasks( user_id: UserID, tasks: list[CompTaskAtDB], iteration: Iteration, - pipeline_params: ScheduledPipelineParams, + comp_run: CompRunsAtDB, ) -> None: try: - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: tasks_results = await asyncio.gather( *[client.get_task_result(t.job_id or "undefined") for t in tasks], return_exceptions=True, @@ -220,13 +266,23 @@ async def _process_completed_tasks( await asyncio.gather( *[ self._process_task_result( - task, result, pipeline_params.run_metadata, iteration + task, result, comp_run.metadata, iteration ) for 
task, result in zip(tasks, tasks_results, strict=True) ] ) finally: - async with _cluster_dask_client(user_id, pipeline_params, self) as client: + async with _cluster_dask_client( + user_id, + self, + use_on_demand_clusters=comp_run.use_on_demand_clusters, + cluster_id=( + comp_run.cluster_id + if comp_run.cluster_id is not None + else DEFAULT_CLUSTER_ID + ), + run_metadata=comp_run.metadata, + ) as client: await asyncio.gather( *[client.release_task_result(t.job_id) for t in tasks if t.job_id] ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py index 4f7812816cc..39b432b9492 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_factory.py @@ -1,19 +1,15 @@ import logging from fastapi import FastAPI -from models_library.clusters import DEFAULT_CLUSTER_ID from servicelib.logging_utils import log_context from settings_library.redis import RedisDatabase from ...core.errors import ConfigurationError from ...core.settings import AppSettings -from ...models.comp_runs import CompRunsAtDB -from ...utils.comp_scheduler import SCHEDULED_STATES from ..dask_clients_pool import DaskClientsPool -from ..db.repositories.comp_runs import CompRunsRepository from ..rabbitmq import get_rabbitmq_client, get_rabbitmq_rpc_client from ..redis import get_redis_client_manager -from ._base_scheduler import BaseCompScheduler, ScheduledPipelineParams +from ._base_scheduler import BaseCompScheduler from ._dask_scheduler import DaskScheduler _logger = logging.getLogger(__name__) @@ -24,39 +20,19 @@ async def create_from_db(app: FastAPI) -> BaseCompScheduler: msg = "Database connection is missing. Please check application configuration." 
raise ConfigurationError(msg) db_engine = app.state.engine - runs_repository = CompRunsRepository.instance(db_engine) - - # get currently scheduled runs - runs: list[CompRunsAtDB] = await runs_repository.list( - filter_by_state=SCHEDULED_STATES - ) - - _logger.debug( - "Following scheduled comp_runs found still to be scheduled: %s", - runs if runs else "NONE", - ) with log_context( _logger, logging.INFO, msg="Creating Dask-based computational scheduler" ): app_settings: AppSettings = app.state.settings - return DaskScheduler( + scheduler = DaskScheduler( settings=app_settings.DIRECTOR_V2_COMPUTATIONAL_BACKEND, dask_clients_pool=DaskClientsPool.instance(app), rabbitmq_client=get_rabbitmq_client(app), rabbitmq_rpc_client=get_rabbitmq_rpc_client(app), redis_client=get_redis_client_manager(app).client(RedisDatabase.LOCKS), db_engine=db_engine, - scheduled_pipelines={ - (r.user_id, r.project_uuid, r.iteration): ScheduledPipelineParams( - cluster_id=( - r.cluster_id if r.cluster_id is not None else DEFAULT_CLUSTER_ID - ), - run_metadata=r.metadata, - mark_for_cancellation=r.cancelled, - use_on_demand_clusters=r.use_on_demand_clusters, - ) - for r in runs - }, service_runtime_heartbeat_interval=app_settings.SERVICE_TRACKING_HEARTBEAT, ) + await scheduler.restore_scheduling_from_db() + return scheduler diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py index 1df1ae09d39..927476e851d 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py @@ -71,9 +71,6 @@ BaseCompScheduler, get_scheduler, ) -from simcore_service_director_v2.modules.comp_scheduler._base_scheduler import ( - ScheduledPipelineParams, -) from simcore_service_director_v2.modules.comp_scheduler._dask_scheduler import ( DaskScheduler, ) @@ 
-165,14 +162,14 @@ async def schedule_all_pipelines(scheduler: BaseCompScheduler) -> None: # misconfigured pipelines that would be removed from the scheduler # NOTE: we simulate multiple dv-2 replicas by running several times # the same pipeline scheduling - local_pipelines = deepcopy(scheduler.scheduled_pipelines) + local_pipelines = deepcopy(scheduler._scheduled_pipelines) # noqa: SLF001 results = await asyncio.gather( *( scheduler._schedule_pipeline( # noqa: SLF001 user_id=user_id, project_id=project_id, iteration=iteration, - pipeline_params=params, + wake_up_callback=params.scheduler_waker.set, ) for _ in range(3) for ( @@ -256,15 +253,16 @@ def mocked_clean_task_output_fct(mocker: MockerFixture) -> mock.MagicMock: @pytest.fixture def with_disabled_auto_scheduling(mocker: MockerFixture) -> mock.MagicMock: - """disables the scheduler task, note that it needs to be triggered manually then""" + """disables the scheduler task, note that it needs to be triggered manually then""" def _fake_starter( self: BaseCompScheduler, - pipeline_params: ScheduledPipelineParams, *args, **kwargs, - ) -> None: - pipeline_params.scheduler_task = mocker.MagicMock() + ): + scheduler_task = mocker.MagicMock() + scheduler_task_wake_up_event = mocker.MagicMock() + return scheduler_task, scheduler_task_wake_up_event return mocker.patch( "simcore_service_director_v2.modules.comp_scheduler._base_scheduler.BaseCompScheduler._start_scheduling", @@ -358,7 +356,7 @@ async def test_empty_pipeline_is_not_scheduled( run_metadata=run_metadata, use_on_demand_clusters=False, ) - assert len(scheduler.scheduled_pipelines) == 0 + assert len(scheduler._scheduled_pipelines) == 0 # noqa: SLF001 # check the database is empty async with aiopg_engine.acquire() as conn: result = await conn.scalar( @@ -397,12 +395,15 @@ async def test_misconfigured_pipeline_is_not_scheduled( run_metadata=run_metadata, use_on_demand_clusters=False, ) - assert len(scheduler.scheduled_pipelines) == 1 - for (u_id, p_id, it),
params in scheduler.scheduled_pipelines.items(): + assert len(scheduler._scheduled_pipelines) == 1 # noqa: SLF001 + for ( + u_id, + p_id, + it, + ) in scheduler._scheduled_pipelines: # noqa: SLF001 assert u_id == user["id"] assert p_id == sleepers_project.uuid assert it > 0 - assert params.mark_for_cancellation is None # check the database was properly updated async with aiopg_engine.acquire() as conn: result = await conn.execute( @@ -416,7 +417,7 @@ async def test_misconfigured_pipeline_is_not_scheduled( # let the scheduler kick in await schedule_all_pipelines(scheduler) # check the scheduled pipelines is again empty since it's misconfigured - assert len(scheduler.scheduled_pipelines) == 0 + assert len(scheduler._scheduled_pipelines) == 0 # noqa: SLF001 # check the database entry is correctly updated async with aiopg_engine.acquire() as conn: result = await conn.execute( @@ -445,13 +446,17 @@ async def _assert_start_pipeline( run_metadata=run_metadata, use_on_demand_clusters=False, ) - assert len(scheduler.scheduled_pipelines) == 1, "the pipeline is not scheduled!" - for (u_id, p_id, it), params in scheduler.scheduled_pipelines.items(): + assert ( + len(scheduler._scheduled_pipelines) == 1 # noqa: SLF001 + ), "the pipeline is not scheduled!" 
+ for ( + u_id, + p_id, + it, + ) in scheduler._scheduled_pipelines: # noqa: SLF001 assert u_id == published_project.project.prj_owner assert p_id == published_project.project.uuid assert it > 0 - assert params.mark_for_cancellation is None - assert params.run_metadata == run_metadata # check the database is correctly updated, the run is published await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) @@ -1019,7 +1024,7 @@ async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskSta assert isinstance(messages[1], RabbitResourceTrackingStoppedMessage) # the scheduled pipeline shall be removed - assert scheduler.scheduled_pipelines == {} + assert scheduler._scheduled_pipelines == {} # noqa: SLF001 async def test_task_progress_triggers( @@ -1160,7 +1165,6 @@ class RebootState: expected_run_state: RunningState -@pytest.mark.flaky(max_runs=3) @pytest.mark.parametrize( "reboot_state", [