diff --git a/milatools/cli/code.py b/milatools/cli/code.py index 3e8beaeb..b3236f23 100644 --- a/milatools/cli/code.py +++ b/milatools/cli/code.py @@ -6,6 +6,7 @@ import shutil import sys from logging import getLogger as get_logger +from typing import Awaitable from typing_extensions import deprecated @@ -105,15 +106,6 @@ def add_mila_code_arguments(subparsers: argparse._SubParsersAction): code_parser.set_defaults(function=code) -async def _check_disk_quota_task(remote: RemoteV2) -> None: - try: - await check_disk_quota(remote) - except MilatoolsUserError: - raise - except Exception as exc: - logger.warning(f"Unable to check the disk-quota on the cluster: {exc}") - - async def code( path: str, command: str, @@ -148,11 +140,17 @@ async def code( home = await login_node.get_output_async("echo $HOME", display=False, hide=True) path = home if path == "." else f"{home}/{path}" - check_disk_quota_task = asyncio.create_task(_check_disk_quota_task(login_node)) - # Raise errors that are meant to be shown to the user (disk quota is reached). + try: + await check_disk_quota(login_node) + except MilatoolsUserError: + # Raise errors meant to be shown to the user (disk quota exceeded). + raise + except Exception as exc: + logger.warning(f"Unable to check the disk-quota on the cluster: {exc}") # NOTE: Perhaps we could eventually do this check dynamically, if the cluster is an # unknown cluster? + sync_vscode_extensions_task: Awaitable | None = None if no_internet_on_compute_nodes(cluster): # Sync the VsCode extensions from the local machine over to the target cluster. console.log( @@ -163,18 +161,14 @@ async def code( # todo: use the mila or the local machine as the reference for vscode # extensions? # todo: could also perhaps make this function asynchronous instead of using a - # multiprocessing process for it. - sync_vscode_extensions_task = asyncio.create_task( - asyncio.to_thread( - sync_vscode_extensions, - LocalV2(), - [login_node], - ), - name="sync_vscode_extensions", + # different thread to run it (seems to not exit as gracefully now). + sync_vscode_extensions_task = asyncio.to_thread( + sync_vscode_extensions, + LocalV2(), + [login_node], ) - compute_node_task: asyncio.Task[ComputeNode] - + compute_node_task: Awaitable[ComputeNode] if job or node: if job and node: logger.warning( @@ -183,9 +177,7 @@ async def code( ) job_id_or_node = job or node assert job_id_or_node is not None - compute_node_task = asyncio.create_task( - connect_to_running_job(job_id_or_node, login_node) - ) + compute_node_task = connect_to_running_job(job_id_or_node, login_node) else: if cluster in DRAC_CLUSTERS and not any("--account" in flag for flag in alloc): logger.warning( @@ -212,77 +204,40 @@ async def code( alloc = alloc + [f"--job-name={job_name}"] if persist: - compute_node_task = asyncio.create_task( - sbatch(login_node, sbatch_flags=alloc, job_name=job_name) + compute_node_task = sbatch( + login_node, sbatch_flags=alloc, job_name=job_name ) - # compute_node = await sbatch( - # login_node, sbatch_flags=alloc, job_name=job_name - # ) else: # NOTE: Here we actually need the job name to be known, so that we can # scancel jobs if the call is interrupted. - compute_node_task = asyncio.create_task( - salloc(login_node, salloc_flags=alloc, job_name=job_name) + compute_node_task = salloc( + login_node, salloc_flags=alloc, job_name=job_name ) - # compute_node = await salloc( - # login_node, salloc_flags=alloc, job_name=job_name - # ) - try: - _, _, compute_node = await asyncio.gather( - check_disk_quota_task, + + if sync_vscode_extensions_task: + # Sync the extensions while waiting for the compute node to be allocated. + # Also waits for the extensions to be synced before launching VsCode. + _, compute_node = await asyncio.gather( sync_vscode_extensions_task, compute_node_task, - return_exceptions=True, ) - except: - # If any of the tasks failed, we want to raise the exception. - for task in ( - check_disk_quota_task, - sync_vscode_extensions_task, - compute_node_task, - ): - if not task.done(): - task.cancel() - for task in ( - check_disk_quota_task, - sync_vscode_extensions_task, - compute_node_task, - ): - if exception := task.exception(): - raise exception - raise - - if isinstance(compute_node, BaseException): - raise compute_node + else: + # Wait for the compute node to be allocated. + compute_node = await compute_node_task try: - while True: - code_command_to_run = ( - code_command, - "--new-window", - "--wait", - "--remote", - f"ssh-remote+{compute_node.hostname}", - path, - ) - console.log( - f"(local) {shlex.join(code_command_to_run)}", style="bold green" - ) - await run_async(code_command_to_run) - # TODO: BUG: This now requires two Ctrl+C's instead of one! - console.print( - "The editor was closed. Reopen it with or terminate the " - "process with (maybe twice)." - ) - if currently_in_a_test(): - # TODO: This early exit kills the job when it is not persistent. - break - input() + await launch_vscode( + compute_node=compute_node, + code_command=code_command, + path=path, + ) except KeyboardInterrupt: logger.info("Keyboard interrupt.") if not persist: # Cancel the job explicitly. + # TODO: Check what happens when calling `del(compute_node)`, perhaps that is + # enough? await compute_node.close() console.print(f"Ended session on '{compute_node.hostname}'") return compute_node.job_id @@ -300,6 +255,25 @@ async def code( return compute_node +async def launch_vscode(compute_node: ComputeNode, code_command: str, path: str): + while True: + code_command_to_run = ( + code_command, + "--new-window", + "--wait", + "--remote", + f"ssh-remote+{compute_node.hostname}", + path, + ) + console.log(f"(local) {shlex.join(code_command_to_run)}", style="bold green") + await run_async(code_command_to_run) + console.print( + "The editor was closed. Reopen it with or terminate the " + "process with (maybe twice)." + ) + input() + + async def connect_to_running_job( jobid_or_nodename: int | str, login_node: RemoteV2, diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index 171fe40c..a2f58f5b 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -362,11 +362,15 @@ def mila(): args_dict = _convert_uppercase_keys_to_lowercase(args_dict) if inspect.iscoroutinefunction(function): + # TODO: Need to let the function handle KeyboardInterrupt by itself, here it + # seems like it never gets there. try: - return asyncio.run(function(**args_dict)) + # NOTE: Not using `asyncio.run` here, because it doesn't exit as cleanly + # when interrupted (prints out some ignored exceptions in + # SubprocessTransport.__del__). Not sure why or what the difference is + # between them. + return asyncio.get_event_loop().run_until_complete(function(**args_dict)) except KeyboardInterrupt: - from milatools.cli import console - console.log("Terminated by user.") exit() diff --git a/milatools/cli/common.py b/milatools/cli/common.py index f239a269..b74ac796 100644 --- a/milatools/cli/common.py +++ b/milatools/cli/common.py @@ -81,8 +81,6 @@ def _parse_lfs_quota_output( async def check_disk_quota(remote: RemoteV2) -> None: - cluster = remote.hostname - # NOTE: This is what the output of the command looks like on the Mila cluster: # # Disk quotas for usr normandf (uid 1471600598): @@ -101,7 +99,7 @@ async def check_disk_quota(remote: RemoteV2) -> None: home_disk_quota_output = await remote.get_output_async( "lfs quota -u $USER $HOME", display=False, hide=True ) - _check_disk_quota_common_part(home_disk_quota_output, cluster) + _check_disk_quota_common_part(home_disk_quota_output, remote.hostname) def check_disk_quota_v1(remote: RemoteV1 | RemoteV2) -> None: diff --git a/milatools/utils/compute_node.py b/milatools/utils/compute_node.py index fe3eaf90..1987fced 100644 --- a/milatools/utils/compute_node.py +++ b/milatools/utils/compute_node.py @@ -100,15 +100,21 @@ async def close(self): """Cancels the running job using `scancel`.""" logger.info(f"Stopping job {self.job_id}.") if self.salloc_subprocess is not None: - # NOTE: This will exit cleanly because we don't have nested terminals or - # job steps. if self.salloc_subprocess.stdin is not None: + # NOTE: This will exit cleanly because we don't have nested terminals or + # job steps. await self.salloc_subprocess.communicate("exit\n".encode()) # noqa: UP012 else: + # todo: Not sure what the best way to do this is.. self.salloc_subprocess.send_signal(signal=signal.SIGINT) + self.salloc_subprocess.send_signal(signal=signal.SIGKILL) + self.salloc_subprocess.kill() # The scancel below is done even though it's redundant, just to be safe. await self.login_node.run_async( - f"scancel {self.job_id}", display=True, hide=False + f"scancel {self.job_id}", + display=True, + hide=False, + warn=True, ) def __repr__(self) -> str: @@ -255,8 +261,13 @@ async def salloc( # eventually be done at the same time as something else (while waiting for the # job to start) using things like `asyncio.gather` and `asyncio.wait_for`. logger.debug(f"(local) $ {shlex.join(command)}") - console.log(f"({login_node.hostname}) $ {salloc_command}", style="green") + console.log( + f"({login_node.hostname}) $ {salloc_command}", style="green", markup=False + ) async with cancel_new_jobs_on_interrupt(login_node, job_name): + # BUG: IF stdin is not set (or set to PIPE?) then writing `salloc`, then the + # terminal is actually 'live' and affects the compute node! For instance if + # you do `mila code .` and then write `salloc`, it spawns a second job!! salloc_subprocess = await asyncio.subprocess.create_subprocess_exec( *command, shell=False, @@ -281,13 +292,15 @@ async def salloc( try: console.log(f"Waiting for job {job_id} to start.", style="green") - _node, state = await wait_while_job_is_pending(login_node, job_id) + await wait_while_job_is_pending(login_node, job_id) except (KeyboardInterrupt, asyncio.CancelledError): if salloc_subprocess is not None: logger.debug("Killing the salloc subprocess following a KeyboardInterrupt.") - salloc_subprocess.send_signal(signal.SIGINT) - salloc_subprocess.terminate() - login_node.run(f"scancel {job_id}", display=True, hide=False) + await salloc_subprocess.communicate("exit\n".encode()) # noqa: UP012 + # salloc_subprocess.send_signal(signal.SIGINT) + # salloc_subprocess.send_signal(signal.SIGKILL) + # salloc_subprocess.terminate() + await login_node.run_async(f"scancel {job_id}", display=True, hide=False) raise # todo: Are there are states between `PENDING` and `RUNNING`? @@ -385,7 +398,7 @@ async def _wait_while_job_is_in_state(login_node: RemoteV2, job_id: int, state: f"Job {job_id} was allocated node(s) {nodes!r} and is in state " f"{current_state!r}." ) - return nodes, current_state + return current_state waiting_until = f"Waiting {wait_time_seconds} seconds until job {job_id} " condition: str | None = None @@ -418,11 +431,9 @@ async def _wait_while_job_is_in_state(login_node: RemoteV2, job_id: int, state: attempt += 1 -async def wait_while_job_is_pending( - login_node: RemoteV2, job_id: int -) -> tuple[str, str]: +async def wait_while_job_is_pending(login_node: RemoteV2, job_id: int) -> str: """Waits until a job show up in `sacct` then waits until its state is not PENDING. - Returns the `Node` and `State` from `sacct` after the job is no longer pending. + Returns the `State` from `sacct` after the job is no longer pending. """ - return await _wait_while_job_is_in_state(login_node, job_id, state="PENDING") + await _wait_while_job_is_in_state(login_node, job_id, state="PENDING") diff --git a/tests/conftest.py b/tests/conftest.py index 5d6f4f9c..58e6f3ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,6 +16,7 @@ import questionary from fabric.connection import Connection +from milatools.cli import console from milatools.cli.init_command import setup_ssh_config from milatools.utils.compute_node import get_queued_milatools_job_ids from milatools.utils.remote_v1 import RemoteV1 @@ -212,10 +213,19 @@ def test_something(remote: Remote): @pytest.fixture(scope="session") def job_name(request: pytest.FixtureRequest) -> str | None: + # TODO: Make the job name different based on the runner that is launching tests, so + # that the `launches_job_fixture` doesn't scancel the test jobs launched from + # another runner (e.g. me on my dev machine or laptop) on a cluster return getattr(request, "param", JOB_NAME) -@pytest_asyncio.fixture(scope="session") +# TODO: This `launches_jobs` fixture has issues. +# - It can cancel jobs from the self-hosted runner in addition to its own +# - Because it is session-scoped, if one test uses it, then all tests have their jobs +# cancelled at the end of their run. Therefore, might as well make it auto-used, no? + + +@pytest_asyncio.fixture(scope="session", autouse=True) async def launches_job_fixture(login_node_v2: RemoteV2, job_name: str): jobs_before = await get_queued_milatools_job_ids(login_node_v2, job_name=job_name) if jobs_before: @@ -231,6 +241,7 @@ async def launches_job_fixture(login_node_v2: RemoteV2, job_name: str): new_jobs = jobs_after - jobs_before if new_jobs: + console.log(f"Cancelling jobs {new_jobs} after running tests...") login_node_v2.run( "scancel " + " ".join(str(job_id) for job_id in new_jobs), display=True ) @@ -238,6 +249,28 @@ async def launches_job_fixture(login_node_v2: RemoteV2, job_name: str): logger.debug("Test apparently didn't launch any new jobs.") +@pytest_asyncio.fixture(scope="session") +async def register_test_job(login_node_v2: RemoteV2): + """Fixture that gives a function to register a job that should be cancelled after + tests are done running. + + This is more targeted than `launches_jobs_fixture` that does an scancel based on the + job name. + """ + jobs: list[int] = [] + register_job = jobs.append + try: + yield register_job + finally: + if jobs: + console.log(f"Cancelling jobs {jobs} after running tests...") + login_node_v2.run( + "scancel " + " ".join(str(job_id) for job_id in jobs), display=True + ) + else: + logger.debug("Test apparently didn't launch any new jobs.") + + launches_jobs = pytest.mark.usefixtures(launches_job_fixture.__name__) diff --git a/tests/integration/test_code/test_code_mila__sbatch_.txt b/tests/integration/test_code/test_code_mila__sbatch_.txt index ec932601..0c394c70 100644 --- a/tests/integration/test_code/test_code_mila__sbatch_.txt +++ b/tests/integration/test_code/test_code_mila__sbatch_.txt @@ -1,6 +1,7 @@ Checking disk quota on $HOME... Disk usage: X / LIMIT GiB and X / LIMIT files -(mila) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep 7d' +(mila) $ cd $SCRATCH && sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep +7d' JOB_ID (local) echo --new-window --wait --remote ssh-remote+COMPUTE_NODE $HOME/bob diff --git a/tests/utils/test_compute_node.py b/tests/utils/test_compute_node.py index e41f867b..73b6ec04 100644 --- a/tests/utils/test_compute_node.py +++ b/tests/utils/test_compute_node.py @@ -4,6 +4,7 @@ import datetime import re from logging import getLogger as get_logger +from typing import Callable import pytest import pytest_asyncio @@ -22,13 +23,16 @@ from ..conftest import launches_jobs, unsupported_on_windows logger = get_logger(__name__) -pytestmark = [launches_jobs, unsupported_on_windows] +pytestmark = [unsupported_on_windows] @pytest.mark.slow @pytest.mark.asyncio async def test_salloc( - login_node_v2: RemoteV2, allocation_flags: list[str], job_name: str + login_node_v2: RemoteV2, + allocation_flags: list[str], + job_name: str, + register_test_job: Callable[[int], None], ): if login_node_v2.hostname == "localhost": # todo: Check why this (and other tests in this file) don't work on the mock @@ -38,6 +42,8 @@ async def test_salloc( pytest.skip(reason="Test doesn't currently work on the mock slurm cluster.") compute_node = await salloc(login_node_v2, allocation_flags, job_name=job_name) + register_test_job(compute_node.job_id) + assert isinstance(compute_node, ComputeNode) assert compute_node.hostname != login_node_v2.hostname @@ -60,13 +66,17 @@ async def test_salloc( @pytest.mark.slow @pytest.mark.asyncio async def test_sbatch( - login_node_v2: RemoteV2, allocation_flags: list[str], job_name: str + login_node_v2: RemoteV2, + allocation_flags: list[str], + job_name: str, + register_test_job: Callable[[int], None], ): if login_node_v2.hostname == "localhost": pytest.skip(reason="Test doesn't currently work on the mock slurm cluster.") compute_node = await sbatch(login_node_v2, allocation_flags, job_name=job_name) assert isinstance(compute_node, ComputeNode) + register_test_job(compute_node.job_id) assert compute_node.hostname != login_node_v2.hostname job_id = compute_node.get_output("echo $SLURM_JOB_ID") @@ -86,6 +96,7 @@ def persist(request: pytest.FixtureRequest): return request.param +@launches_jobs @pytest.mark.slow @pytest.mark.asyncio async def test_interrupt_allocation( @@ -154,6 +165,7 @@ async def get_new_job_ids() -> set[int]: assert jobs_after <= _jobs_before +@launches_jobs @pytest.mark.slow @launches_jobs class TestComputeNode(RunnerTests): @@ -246,3 +258,38 @@ async def test_close( else: # interactive jobs are exited cleanly by just exiting in the terminal. assert job_state == "COMPLETED" + + +@pytest.mark.asyncio +async def test_del_compute_node( + login_node_v2: RemoteV2, + persist: bool, + allocation_flags: list[str], + job_name: str, + max_job_duration: datetime.timedelta, +): + if persist: + compute_node = await sbatch( + login_node_v2, sbatch_flags=allocation_flags, job_name=job_name + ) + else: + compute_node = await salloc( + login_node_v2, salloc_flags=allocation_flags, job_name=job_name + ) + + job_id = compute_node.job_id + + del compute_node + + # Make sure that if we sleep 10 seconds, the job should still be alive. + assert max_job_duration > datetime.timedelta(seconds=10) + await asyncio.sleep(10) + + job_state = await login_node_v2.get_output_async( + f"sacct --noheader --allocations --jobs {job_id} --format=State%100", + display=True, + hide=False, + ) + assert ( + False + ), f"Job {job_id} is in state {job_state} after the ComputeNode was deleted."