diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 55449bc620b..ab2ddb4c495 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -38,16 +38,14 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None: self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue() self._connection: Optional[ClientConnection] = None self._receiver_task: Optional[asyncio.Task[None]] = None - self._connected: asyncio.Event = asyncio.Event() + self._connected: asyncio.Future[None] = asyncio.Future() self._connection_timeout: float = 120.0 self._receiver_timeout: float = 60.0 async def __aenter__(self) -> "Monitor": self._receiver_task = asyncio.create_task(self._receiver()) try: - await asyncio.wait_for( - self._connected.wait(), timeout=self._connection_timeout - ) + await asyncio.wait_for(self._connected, timeout=self._connection_timeout) except asyncio.TimeoutError as exc: msg = "Couldn't establish connection with the ensemble evaluator!" logger.error(msg) @@ -64,7 +62,6 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None self._receiver_task, return_exceptions=True, ) - if self._connection: await self._connection.close() @@ -127,13 +124,16 @@ async def _receiver(self) -> None: headers = Headers() if self._ee_con_info.token: headers["token"] = self._ee_con_info.token - - await wait_for_evaluator( - base_url=self._ee_con_info.url, - token=self._ee_con_info.token, - cert=self._ee_con_info.cert, - timeout=5, - ) + try: + await wait_for_evaluator( + base_url=self._ee_con_info.url, + token=self._ee_con_info.token, + cert=self._ee_con_info.cert, + timeout=5, + ) + except Exception as e: + self._connected.set_exception(e) + return async for conn in connect( self._ee_con_info.client_uri, ssl=tls, @@ -147,13 +147,13 @@ async def _receiver(self) -> None: ): try: self._connection = conn - self._connected.set() + self._connected.set_result(None) async for raw_msg in self._connection: event = event_from_json(raw_msg) await self._event_queue.put(event) except (ConnectionRefusedError, ConnectionClosed, ClientError) as exc: self._connection = None - self._connected.clear() + self._connected = asyncio.Future() logger.debug( f"Monitor connection to EnsembleEvaluator went down, reconnecting: {exc}" ) diff --git a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py index 2a201a46c21..e4615649c72 100644 --- a/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py +++ b/tests/ert/unit_tests/ensemble_evaluator/test_monitor.py @@ -1,12 +1,15 @@ import asyncio import logging from http import HTTPStatus +from typing import NoReturn from urllib.parse import urlparse import pytest from websockets.asyncio import server from websockets.exceptions import ConnectionClosedOK +import ert +import ert.ensemble_evaluator from _ert.events import EEUserCancel, EEUserDone, event_from_json from ert.ensemble_evaluator import Monitor from ert.ensemble_evaluator.config import EvaluatorConnectionInfo @@ -135,3 +138,22 @@ async def test_that_monitor_can_emit_heartbeats(unused_tcp_port): set_when_done.set() # shuts down websocket server await websocket_server_task + + +@pytest.mark.timeout(10) +async def test_that_monitor_will_raise_exception_if_wait_for_evaluator_fails( + monkeypatch, +): + async def mock_failing_wait_for_evaluator(*args, **kwargs) -> NoReturn: + raise ValueError() + + monkeypatch.setattr( + ert.ensemble_evaluator.monitor, + "wait_for_evaluator", + mock_failing_wait_for_evaluator, + ) + ee_con_info = EvaluatorConnectionInfo("") + + with pytest.raises(ValueError): + async with Monitor(ee_con_info): + pass