diff --git a/tests/utils/test_compute_node.py b/tests/utils/test_compute_node.py index 28a0d988..05b29e5d 100644 --- a/tests/utils/test_compute_node.py +++ b/tests/utils/test_compute_node.py @@ -5,8 +5,9 @@ import re import subprocess from logging import getLogger as get_logger +from pathlib import Path from typing import Callable -from unittest.mock import ANY, AsyncMock, Mock +from unittest.mock import AsyncMock, Mock import pytest import pytest_asyncio @@ -298,16 +299,19 @@ async def test_del_computenode( await login_node_v2.run_async(f"scancel {job_id}") -@pytest.mark.asyncio -@pytest.mark.parametrize("close_async", [False, True], ids=["sync", "async"]) -async def test_using_closed_compute_node_raises_error( - login_node_v2: RemoteV2, close_async: bool +@pytest_asyncio.fixture(params=[False, True], ids=["sync", "async"]) +async def mock_closed_compute_node( + request: pytest.FixtureRequest, + ssh_config_file: Path, ): + """Cheaply constructs a *closed* ComputeNode, without launching any jobs.""" + close_async: bool = request.param + fake_job_id = 1234 def _mock_run(command: str, *args, input: str | None = None, **kwargs): if input == "echo $SLURMD_NODENAME\n": - return subprocess.CompletedProcess(command, 0, "bobobo", "") + return subprocess.CompletedProcess(command, 0, "cn-a001", "") if command == f"scancel {fake_job_id}": return subprocess.CompletedProcess(command, 0, "", "") # Unexpected command. @@ -319,14 +323,12 @@ async def _mock_run_async(command: str, *args, input: str | None = None, **kwarg # Unexpected command. assert False, (command, input) - mock_run = Mock(spec=login_node_v2.run, side_effect=_mock_run) - mock_run_async = AsyncMock( - spec=login_node_v2.run_async, side_effect=_mock_run_async - ) + mock_run = Mock(spec=RemoteV2.run, side_effect=_mock_run) + mock_run_async = AsyncMock(spec=RemoteV2.run_async, side_effect=_mock_run_async) mock_login_node = Mock( - wraps=login_node_v2, - hostname=login_node_v2.hostname, - ssh_config_path=login_node_v2.ssh_config_path, + spec=RemoteV2, + hostname="mila", + ssh_config_path=ssh_config_file, ) mock_login_node.configure_mock( run=mock_run, @@ -339,16 +341,27 @@ async def _mock_run_async(command: str, *args, input: str | None = None, **kwarg if close_async: await compute_node.close_async() - _other_kwargs = dict(display=ANY, hide=ANY, warn=ANY) - mock_run_async.assert_called_once_with( - command=f"scancel {fake_job_id}", **_other_kwargs - ) + mock_run_async.assert_called_once() + assert mock_run_async.mock_calls[0].args[0] == f"scancel {fake_job_id}" else: compute_node.close() - mock_run.assert_called_once_with(command=f"scancel {fake_job_id}") + # bug? this doesn't work but the output is identical? + # mock_run.assert_called_once_with(f"scancel {fake_job_id}") + mock_run.assert_called_once() + assert mock_run.mock_calls[0].args[0] == f"scancel {fake_job_id}" mock_run.reset_mock() mock_run_async.reset_mock() + return compute_node + +@pytest.mark.asyncio +async def test_using_closed_compute_node_raises_error( + mock_closed_compute_node: ComputeNode, +): + compute_node = mock_closed_compute_node + assert isinstance(compute_node.login_node, Mock) + mock_run: Mock = mock_closed_compute_node.login_node.run # type: ignore + mock_run_async: Mock = compute_node.login_node.run_async # type: ignore for method in ["run", "run_async", "get_output", "get_output_async"]: with pytest.raises(JobNotRunningError): output_or_coroutine = getattr(compute_node, method)("echo OK")