Skip to content

Commit

Permalink
Improve test_using_closed_compute_node test
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Apr 23, 2024
1 parent 021a434 commit ae46a98
Showing 1 changed file with 31 additions and 18 deletions.
49 changes: 31 additions & 18 deletions tests/utils/test_compute_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import re
import subprocess
from logging import getLogger as get_logger
from pathlib import Path
from typing import Callable
from unittest.mock import ANY, AsyncMock, Mock
from unittest.mock import AsyncMock, Mock

import pytest
import pytest_asyncio
Expand Down Expand Up @@ -298,16 +299,19 @@ async def test_del_computenode(
await login_node_v2.run_async(f"scancel {job_id}")


@pytest.mark.asyncio
@pytest.mark.parametrize("close_async", [False, True], ids=["sync", "async"])
async def test_using_closed_compute_node_raises_error(
login_node_v2: RemoteV2, close_async: bool
@pytest_asyncio.fixture(params=[False, True], ids=["sync", "async"])
async def mock_closed_compute_node(
request: pytest.FixtureRequest,
ssh_config_file: Path,
):
"""Cheaply constructs a *closed* ComputeNode, without launching any jobs."""
close_async: bool = request.param

fake_job_id = 1234

def _mock_run(command: str, *args, input: str | None = None, **kwargs):
if input == "echo $SLURMD_NODENAME\n":
return subprocess.CompletedProcess(command, 0, "bobobo", "")
return subprocess.CompletedProcess(command, 0, "cn-a001", "")
if command == f"scancel {fake_job_id}":
return subprocess.CompletedProcess(command, 0, "", "")
# Unexpected command.
Expand All @@ -319,14 +323,12 @@ async def _mock_run_async(command: str, *args, input: str | None = None, **kwarg
# Unexpected command.
assert False, (command, input)

mock_run = Mock(spec=login_node_v2.run, side_effect=_mock_run)
mock_run_async = AsyncMock(
spec=login_node_v2.run_async, side_effect=_mock_run_async
)
mock_run = Mock(spec=RemoteV2.run, side_effect=_mock_run)
mock_run_async = AsyncMock(spec=RemoteV2.run_async, side_effect=_mock_run_async)
mock_login_node = Mock(
wraps=login_node_v2,
hostname=login_node_v2.hostname,
ssh_config_path=login_node_v2.ssh_config_path,
spec=RemoteV2,
hostname="mila",
ssh_config_path=ssh_config_file,
)
mock_login_node.configure_mock(
run=mock_run,
Expand All @@ -339,16 +341,27 @@ async def _mock_run_async(command: str, *args, input: str | None = None, **kwarg

if close_async:
await compute_node.close_async()
_other_kwargs = dict(display=ANY, hide=ANY, warn=ANY)
mock_run_async.assert_called_once_with(
command=f"scancel {fake_job_id}", **_other_kwargs
)
mock_run_async.assert_called_once()
assert mock_run_async.mock_calls[0].args[0] == f"scancel {fake_job_id}"
else:
compute_node.close()
mock_run.assert_called_once_with(command=f"scancel {fake_job_id}")
# bug? this doesn't work but the output is identical?
# mock_run.assert_called_once_with(f"scancel {fake_job_id}")
mock_run.assert_called_once()
assert mock_run.mock_calls[0].args[0] == f"scancel {fake_job_id}"
mock_run.reset_mock()
mock_run_async.reset_mock()
return compute_node


@pytest.mark.asyncio
async def test_using_closed_compute_node_raises_error(
mock_closed_compute_node: ComputeNode,
):
compute_node = mock_closed_compute_node
assert isinstance(compute_node.login_node, Mock)
mock_run: Mock = mock_closed_compute_node.login_node.run # type: ignore
mock_run_async: Mock = compute_node.login_node.run_async # type: ignore
for method in ["run", "run_async", "get_output", "get_output_async"]:
with pytest.raises(JobNotRunningError):
output_or_coroutine = getattr(compute_node, method)("echo OK")
Expand Down

0 comments on commit ae46a98

Please sign in to comment.