From 021a434eb870ef0ce0249bbb45f0a6f0b00047b1 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 23 Apr 2024 18:00:22 -0400 Subject: [PATCH] Improve tests for LocalV2, RemoteV2, ComputeNode Signed-off-by: Fabrice Normandin --- tests/utils/runner_tests.py | 337 ++++++++++++++++++++----------- tests/utils/test_compute_node.py | 63 +++--- 2 files changed, 253 insertions(+), 147 deletions(-) diff --git a/tests/utils/runner_tests.py b/tests/utils/runner_tests.py index da2dba89..e08cf1c1 100644 --- a/tests/utils/runner_tests.py +++ b/tests/utils/runner_tests.py @@ -6,9 +6,11 @@ import re import subprocess import time +from unittest.mock import AsyncMock, Mock import pytest +from milatools.utils.remote_v1 import Hide from milatools.utils.remote_v2 import RemoteV2 from milatools.utils.runner import Runner @@ -18,11 +20,11 @@ class RunnerTests(abc.ABC): Subclasses have to implement these methods: - the `runner` fixture which should ideally be class or session-scoped; - - the `command_output_err` fixture should return a tuple containing 3 items: + - the `command_with_result` fixture should return a tuple containing 3 items: - The command to run successfully - the expected stdout or a `re.Pattern` to match against the stdout - the expected stderr or a `re.Pattern` to match against the stderr - - the `command_exception_err` fixture should return a tuple containing 3 items: + - the `command_with_exception_and_stderr` fixture should return a tuple containing 3 items: (The command to run uncessfully, the expected exception, the expected stderr). """ @@ -35,13 +37,21 @@ def runner(self) -> Runner: scope="class", params=[ ("echo OK", "OK", ""), - # TODO: Test the proper escaping of variables. - # ("echo $USER", "todo", ""), ], ) - def command_output_err(self, request: pytest.FixtureRequest): + def command_with_result(self, request: pytest.FixtureRequest): + """Parametrized fixture for commands that are expected to raise an exception. + + These should be a tuple of: + - the command to run + - the expected stdout or a regular expression that matches stdout; + - the expected stderr or a regular expression that matches stderr + + Subclasses should override this fixture to provide more commands to run. + """ return request.param + # @abc.abstractmethod @pytest.fixture( scope="class", params=[ @@ -52,111 +62,160 @@ def command_output_err(self, request: pytest.FixtureRequest): ), ], ) - def command_exception_err(self, request: pytest.FixtureRequest): + def command_with_exception_and_stderr(self, request: pytest.FixtureRequest): + """Parametrized fixture for commands that are expected to raise an exception. + + These should be a tuple of: + - the command to run + - The type of exception that is expected to be raised + - a string or regular expression that matches stderr. + + Subclasses should override this fixture to provide more commands to run. + """ return request.param + @pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) @pytest.mark.parametrize("display", [True, False]) @pytest.mark.parametrize("hide", [True, False, "out", "err", "stdout", "stderr"]) - def test_run( + @pytest.mark.asyncio + async def test_run( self, runner: Runner, - command_output_err: tuple[str, str | re.Pattern, str | re.Pattern], - hide: bool, + command_with_result: tuple[str, str | re.Pattern, str | re.Pattern], + hide: Hide, display: bool, capsys: pytest.CaptureFixture, + caplog: pytest.LogCaptureFixture, + use_async: bool, ): - command, expected_output, expected_err = command_output_err - result = runner.run(command, display=display, hide=hide) + command, expected_output, expected_err = command_with_result - if isinstance(expected_output, re.Pattern): - assert expected_output.search(result.stdout) + if use_async: + result = await runner.run_async(command, display=display, hide=hide) else: - assert result.stdout.strip() == expected_output - - if isinstance(expected_err, re.Pattern): - assert expected_err.search(result.stderr) - else: - assert result.stderr.strip() == expected_err - - printed_output, printed_err = capsys.readouterr() - assert isinstance(printed_output, str) - assert isinstance(printed_err, str) - - assert (f"({runner.hostname}) $ {command}" in printed_output) == display - - if result.stdout: - stdout_should_be_printed = hide not in [ - True, - "out", - "stdout", - ] - stdout_was_printed = result.stdout in printed_output - assert stdout_was_printed == stdout_should_be_printed + result = runner.run(command, display=display, hide=hide) + + self._shared_run_checks( + runner=runner, + hide=hide, + display=display, + capsys=capsys, + caplog=caplog, + command=command, + expected_output=expected_output, + expected_err=expected_err, + result=result, + warn=False, + ) - if result.stderr: - error_should_be_printed = hide not in [ - True, - "err", - "stderr", - ] - error_was_printed = result.stderr in printed_err - assert error_was_printed == error_should_be_printed, ( - result.stderr, - printed_err, - ) + @pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) + @pytest.mark.parametrize("display", [True, False]) + @pytest.mark.parametrize("hide", [True, False, "out", "err", "stdout", "stderr"]) + @pytest.mark.asyncio + async def test_run_with_error( + self, + runner: Runner, + command_with_exception_and_stderr: tuple[ + str, type[Exception], str | re.Pattern + ], + hide: Hide, + display: bool, + use_async: bool, + ): + command, expected_exception, expected_err = command_with_exception_and_stderr - @pytest.mark.parametrize("warn", [True, False]) + assert isinstance(expected_exception, type) and issubclass( + expected_exception, Exception + ) + # Should raise an exception of this type. + with pytest.raises(expected_exception=expected_exception): + if use_async: + _ = await runner.run_async(command, display=display, hide=hide) + else: + _ = runner.run(command, display=display, hide=hide) + + @pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) @pytest.mark.parametrize("display", [True, False]) @pytest.mark.parametrize("hide", [True, False, "out", "err", "stdout", "stderr"]) - def test_run_with_error( + @pytest.mark.asyncio + async def test_run_with_error_warn( self, runner: Runner, - command_exception_err: tuple[str, type[Exception], str | re.Pattern], - hide: bool, - warn: bool, + command_with_exception_and_stderr: tuple[ + str, type[Exception], str | re.Pattern + ], + hide: Hide, display: bool, capsys: pytest.CaptureFixture, caplog: pytest.LogCaptureFixture, + use_async: bool, ): - command, expected_exception, expected_err = command_exception_err + command, expected_exception, expected_err = command_with_exception_and_stderr assert isinstance(expected_exception, type) and issubclass( expected_exception, Exception ) - - if not warn: - # Should raise an exception of this type. - with pytest.raises(expected_exception=expected_exception): - _ = runner.run(command, display=display, hide=hide, warn=warn) - # unreachable code here, so just pretend like it returns directly. - return - with caplog.at_level(logging.WARNING): - result = runner.run(command, display=display, hide=hide, warn=warn) + if use_async: + result = await runner.run_async( + command, display=display, hide=hide, warn=True + ) + else: + result = runner.run(command, display=display, hide=hide, warn=True) assert result.stdout == "" - if isinstance(expected_err, re.Pattern): - assert expected_err.search(result.stderr) - else: - assert result.stderr.strip() == expected_err + self._shared_run_checks( + runner=runner, + hide=hide, + display=display, + capsys=capsys, + caplog=caplog, + command=command, + expected_output="", + expected_err=expected_err, + result=result, + warn=True, + ) - if hide is True: - # Warnings not logged at all (because `warn=True` and `hide=True`). - assert caplog.records == [] - elif isinstance(expected_err, str): - assert len(caplog.records) == 1 - assert ( - caplog.records[0].message.strip() - == f"Command {command!r} returned non-zero exit code 1: {expected_err}" - ) - elif isinstance(expected_err, re.Pattern): - assert len(caplog.records) == 1 - message = caplog.records[0].message.strip() - # assert message.startswith( - # f"Command {command!r} returned non-zero exit code 1:" - # ) - assert expected_err.search(message) + def _shared_run_checks( + self, + runner: Runner, + hide: Hide, + display: bool, + capsys: pytest.CaptureFixture, + caplog: pytest.LogCaptureFixture, + command: str, + expected_output: str | re.Pattern, + expected_err: str | re.Pattern, + result: subprocess.CompletedProcess, + warn: bool, + ): + self._check_result(expected_output, expected_err, result) + self._check_printed_stdout_stderr( + runner=runner, + command=command, + display=display, + hide=hide, + result=result, + capsys=capsys, + ) + self._check_warning_logs( + hide=hide, + caplog=caplog, + command=command, + expected_err=expected_err, + warn=warn, + ) + def _check_printed_stdout_stderr( + self, + runner: Runner, + command: str, + display: bool, + hide: Hide, + result: subprocess.CompletedProcess, + capsys: pytest.CaptureFixture, + ): printed_output, printed_err = capsys.readouterr() assert isinstance(printed_output, str) assert isinstance(printed_err, str) @@ -184,56 +243,98 @@ def test_run_with_error( printed_err, ) - @pytest.mark.parametrize("display", [True, False]) - @pytest.mark.parametrize("hide", [True, False, "out", "err", "stdout", "stderr"]) - @pytest.mark.asyncio - async def test_run_async( + def _check_result( self, - runner: Runner, - command_output_err: tuple[str, str | re.Pattern, str | re.Pattern], - hide: bool, - display: bool, - capsys: pytest.CaptureFixture, + expected_output: str | re.Pattern, + expected_err: str | re.Pattern, + result: subprocess.CompletedProcess, ): - command, expected_output, expected_err = command_output_err - result = await runner.run_async(command, display=display, hide=hide) - if isinstance(expected_output, re.Pattern): - assert expected_output.match(result.stdout) + assert expected_output.search(result.stdout) else: assert result.stdout.strip() == expected_output if isinstance(expected_err, re.Pattern): - assert expected_err.match(result.stderr) + assert expected_err.search(result.stderr) else: assert result.stderr.strip() == expected_err - printed_output, printed_err = capsys.readouterr() - assert isinstance(printed_output, str) - assert isinstance(printed_err, str) - - assert (f"({runner.hostname}) $ {command}" in printed_output) == display - - if result.stdout: - stdout_should_be_printed = hide not in [ - True, - "out", - "stdout", - ] - stdout_was_printed = result.stdout in printed_output - assert stdout_was_printed == stdout_should_be_printed + def _check_warning_logs( + self, + command: str, + warn: bool, + hide: Hide, + expected_err: str | re.Pattern, + caplog: pytest.LogCaptureFixture, + ): + if not warn: + # No warnings should have been logged. + assert caplog.records == [] + return - if result.stderr: - error_should_be_printed = hide not in [ - True, - "err", - "stderr", - ] - error_was_printed = result.stderr in printed_err - assert error_was_printed == error_should_be_printed, ( - result.stderr, - printed_err, + if hide is True: + # Warnings not logged at all (because `warn=True` and `hide=True`). + assert caplog.records == [] + elif isinstance(expected_err, str): + assert len(caplog.records) == 1 + assert ( + caplog.records[0].message.strip() + == f"Command {command!r} returned non-zero exit code 1: {expected_err}" ) + elif isinstance(expected_err, re.Pattern): + assert len(caplog.records) == 1 + message = caplog.records[0].message.strip() + assert expected_err.search(message) + + @pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) + @pytest.mark.asyncio + async def test_get_output_calls_run( + self, + runner: Runner, + use_async: bool, + monkeypatch: pytest.MonkeyPatch, + ): + mock = Mock(spec=subprocess.CompletedProcess, stdout=Mock()) + command = "echo OK" + if use_async: + if ( + runner.run_async is type(runner).run_async + and runner.get_output_async is type(runner).get_output_async + ): + # It's a static method! Path the class instead of the "instance". + monkeypatch.setattr( + type(runner), + type(runner).run_async.__name__, + AsyncMock(spec=runner.run_async, spec_set=True, return_value=mock), + ) + else: + monkeypatch.setattr( + runner, + runner.run_async.__name__, + AsyncMock(spec=runner.run_async, spec_set=True, return_value=mock), + ) + output = await runner.get_output_async(command) + else: + if ( + runner.run is type(runner).run + and runner.get_output is type(runner).get_output + ): + # It's a static method! Path the class instead. + monkeypatch.setattr( + type(runner), + type(runner).run.__name__, + Mock(spec=runner.run, spec_set=True, return_value=mock), + ) + else: + # It's a regular method: + monkeypatch.setattr( + runner, + runner.run.__name__, + Mock(spec=runner.run, spec_set=True, return_value=mock), + ) + output = runner.get_output(command) + assert isinstance(output, Mock) + assert output is mock.stdout.strip() @pytest.mark.asyncio async def test_run_async_runs_in_parallel(self, runner: RemoteV2): diff --git a/tests/utils/test_compute_node.py b/tests/utils/test_compute_node.py index 0206ced6..28a0d988 100644 --- a/tests/utils/test_compute_node.py +++ b/tests/utils/test_compute_node.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import inspect import re import subprocess from logging import getLogger as get_logger @@ -197,22 +198,31 @@ async def runner( ("echo $SLURM_PROCID", "0", ""), ], ) - def command_output_err(self, request: pytest.FixtureRequest): + def command_and_expected_result(self, request: pytest.FixtureRequest): return request.param - def test_run_gets_executed_in_job_step(self, runner: ComputeNode): - job_step_a = int(runner.get_output("echo $SLURM_STEP_ID")) - job_step_b = int(runner.get_output("echo $SLURM_STEP_ID")) - assert job_step_a >= 0 - assert job_step_b == job_step_a + 1 - + @pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) @pytest.mark.asyncio - async def test_run_async_gets_executed_in_job_step(self, runner: ComputeNode): - job_step_a = int(await runner.get_output_async("echo $SLURM_STEP_ID")) - job_step_b = int(await runner.get_output_async("echo $SLURM_STEP_ID")) + async def test_run_gets_executed_in_job_step( + self, runner: ComputeNode, use_async: bool + ): + command = "echo $SLURM_STEP_ID" + output_a = ( + await runner.get_output_async(command) + if use_async + else runner.get_output(command) + ) + output_b = ( + await runner.get_output_async(command) + if use_async + else runner.get_output(command) + ) + job_step_a = int(output_a) + job_step_b = int(output_b) assert job_step_a >= 0 assert job_step_b == job_step_a + 1 + @pytest.mark.parametrize("use_async", [False, True], ids=["sync", "async"]) @pytest.mark.asyncio async def test_close( self, @@ -220,6 +230,7 @@ async def test_close( persist: bool, allocation_flags: list[str], job_name: str, + use_async: bool, ): if login_node_v2.hostname == "localhost": pytest.skip(reason="Test doesn't currently work on the mock slurm cluster.") @@ -233,7 +244,10 @@ async def test_close( login_node_v2, salloc_flags=allocation_flags, job_name=job_name ) - await compute_node.close_async() + if use_async: + await compute_node.close_async() + else: + compute_node.close() job_state = await login_node_v2.get_output_async( f"sacct --noheader --allocations --jobs {compute_node.job_id} --format=State%100", @@ -335,22 +349,13 @@ async def _mock_run_async(command: str, *args, input: str | None = None, **kwarg mock_run.reset_mock() mock_run_async.reset_mock() - with pytest.raises(JobNotRunningError): - compute_node.run("echo OK") - mock_run.assert_not_called() - mock_run_async.assert_not_called() - - with pytest.raises(JobNotRunningError): - compute_node.get_output("echo OK") - mock_run.assert_not_called() - mock_run_async.assert_not_called() + for method in ["run", "run_async", "get_output", "get_output_async"]: + with pytest.raises(JobNotRunningError): + output_or_coroutine = getattr(compute_node, method)("echo OK") + # if we get here it means it's a coroutine, since we'd otherwise have raised + # the error. + assert inspect.iscoroutine(output_or_coroutine) + await output_or_coroutine - with pytest.raises(JobNotRunningError): - await compute_node.run_async("echo OK") - mock_run.assert_not_called() - mock_run_async.assert_not_called() - - with pytest.raises(JobNotRunningError): - await compute_node.get_output_async("echo OK") - mock_run.assert_not_called() - mock_run_async.assert_not_called() + mock_run.assert_not_called() + mock_run_async.assert_not_called()