diff --git a/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py b/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py index e624ed0785cb..4dd160ed6f40 100644 --- a/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py +++ b/services/director-v2/src/simcore_service_director_v2/api/routes/computations.py @@ -53,10 +53,10 @@ ClusterNotFoundError, ClustersKeeperNotAvailableError, ComputationalRunNotFoundError, + ComputationalSchedulerError, ConfigurationError, PricingPlanUnitNotFoundError, ProjectNotFoundError, - SchedulerError, WalletNotEnoughCreditsError, ) from ...models.comp_pipelines import CompPipelineAtDB @@ -510,7 +510,9 @@ async def get_computation( pipeline_details=pipeline_details, url=TypeAdapter(AnyHttpUrl).validate_python(f"{request.url}"), stop_url=( - TypeAdapter(AnyHttpUrl).validate_python(f"{self_url}:stop?user_id={user_id}") + TypeAdapter(AnyHttpUrl).validate_python( + f"{self_url}:stop?user_id={user_id}" + ) if pipeline_state.is_running() else None ), @@ -598,7 +600,7 @@ async def stop_computation( except ProjectNotFoundError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{e}") from e - except SchedulerError as e: + except ComputationalSchedulerError as e: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{e}") from e @@ -639,7 +641,7 @@ async def delete_computation( # abort the pipeline first try: await scheduler.stop_pipeline(computation_stop.user_id, project_id) - except SchedulerError as e: + except ComputationalSchedulerError as e: _logger.warning( "Project %s could not be stopped properly.\n reason: %s", project_id, diff --git a/services/director-v2/src/simcore_service_director_v2/core/errors.py b/services/director-v2/src/simcore_service_director_v2/core/errors.py index 285e4d13a3c7..eb605c307eab 100644 --- a/services/director-v2/src/simcore_service_director_v2/core/errors.py +++ 
b/services/director-v2/src/simcore_service_director_v2/core/errors.py @@ -49,71 +49,47 @@ class PipelineNotFoundError(DirectorError): msg_template: str = "pipeline {pipeline_id} not found" -class ComputationalRunNotFoundError(OsparcErrorMixin, DirectorError): +class ComputationalRunNotFoundError(DirectorError): msg_template = "Computational run not found" -class ComputationalTaskNotFoundError(OsparcErrorMixin, DirectorError): +class ComputationalTaskNotFoundError(DirectorError): msg_template = "Computational task {node_id} not found" -class WalletNotEnoughCreditsError(OsparcErrorMixin, DirectorError): +class WalletNotEnoughCreditsError(DirectorError): msg_template = "Wallet '{wallet_name}' has {wallet_credit_amount} credits." # # SCHEDULER ERRORS # +class ComputationalSchedulerError(DirectorError): + msg_template = "Computational scheduler unexpected error" -class SchedulerError(DirectorError): - def __init__(self, msg: str | None = None): - super().__init__(msg or "Unexpected error in the scheduler") +class InvalidPipelineError(ComputationalSchedulerError): + msg_template = "Computational scheduler: Invalid configuration of pipeline {pipeline_id}: {msg}" -class InvalidPipelineError(SchedulerError): - """A pipeline is misconfigured""" - - def __init__(self, pipeline_id: str, msg: str | None = None): - super().__init__(msg or f"Invalid configuration of pipeline {pipeline_id}") - - -class TaskSchedulingError(SchedulerError): - """A task cannot be scheduled""" - - code: str = "task scheduler error" - - def __init__(self, project_id: ProjectID, node_id: NodeID, msg: str | None = None): - super().__init__(msg=msg) - self.project_id = project_id - self.node_id = node_id - - def get_errors(self) -> list[ErrorDict]: - # default implementation - return [ - { - "loc": ( - f"{self.project_id}", - f"{self.node_id}", - ), - "msg": f"{self.args[0]}", - "type": self.code, - }, - ] +class TaskSchedulingError(ComputationalSchedulerError): + msg_template = "Computational scheduler: 
Task {node_id} in project {project_id} could not be scheduled {msg}" class MissingComputationalResourcesError(TaskSchedulingError): - """A task cannot be scheduled because the cluster does not have the required resources""" - - def __init__(self, project_id: ProjectID, node_id: NodeID, msg: str | None = None): - super().__init__(project_id, node_id, msg=msg) + msg_template = ( + "Service {service_name}:{service_version} cannot be scheduled " + "on cluster {cluster_id}: task needs '{task_resources}', " + "cluster has {cluster_resources}" + ) class InsuficientComputationalResourcesError(TaskSchedulingError): - """A task cannot be scheduled because the cluster does not have *enough* of the required resources""" - - def __init__(self, project_id: ProjectID, node_id: NodeID, msg: str | None = None): - super().__init__(project_id, node_id, msg=msg) + msg_template: str = ( + "Insufficient computational resources to run {service_name}:{service_version} with {service_requested_resources} on cluster {cluster_id}. "
+ "Cluster available workers: {cluster_available_resources} " + "TIP: Reduce service required resources or contact oSparc support" + ) class PortsValidationError(TaskSchedulingError): @@ -158,33 +134,33 @@ def get_errors(self) -> list[ErrorDict]: return value_errors -class ComputationalSchedulerChangedError(OsparcErrorMixin, SchedulerError): +class ComputationalSchedulerChangedError(ComputationalSchedulerError): msg_template = "The dask scheduler ID changed from '{original_scheduler_id}' to '{current_scheduler_id}'" -class ComputationalBackendNotConnectedError(OsparcErrorMixin, SchedulerError): +class ComputationalBackendNotConnectedError(ComputationalSchedulerError): msg_template = "The dask computational backend is not connected" -class ComputationalBackendNoS3AccessError(OsparcErrorMixin, SchedulerError): +class ComputationalBackendNoS3AccessError(ComputationalSchedulerError): msg_template = "The S3 backend is not ready, please try again later" -class ComputationalBackendTaskNotFoundError(OsparcErrorMixin, SchedulerError): +class ComputationalBackendTaskNotFoundError(ComputationalSchedulerError): msg_template = ( "The dask computational backend does not know about the task '{job_id}'" ) -class ComputationalBackendTaskResultsNotReadyError(OsparcErrorMixin, SchedulerError): +class ComputationalBackendTaskResultsNotReadyError(ComputationalSchedulerError): msg_template = "The task result is not ready yet for job '{job_id}'" -class ClustersKeeperNotAvailableError(OsparcErrorMixin, SchedulerError): +class ClustersKeeperNotAvailableError(ComputationalSchedulerError): msg_template = "clusters-keeper service is not available!" -class ComputationalBackendOnDemandNotReadyError(OsparcErrorMixin, SchedulerError): +class ComputationalBackendOnDemandNotReadyError(ComputationalSchedulerError): msg_template = ( "The on demand computational cluster is not ready 'est. 
remaining time: {eta}'" ) @@ -193,15 +169,15 @@ class ComputationalBackendOnDemandNotReadyError(OsparcErrorMixin, SchedulerError # # SCHEDULER/CLUSTER ERRORS # -class ClusterNotFoundError(OsparcErrorMixin, SchedulerError): +class ClusterNotFoundError(ComputationalSchedulerError): msg_template = "The cluster '{cluster_id}' not found" -class ClusterAccessForbiddenError(OsparcErrorMixin, SchedulerError): +class ClusterAccessForbiddenError(ComputationalSchedulerError): msg_template = "Insufficient rights to access cluster '{cluster_id}'" -class ClusterInvalidOperationError(OsparcErrorMixin, SchedulerError): +class ClusterInvalidOperationError(ComputationalSchedulerError): msg_template = "Invalid operation on cluster '{cluster_id}'" @@ -210,21 +186,21 @@ class ClusterInvalidOperationError(OsparcErrorMixin, SchedulerError): # -class DaskClientRequestError(OsparcErrorMixin, SchedulerError): +class DaskClientRequestError(ComputationalSchedulerError): msg_template = ( "The dask client to cluster on '{endpoint}' did an invalid request '{error}'" ) -class DaskClusterError(OsparcErrorMixin, SchedulerError): +class DaskClusterError(ComputationalSchedulerError): msg_template = "The dask cluster on '{endpoint}' encountered an error: '{error}'" -class DaskGatewayServerError(OsparcErrorMixin, SchedulerError): +class DaskGatewayServerError(ComputationalSchedulerError): msg_template = "The dask gateway on '{endpoint}' encountered an error: '{error}'" -class DaskClientAcquisisitonError(OsparcErrorMixin, SchedulerError): +class DaskClientAcquisisitonError(ComputationalSchedulerError): msg_template = ( "The dask client to cluster '{cluster}' encountered an error '{error}'" ) diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py index 5fd39dda4bc2..9dfae4bc6660 100644 --- 
a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_base_scheduler.py @@ -45,10 +45,10 @@ ComputationalBackendNotConnectedError, ComputationalBackendOnDemandNotReadyError, ComputationalSchedulerChangedError, + ComputationalSchedulerError, DaskClientAcquisisitonError, InvalidPipelineError, PipelineNotFoundError, - SchedulerError, TaskSchedulingError, ) from ...core.settings import ComputationalBackendSettings @@ -242,7 +242,7 @@ async def stop_pipeline( } if not possible_iterations: msg = f"There are no pipeline scheduled for {user_id}:{project_id}" - raise SchedulerError(msg) + raise ComputationalSchedulerError(msg) current_max_iteration = max(possible_iterations) selected_iteration = current_max_iteration else: @@ -281,7 +281,7 @@ def _get_last_iteration(self, user_id: UserID, project_id: ProjectID) -> Iterati } if not possible_iterations: msg = f"There are no pipeline scheduled for {user_id}:{project_id}" - raise SchedulerError(msg) + raise ComputationalSchedulerError(msg) return max(possible_iterations) def _start_scheduling( @@ -342,10 +342,10 @@ async def _get_pipeline_tasks( } if len(pipeline_comp_tasks) != len(pipeline_dag.nodes()): # type: ignore[arg-type] msg = ( - f"{project_id}The tasks defined for {project_id} do not contain all" + f"The tasks defined for {project_id} do not contain all" f" the tasks defined in the pipeline [{list(pipeline_dag.nodes)}]! Please check." 
) - raise InvalidPipelineError(msg) + raise InvalidPipelineError(pipeline_id=project_id, msg=msg) return pipeline_comp_tasks async def _update_run_result_from_tasks( @@ -929,10 +929,10 @@ async def _schedule_tasks_to_start( # noqa: C901 comp_tasks[NodeIDStr(f"{task}")].state = RunningState.FAILED raise except TaskSchedulingError as exc: _logger.exception( "Project '%s''s task '%s' could not be scheduled", - exc.project_id, - exc.node_id, + exc.error_context()["project_id"], + exc.error_context()["node_id"], ) await CompTasksRepository.instance( self.db_engine diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py index 2fdf7acd2e9e..e578715eb799 100644 --- a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_dask_scheduler.py @@ -345,11 +345,10 @@ async def _process_task_result( except TaskSchedulingError as err: task_final_state = RunningState.FAILED simcore_platform_status = SimcorePlatformStatus.BAD - errors = err.get_errors() _logger.debug( "Unexpected failure while processing results of %s: %s", f"{task=}", - f"{errors=}", + f"{err=}", ) # resource tracking diff --git a/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py new file mode 100644 index 000000000000..ce15311f9949 --- /dev/null +++ b/services/director-v2/src/simcore_service_director_v2/modules/comp_scheduler/_scheduler_base.py @@ -0,0 +1,792 @@ +"""The scheduler shall be run as a background task. +Based on oSparc pipelines, it monitors when to start the next worker task(s), either one at a time or as a group of tasks. 
+ +In principle the Scheduler maintains the comp_runs table in the database. +It contains how the pipeline was run and by whom. +It also contains the final result of the pipeline run. + +When a pipeline is scheduled first all the tasks contained in the DAG are set to PUBLISHED state. +Once the scheduler determines a task shall run, its state is set to PENDING, so that the sidecar can pick up the task. +The sidecar will then change the state to STARTED, then to SUCCESS or FAILED. + +""" + +import asyncio +import datetime +import logging +from abc import ABC, abstractmethod +from collections.abc import Callable +from dataclasses import dataclass +from typing import Final + +import arrow +import networkx as nx +from aiopg.sa.engine import Engine +from models_library.projects import ProjectID +from models_library.projects_nodes_io import NodeID, NodeIDStr +from models_library.projects_state import RunningState +from models_library.services import ServiceKey, ServiceType, ServiceVersion +from models_library.users import UserID +from networkx.classes.reportviews import InDegreeView +from servicelib.common_headers import UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE +from servicelib.logging_utils import log_context +from servicelib.rabbitmq import RabbitMQClient, RabbitMQRPCClient +from servicelib.redis import RedisClientSDK + +from ...constants import UNDEFINED_STR_METADATA +from ...core.errors import ( + ClustersKeeperNotAvailableError, + ComputationalBackendNotConnectedError, + ComputationalBackendOnDemandNotReadyError, + ComputationalSchedulerChangedError, + DaskClientAcquisisitonError, + InvalidPipelineError, + PipelineNotFoundError, + TaskSchedulingError, +) +from ...core.settings import ComputationalBackendSettings +from ...models.comp_pipelines import CompPipelineAtDB +from ...models.comp_runs import CompRunsAtDB, Iteration, RunMetadataDict +from ...models.comp_tasks import CompTaskAtDB +from ...utils.computations import get_pipeline_state_from_task_states +from 
...utils.rabbitmq import ( + publish_project_log, + publish_service_resource_tracking_heartbeat, + publish_service_resource_tracking_started, + publish_service_started_metrics, +) +from ..db.repositories.comp_pipelines import CompPipelinesRepository +from ..db.repositories.comp_runs import CompRunsRepository +from ..db.repositories.comp_tasks import CompTasksRepository +from ._utils import ( + COMPLETED_STATES, + PROCESSING_STATES, + RUNNING_STATES, + TASK_TO_START_STATES, + WAITING_FOR_START_STATES, + create_service_resources_from_task, + get_resource_tracking_run_id, +) + +_logger = logging.getLogger(__name__) + + +_Previous = CompTaskAtDB +_Current = CompTaskAtDB +_MAX_WAITING_FOR_CLUSTER_TIMEOUT_IN_MIN: Final[int] = 10 + + +@dataclass(frozen=True, slots=True) +class SortedTasks: + started: list[CompTaskAtDB] + completed: list[CompTaskAtDB] + waiting: list[CompTaskAtDB] + potentially_lost: list[CompTaskAtDB] + + +async def _triage_changed_tasks( + changed_tasks: list[tuple[_Previous, _Current]] +) -> SortedTasks: + started_tasks = [ + current + for previous, current in changed_tasks + if current.state in RUNNING_STATES + or ( + previous.state in WAITING_FOR_START_STATES + and current.state in COMPLETED_STATES + ) + ] + + # NOTE: some tasks can be both started and completed since we might have the time they were running + completed_tasks = [ + current for _, current in changed_tasks if current.state in COMPLETED_STATES + ] + + waiting_for_resources_tasks = [ + current + for previous, current in changed_tasks + if current.state in WAITING_FOR_START_STATES + ] + + lost_or_momentarily_lost_tasks = [ + current for _, current in changed_tasks if current.state is RunningState.UNKNOWN + ] + if lost_or_momentarily_lost_tasks: + _logger.warning( + "%s are currently in unknown state. TIP: If they are running in an external cluster and it is not yet ready, that might explain it. 
But inform @sanderegg nevertheless!", + [t.node_id for t in lost_or_momentarily_lost_tasks], + ) + + return SortedTasks( + started_tasks, + completed_tasks, + waiting_for_resources_tasks, + lost_or_momentarily_lost_tasks, + ) + + +@dataclass +class BaseCompScheduler(ABC): + db_engine: Engine + rabbitmq_client: RabbitMQClient + rabbitmq_rpc_client: RabbitMQRPCClient + settings: ComputationalBackendSettings + service_runtime_heartbeat_interval: datetime.timedelta + redis_client: RedisClientSDK + + async def _get_pipeline_dag(self, project_id: ProjectID) -> nx.DiGraph: + comp_pipeline_repo = CompPipelinesRepository.instance(self.db_engine) + pipeline_at_db: CompPipelineAtDB = await comp_pipeline_repo.get_pipeline( + project_id + ) + dag = pipeline_at_db.get_graph() + _logger.debug("%s: current %s", f"{project_id=}", f"{dag=}") + return dag + + async def _get_pipeline_tasks( + self, project_id: ProjectID, pipeline_dag: nx.DiGraph + ) -> dict[NodeIDStr, CompTaskAtDB]: + comp_tasks_repo = CompTasksRepository.instance(self.db_engine) + pipeline_comp_tasks: dict[NodeIDStr, CompTaskAtDB] = { + NodeIDStr(f"{t.node_id}"): t + for t in await comp_tasks_repo.list_computational_tasks(project_id) + if (f"{t.node_id}" in list(pipeline_dag.nodes())) + } + if len(pipeline_comp_tasks) != len(pipeline_dag.nodes()): # type: ignore[arg-type] + msg = ( + f"The tasks defined for {project_id} do not contain all" + f" the tasks defined in the pipeline [{list(pipeline_dag.nodes)}]! Please check." 
+ ) + raise InvalidPipelineError(pipeline_id=project_id, msg=msg) + return pipeline_comp_tasks + + async def _update_run_result_from_tasks( + self, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration, + pipeline_tasks: dict[NodeIDStr, CompTaskAtDB], + ) -> RunningState: + pipeline_state_from_tasks: RunningState = get_pipeline_state_from_task_states( + list(pipeline_tasks.values()), + ) + _logger.debug( + "pipeline %s is currently in %s", + f"{user_id=}_{project_id=}_{iteration=}", + f"{pipeline_state_from_tasks}", + ) + await self._set_run_result( + user_id, project_id, iteration, pipeline_state_from_tasks + ) + return pipeline_state_from_tasks + + async def _set_run_result( + self, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration, + run_result: RunningState, + ) -> None: + comp_runs_repo = CompRunsRepository.instance(self.db_engine) + await comp_runs_repo.set_run_result( + user_id=user_id, + project_id=project_id, + iteration=iteration, + result_state=run_result, + final_state=(run_result in COMPLETED_STATES), + ) + + async def _set_states_following_failed_to_aborted( + self, project_id: ProjectID, dag: nx.DiGraph + ) -> dict[NodeIDStr, CompTaskAtDB]: + tasks: dict[NodeIDStr, CompTaskAtDB] = await self._get_pipeline_tasks( + project_id, dag + ) + node_ids_to_set_as_aborted: set[NodeIDStr] = set() + for task in tasks.values(): + if task.state == RunningState.FAILED: + node_ids_to_set_as_aborted.update(nx.bfs_tree(dag, f"{task.node_id}")) + node_ids_to_set_as_aborted.remove(NodeIDStr(f"{task.node_id}")) + for node_id in node_ids_to_set_as_aborted: + tasks[NodeIDStr(f"{node_id}")].state = RunningState.ABORTED + if node_ids_to_set_as_aborted: + # update the current states back in DB + comp_tasks_repo = CompTasksRepository.instance(self.db_engine) + await comp_tasks_repo.update_project_tasks_state( + project_id, + [NodeID(n) for n in node_ids_to_set_as_aborted], + RunningState.ABORTED, + optional_progress=1.0, + optional_stopped=arrow.utcnow().datetime, + ) + 
return tasks + + async def _send_running_tasks_heartbeat( + self, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration, + dag: nx.DiGraph, + ) -> None: + utc_now = arrow.utcnow().datetime + + def _need_heartbeat(task: CompTaskAtDB) -> bool: + if task.state not in RUNNING_STATES: + return False + if task.last_heartbeat is None: + assert task.start # nosec + return bool( + (utc_now - task.start.replace(tzinfo=datetime.UTC)) + > self.service_runtime_heartbeat_interval + ) + return bool( + (utc_now - task.last_heartbeat) + > self.service_runtime_heartbeat_interval + ) + + tasks: dict[NodeIDStr, CompTaskAtDB] = await self._get_pipeline_tasks( + project_id, dag + ) + if running_tasks := [t for t in tasks.values() if _need_heartbeat(t)]: + await asyncio.gather( + *( + publish_service_resource_tracking_heartbeat( + self.rabbitmq_client, + get_resource_tracking_run_id( + user_id, t.project_id, t.node_id, iteration + ), + ) + for t in running_tasks + ) + ) + comp_tasks_repo = CompTasksRepository(self.db_engine) + await asyncio.gather( + *( + comp_tasks_repo.update_project_task_last_heartbeat( + t.project_id, t.node_id, utc_now + ) + for t in running_tasks + ) + ) + + async def _get_changed_tasks_from_backend( + self, + user_id: UserID, + processing_tasks: list[CompTaskAtDB], + comp_run: CompRunsAtDB, + ) -> list[tuple[_Previous, _Current]]: + tasks_backend_status = await self._get_tasks_status( + user_id, processing_tasks, comp_run + ) + + return [ + ( + task, + task.copy(update={"state": backend_state}), + ) + for task, backend_state in zip( + processing_tasks, tasks_backend_status, strict=True + ) + if task.state is not backend_state + ] + + async def _process_started_tasks( + self, + tasks: list[CompTaskAtDB], + *, + user_id: UserID, + iteration: Iteration, + run_metadata: RunMetadataDict, + ) -> None: + utc_now = arrow.utcnow().datetime + + # resource tracking + await asyncio.gather( + *( + publish_service_resource_tracking_started( + self.rabbitmq_client, 
+ service_run_id=get_resource_tracking_run_id( + user_id, t.project_id, t.node_id, iteration + ), + wallet_id=run_metadata.get("wallet_id"), + wallet_name=run_metadata.get("wallet_name"), + pricing_plan_id=( + t.pricing_info.get("pricing_plan_id") + if t.pricing_info + else None + ), + pricing_unit_id=( + t.pricing_info.get("pricing_unit_id") + if t.pricing_info + else None + ), + pricing_unit_cost_id=( + t.pricing_info.get("pricing_unit_cost_id") + if t.pricing_info + else None + ), + product_name=run_metadata.get( + "product_name", UNDEFINED_STR_METADATA + ), + simcore_user_agent=run_metadata.get( + "simcore_user_agent", UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE + ), + user_id=user_id, + user_email=run_metadata.get("user_email", UNDEFINED_STR_METADATA), + project_id=t.project_id, + project_name=run_metadata.get("project_metadata", {}).get( # type: ignore[arg-type] + "project_name", UNDEFINED_STR_METADATA + ), + node_id=t.node_id, + node_name=run_metadata.get("node_id_names_map", {}).get( + t.node_id, UNDEFINED_STR_METADATA + ), + parent_project_id=run_metadata.get("project_metadata", {}).get( + "parent_project_id" + ), + parent_node_id=run_metadata.get("project_metadata", {}).get( + "parent_node_id" + ), + root_parent_project_id=run_metadata.get("project_metadata", {}).get( + "root_parent_project_id" + ), + root_parent_project_name=run_metadata.get( + "project_metadata", {} + ).get("root_parent_project_name"), + root_parent_node_id=run_metadata.get("project_metadata", {}).get( + "root_parent_node_id" + ), + service_key=ServiceKey(t.image.name), + service_version=ServiceVersion(t.image.tag), + service_type=ServiceType.COMPUTATIONAL, + service_resources=create_service_resources_from_task(t), + service_additional_metadata={}, + ) + for t in tasks + ) + ) + # instrumentation + await asyncio.gather( + *( + publish_service_started_metrics( + self.rabbitmq_client, + user_id=user_id, + simcore_user_agent=run_metadata.get( + "simcore_user_agent", 
UNDEFINED_DEFAULT_SIMCORE_USER_AGENT_VALUE + ), + task=t, + ) + for t in tasks + ) + ) + + # update DB + comp_tasks_repo = CompTasksRepository(self.db_engine) + await asyncio.gather( + *( + comp_tasks_repo.update_project_tasks_state( + t.project_id, + [t.node_id], + t.state, + optional_started=utc_now, + optional_progress=t.progress, + ) + for t in tasks + ) + ) + + async def _process_waiting_tasks(self, tasks: list[CompTaskAtDB]) -> None: + comp_tasks_repo = CompTasksRepository(self.db_engine) + await asyncio.gather( + *( + comp_tasks_repo.update_project_tasks_state( + t.project_id, + [t.node_id], + t.state, + ) + for t in tasks + ) + ) + + async def _update_states_from_comp_backend( + self, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration, + pipeline_dag: nx.DiGraph, + comp_run: CompRunsAtDB, + ) -> None: + tasks = await self._get_pipeline_tasks(project_id, pipeline_dag) + tasks_inprocess = [t for t in tasks.values() if t.state in PROCESSING_STATES] + if not tasks_inprocess: + return + + # get the tasks which state actually changed since last check + tasks_with_changed_states = await self._get_changed_tasks_from_backend( + user_id, tasks_inprocess, comp_run + ) + # NOTE: typical states a task goes through + # NOT_STARTED (initial state) -> PUBLISHED (user press run/API call) -> PENDING -> WAITING_FOR_CLUSTER (cluster creation) -> + # PENDING -> WAITING_FOR_RESOURCES (workers creation or missing) -> PENDING -> STARTED (worker started processing the task) -> SUCCESS/FAILED + # or ABORTED (user cancelled) or UNKNOWN (lost task - it might be transient, be careful with this one) + sorted_tasks = await _triage_changed_tasks(tasks_with_changed_states) + + # now process the tasks + if sorted_tasks.started: + # NOTE: the dask-scheduler cannot differentiate between tasks that are effectively computing and + # tasks that are only queued and accepted by a dask-worker. 
+ # tasks_started should therefore be mostly empty but for cases where + # - dask Pub/Sub mechanism failed, the tasks goes from PENDING -> SUCCESS/FAILED/ABORTED without STARTED + # - the task finished so fast that the STARTED state was skipped between 2 runs of the dv-2 comp scheduler + await self._process_started_tasks( + sorted_tasks.started, + user_id=user_id, + iteration=iteration, + run_metadata=comp_run.metadata, + ) + + if sorted_tasks.completed or sorted_tasks.potentially_lost: + await self._process_completed_tasks( + user_id, + sorted_tasks.completed + sorted_tasks.potentially_lost, + iteration, + comp_run=comp_run, + ) + + if sorted_tasks.waiting: + await self._process_waiting_tasks(sorted_tasks.waiting) + + @abstractmethod + async def _start_tasks( + self, + *, + user_id: UserID, + project_id: ProjectID, + scheduled_tasks: dict[NodeID, CompTaskAtDB], + comp_run: CompRunsAtDB, + wake_up_callback: Callable[[], None], + ) -> None: + ... + + @abstractmethod + async def _get_tasks_status( + self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB + ) -> list[RunningState]: + ... + + @abstractmethod + async def _stop_tasks( + self, user_id: UserID, tasks: list[CompTaskAtDB], comp_run: CompRunsAtDB + ) -> None: + ... + + @abstractmethod + async def _process_completed_tasks( + self, + user_id: UserID, + tasks: list[CompTaskAtDB], + iteration: Iteration, + comp_run: CompRunsAtDB, + ) -> None: + ... + + async def schedule_pipeline( + self, + *, + user_id: UserID, + project_id: ProjectID, + iteration: Iteration, + wake_up_callback: Callable[[], None], + ) -> None: + with log_context( + _logger, + level=logging.INFO, + msg=f"scheduling pipeline {user_id=}:{project_id=}:{iteration=}", + ): + dag: nx.DiGraph = nx.DiGraph() + try: + comp_run = await CompRunsRepository.instance(self.db_engine).get( + user_id, project_id, iteration + ) + dag = await self._get_pipeline_dag(project_id) + # 1. 
Update our list of tasks with data from backend (state, results) + await self._update_states_from_comp_backend( + user_id, project_id, iteration, dag, comp_run + ) + # 2. Any task following a FAILED task shall be ABORTED + comp_tasks = await self._set_states_following_failed_to_aborted( + project_id, dag + ) + # 3. do we want to stop the pipeline now? + if comp_run.cancelled: + await self._schedule_tasks_to_stop( + user_id, project_id, comp_tasks, comp_run + ) + else: + # let's get the tasks to schedule then + comp_tasks = await self._schedule_tasks_to_start( + user_id=user_id, + project_id=project_id, + comp_tasks=comp_tasks, + dag=dag, + comp_run=comp_run, + wake_up_callback=wake_up_callback, + ) + # 4. timeout if waiting for cluster has been there for more than X minutes + comp_tasks = await self._timeout_if_waiting_for_cluster_too_long( + user_id, project_id, comp_tasks + ) + # 5. send a heartbeat + await self._send_running_tasks_heartbeat( + user_id, project_id, iteration, dag + ) + + # 6. Update the run result + pipeline_result = await self._update_run_result_from_tasks( + user_id, project_id, iteration, comp_tasks + ) + + # 7. Are we done scheduling that pipeline? + if not dag.nodes() or pipeline_result in COMPLETED_STATES: + # there is nothing left, the run is completed, we're done here + _logger.info( + "pipeline %s scheduling completed with result %s", + f"{project_id=}", + f"{pipeline_result=}", + ) + except PipelineNotFoundError: + _logger.warning( + "pipeline %s does not exist in comp_pipeline table, it will be removed from scheduler", + f"{project_id=}", + ) + await self._set_run_result( + user_id, project_id, iteration, RunningState.ABORTED + ) + except InvalidPipelineError as exc: + _logger.warning( + "pipeline %s appears to be misconfigured, it will be removed from scheduler. 
Please check pipeline:\n%s", + f"{project_id=}", + exc, + ) + await self._set_run_result( + user_id, project_id, iteration, RunningState.ABORTED + ) + except (DaskClientAcquisisitonError, ClustersKeeperNotAvailableError): + _logger.exception( + "Unexpected error while connecting with computational backend, aborting pipeline" + ) + tasks: dict[NodeIDStr, CompTaskAtDB] = await self._get_pipeline_tasks( + project_id, dag + ) + comp_tasks_repo = CompTasksRepository(self.db_engine) + await comp_tasks_repo.update_project_tasks_state( + project_id, + [t.node_id for t in tasks.values()], + RunningState.FAILED, + ) + await self._set_run_result( + user_id, project_id, iteration, RunningState.FAILED + ) + except ComputationalBackendNotConnectedError: + _logger.exception("Computational backend is not connected!") + + async def _schedule_tasks_to_stop( + self, + user_id: UserID, + project_id: ProjectID, + comp_tasks: dict[NodeIDStr, CompTaskAtDB], + comp_run: CompRunsAtDB, + ) -> None: + # get any running task and stop them + comp_tasks_repo = CompTasksRepository.instance(self.db_engine) + await comp_tasks_repo.mark_project_published_waiting_for_cluster_tasks_as_aborted( + project_id + ) + # stop any remaining running task, these are already submitted + if tasks_to_stop := [ + t for t in comp_tasks.values() if t.state in PROCESSING_STATES + ]: + await self._stop_tasks(user_id, tasks_to_stop, comp_run) + + async def _schedule_tasks_to_start( # noqa: C901 + self, + user_id: UserID, + project_id: ProjectID, + comp_tasks: dict[NodeIDStr, CompTaskAtDB], + dag: nx.DiGraph, + comp_run: CompRunsAtDB, + wake_up_callback: Callable[[], None], + ) -> dict[NodeIDStr, CompTaskAtDB]: + # filter out the successfully completed tasks + dag.remove_nodes_from( + { + node_id + for node_id, t in comp_tasks.items() + if t.state == RunningState.SUCCESS + } + ) + dag_in_degree = dag.in_degree() + assert isinstance(dag_in_degree, InDegreeView) # nosec + next_task_node_ids = [ + node_id for node_id, 
degree in dag_in_degree if degree == 0 + ] + + # get the tasks to start + tasks_ready_to_start: dict[NodeID, CompTaskAtDB] = { + node_id: comp_tasks[NodeIDStr(f"{node_id}")] + for node_id in next_task_node_ids + if comp_tasks[NodeIDStr(f"{node_id}")].state in TASK_TO_START_STATES + } + + if not tasks_ready_to_start: + # nothing to do + return comp_tasks + + try: + await self._start_tasks( + user_id=user_id, + project_id=project_id, + scheduled_tasks=tasks_ready_to_start, + comp_run=comp_run, + wake_up_callback=wake_up_callback, + ) + except ( + ComputationalBackendNotConnectedError, + ComputationalSchedulerChangedError, + ): + _logger.exception( + "Issue with computational backend. Tasks are set back " + "to WAITING_FOR_CLUSTER state until scheduler comes back!", + ) + await CompTasksRepository.instance( + self.db_engine + ).update_project_tasks_state( + project_id, + list(tasks_ready_to_start.keys()), + RunningState.WAITING_FOR_CLUSTER, + ) + for task in tasks_ready_to_start: + comp_tasks[ + NodeIDStr(f"{task}") + ].state = RunningState.WAITING_FOR_CLUSTER + + except ComputationalBackendOnDemandNotReadyError as exc: + _logger.info( + "The on demand computational backend is not ready yet: %s", exc + ) + await publish_project_log( + self.rabbitmq_client, + user_id, + project_id, + log=f"{exc}", + log_level=logging.INFO, + ) + + await CompTasksRepository.instance( + self.db_engine + ).update_project_tasks_state( + project_id, + list(tasks_ready_to_start.keys()), + RunningState.WAITING_FOR_CLUSTER, + ) + for task in tasks_ready_to_start: + comp_tasks[ + NodeIDStr(f"{task}") + ].state = RunningState.WAITING_FOR_CLUSTER + except ClustersKeeperNotAvailableError: + _logger.exception("Unexpected error while starting tasks:") + await publish_project_log( + self.rabbitmq_client, + user_id, + project_id, + log="Unexpected error while scheduling computational tasks! 
TIP: contact osparc support.", + log_level=logging.ERROR, + ) + + await CompTasksRepository.instance( + self.db_engine + ).update_project_tasks_state( + project_id, + list(tasks_ready_to_start.keys()), + RunningState.FAILED, + optional_progress=1.0, + optional_stopped=arrow.utcnow().datetime, + ) + for task in tasks_ready_to_start: + comp_tasks[NodeIDStr(f"{task}")].state = RunningState.FAILED + raise + except TaskSchedulingError as exc: + _logger.exception("Project '%s' task could not be scheduled", project_id) + node_id = NodeID(exc.error_context()["node_id"]) + await CompTasksRepository.instance( + self.db_engine + ).update_project_tasks_state( + project_id, + [node_id], + RunningState.FAILED, + [{exc.error_code(): f"{exc}"}], + optional_progress=1.0, + optional_stopped=arrow.utcnow().datetime, + ) + comp_tasks[NodeIDStr(f"{node_id}")].state = RunningState.FAILED + except Exception: + _logger.exception( + "Unexpected error for %s with %s on %s happened when scheduling %s:", + f"{user_id=}", + f"{project_id=}", + f"{comp_run.cluster_id=}", + f"{tasks_ready_to_start.keys()=}", + ) + await CompTasksRepository.instance( + self.db_engine + ).update_project_tasks_state( + project_id, + list(tasks_ready_to_start.keys()), + RunningState.FAILED, + optional_progress=1.0, + optional_stopped=arrow.utcnow().datetime, + ) + for task in tasks_ready_to_start: + comp_tasks[NodeIDStr(f"{task}")].state = RunningState.FAILED + raise + + return comp_tasks + + async def _timeout_if_waiting_for_cluster_too_long( + self, + user_id: UserID, + project_id: ProjectID, + comp_tasks: dict[NodeIDStr, CompTaskAtDB], + ) -> dict[NodeIDStr, CompTaskAtDB]: + if all( + c.state is RunningState.WAITING_FOR_CLUSTER for c in comp_tasks.values() + ): + # get latest modified task + latest_modified_of_all_tasks = max( + comp_tasks.values(), key=lambda task: task.modified + ).modified + + if ( + arrow.utcnow().datetime - latest_modified_of_all_tasks + ) > 
datetime.timedelta(minutes=_MAX_WAITING_FOR_CLUSTER_TIMEOUT_IN_MIN): + await CompTasksRepository.instance( + self.db_engine + ).update_project_tasks_state( + project_id, + [NodeID(idstr) for idstr in comp_tasks], + RunningState.FAILED, + optional_progress=1.0, + optional_stopped=arrow.utcnow().datetime, + ) + for task in comp_tasks.values(): + task.state = RunningState.FAILED + msg = "Timed-out waiting for computational cluster! Please try again and/or contact Osparc support." + _logger.error(msg) + await publish_project_log( + self.rabbitmq_client, + user_id, + project_id, + log=msg, + log_level=logging.ERROR, + ) + return comp_tasks diff --git a/services/director-v2/src/simcore_service_director_v2/utils/dask.py b/services/director-v2/src/simcore_service_director_v2/utils/dask.py index d76596b5bf19..a42a28e392c1 100644 --- a/services/director-v2/src/simcore_service_director_v2/utils/dask.py +++ b/services/director-v2/src/simcore_service_director_v2/utils/dask.py @@ -608,18 +608,25 @@ def check_if_cluster_is_able_to_run_pipeline( raise MissingComputationalResourcesError( project_id=project_id, node_id=node_id, - msg=f"Service {node_image.name}:{node_image.tag} cannot be scheduled " - f"on cluster {cluster_id}: task needs '{task_resources}', " - f"cluster has {cluster_resources}", + service_name=node_image.name, + service_version=node_image.tag, + cluster_id=cluster_id, + task_resources=task_resources, + cluster_resources=cluster_resources, ) # well then our workers are not powerful enough raise InsuficientComputationalResourcesError( project_id=project_id, node_id=node_id, - msg=f"Insufficient computational resources to run {node_image.name}:{node_image.tag} with {_to_human_readable_resource_values( task_resources)} on cluster {cluster_id}." 
- f"Cluster available workers: {[_to_human_readable_resource_values( worker.get('resources', None)) for worker in workers.values()]}" - "TIP: Reduce service required resources or contact oSparc support", + service_name=node_image.name, + service_version=node_image.tag, + service_requested_resources=_to_human_readable_resource_values(task_resources), + cluster_id=cluster_id, + cluster_available_resources=[ + _to_human_readable_resource_values(worker.get("resources", None)) + for worker in workers.values() + ], ) diff --git a/services/director-v2/src/simcore_service_director_v2/utils/dask_client_utils.py b/services/director-v2/src/simcore_service_director_v2/utils/dask_client_utils.py index fca890c6128e..e1367e71c842 100644 --- a/services/director-v2/src/simcore_service_director_v2/utils/dask_client_utils.py +++ b/services/director-v2/src/simcore_service_director_v2/utils/dask_client_utils.py @@ -26,11 +26,11 @@ from pydantic import AnyUrl from ..core.errors import ( + ComputationalSchedulerError, ConfigurationError, DaskClientRequestError, DaskClusterError, DaskGatewayServerError, - SchedulerError, ) from .dask import check_maximize_workers, wrap_client_async_routine @@ -220,7 +220,7 @@ async def test_scheduler_endpoint( ) as dask_client: if dask_client.status != _DASK_SCHEDULER_RUNNING_STATE: msg = "internal scheduler is not running!" 
- raise SchedulerError(msg) + raise ComputationalSchedulerError(msg) else: gateway_auth = await get_gateway_auth_from_params(authentication) @@ -247,7 +247,7 @@ async def test_scheduler_endpoint( ClientConnectionError, ClientResponseError, httpx.HTTPError, - SchedulerError, + ComputationalSchedulerError, ) as exc: logger.debug("Pinging %s, failed: %s", f"{endpoint=}", f"{exc=!r}") msg = f"Could not connect to cluster in {endpoint}: error: {exc}" diff --git a/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_modules_comp_scheduler_dask_scheduler.py b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_modules_comp_scheduler_dask_scheduler.py new file mode 100644 index 000000000000..6984fcea12c4 --- /dev/null +++ b/services/director-v2/tests/unit/with_dbs/comp_scheduler/test_modules_comp_scheduler_dask_scheduler.py @@ -0,0 +1,1630 @@ +# pylint:disable=unused-variable +# pylint:disable=unused-argument +# pylint:disable=redefined-outer-name +# pylint:disable=no-value-for-parameter +# pylint:disable=protected-access +# pylint:disable=too-many-arguments +# pylint:disable=no-name-in-module +# pylint: disable=too-many-statements + + +import asyncio +import datetime +from collections.abc import AsyncIterator, Awaitable, Callable +from copy import deepcopy +from dataclasses import dataclass +from typing import Any, cast +from unittest import mock + +import aiopg +import aiopg.sa +import httpx +import pytest +from _helpers import PublishedProject, RunningProject +from dask.distributed import SpecCluster +from dask_task_models_library.container_tasks.errors import TaskCancelledError +from dask_task_models_library.container_tasks.events import TaskProgressEvent +from dask_task_models_library.container_tasks.io import TaskOutputData +from dask_task_models_library.container_tasks.protocol import TaskOwner +from faker import Faker +from fastapi.applications import FastAPI +from models_library.clusters import DEFAULT_CLUSTER_ID +from models_library.projects 
import ProjectAtDB, ProjectID +from models_library.projects_nodes_io import NodeID +from models_library.projects_state import RunningState +from models_library.rabbitmq_messages import ( + InstrumentationRabbitMessage, + RabbitResourceTrackingBaseMessage, + RabbitResourceTrackingHeartbeatMessage, + RabbitResourceTrackingMessages, + RabbitResourceTrackingStartedMessage, + RabbitResourceTrackingStoppedMessage, +) +from models_library.users import UserID +from pydantic import parse_obj_as, parse_raw_as +from pytest_mock.plugin import MockerFixture +from pytest_simcore.helpers.typing_env import EnvVarsDict +from servicelib.rabbitmq import RabbitMQClient +from servicelib.redis import CouldNotAcquireLockError +from settings_library.rabbit import RabbitSettings +from settings_library.redis import RedisSettings +from simcore_postgres_database.models.comp_runs import comp_runs +from simcore_postgres_database.models.comp_tasks import NodeClass, comp_tasks +from simcore_service_director_v2.core.application import init_app +from simcore_service_director_v2.core.errors import ( + ClustersKeeperNotAvailableError, + ComputationalBackendNotConnectedError, + ComputationalBackendOnDemandNotReadyError, + ComputationalBackendTaskNotFoundError, + ComputationalBackendTaskResultsNotReadyError, + ComputationalSchedulerChangedError, + ComputationalSchedulerError, + ConfigurationError, + PipelineNotFoundError, +) +from simcore_service_director_v2.core.settings import AppSettings +from simcore_service_director_v2.models.comp_pipelines import CompPipelineAtDB +from simcore_service_director_v2.models.comp_runs import CompRunsAtDB, RunMetadataDict +from simcore_service_director_v2.models.comp_tasks import CompTaskAtDB, Image +from simcore_service_director_v2.models.dask_subsystem import DaskClientTaskState +from simcore_service_director_v2.modules.comp_scheduler._scheduler_base import ( + BaseCompScheduler, +) +from simcore_service_director_v2.modules.comp_scheduler._scheduler_dask import ( + 
DaskScheduler, +) +from simcore_service_director_v2.modules.comp_scheduler._utils import COMPLETED_STATES +from simcore_service_director_v2.modules.dask_client import ( + DaskJobID, + PublishedComputationTask, +) +from simcore_service_director_v2.utils.dask_client_utils import TaskHandlers +from starlette.testclient import TestClient +from tenacity.asyncio import AsyncRetrying +from tenacity.retry import retry_if_exception_type +from tenacity.stop import stop_after_delay +from tenacity.wait import wait_fixed + +pytest_simcore_core_services_selection = ["postgres", "rabbit", "redis"] +pytest_simcore_ops_services_selection = [ + "adminer", +] + + +def _assert_dask_client_correctly_initialized( + mocked_dask_client: mock.MagicMock, scheduler: BaseCompScheduler +) -> None: + mocked_dask_client.create.assert_called_once_with( + app=mock.ANY, + settings=mock.ANY, + endpoint=mock.ANY, + authentication=mock.ANY, + tasks_file_link_type=mock.ANY, + cluster_type=mock.ANY, + ) + mocked_dask_client.register_handlers.assert_called_once_with( + TaskHandlers( + cast( # noqa: SLF001 + DaskScheduler, scheduler + )._task_progress_change_handler, + cast(DaskScheduler, scheduler)._task_log_change_handler, # noqa: SLF001 + ) + ) + + +async def _assert_comp_run_db( + aiopg_engine: aiopg.sa.engine.Engine, + pub_project: PublishedProject, + expected_state: RunningState, +) -> None: + # check the database is correctly updated, the run is published + async with aiopg_engine.acquire() as conn: + result = await conn.execute( + comp_runs.select().where( + (comp_runs.c.user_id == pub_project.project.prj_owner) + & (comp_runs.c.project_uuid == f"{pub_project.project.uuid}") + ) # there is only one entry + ) + run_entry = CompRunsAtDB.parse_obj(await result.first()) + assert ( + run_entry.result == expected_state + ), f"comp_runs: expected state '{expected_state}, found '{run_entry.result}'" + + +async def _assert_comp_tasks_db( + aiopg_engine: aiopg.sa.engine.Engine, + project_uuid: ProjectID, + 
task_ids: list[NodeID], + *, + expected_state: RunningState, + expected_progress: float | None, +) -> None: + # check the database is correctly updated, the run is published + async with aiopg_engine.acquire() as conn: + result = await conn.execute( + comp_tasks.select().where( + (comp_tasks.c.project_id == f"{project_uuid}") + & (comp_tasks.c.node_id.in_([f"{n}" for n in task_ids])) + ) # there is only one entry + ) + tasks = parse_obj_as(list[CompTaskAtDB], await result.fetchall()) + assert all( + t.state == expected_state for t in tasks + ), f"expected state: {expected_state}, found: {[t.state for t in tasks]}" + assert all( + t.progress == expected_progress for t in tasks + ), f"{expected_progress=}, found: {[t.progress for t in tasks]}" + + +async def schedule_all_pipelines(scheduler: BaseCompScheduler) -> None: + # NOTE: we take a copy of the pipelines, as this could change quickly if there are + # misconfigured pipelines that would be removed from the scheduler + # NOTE: we simulate multiple dv-2 replicas by running several times + # the same pipeline scheduling + local_pipelines = deepcopy(scheduler._scheduled_pipelines) # noqa: SLF001 + results = await asyncio.gather( + *( + scheduler.schedule_pipeline( + user_id=user_id, + project_id=project_id, + iteration=iteration, + wake_up_callback=params.scheduler_waker.set, + ) + for _ in range(3) + for ( + user_id, + project_id, + iteration, + ), params in local_pipelines.items() + ), + return_exceptions=True, + ) + # we should have exceptions 2/3 of the time + could_not_acquire_lock_count = sum( + isinstance(r, CouldNotAcquireLockError) for r in results + ) + total_results_count = len(results) + + # Check if 2/3 of the results are CouldNotAcquireLockError + # checks that scheduling is done exclusively + assert could_not_acquire_lock_count == (2 / 3) * total_results_count + + +@pytest.fixture +def minimal_scheduler_dask_config( + mock_env: EnvVarsDict, + postgres_host_config: dict[str, str], + monkeypatch: 
pytest.MonkeyPatch, + rabbit_service: RabbitSettings, + redis_service: RedisSettings, + faker: Faker, +) -> None: + """set a minimal configuration for testing the dask connection only""" + monkeypatch.setenv("DIRECTOR_V2_DYNAMIC_SIDECAR_ENABLED", "false") + monkeypatch.setenv("DIRECTOR_V0_ENABLED", "0") + monkeypatch.setenv("COMPUTATIONAL_BACKEND_DASK_CLIENT_ENABLED", "1") + monkeypatch.setenv("COMPUTATIONAL_BACKEND_ENABLED", "1") + monkeypatch.setenv("R_CLONE_PROVIDER", "MINIO") + monkeypatch.setenv("S3_ENDPOINT", faker.url()) + monkeypatch.setenv("S3_ACCESS_KEY", faker.pystr()) + monkeypatch.setenv("S3_REGION", faker.pystr()) + monkeypatch.setenv("S3_SECRET_KEY", faker.pystr()) + monkeypatch.setenv("S3_BUCKET_NAME", faker.pystr()) + + +@pytest.fixture +def scheduler( + minimal_scheduler_dask_config: None, + aiopg_engine: aiopg.sa.engine.Engine, + minimal_app: FastAPI, +) -> BaseCompScheduler: + scheduler = _get_scheduler_worker(minimal_app) + assert scheduler is not None + return scheduler + + +@pytest.fixture +def mocked_dask_client(mocker: MockerFixture) -> mock.MagicMock: + mocked_dask_client = mocker.patch( + "simcore_service_director_v2.modules.dask_clients_pool.DaskClient", + autospec=True, + ) + mocked_dask_client.create.return_value = mocked_dask_client + return mocked_dask_client + + +@pytest.fixture +def mocked_parse_output_data_fct(mocker: MockerFixture) -> mock.Mock: + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.parse_output_data", + autospec=True, + ) + + +@pytest.fixture +def mocked_clean_task_output_fct(mocker: MockerFixture) -> mock.MagicMock: + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.clean_task_output_and_log_files_if_invalid", + return_value=None, + autospec=True, + ) + + +@pytest.fixture +def with_disabled_auto_scheduling(mocker: MockerFixture) -> mock.MagicMock: + """disables the scheduler task, note that it needs to be triggered manually then""" 
+ + def _fake_starter( + self: BaseCompScheduler, + *args, + **kwargs, + ): + scheduler_task = mocker.MagicMock() + scheduler_task_wake_up_event = mocker.MagicMock() + return scheduler_task, scheduler_task_wake_up_event + + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._scheduler_base.BaseCompScheduler._start_scheduling", + autospec=True, + side_effect=_fake_starter, + ) + + +@pytest.fixture +async def minimal_app(async_client: httpx.AsyncClient) -> FastAPI: + # must use the minimal app from the `async_client` + # the `client` uses starlette's TestClient which spawns + # a new thread on which it creates a new loop + # causing issues downstream with coroutines not + # being created on the same loop + return async_client._transport.app # type: ignore # noqa: SLF001 + + +@pytest.fixture +def mocked_clean_task_output_and_log_files_if_invalid(mocker: MockerFixture) -> None: + mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.clean_task_output_and_log_files_if_invalid", + autospec=True, + ) + + +async def test_scheduler_gracefully_starts_and_stops( + minimal_scheduler_dask_config: None, + aiopg_engine: aiopg.sa.engine.Engine, + dask_spec_local_cluster: SpecCluster, + minimal_app: FastAPI, +): + # check it started correctly + assert _get_scheduler_worker(minimal_app) is not None + + +@pytest.mark.parametrize( + "missing_dependency", + [ + "COMPUTATIONAL_BACKEND_DASK_CLIENT_ENABLED", + ], +) +def test_scheduler_raises_exception_for_missing_dependencies( + minimal_scheduler_dask_config: None, + aiopg_engine: aiopg.sa.engine.Engine, + dask_spec_local_cluster: SpecCluster, + monkeypatch: pytest.MonkeyPatch, + missing_dependency: str, +): + # disable the dependency + monkeypatch.setenv(missing_dependency, "0") + # create the client + settings = AppSettings.create_from_envs() + app = init_app(settings) + + with pytest.raises(ConfigurationError), TestClient( + app, raise_server_exceptions=True + ) as _: + pass + 
+ +async def test_empty_pipeline_is_not_scheduled( + with_disabled_auto_scheduling: None, + scheduler: BaseCompScheduler, + registered_user: Callable[..., dict[str, Any]], + project: Callable[..., Awaitable[ProjectAtDB]], + pipeline: Callable[..., CompPipelineAtDB], + aiopg_engine: aiopg.sa.engine.Engine, + run_metadata: RunMetadataDict, +): + user = registered_user() + empty_project = await project(user) + + # the project is not in the comp_pipeline, therefore scheduling it should fail + with pytest.raises(PipelineNotFoundError): + await scheduler.run_new_pipeline( + user_id=user["id"], + project_id=empty_project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + # create the empty pipeline now + pipeline(project_id=f"{empty_project.uuid}") + + # creating a run with an empty pipeline is useless, check the scheduler is not kicking in + await scheduler.run_new_pipeline( + user_id=user["id"], + project_id=empty_project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + assert len(scheduler._scheduled_pipelines) == 0 # noqa: SLF001 + # check the database is empty + async with aiopg_engine.acquire() as conn: + result = await conn.scalar( + comp_runs.select().where( + (comp_runs.c.user_id == user["id"]) + & (comp_runs.c.project_uuid == f"{empty_project.uuid}") + ) # there is only one entry + ) + assert result is None + + +async def test_misconfigured_pipeline_is_not_scheduled( + with_disabled_auto_scheduling: None, + scheduler: BaseCompScheduler, + registered_user: Callable[..., dict[str, Any]], + project: Callable[..., Awaitable[ProjectAtDB]], + pipeline: Callable[..., CompPipelineAtDB], + fake_workbench_without_outputs: dict[str, Any], + fake_workbench_adjacency: dict[str, Any], + aiopg_engine: aiopg.sa.engine.Engine, + run_metadata: RunMetadataDict, +): + """A pipeline which comp_tasks are missing should not be scheduled. 
+ It shall be aborted and shown as such in the comp_runs db""" + user = registered_user() + sleepers_project = await project(user, workbench=fake_workbench_without_outputs) + pipeline( + project_id=f"{sleepers_project.uuid}", + dag_adjacency_list=fake_workbench_adjacency, + ) + # check the pipeline is correctly added to the scheduled pipelines + await scheduler.run_new_pipeline( + user_id=user["id"], + project_id=sleepers_project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + assert len(scheduler._scheduled_pipelines) == 1 # noqa: SLF001 + for ( + u_id, + p_id, + it, + ) in scheduler._scheduled_pipelines: # noqa: SLF001 + assert u_id == user["id"] + assert p_id == sleepers_project.uuid + assert it > 0 + # check the database was properly updated + async with aiopg_engine.acquire() as conn: + result = await conn.execute( + comp_runs.select().where( + (comp_runs.c.user_id == user["id"]) + & (comp_runs.c.project_uuid == f"{sleepers_project.uuid}") + ) # there is only one entry + ) + run_entry = CompRunsAtDB.parse_obj(await result.first()) + assert run_entry.result == RunningState.PUBLISHED + # let the scheduler kick in + await schedule_all_pipelines(scheduler) + # check the scheduled pipelines is again empty since it's misconfigured + assert len(scheduler._scheduled_pipelines) == 0 # noqa: SLF001 + # check the database entry is correctly updated + async with aiopg_engine.acquire() as conn: + result = await conn.execute( + comp_runs.select().where( + (comp_runs.c.user_id == user["id"]) + & (comp_runs.c.project_uuid == f"{sleepers_project.uuid}") + ) # there is only one entry + ) + run_entry = CompRunsAtDB.parse_obj(await result.first()) + assert run_entry.result == RunningState.ABORTED + assert run_entry.metadata == run_metadata + + +async def _assert_start_pipeline( + aiopg_engine, + published_project: PublishedProject, + scheduler: BaseCompScheduler, + run_metadata: RunMetadataDict, +) -> list[CompTaskAtDB]: + 
exp_published_tasks = deepcopy(published_project.tasks) + assert published_project.project.prj_owner + await scheduler.run_new_pipeline( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + assert ( + len(scheduler._scheduled_pipelines) == 1 # noqa: SLF001 + ), "the pipeline is not scheduled!" + for ( + u_id, + p_id, + it, + ) in scheduler._scheduled_pipelines: # noqa: SLF001 + assert u_id == published_project.project.prj_owner + assert p_id == published_project.project.uuid + assert it > 0 + + # check the database is correctly updated, the run is published + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in exp_published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, + ) + return exp_published_tasks + + +async def _assert_schedule_pipeline_PENDING( # noqa: N802 + aiopg_engine, + published_project: PublishedProject, + published_tasks: list[CompTaskAtDB], + mocked_dask_client: mock.MagicMock, + scheduler: BaseCompScheduler, +) -> list[CompTaskAtDB]: + expected_pending_tasks = [ + published_tasks[1], + published_tasks[3], + ] + for p in expected_pending_tasks: + published_tasks.remove(p) + + async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]: + return [DaskClientTaskState.PENDING for job_id in job_ids] + + mocked_dask_client.get_tasks_status.side_effect = _return_tasks_pending + await schedule_all_pipelines(scheduler) + _assert_dask_client_correctly_initialized(mocked_dask_client, scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + 
expected_progress=None, + ) + # the other tasks are still waiting in published state + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, # since we bypass the API entrypoint this is correct + ) + # tasks were sent to the backend + assert published_project.project.prj_owner is not None + mocked_dask_client.send_computation_tasks.assert_has_calls( + calls=[ + mock.call( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + tasks={f"{p.node_id}": p.image}, + callback=mock.ANY, + metadata=mock.ANY, + hardware_info=mock.ANY, + ) + for p in expected_pending_tasks + ], + any_order=True, + ) + mocked_dask_client.send_computation_tasks.reset_mock() + mocked_dask_client.get_tasks_status.assert_not_called() + mocked_dask_client.get_task_result.assert_not_called() + # there is a second run of the scheduler to move comp_runs to pending, the rest does not change + await schedule_all_pipelines(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PENDING) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=None, + ) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, + ) + mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_has_calls( + calls=[mock.call([p.job_id for p in expected_pending_tasks])], any_order=True + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_not_called() + return expected_pending_tasks + + +@pytest.fixture +async def instrumentation_rabbit_client_parser( + 
create_rabbitmq_client: Callable[[str], RabbitMQClient], mocker: MockerFixture +) -> AsyncIterator[mock.AsyncMock]: + client = create_rabbitmq_client("instrumentation_pytest_consumer") + mock = mocker.AsyncMock(return_value=True) + queue_name = await client.subscribe( + InstrumentationRabbitMessage.get_channel_name(), mock + ) + yield mock + await client.unsubscribe(queue_name) + + +@pytest.fixture +async def resource_tracking_rabbit_client_parser( + create_rabbitmq_client: Callable[[str], RabbitMQClient], mocker: MockerFixture +) -> AsyncIterator[mock.AsyncMock]: + client = create_rabbitmq_client("resource_tracking_pytest_consumer") + mock = mocker.AsyncMock(return_value=True) + queue_name = await client.subscribe( + RabbitResourceTrackingBaseMessage.get_channel_name(), mock + ) + yield mock + await client.unsubscribe(queue_name) + + +async def _assert_message_received( + mocked_message_parser: mock.AsyncMock, + expected_call_count: int, + message_parser: Callable, +) -> list: + async for attempt in AsyncRetrying( + wait=wait_fixed(0.1), + stop=stop_after_delay(5), + retry=retry_if_exception_type(AssertionError), + reraise=True, + ): + with attempt: + print( + f"--> waiting for rabbitmq message [{attempt.retry_state.attempt_number}, {attempt.retry_state.idle_for}]" + ) + assert mocked_message_parser.call_count == expected_call_count + print( + f"<-- rabbitmq message received after [{attempt.retry_state.attempt_number}, {attempt.retry_state.idle_for}]" + ) + parsed_messages = [ + message_parser(mocked_message_parser.call_args_list[c].args[0]) + for c in range(expected_call_count) + ] + + mocked_message_parser.reset_mock() + return parsed_messages + + +def _mock_send_computation_tasks( + tasks: list[CompTaskAtDB], mocked_dask_client: mock.MagicMock +) -> None: + node_id_to_job_id_map = {task.node_id: task.job_id for task in tasks} + + async def _send_computation_tasks( + *args, tasks: dict[NodeID, Image], **kwargs + ) -> list[PublishedComputationTask]: + for node_id 
in tasks: + assert NodeID(f"{node_id}") in node_id_to_job_id_map + return [ + PublishedComputationTask( + node_id=NodeID(f"{node_id}"), + job_id=DaskJobID(node_id_to_job_id_map[NodeID(f"{node_id}")]), + ) + for node_id in tasks + ] # type: ignore + + mocked_dask_client.send_computation_tasks.side_effect = _send_computation_tasks + + +async def _trigger_progress_event( + scheduler: BaseCompScheduler, + *, + job_id: str, + user_id: UserID, + project_id: ProjectID, + node_id: NodeID, +) -> None: + event = TaskProgressEvent( + job_id=job_id, + progress=0, + task_owner=TaskOwner( + user_id=user_id, + project_id=project_id, + node_id=node_id, + parent_project_id=None, + parent_node_id=None, + ), + ) + await cast(DaskScheduler, scheduler)._task_progress_change_handler( # noqa: SLF001 + event.json() + ) + + +@pytest.mark.acceptance_test() +async def test_proper_pipeline_is_scheduled( # noqa: PLR0915 + with_disabled_auto_scheduling: None, + mocked_dask_client: mock.MagicMock, + scheduler: BaseCompScheduler, + aiopg_engine: aiopg.sa.engine.Engine, + published_project: PublishedProject, + mocked_parse_output_data_fct: mock.Mock, + mocked_clean_task_output_and_log_files_if_invalid: None, + instrumentation_rabbit_client_parser: mock.AsyncMock, + resource_tracking_rabbit_client_parser: mock.AsyncMock, + run_metadata: RunMetadataDict, +): + _mock_send_computation_tasks(published_project.tasks, mocked_dask_client) + + expected_published_tasks = await _assert_start_pipeline( + aiopg_engine, published_project, scheduler, run_metadata + ) + + # ------------------------------------------------------------------------------- + # 1. first run will move comp_tasks to PENDING so the worker can take them + expected_pending_tasks = await _assert_schedule_pipeline_PENDING( + aiopg_engine, + published_project, + expected_published_tasks, + mocked_dask_client, + scheduler, + ) + + # ------------------------------------------------------------------------------- + # 2.1. 
the worker might be taking the task, until we get a progress we do not know + # whether it effectively started or it is still queued in the worker process + exp_started_task = expected_pending_tasks[0] + expected_pending_tasks.remove(exp_started_task) + + async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]: + return [ + ( + DaskClientTaskState.PENDING_OR_STARTED + if job_id == exp_started_task.job_id + else DaskClientTaskState.PENDING + ) + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running + + await schedule_all_pipelines(scheduler) + + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PENDING) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [exp_started_task.node_id], + expected_state=RunningState.PENDING, + expected_progress=None, + ) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=None, + ) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in expected_published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, # since we bypass the API entrypoint this is correct + ) + mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in (exp_started_task, *expected_pending_tasks)], + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_not_called() + + # ------------------------------------------------------------------------------- + # 3. 
the "worker" starts processing a task + # here we trigger a progress from the worker + assert exp_started_task.job_id + assert exp_started_task.project_id + assert exp_started_task.node_id + assert published_project.project.prj_owner + await _trigger_progress_event( + scheduler, + job_id=exp_started_task.job_id, + user_id=published_project.project.prj_owner, + project_id=exp_started_task.project_id, + node_id=exp_started_task.node_id, + ) + + await schedule_all_pipelines(scheduler) + # comp_run, the comp_task switch to STARTED + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [exp_started_task.node_id], + expected_state=RunningState.STARTED, + expected_progress=0, + ) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=None, + ) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in expected_published_tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, + ) + mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in (exp_started_task, *expected_pending_tasks)], + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_not_called() + messages = await _assert_message_received( + instrumentation_rabbit_client_parser, 1, InstrumentationRabbitMessage.parse_raw + ) + assert messages[0].metrics == "service_started" + assert messages[0].service_uuid == exp_started_task.node_id + + def _parser(x) -> RabbitResourceTrackingMessages: + return parse_raw_as(RabbitResourceTrackingMessages, x) + + messages = await _assert_message_received( + resource_tracking_rabbit_client_parser, + 1, + RabbitResourceTrackingStartedMessage.parse_raw, + ) 
+ assert messages[0].node_id == exp_started_task.node_id + + # ------------------------------------------------------------------------------- + # 4. the "worker" completed the task successfully + async def _return_1st_task_success(job_ids: list[str]) -> list[DaskClientTaskState]: + return [ + ( + DaskClientTaskState.SUCCESS + if job_id == exp_started_task.job_id + else DaskClientTaskState.PENDING + ) + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_success + + async def _return_random_task_result(job_id) -> TaskOutputData: + return TaskOutputData.parse_obj({"out_1": None, "out_2": 45}) + + mocked_dask_client.get_task_result.side_effect = _return_random_task_result + await schedule_all_pipelines(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [exp_started_task.node_id], + expected_state=RunningState.SUCCESS, + expected_progress=1, + ) + messages = await _assert_message_received( + instrumentation_rabbit_client_parser, 1, InstrumentationRabbitMessage.parse_raw + ) + assert messages[0].metrics == "service_stopped" + assert messages[0].service_uuid == exp_started_task.node_id + messages = await _assert_message_received( + resource_tracking_rabbit_client_parser, + 1, + RabbitResourceTrackingStoppedMessage.parse_raw, + ) + + completed_tasks = [exp_started_task] + next_pending_task = published_project.tasks[2] + expected_pending_tasks.append(next_pending_task) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [p.node_id for p in expected_pending_tasks], + expected_state=RunningState.PENDING, + expected_progress=None, + ) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [ + p.node_id + for p in published_project.tasks + if p not in expected_pending_tasks + completed_tasks + ], + expected_state=RunningState.PUBLISHED, + 
expected_progress=None, # since we bypass the API entrypoint this is correct + ) + mocked_dask_client.send_computation_tasks.assert_called_once_with( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + tasks={ + f"{next_pending_task.node_id}": next_pending_task.image, + }, + callback=mock.ANY, + metadata=mock.ANY, + hardware_info=mock.ANY, + ) + mocked_dask_client.send_computation_tasks.reset_mock() + mocked_dask_client.get_tasks_status.assert_has_calls( + calls=[ + mock.call([p.job_id for p in completed_tasks + expected_pending_tasks[:1]]) + ], + any_order=True, + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_called_once_with( + completed_tasks[0].job_id + ) + mocked_dask_client.get_task_result.reset_mock() + mocked_parse_output_data_fct.assert_called_once_with( + mock.ANY, + completed_tasks[0].job_id, + await _return_random_task_result(completed_tasks[0].job_id), + ) + mocked_parse_output_data_fct.reset_mock() + + # ------------------------------------------------------------------------------- + # 6. 
the "worker" starts processing a task + exp_started_task = next_pending_task + + async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskState]: + return [ + ( + DaskClientTaskState.PENDING_OR_STARTED + if job_id == exp_started_task.job_id + else DaskClientTaskState.PENDING + ) + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_2nd_task_running + # trigger the scheduler, run state should keep to STARTED, task should be as well + assert exp_started_task.job_id + await _trigger_progress_event( + scheduler, + job_id=exp_started_task.job_id, + user_id=published_project.project.prj_owner, + project_id=exp_started_task.project_id, + node_id=exp_started_task.node_id, + ) + await schedule_all_pipelines(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [exp_started_task.node_id], + expected_state=RunningState.STARTED, + expected_progress=0, + ) + mocked_dask_client.send_computation_tasks.assert_not_called() + expected_pending_tasks.reverse() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in expected_pending_tasks] + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_not_called() + messages = await _assert_message_received( + instrumentation_rabbit_client_parser, 1, InstrumentationRabbitMessage.parse_raw + ) + assert messages[0].metrics == "service_started" + assert messages[0].service_uuid == exp_started_task.node_id + messages = await _assert_message_received( + resource_tracking_rabbit_client_parser, + 1, + RabbitResourceTrackingStartedMessage.parse_raw, + ) + assert messages[0].node_id == exp_started_task.node_id + + # ------------------------------------------------------------------------------- + # 7. 
the task fails + async def _return_2nd_task_failed(job_ids: list[str]) -> list[DaskClientTaskState]: + return [ + ( + DaskClientTaskState.ERRED + if job_id == exp_started_task.job_id + else DaskClientTaskState.PENDING + ) + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_2nd_task_failed + mocked_dask_client.get_task_result.side_effect = None + await schedule_all_pipelines(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.STARTED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [exp_started_task.node_id], + expected_state=RunningState.FAILED, + expected_progress=1, + ) + mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in expected_pending_tasks] + ) + mocked_dask_client.get_tasks_status.reset_mock() + mocked_dask_client.get_task_result.assert_called_once_with(exp_started_task.job_id) + mocked_dask_client.get_task_result.reset_mock() + mocked_parse_output_data_fct.assert_not_called() + expected_pending_tasks.remove(exp_started_task) + messages = await _assert_message_received( + instrumentation_rabbit_client_parser, 1, InstrumentationRabbitMessage.parse_raw + ) + assert messages[0].metrics == "service_stopped" + assert messages[0].service_uuid == exp_started_task.node_id + messages = await _assert_message_received( + resource_tracking_rabbit_client_parser, + 1, + RabbitResourceTrackingStoppedMessage.parse_raw, + ) + + # ------------------------------------------------------------------------------- + # 8. 
the last task shall succeed + exp_started_task = expected_pending_tasks[0] + + async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskState]: + return [ + ( + DaskClientTaskState.SUCCESS + if job_id == exp_started_task.job_id + else DaskClientTaskState.PENDING + ) + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_3rd_task_success + mocked_dask_client.get_task_result.side_effect = _return_random_task_result + + # trigger the scheduler, it should switch to FAILED, as we are done + await schedule_all_pipelines(scheduler) + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.FAILED) + + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [exp_started_task.node_id], + expected_state=RunningState.SUCCESS, + expected_progress=1, + ) + mocked_dask_client.send_computation_tasks.assert_not_called() + mocked_dask_client.get_tasks_status.assert_called_once_with( + [p.job_id for p in expected_pending_tasks] + ) + mocked_dask_client.get_task_result.assert_called_once_with(exp_started_task.job_id) + messages = await _assert_message_received( + instrumentation_rabbit_client_parser, 2, InstrumentationRabbitMessage.parse_raw + ) + # NOTE: the service was fast and went directly to success + assert messages[0].metrics == "service_started" + assert messages[0].service_uuid == exp_started_task.node_id + assert messages[1].metrics == "service_stopped" + assert messages[1].service_uuid == exp_started_task.node_id + messages = await _assert_message_received( + resource_tracking_rabbit_client_parser, + 2, + _parser, + ) + assert isinstance(messages[0], RabbitResourceTrackingStartedMessage) + assert isinstance(messages[1], RabbitResourceTrackingStoppedMessage) + + # the scheduled pipeline shall be removed + assert scheduler._scheduled_pipelines == {} # noqa: SLF001 + + +async def test_task_progress_triggers( + with_disabled_auto_scheduling: None, + mocked_dask_client: 
mock.MagicMock, + scheduler: BaseCompScheduler, + aiopg_engine: aiopg.sa.engine.Engine, + published_project: PublishedProject, + mocked_parse_output_data_fct: None, + mocked_clean_task_output_and_log_files_if_invalid: None, + run_metadata: RunMetadataDict, +): + _mock_send_computation_tasks(published_project.tasks, mocked_dask_client) + expected_published_tasks = await _assert_start_pipeline( + aiopg_engine, published_project, scheduler, run_metadata + ) + # ------------------------------------------------------------------------------- + # 1. first run will move comp_tasks to PENDING so the worker can take them + expected_pending_tasks = await _assert_schedule_pipeline_PENDING( + aiopg_engine, + published_project, + expected_published_tasks, + mocked_dask_client, + scheduler, + ) + + # send some progress + started_task = expected_pending_tasks[0] + assert started_task.job_id + assert published_project.project.prj_owner + for progress in [-1, 0, 0.3, 0.5, 1, 1.5, 0.7, 0, 20]: + progress_event = TaskProgressEvent( + job_id=started_task.job_id, + progress=progress, + task_owner=TaskOwner( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + node_id=started_task.node_id, + parent_node_id=None, + parent_project_id=None, + ), + ) + await cast( # noqa: SLF001 + DaskScheduler, scheduler + )._task_progress_change_handler(progress_event.json()) + # NOTE: not sure whether it should switch to STARTED.. 
it would make sense + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [started_task.node_id], + expected_state=RunningState.STARTED, + expected_progress=min(max(0, progress), 1), + ) + + +@pytest.mark.parametrize( + "backend_error", + [ + ComputationalBackendNotConnectedError(msg="faked disconnected backend"), + ComputationalSchedulerChangedError( + original_scheduler_id="some_old_scheduler_id", + current_scheduler_id="some_new_scheduler_id", + ), + ], +) +async def test_handling_of_disconnected_scheduler_dask( + with_disabled_auto_scheduling: None, + mocked_dask_client: mock.MagicMock, + scheduler: BaseCompScheduler, + aiopg_engine: aiopg.sa.engine.Engine, + mocker: MockerFixture, + published_project: PublishedProject, + backend_error: ComputationalSchedulerError, + run_metadata: RunMetadataDict, +): + # this will create a non connected backend issue that will trigger re-connection + mocked_dask_client_send_task = mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.DaskClient.send_computation_tasks", + side_effect=backend_error, + ) + assert mocked_dask_client_send_task + + # running the pipeline will now raise and the tasks are set back to PUBLISHED + assert published_project.project.prj_owner + await scheduler.run_new_pipeline( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=False, + ) + + # since there is no cluster, there is no dask-scheduler, + # the tasks shall all still be in PUBLISHED state now + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) + + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [t.node_id for t in published_project.tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, + ) + # on the next iteration of the pipeline it will try to re-connect + # now try to abort 
the tasks since we are wondering what is happening, this should auto-trigger the scheduler + await scheduler.stop_pipeline( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + ) + # we ensure the scheduler was run + await schedule_all_pipelines(scheduler) + # after this step the tasks are marked as ABORTED + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [ + t.node_id + for t in published_project.tasks + if t.node_class == NodeClass.COMPUTATIONAL + ], + expected_state=RunningState.ABORTED, + expected_progress=1, + ) + # then we have another scheduler run + await schedule_all_pipelines(scheduler) + # now the run should be ABORTED + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.ABORTED) + + +@dataclass(frozen=True, kw_only=True) +class RebootState: + dask_task_status: DaskClientTaskState + task_result: Exception | TaskOutputData + expected_task_state_group1: RunningState + expected_task_progress_group1: float + expected_task_state_group2: RunningState + expected_task_progress_group2: float + expected_run_state: RunningState + + +@pytest.mark.parametrize( + "reboot_state", + [ + pytest.param( + RebootState( + dask_task_status=DaskClientTaskState.LOST, + task_result=ComputationalBackendTaskNotFoundError(job_id="fake_job_id"), + expected_task_state_group1=RunningState.FAILED, + expected_task_progress_group1=1, + expected_task_state_group2=RunningState.ABORTED, + expected_task_progress_group2=1, + expected_run_state=RunningState.FAILED, + ), + id="reboot with lost tasks", + ), + pytest.param( + RebootState( + dask_task_status=DaskClientTaskState.ABORTED, + task_result=TaskCancelledError(job_id="fake_job_id"), + expected_task_state_group1=RunningState.ABORTED, + expected_task_progress_group1=1, + expected_task_state_group2=RunningState.ABORTED, + expected_task_progress_group2=1, + expected_run_state=RunningState.ABORTED, + ), + id="reboot with aborted tasks", + ), + 
pytest.param( + RebootState( + dask_task_status=DaskClientTaskState.ERRED, + task_result=ValueError("some error during the call"), + expected_task_state_group1=RunningState.FAILED, + expected_task_progress_group1=1, + expected_task_state_group2=RunningState.ABORTED, + expected_task_progress_group2=1, + expected_run_state=RunningState.FAILED, + ), + id="reboot with failed tasks", + ), + pytest.param( + RebootState( + dask_task_status=DaskClientTaskState.PENDING_OR_STARTED, + task_result=ComputationalBackendTaskResultsNotReadyError( + job_id="fake_job_id" + ), + expected_task_state_group1=RunningState.STARTED, + expected_task_progress_group1=0, + expected_task_state_group2=RunningState.STARTED, + expected_task_progress_group2=0, + expected_run_state=RunningState.STARTED, + ), + id="reboot with running tasks", + ), + pytest.param( + RebootState( + dask_task_status=DaskClientTaskState.SUCCESS, + task_result=TaskOutputData.parse_obj({"whatever_output": 123}), + expected_task_state_group1=RunningState.SUCCESS, + expected_task_progress_group1=1, + expected_task_state_group2=RunningState.SUCCESS, + expected_task_progress_group2=1, + expected_run_state=RunningState.SUCCESS, + ), + id="reboot with completed tasks", + ), + ], +) +async def test_handling_scheduling_after_reboot( + with_disabled_auto_scheduling: None, + mocked_dask_client: mock.MagicMock, + aiopg_engine: aiopg.sa.engine.Engine, + running_project: RunningProject, + scheduler: BaseCompScheduler, + mocked_parse_output_data_fct: mock.MagicMock, + mocked_clean_task_output_fct: mock.MagicMock, + reboot_state: RebootState, +): + """After the dask client is rebooted, or that the director-v2 reboots the dv-2 internal scheduler + shall continue scheduling correctly. 
Even though the task might have continued to run + in the dask-scheduler.""" + + async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskState]: + return [reboot_state.dask_task_status for j in job_ids] + + mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status + + async def mocked_get_task_result(_job_id: str) -> TaskOutputData: + if isinstance(reboot_state.task_result, Exception): + raise reboot_state.task_result + return reboot_state.task_result + + mocked_dask_client.get_task_result.side_effect = mocked_get_task_result + + await schedule_all_pipelines(scheduler) + # the status will be called once for all RUNNING tasks + mocked_dask_client.get_tasks_status.assert_called_once() + if reboot_state.expected_run_state in COMPLETED_STATES: + mocked_dask_client.get_task_result.assert_has_calls( + [ + mock.call(t.job_id) + for t in running_project.tasks + if t.node_class == NodeClass.COMPUTATIONAL + ], + any_order=True, + ) + else: + mocked_dask_client.get_task_result.assert_not_called() + if reboot_state.expected_run_state in [RunningState.ABORTED, RunningState.FAILED]: + # the clean up of the outputs should be done + mocked_clean_task_output_fct.assert_has_calls( + [ + mock.call( + mock.ANY, + running_project.project.prj_owner, + running_project.project.uuid, + t.node_id, + ) + for t in running_project.tasks + if t.node_class == NodeClass.COMPUTATIONAL + ], + any_order=True, + ) + else: + mocked_clean_task_output_fct.assert_not_called() + + await _assert_comp_tasks_db( + aiopg_engine, + running_project.project.uuid, + [ + running_project.tasks[1].node_id, + running_project.tasks[2].node_id, + running_project.tasks[3].node_id, + ], + expected_state=reboot_state.expected_task_state_group1, + expected_progress=reboot_state.expected_task_progress_group1, + ) + await _assert_comp_tasks_db( + aiopg_engine, + running_project.project.uuid, + [running_project.tasks[4].node_id], + expected_state=reboot_state.expected_task_state_group2, + 
expected_progress=reboot_state.expected_task_progress_group2, + ) + assert running_project.project.prj_owner + await _assert_comp_run_db( + aiopg_engine, running_project, reboot_state.expected_run_state + ) + + +async def test_handling_cancellation_of_jobs_after_reboot( + with_disabled_auto_scheduling: None, + mocked_dask_client: mock.MagicMock, + aiopg_engine: aiopg.sa.engine.Engine, + running_project_mark_for_cancellation: RunningProject, + scheduler: BaseCompScheduler, + mocked_parse_output_data_fct: mock.MagicMock, + mocked_clean_task_output_fct: mock.MagicMock, +): + """A running pipeline was cancelled by a user and the DV-2 was restarted BEFORE + It could actually cancel the task. On reboot the DV-2 shall recover + and actually cancel the pipeline properly""" + + # check initial status + await _assert_comp_run_db( + aiopg_engine, running_project_mark_for_cancellation, RunningState.STARTED + ) + await _assert_comp_tasks_db( + aiopg_engine, + running_project_mark_for_cancellation.project.uuid, + [t.node_id for t in running_project_mark_for_cancellation.tasks], + expected_state=RunningState.STARTED, + expected_progress=0, + ) + + # the backend shall report the tasks as running + async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskState]: + return [DaskClientTaskState.PENDING_OR_STARTED for j in job_ids] + + mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status + # Running the scheduler, should actually cancel the run now + await schedule_all_pipelines(scheduler) + mocked_dask_client.abort_computation_task.assert_called() + assert mocked_dask_client.abort_computation_task.call_count == len( + [ + t.node_id + for t in running_project_mark_for_cancellation.tasks + if t.node_class == NodeClass.COMPUTATIONAL + ] + ) + # in the DB they are still running, they will be stopped in the next iteration + await _assert_comp_tasks_db( + aiopg_engine, + running_project_mark_for_cancellation.project.uuid, + [ + t.node_id + for t in 
running_project_mark_for_cancellation.tasks + if t.node_class == NodeClass.COMPUTATIONAL + ], + expected_state=RunningState.STARTED, + expected_progress=0, + ) + await _assert_comp_run_db( + aiopg_engine, running_project_mark_for_cancellation, RunningState.STARTED + ) + + # the backend shall now report the tasks as aborted + async def mocked_get_tasks_status_aborted( + job_ids: list[str], + ) -> list[DaskClientTaskState]: + return [DaskClientTaskState.ABORTED for j in job_ids] + + mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status_aborted + + async def _return_random_task_result(job_id) -> TaskOutputData: + raise TaskCancelledError + + mocked_dask_client.get_task_result.side_effect = _return_random_task_result + await schedule_all_pipelines(scheduler) + # now should be stopped + await _assert_comp_tasks_db( + aiopg_engine, + running_project_mark_for_cancellation.project.uuid, + [ + t.node_id + for t in running_project_mark_for_cancellation.tasks + if t.node_class == NodeClass.COMPUTATIONAL + ], + expected_state=RunningState.ABORTED, + expected_progress=1, + ) + await _assert_comp_run_db( + aiopg_engine, running_project_mark_for_cancellation, RunningState.ABORTED + ) + mocked_clean_task_output_fct.assert_called() + + +@pytest.fixture +def with_fast_service_heartbeat_s(monkeypatch: pytest.MonkeyPatch) -> int: + seconds = 1 + monkeypatch.setenv("SERVICE_TRACKING_HEARTBEAT", f"{seconds}") + return seconds + + +async def test_running_pipeline_triggers_heartbeat( + with_disabled_auto_scheduling: None, + with_fast_service_heartbeat_s: int, + mocked_dask_client: mock.MagicMock, + scheduler: BaseCompScheduler, + aiopg_engine: aiopg.sa.engine.Engine, + published_project: PublishedProject, + resource_tracking_rabbit_client_parser: mock.AsyncMock, + run_metadata: RunMetadataDict, +): + _mock_send_computation_tasks(published_project.tasks, mocked_dask_client) + expected_published_tasks = await _assert_start_pipeline( + aiopg_engine, published_project, 
scheduler, run_metadata + ) + # ------------------------------------------------------------------------------- + # 1. first run will move comp_tasks to PENDING so the worker can take them + expected_pending_tasks = await _assert_schedule_pipeline_PENDING( + aiopg_engine, + published_project, + expected_published_tasks, + mocked_dask_client, + scheduler, + ) + # ------------------------------------------------------------------------------- + # 2. the "worker" starts processing a task + exp_started_task = expected_pending_tasks[0] + expected_pending_tasks.remove(exp_started_task) + + async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]: + return [ + ( + DaskClientTaskState.PENDING_OR_STARTED + if job_id == exp_started_task.job_id + else DaskClientTaskState.PENDING + ) + for job_id in job_ids + ] + + mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running + assert exp_started_task.job_id + assert published_project.project.prj_owner + await _trigger_progress_event( + scheduler, + job_id=exp_started_task.job_id, + user_id=published_project.project.prj_owner, + project_id=exp_started_task.project_id, + node_id=exp_started_task.node_id, + ) + await schedule_all_pipelines(scheduler) + + messages = await _assert_message_received( + resource_tracking_rabbit_client_parser, + 1, + RabbitResourceTrackingStartedMessage.parse_raw, + ) + assert messages[0].node_id == exp_started_task.node_id + + # ------------------------------------------------------------------------------- + # 3. wait a bit and run again we should get another heartbeat, but only one! 
+ await asyncio.sleep(with_fast_service_heartbeat_s + 1) + await schedule_all_pipelines(scheduler) + await schedule_all_pipelines(scheduler) + messages = await _assert_message_received( + resource_tracking_rabbit_client_parser, + 1, + RabbitResourceTrackingHeartbeatMessage.parse_raw, + ) + assert isinstance(messages[0], RabbitResourceTrackingHeartbeatMessage) + + # ------------------------------------------------------------------------------- + # 4. wait a bit and run again we should get another heartbeat, but only one! + await asyncio.sleep(with_fast_service_heartbeat_s + 1) + await schedule_all_pipelines(scheduler) + await schedule_all_pipelines(scheduler) + messages = await _assert_message_received( + resource_tracking_rabbit_client_parser, + 1, + RabbitResourceTrackingHeartbeatMessage.parse_raw, + ) + assert isinstance(messages[0], RabbitResourceTrackingHeartbeatMessage) + + +@pytest.fixture +async def mocked_get_or_create_cluster(mocker: MockerFixture) -> mock.Mock: + return mocker.patch( + "simcore_service_director_v2.modules.comp_scheduler._scheduler_dask.get_or_create_on_demand_cluster", + autospec=True, + ) + + +async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits( + with_disabled_auto_scheduling: None, + scheduler: BaseCompScheduler, + aiopg_engine: aiopg.sa.engine.Engine, + published_project: PublishedProject, + run_metadata: RunMetadataDict, + mocked_get_or_create_cluster: mock.Mock, + faker: Faker, +): + mocked_get_or_create_cluster.side_effect = ( + ComputationalBackendOnDemandNotReadyError( + eta=faker.time_delta(datetime.timedelta(hours=1)) + ) + ) + # running the pipeline will trigger a call to the clusters-keeper + assert published_project.project.prj_owner + await scheduler.run_new_pipeline( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=True, + ) + + # we ask to use an on-demand cluster, 
therefore the tasks are published first + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [t.node_id for t in published_project.tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, + ) + mocked_get_or_create_cluster.assert_not_called() + # now it should switch to waiting + expected_waiting_tasks = [ + published_project.tasks[1], + published_project.tasks[3], + ] + await schedule_all_pipelines(scheduler) + mocked_get_or_create_cluster.assert_called() + assert mocked_get_or_create_cluster.call_count == 1 + mocked_get_or_create_cluster.reset_mock() + await _assert_comp_run_db( + aiopg_engine, published_project, RunningState.WAITING_FOR_CLUSTER + ) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [t.node_id for t in expected_waiting_tasks], + expected_state=RunningState.WAITING_FOR_CLUSTER, + expected_progress=None, + ) + # again will trigger the same response + await schedule_all_pipelines(scheduler) + mocked_get_or_create_cluster.assert_called() + assert mocked_get_or_create_cluster.call_count == 1 + mocked_get_or_create_cluster.reset_mock() + await _assert_comp_run_db( + aiopg_engine, published_project, RunningState.WAITING_FOR_CLUSTER + ) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [t.node_id for t in expected_waiting_tasks], + expected_state=RunningState.WAITING_FOR_CLUSTER, + expected_progress=None, + ) + + +@pytest.mark.parametrize( + "get_or_create_exception", + [ClustersKeeperNotAvailableError], +) +async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_fails( + with_disabled_auto_scheduling: None, + scheduler: BaseCompScheduler, + aiopg_engine: aiopg.sa.engine.Engine, + published_project: PublishedProject, + run_metadata: RunMetadataDict, + mocked_get_or_create_cluster: mock.Mock, + get_or_create_exception: Exception, +): + 
mocked_get_or_create_cluster.side_effect = get_or_create_exception + # running the pipeline will trigger a call to the clusters-keeper + assert published_project.project.prj_owner + await scheduler.run_new_pipeline( + user_id=published_project.project.prj_owner, + project_id=published_project.project.uuid, + cluster_id=DEFAULT_CLUSTER_ID, + run_metadata=run_metadata, + use_on_demand_clusters=True, + ) + + # we ask to use an on-demand cluster, therefore the tasks are published first + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.PUBLISHED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [t.node_id for t in published_project.tasks], + expected_state=RunningState.PUBLISHED, + expected_progress=None, + ) + # now it should switch to failed, the run still runs until the next iteration + expected_failed_tasks = [ + published_project.tasks[1], + published_project.tasks[3], + ] + await schedule_all_pipelines(scheduler) + mocked_get_or_create_cluster.assert_called() + assert mocked_get_or_create_cluster.call_count == 1 + mocked_get_or_create_cluster.reset_mock() + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.FAILED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [t.node_id for t in expected_failed_tasks], + expected_state=RunningState.FAILED, + expected_progress=1.0, + ) + # again will not re-trigger the call to clusters-keeper + await schedule_all_pipelines(scheduler) + mocked_get_or_create_cluster.assert_not_called() + await _assert_comp_run_db(aiopg_engine, published_project, RunningState.FAILED) + await _assert_comp_tasks_db( + aiopg_engine, + published_project.project.uuid, + [t.node_id for t in expected_failed_tasks], + expected_state=RunningState.FAILED, + expected_progress=1.0, + ) diff --git a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py 
b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py index 495023dbda2a..8fbc2d9006b7 100644 --- a/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py +++ b/services/director-v2/tests/unit/with_dbs/test_modules_comp_scheduler_dask_scheduler.py @@ -58,9 +58,9 @@ ComputationalBackendTaskNotFoundError, ComputationalBackendTaskResultsNotReadyError, ComputationalSchedulerChangedError, + ComputationalSchedulerError, ConfigurationError, PipelineNotFoundError, - SchedulerError, ) from simcore_service_director_v2.core.settings import AppSettings from simcore_service_director_v2.models.comp_pipelines import CompPipelineAtDB @@ -1107,7 +1107,7 @@ async def test_handling_of_disconnected_dask_scheduler( aiopg_engine: aiopg.sa.engine.Engine, mocker: MockerFixture, published_project: PublishedProject, - backend_error: SchedulerError, + backend_error: ComputationalSchedulerError, run_metadata: RunMetadataDict, ): # this will create a non connected backend issue that will trigger re-connection