Raise RuntimeError if using closed ComputeNode
Signed-off-by: Fabrice Normandin <[email protected]>
lebrice committed Apr 23, 2024
1 parent 84987a8 commit 533fbb2
Showing 4 changed files with 205 additions and 33 deletions.
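
For context, here is a rough illustration of the behaviour this commit introduces: once a ComputeNode has been closed, further `run` / `run_async` / `get_output` calls raise the new `JobNotRunningError`. This is only a sketch, given an already-connected `login_node` (a RemoteV2 handle for the cluster's login node); the allocation flags and job name are placeholders, not values taken from this commit.

from milatools.utils.compute_node import JobNotRunningError, salloc
from milatools.utils.remote_v2 import RemoteV2


async def demo(login_node: RemoteV2) -> None:
    # Placeholder allocation flags and job name; adjust for your cluster.
    node = await salloc(
        login_node, salloc_flags=["--time=00:05:00"], job_name="mila-code"
    )
    print(node.get_output("hostname"))  # runs on the allocated compute node

    node.close()  # ends the job; the ComputeNode is now unusable

    try:
        node.run("echo this should fail")
    except JobNotRunningError as error:
        print(f"Job {error.job_id} has already ended, as expected.")
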
4 changes: 1 addition & 3 deletions milatools/cli/utils.py
@@ -330,9 +330,7 @@ def make_process(
) -> multiprocessing.Process:
# Tiny wrapper around the `multiprocessing.Process` init to detect if the args and
# kwargs don't match the target signature using typing instead of at runtime.
return multiprocessing.Process(
target=target, daemon=False, args=args, kwargs=kwargs
)
return multiprocessing.Process(target=target, daemon=True, args=args, kwargs=kwargs)


def currently_in_a_test() -> bool:
110 changes: 87 additions & 23 deletions milatools/utils/compute_node.py
@@ -7,7 +7,6 @@
import inspect
import re
import shlex
import signal
import subprocess
import sys
import warnings
@@ -19,6 +18,20 @@
from milatools.utils.remote_v2 import RemoteV2, logger, ssh_command
from milatools.utils.runner import Runner

JOB_NOT_RUNNING_MESSAGE = (
"ComputeNode for job {job_id} has been closed and is unusable, since the job has "
"already ended!"
)


class JobNotRunningError(RuntimeError):
"""Raised when trying to call `run` or `run_async` on a ComputeNode whose job has
already been closed."""

def __init__(self, job_id: int, *args: object) -> None:
super().__init__(JOB_NOT_RUNNING_MESSAGE.format(job_id=job_id), *args)
self.job_id = job_id


@dataclasses.dataclass
class ComputeNode(Runner):
@@ -34,6 +47,7 @@ class ComputeNode(Runner):
login_node: RemoteV2
job_id: int
salloc_subprocess: asyncio.subprocess.Process | None = None
_closed: bool = dataclasses.field(default=False, init=False, repr=False)

def __post_init__(self):
# The hostname will be of the compute node, not the login node.
@@ -51,6 +65,8 @@ def __post_init__(self):
def run(
self, command: str, display: bool = True, warn: bool = False, hide: Hide = False
):
if self._closed:
raise JobNotRunningError(self.job_id)
if display:
# Show the compute node hostname instead of the login node.
console.log(f"({self.hostname}) $ {command}", style="green")
@@ -77,6 +93,8 @@ async def run_async(
warn: bool = False,
hide: Hide = False,
) -> subprocess.CompletedProcess[str]:
if self._closed:
raise JobNotRunningError(self.job_id)
if display:
# Show the compute node hostname instead of the login node.
console.log(f"({self.hostname}) $ {command}", style="green")
@@ -96,32 +114,68 @@ async def run_async(
hide=hide,
)

def __del__(self):
if not self._closed and self.salloc_subprocess:
try:
self.salloc_subprocess.terminate()
except ProcessLookupError:
pass # salloc subprocess has already been terminated.
else:
# NOTE: We only get here if the job is being deleted without having been
# closed.
logger.warning(
f"Compute node is being deleted without having been closed!\n"
f"Terminating job {self.job_id} on {self.hostname}."
)

def __enter__(self):
return self

def __exit__(self, *excinfo):
self.close()

async def __aenter__(self):
return self

async def __aexit__(self, *excinfo):
await self.close()
await self.close_async()

async def close(self):
def close(self):
"""Ends the job.
The ComputeNode becomes unusable.
"""
if self._closed:
logger.warning(f"Job {self.job_id} has already been closed.")
return
logger.info(f"Stopping job {self.job_id}.")
if self.salloc_subprocess:
self.salloc_subprocess.terminate()
else:
self.login_node.run(f"scancel {self.job_id}")
self._closed = True

async def close_async(self):
"""Cancels the running job using `scancel`."""
if self._closed:
logger.warning(f"Job {self.job_id} has already been closed.")
return
logger.info(f"Stopping job {self.job_id}.")
if self.salloc_subprocess is not None:
if self.salloc_subprocess.stdin is not None:
# NOTE: This will exit cleanly because we don't have nested terminals or
# job steps.
await self.salloc_subprocess.communicate("exit\n".encode()) # noqa: UP012
else:
# todo: Not sure what the best way to do this is..
self.salloc_subprocess.send_signal(signal=signal.SIGINT)
self.salloc_subprocess.send_signal(signal=signal.SIGKILL)
self.salloc_subprocess.kill()
# The scancel below is done even though it's redundant, just to be safe.
await self.login_node.run_async(
f"scancel {self.job_id}",
display=True,
hide=False,
warn=True,
)
assert self.salloc_subprocess.stdin is not None
# NOTE: This will exit cleanly because we don't have nested terminals or
# job steps.
logger.debug("Exiting the salloc subprocess gracefully.")
await self.salloc_subprocess.communicate("exit\n".encode()) # noqa: UP012
else:
# The scancel below is done even though it's redundant, just to be safe.
await self.login_node.run_async(
f"scancel {self.job_id}",
display=True,
hide=False,
warn=True,
)
self._closed = True

def __repr__(self) -> str:
params = ", ".join(
@@ -170,6 +224,14 @@ async def cancel_new_jobs_on_interrupt(login_node: RemoteV2, job_name: str):
allocation on the cluster.
"""
jobs_before = await get_queued_milatools_job_ids(login_node, job_name=job_name)
if jobs_before:
logger.info(
f"Existing jobs on {login_node.hostname} with name {job_name}: {jobs_before}"
)
else:
logger.debug(
f"There are currently no jobs with name {job_name} on {login_node.hostname}."
)
try:
yield
except (KeyboardInterrupt, asyncio.CancelledError):
@@ -194,10 +256,12 @@ async def cancel_new_jobs_on_interrupt(login_node: RemoteV2, job_name: str):
"Cancelling all of them to be safe...",
style="yellow",
)
login_node.run(
"scancel " + " ".join(str(job_id) for job_id in new_jobs),
display=True,
hide=False,
await asyncio.shield(
login_node.run_async(
"scancel " + " ".join(str(job_id) for job_id in new_jobs),
display=True,
hide=False,
)
)
else:
warnings.warn(
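
Because the compute_node.py diff above also adds synchronous __enter__/__exit__ hooks next to the async ones (with __aexit__ now delegating to close_async()), a caller can tie the job's lifetime to a with-block. A minimal sketch, assuming an already-connected login_node and placeholder allocation flags:

from milatools.utils.compute_node import salloc
from milatools.utils.remote_v2 import RemoteV2


async def run_one_command(login_node: RemoteV2) -> str:
    # __aexit__ calls close_async(), so the allocation is released on exit,
    # even if the body raises.
    async with await salloc(
        login_node, salloc_flags=["--time=00:05:00"], job_name="mila-code"
    ) as node:
        return await node.get_output_async("hostname")
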
3 changes: 1 addition & 2 deletions milatools/utils/local_v2.py
@@ -190,11 +190,10 @@ async def run_async(
*program_and_args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
stdin=asyncio.subprocess.PIPE if input else None,
stdin=asyncio.subprocess.PIPE if input else asyncio.subprocess.DEVNULL,
start_new_session=False,
)
if input:
# TODO: Getting a bug when sending 'echo $SCRATCH'!
logger.debug(f"Sending {input=!r} to the subprocess' stdin.")
stdout, stderr = await proc.communicate(input.encode() if input else None)

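
The local_v2.py change above swaps stdin=None for asyncio.subprocess.DEVNULL when no input is given, so the child process no longer inherits the parent's stdin. A standalone illustration of the difference (plain asyncio, not milatools code): with DEVNULL the child sees an immediate end-of-file instead of blocking on the inherited terminal.

import asyncio


async def demo() -> None:
    # `cat` reads stdin until EOF. With DEVNULL it gets EOF right away;
    # with stdin=None it would inherit the terminal and wait for input.
    proc = await asyncio.create_subprocess_exec(
        "cat",
        stdin=asyncio.subprocess.DEVNULL,
        stdout=asyncio.subprocess.PIPE,
    )
    stdout, _ = await proc.communicate()
    print(repr(stdout))  # b''


asyncio.run(demo())
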
121 changes: 116 additions & 5 deletions tests/utils/test_compute_node.py
@@ -3,24 +3,27 @@
import asyncio
import datetime
import re
import subprocess
from logging import getLogger as get_logger
from typing import Callable
from unittest.mock import ANY, AsyncMock, Mock

import pytest
import pytest_asyncio

from milatools.cli.utils import td_format
from milatools.utils.compute_node import (
ComputeNode,
JobNotRunningError,
get_queued_milatools_job_ids,
get_time_to_job,
salloc,
sbatch,
)
from milatools.utils.remote_v2 import RemoteV2
from tests.utils.runner_tests import RunnerTests

from ..conftest import launches_jobs, unsupported_on_windows
from .runner_tests import RunnerTests

logger = get_logger(__name__)
pytestmark = [unsupported_on_windows]
@@ -60,7 +63,7 @@ async def test_salloc(
# using `srun` with the job id on the login node to run our jobs.
assert all_slurm_env_vars["SLURM_JOB_ID"] == str(compute_node.job_id)
assert len(all_slurm_env_vars) > 1
await compute_node.close()
await compute_node.close_async()


@pytest.mark.slow
@@ -88,7 +91,7 @@ async def test_sbatch(
}
assert all_slurm_env_vars["SLURM_JOB_ID"] == str(compute_node.job_id)
assert len(all_slurm_env_vars) > 1
await compute_node.close()
await compute_node.close_async()


@pytest.fixture(scope="session", params=[True, False], ids=["sbatch", "salloc"])
@@ -199,7 +202,7 @@ async def runner(
login_node_v2, salloc_flags=allocation_flags, job_name="mila-code"
)
yield runner
await runner.close()
await runner.close_async()

@pytest.fixture(
scope="class",
@@ -244,7 +247,7 @@ async def test_close(
login_node_v2, salloc_flags=allocation_flags, job_name=job_name
)

await compute_node.close()
await compute_node.close_async()

job_state = await login_node_v2.get_output_async(
f"sacct --noheader --allocations --jobs {compute_node.job_id} --format=State%100",
@@ -257,3 +260,111 @@ async def test_close(
else:
# interactive jobs are exited cleanly by just exiting in the terminal.
assert job_state == "COMPLETED"


@launches_jobs
@pytest.mark.slow
@pytest.mark.asyncio
async def test_del_computenode(
login_node_v2: RemoteV2, persist: bool, allocation_flags: list[str], job_name: str
):
"""Test what happens when we delete a ComputeNode instance (persistent vs non-
persistent).
TODO: Perhaps we could use mocks here instead of allocating a job just to end it after.
"""
if persist:
compute_node = await sbatch(
login_node_v2, sbatch_flags=allocation_flags, job_name=job_name
)
else:
compute_node = await salloc(
login_node_v2, salloc_flags=allocation_flags, job_name=job_name
)

job_id = compute_node.job_id
del compute_node
# if deleting does anything, wait for its effect to propagate to sacct
await asyncio.sleep(5)
state_after = await login_node_v2.get_output_async(
f"sacct --jobs {job_id} --allocations --noheader --format=State",
)
try:
if persist:
assert state_after == "RUNNING"
else:
assert state_after == "COMPLETED"
finally:
await login_node_v2.run_async(f"scancel {job_id}")


@pytest.mark.asyncio
@pytest.mark.parametrize("close_async", [False, True], ids=["sync", "async"])
async def test_using_closed_compute_node_raises_error(
login_node_v2: RemoteV2, close_async: bool
):
fake_job_id = 1234

def _mock_run(command: str, *args, input: str | None = None, **kwargs):
if input == "echo $SLURMD_NODENAME\n":
return subprocess.CompletedProcess(command, 0, "bobobo", "")
if command == f"scancel {fake_job_id}":
return subprocess.CompletedProcess(command, 0, "", "")
# Unexpected command.
assert False, (command, input)

async def _mock_run_async(command: str, *args, input: str | None = None, **kwargs):
if command == f"scancel {fake_job_id}":
return subprocess.CompletedProcess(command, 0, "", "")
# Unexpected command.
assert False, (command, input)

mock_run = Mock(spec=login_node_v2.run, side_effect=_mock_run)
mock_run_async = AsyncMock(
spec=login_node_v2.run_async, side_effect=_mock_run_async
)
mock_login_node = Mock(
wraps=login_node_v2,
hostname=login_node_v2.hostname,
ssh_config_path=login_node_v2.ssh_config_path,
)
mock_login_node.configure_mock(
run=mock_run,
run_async=mock_run_async,
)
compute_node = ComputeNode(mock_login_node, job_id=1234)
mock_run.assert_called()
mock_run_async.assert_not_called()
mock_run.reset_mock()

if close_async:
await compute_node.close_async()
_other_kwargs = dict(display=ANY, hide=ANY, warn=ANY)
mock_run_async.assert_called_once_with(
command=f"scancel {fake_job_id}", **_other_kwargs
)
else:
compute_node.close()
mock_run.assert_called_once_with(command=f"scancel {fake_job_id}")
mock_run.reset_mock()
mock_run_async.reset_mock()

with pytest.raises(JobNotRunningError):
compute_node.run("echo OK")
mock_run.assert_not_called()
mock_run_async.assert_not_called()

with pytest.raises(JobNotRunningError):
compute_node.get_output("echo OK")
mock_run.assert_not_called()
mock_run_async.assert_not_called()

with pytest.raises(JobNotRunningError):
await compute_node.run_async("echo OK")
mock_run.assert_not_called()
mock_run_async.assert_not_called()

with pytest.raises(JobNotRunningError):
await compute_node.get_output_async("echo OK")
mock_run.assert_not_called()
mock_run_async.assert_not_called()
