Commit 84c8d8f: renaming
sanderegg committed Nov 26, 2024
1 parent 862173e
Showing 4 changed files with 52 additions and 42 deletions.
@@ -521,7 +521,7 @@ async def _process_completed_tasks(
) -> None:
"""process tasks from the 3rd party backend"""

-async def schedule_pipeline(
+async def apply(
self,
*,
user_id: UserID,
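The hunk above is cut off after the first keyword argument. Judging from the call sites further down in this diff, the renamed entrypoint presumably keeps the same keyword-only parameters as schedule_pipeline; a reconstructed sketch, not the actual file contents:

```python
async def apply(
    self,
    *,
    user_id: UserID,
    project_id: ProjectID,
    iteration: Iteration,
) -> None:
    ...  # body elided; only the method name changes in this hunk
```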
@@ -40,14 +40,14 @@ def _unique_key_builder(
async def _exclusively_schedule_pipeline(
app: FastAPI, *, user_id: UserID, project_id: ProjectID, iteration: Iteration
) -> None:
-await _get_scheduler_worker(app).schedule_pipeline(
+await _get_scheduler_worker(app).apply(
user_id=user_id,
project_id=project_id,
iteration=iteration,
)


-async def _handle_distributed_pipeline(app: FastAPI, data: bytes) -> bool:
+async def _handle_apply_distributed_schedule(app: FastAPI, data: bytes) -> bool:

with log_context(_logger, logging.DEBUG, msg="handling scheduling"):
to_schedule_pipeline = SchedulePipelineRabbitMessage.model_validate_json(data)
@@ -65,7 +65,7 @@ async def setup_worker(app: FastAPI) -> None:
rabbitmq_client = get_rabbitmq_client(app)
await rabbitmq_client.subscribe(
SchedulePipelineRabbitMessage.get_channel_name(),
-functools.partial(_handle_distributed_pipeline, app),
+functools.partial(_handle_apply_distributed_schedule, app),
exclusive_queue=False,
)
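Since the subscription above passes exclusive_queue=False, several worker instances can consume the same SchedulePipelineRabbitMessage channel and each call the renamed apply entrypoint, which is how scheduling work gets distributed. Below is a minimal, self-contained asyncio sketch of that fan-out pattern; the queue, message type, and apply stub are illustrative stand-ins, not the repository's RabbitMQ wiring:

```python
import asyncio
from dataclasses import dataclass


@dataclass
class ScheduleMessage:
    # stand-in for SchedulePipelineRabbitMessage
    user_id: int
    project_id: str
    iteration: int


async def apply(*, user_id: int, project_id: str, iteration: int) -> None:
    # stand-in for the scheduler's apply(): process one scheduling round
    await asyncio.sleep(0.01)
    print(f"scheduled project {project_id}, iteration {iteration}, for user {user_id}")


async def worker(queue: "asyncio.Queue[ScheduleMessage]") -> None:
    # every worker reads from the same shared (non-exclusive) queue,
    # so scheduling requests are load-balanced across instances
    while True:
        msg = await queue.get()
        try:
            await apply(
                user_id=msg.user_id,
                project_id=msg.project_id,
                iteration=msg.iteration,
            )
        finally:
            queue.task_done()


async def main() -> None:
    queue: "asyncio.Queue[ScheduleMessage]" = asyncio.Queue()
    workers = [asyncio.create_task(worker(queue)) for _ in range(3)]
    for i in range(5):
        queue.put_nowait(ScheduleMessage(user_id=1, project_id=f"project-{i}", iteration=1))
    await queue.join()  # wait until every message has been processed
    for task in workers:
        task.cancel()
    await asyncio.gather(*workers, return_exceptions=True)


asyncio.run(main())
```

A shared broker queue gives the same load-balancing behaviour across separate processes.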

@@ -213,7 +213,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]

mocked_dask_client.get_tasks_status.side_effect = _return_tasks_pending
assert published_project.project.prj_owner
-await scheduler.schedule_pipeline(
+await scheduler.apply(
user_id=published_project.project.prj_owner,
project_id=published_project.project.uuid,
iteration=1,
@@ -273,7 +273,7 @@ async def _return_tasks_pending(job_ids: list[str]) -> list[DaskClientTaskState]
mocked_dask_client.get_tasks_status.assert_not_called()
mocked_dask_client.get_task_result.assert_not_called()
# there is a second run of the scheduler to move comp_runs to pending, the rest does not change
-await scheduler.schedule_pipeline(
+await scheduler.apply(
user_id=published_project.project.prj_owner,
project_id=published_project.project.uuid,
iteration=1,
@@ -464,7 +464,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
]

mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -514,7 +514,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
node_id=exp_started_task.node_id,
)

-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -591,7 +591,7 @@ async def _return_random_task_result(job_id) -> TaskOutputData:
return TaskOutputData.model_validate({"out_1": None, "out_2": 45})

mocked_dask_client.get_task_result.side_effect = _return_random_task_result
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -702,7 +702,7 @@ async def _return_2nd_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
project_id=exp_started_task.project_id,
node_id=exp_started_task.node_id,
)
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -758,7 +758,7 @@ async def _return_2nd_task_failed(job_ids: list[str]) -> list[DaskClientTaskState]

mocked_dask_client.get_tasks_status.side_effect = _return_2nd_task_failed
mocked_dask_client.get_task_result.side_effect = None
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -822,7 +822,7 @@ async def _return_3rd_task_success(job_ids: list[str]) -> list[DaskClientTaskState]
mocked_dask_client.get_task_result.side_effect = _return_random_task_result

# trigger the scheduler, it should switch to FAILED, as we are done
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -934,7 +934,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
assert isinstance(mocked_dask_client.send_computation_tasks, mock.Mock)
assert isinstance(mocked_dask_client.get_task_result, mock.Mock)
mocked_dask_client.get_tasks_status.side_effect = _return_1st_task_running
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -984,7 +984,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
node_id=exp_started_task.node_id,
)

-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1133,7 +1133,7 @@ async def test_broken_pipeline_configuration_is_not_scheduled_and_aborted(
#
# Trigger scheduling manually. since the pipeline is broken, it shall be aborted
#
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_entry.user_id,
project_id=run_entry.project_uuid,
iteration=run_entry.iteration,
@@ -1273,7 +1273,7 @@ async def test_handling_of_disconnected_scheduler_dask(
project_id=published_project.project.uuid,
)
# we ensure the scheduler was run
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1291,7 +1291,7 @@ async def test_handling_of_disconnected_scheduler_dask(
expected_progress=1,
)
# then we have another scheduler run
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1413,7 +1413,7 @@ async def mocked_get_task_result(_job_id: str) -> TaskOutputData:

mocked_dask_client.get_task_result.side_effect = mocked_get_task_result
assert running_project.project.prj_owner
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=running_project.project.prj_owner,
project_id=running_project.project.uuid,
iteration=1,
@@ -1522,7 +1522,7 @@ async def mocked_get_tasks_status(job_ids: list[str]) -> list[DaskClientTaskState]

mocked_dask_client.get_tasks_status.side_effect = mocked_get_tasks_status
# Running the scheduler, should actually cancel the run now
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1571,7 +1571,7 @@ async def _return_random_task_result(job_id) -> TaskOutputData:
raise TaskCancelledError

mocked_dask_client.get_task_result.side_effect = _return_random_task_result
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1662,7 +1662,7 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
project_id=exp_started_task.project_id,
node_id=exp_started_task.node_id,
)
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1678,12 +1678,12 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
# -------------------------------------------------------------------------------
# 3. wait a bit and run again we should get another heartbeat, but only one!
await asyncio.sleep(with_fast_service_heartbeat_s + 1)
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
)
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1698,12 +1698,12 @@ async def _return_1st_task_running(job_ids: list[str]) -> list[DaskClientTaskState]
# -------------------------------------------------------------------------------
# 4. wait a bit and run again we should get another heartbeat, but only one!
await asyncio.sleep(with_fast_service_heartbeat_s + 1)
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
)
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1776,7 +1776,7 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits(
published_project.tasks[1],
published_project.tasks[3],
]
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1801,7 +1801,7 @@ async def test_pipeline_with_on_demand_cluster_with_not_ready_backend_waits(
expected_progress=None,
)
# again will trigger the same response
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1880,7 +1880,7 @@ async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_fails(
published_project.tasks[1],
published_project.tasks[3],
]
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -1905,7 +1905,7 @@ async def test_pipeline_with_on_demand_cluster_with_no_clusters_keeper_fails(
expected_progress=1.0,
)
# again will not re-trigger the call to clusters-keeper
-await scheduler_api.schedule_pipeline(
+await scheduler_api.apply(
user_id=run_in_db.user_id,
project_id=run_in_db.project_uuid,
iteration=run_in_db.iteration,
@@ -76,7 +76,7 @@ async def test_worker_properly_autocalls_scheduler_api(
@pytest.fixture
async def mocked_scheduler_api(mocker: MockerFixture) -> mock.Mock:
return mocker.patch(
"simcore_service_director_v2.modules.comp_scheduler._scheduler_base.BaseCompScheduler.schedule_pipeline"
"simcore_service_director_v2.modules.comp_scheduler._scheduler_base.BaseCompScheduler.apply"
)


@@ -89,16 +89,26 @@ async def test_worker_scheduling_parallelism(
):
with_disabled_auto_scheduling.assert_called_once()

-    mocked_scheduler_api.side_effect = asyncio.sleep(10)
-
-    published_project = await publish_project()
-    assert published_project.project.prj_owner
-    await run_new_pipeline(
-        initialized_app,
-        user_id=published_project.project.prj_owner,
-        project_id=published_project.project.uuid,
-        cluster_id=DEFAULT_CLUSTER_ID,
-        run_metadata=run_metadata,
-        use_on_demand_clusters=False,
+    async def _side_effect(*args, **kwargs):
+        await asyncio.sleep(10)
+
+    mocked_scheduler_api.side_effect = _side_effect
+
+    async def _project_pipeline_creation_workflow():
+        published_project = await publish_project()
+        assert published_project.project.prj_owner
+        await run_new_pipeline(
+            initialized_app,
+            user_id=published_project.project.prj_owner,
+            project_id=published_project.project.uuid,
+            cluster_id=DEFAULT_CLUSTER_ID,
+            run_metadata=run_metadata,
+            use_on_demand_clusters=False,
+        )
+
+    num_concurrent_calls = 10
+    await asyncio.gather(
+        *(_project_pipeline_creation_workflow() for _ in range(num_concurrent_calls))
+    )
-    mocked_scheduler_api.assert_called_once()
+    mocked_scheduler_api.assert_called()
+    assert mocked_scheduler_api.call_count == num_concurrent_calls
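The last hunk swaps a bare coroutine object (asyncio.sleep(10)) for an async function as the mock's side_effect, and checks the call count instead of assert_called_once, since the test now launches ten pipeline-creation workflows concurrently with asyncio.gather. A self-contained sketch of that mock pattern, with purely illustrative names:

```python
import asyncio
from unittest import mock


async def main() -> None:
    # stands in for the patched BaseCompScheduler.apply
    scheduler_apply = mock.AsyncMock()

    async def _side_effect(*args, **kwargs) -> None:
        # an async callable is awaited once per call; a coroutine object such as
        # asyncio.sleep(10) is created a single time and cannot serve repeated calls
        await asyncio.sleep(0.1)

    scheduler_apply.side_effect = _side_effect

    num_concurrent_calls = 10
    await asyncio.gather(
        *(
            scheduler_apply(user_id=1, project_id="project", iteration=i)
            for i in range(num_concurrent_calls)
        )
    )
    assert scheduler_apply.call_count == num_concurrent_calls


asyncio.run(main())
```

This mirrors the parallelism test above: the workflows are fired in parallel and the mocked apply must be hit exactly once per workflow.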
