Skip to content

Commit

Permalink
Add basic retry loop to account for max_submit functionality
Browse files Browse the repository at this point in the history
Use a `while retry` loop to take a failed job from the running state back to
the waiting state. Includes a simple test that checks that the job has been
started 3 times.
  • Loading branch information
xjules committed Dec 12, 2023
1 parent 374c4b7 commit a9e4798
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 39 deletions.
80 changes: 50 additions & 30 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import logging
from enum import Enum
from typing import TYPE_CHECKING

Expand All @@ -16,6 +17,8 @@
from ert.ensemble_evaluator._builder._realization import Realization
from ert.scheduler.scheduler import Scheduler

logger = logging.getLogger(__name__)


class State(str, Enum):
WAITING = "WAITING"
Expand Down Expand Up @@ -66,36 +69,53 @@ async def __call__(
self, start: asyncio.Event, sem: asyncio.BoundedSemaphore
) -> None:
await start.wait()
# NOTE: these lines appear to be the pre-change body of Job.__call__ (the
# lines this commit removes); indentation is not preserved in this rendering.
# Single pass, no resubmission: take the semaphore, submit, wait for the job
# to start and finish, then report COMPLETED or FAILED.
await sem.acquire()

try:
# Announce the submission, then hand the job script to the driver.
await self._send(State.SUBMITTING)
await self.driver.submit(
self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
)

# Block until the `started` event is set.
await self._send(State.STARTING)
await self.started.wait()

await self._send(State.RUNNING)
returncode = await self.returncode
# Success requires both a zero return code and a LOAD_SUCCESSFUL status
# from forward_model_ok.
if (
returncode == 0
and forward_model_ok(self.real.run_arg).status
== LoadStatus.LOAD_SUCCESSFUL
):
await self._send(State.COMPLETED)
else:
await self._send(State.FAILED)

except asyncio.CancelledError:
# Cancellation path: report ABORTING, ask the driver to kill the job, wait
# for the `aborted` event and then report ABORTED. The CancelledError is
# not re-raised.
await self._send(State.ABORTING)
await self.driver.kill(self.iens)

await self.aborted.wait()
await self._send(State.ABORTED)
finally:
# Always give the semaphore slot back, whatever the outcome.
sem.release()
# Retry loop for Job.__call__ (indentation is not preserved in this
# rendering). Each pass acquires the semaphore, submits the job through the
# driver, waits for it to start and finish, and reports the resulting state.
# `retries` counts failed attempts; `retry` decides whether another pass runs.
retries = 0
retry: bool = True
while retry:
# Cleared on every pass and only set again on the failure path below, so a
# successful run ends the loop.
retry = False
await sem.acquire()
try:
# Announce the submission, then hand the job script to the driver.
await self._send(State.SUBMITTING)
await self.driver.submit(
self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
)

# NOTE(review): `self.started` and `self.returncode` are awaited again on
# every pass and nothing in this block resets them. If they are a one-shot
# event/future, a resubmitted attempt would not wait for the new run and
# would see the previous attempt's return code -- confirm against the class
# definition.
await self._send(State.STARTING)
await self.started.wait()

await self._send(State.RUNNING)
returncode = await self.returncode
# Success requires both a zero return code and a LOAD_SUCCESSFUL status
# from forward_model_ok.
if (
returncode == 0
and forward_model_ok(self.real.run_arg).status
== LoadStatus.LOAD_SUCCESSFUL
):
await self._send(State.COMPLETED)
else:
await self._send(State.FAILED)
# Count this failed attempt and resubmit only while the count is still
# below the scheduler's _max_submit.
retries += 1
retry = retries < self._scheduler._max_submit
# NOTE(review): each message is both printed and logged -- confirm the
# duplicate output on stdout is intended.
if retry:
message = f"Realization: {self.iens} failed, resubmitting"
logger.warning(message)
print(message)
else:
message = (
f"Realization: {self.iens} "
"failed after reaching max submit "
f"{self._scheduler._max_submit}!"
)
print(message)
logger.error(message)

except asyncio.CancelledError:
# Cancellation path: report ABORTING, ask the driver to kill the job, wait
# for the `aborted` event and then report ABORTED. The CancelledError is
# not re-raised, and `retry` is not set here.
await self._send(State.ABORTING)
await self.driver.kill(self.iens)

await self.aborted.wait()
await self._send(State.ABORTED)
finally:
# Give the semaphore slot back after every pass; a resubmission has to
# acquire it again at the top of the loop.
sem.release()

async def _send(self, state: State) -> None:
status = STATE_TO_LEGACY[state]
Expand Down
10 changes: 2 additions & 8 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@
import ssl
import threading
from dataclasses import asdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
MutableMapping,
Optional,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, MutableMapping, Optional

from pydantic.dataclasses import dataclass
from websockets import Headers
Expand Down Expand Up @@ -51,6 +44,7 @@ def __init__(self, driver: Optional[Driver] = None) -> None:
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}

self._events: Optional[asyncio.Queue[Any]] = None
self._max_submit: int = 2

self._ee_uri = ""
self._ens_id = ""
Expand Down
18 changes: 17 additions & 1 deletion tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import json
import os
import shutil
from dataclasses import asdict
from pathlib import Path
from textwrap import dedent
from typing import Sequence
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -108,3 +110,17 @@ async def test_cancel(tmp_path: Path, realization):

assert (tmp_path / "a").exists()
assert not (tmp_path / "b").exists()


async def test_that_max_submit_was_reached(tmp_path: Path, realization):
    """Check that an always-failing step is started three times with ``_max_submit = 3``.

    The bash step creates a counter file ``cnt`` on its first run, increments
    it on later runs, and always exits with status 1. After the scheduler has
    finished, ``cnt`` under ``tmp_path`` is expected to contain ``3``.
    """
    always_failing_step = create_bash_step(
        "[ -f cnt ] && echo $(( $(cat cnt) + 1 )) > cnt || echo 1 > cnt; exit 1"
    )
    realization.forward_models = [always_failing_step]

    sch = scheduler.Scheduler()
    sch._max_submit = 3
    sch.add_realization(realization, callback_timeout=lambda _: None)
    create_jobs_json(tmp_path, [always_failing_step])
    sch.add_dispatch_information_to_jobs_file()

    # Run the scheduler as its own task and wait for it to finish.
    await asyncio.create_task(sch.execute())

    counter_file = tmp_path / "cnt"
    assert counter_file.read_text() == "3\n"

0 comments on commit a9e4798

Please sign in to comment.