diff --git a/milatools/cli/code_command.py b/milatools/cli/code_command.py
index bf24666c..bf3f66dc 100644
--- a/milatools/cli/code_command.py
+++ b/milatools/cli/code_command.py
@@ -1,15 +1,19 @@
 from __future__ import annotations
 
 import argparse
+import shlex
 import shutil
 import sys
 from logging import getLogger as get_logger
 
+from typing_extensions import deprecated
+
 from milatools.cli import console
 from milatools.cli.common import (
     check_disk_quota,
     find_allocation,
 )
+from milatools.cli.init_command import DRAC_CLUSTERS
 from milatools.cli.utils import (
     CLUSTERS,
     Cluster,
@@ -22,9 +26,18 @@
     no_internet_on_compute_nodes,
     running_inside_WSL,
 )
+from milatools.utils.compute_node import (
+    ComputeNode,
+    salloc,
+    sbatch,
+    wait_while_job_is_pending,
+)
 from milatools.utils.local_v1 import LocalV1
-from milatools.utils.local_v2 import LocalV2
+from milatools.utils.local_v2 import LocalV2, run_async
 from milatools.utils.remote_v1 import RemoteV1
+from milatools.utils.remote_v2 import (
+    RemoteV2,
+)
 from milatools.utils.vscode_utils import (
     get_code_command,
     sync_vscode_extensions,
@@ -84,10 +97,204 @@ def add_mila_code_arguments(subparsers: argparse._SubParsersAction):
         action="store_true",
         help="Whether the server should persist or not",
     )
-    code_parser.set_defaults(function=code)
+    if sys.platform == "win32":
+        code_parser.set_defaults(function=code_v1)
+    else:
+        code_parser.set_defaults(function=code)
 
 
-def code(
+async def code(
+    path: str,
+    command: str,
+    persist: bool,
+    job: int | None,
+    node: str | None,
+    alloc: list[str],
+    cluster: Cluster = "mila",
+) -> ComputeNode | int:
+    """Open a remote VSCode session on a compute node.
+
+    Arguments:
+        path: Path to open on the remote machine
+        command: Command to use to start vscode (defaults to "code" or the value of \
+            $MILATOOLS_CODE_COMMAND)
+        persist: Whether the server should persist or not after exiting the terminal.
+        job: ID of the job to connect to
+        node: Name of the node to connect to
+        alloc: Extra options to pass to slurm
+    """
+    # Check that the `code` command is in the $PATH so that we can use just `code` as
+    # the command.
+    code_command = command
+    if not shutil.which(code_command):
+        raise CommandNotFoundError(code_command)
+
+    # Connect to the cluster's login node.
+    login_node = await RemoteV2.connect(cluster)
+
+    if not path.startswith("/"):
+        # Get $HOME because we have to give the full path to code
+        home = login_node.get_output("echo $HOME", display=False, hide=True)
+        path = home if path == "." else f"{home}/{path}"
+
+    try:
+        check_disk_quota(login_node)
+    except MilatoolsUserError:
+        # Raise errors that are meant to be shown to the user (disk quota is reached).
+        raise
+    except Exception as exc:
+        logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")
+
+    # NOTE: Perhaps we could eventually do this check dynamically, if the cluster is an
+    # unknown cluster?
+    if no_internet_on_compute_nodes(cluster):
+        # Sync the VsCode extensions from the local machine over to the target cluster.
+        console.log(
+            f"Installing VSCode extensions that are on the local machine on "
+            f"{cluster} in the background.",
+            style="cyan",
+        )
+        # todo: use the mila or the local machine as the reference for vscode
+        # extensions?
+        # todo: could also perhaps make this function asynchronous instead of using a
+        # multiprocessing process for it.
+        copy_vscode_extensions_process = make_process(
+            sync_vscode_extensions,
+            LocalV2(),
+            [login_node],
+        )
+        copy_vscode_extensions_process.start()
+        # todo: could potentially do this at the same time as the blocks above and just wait
+        # for the result here, instead of running each block in sequence.
+        if currently_in_a_test():
+            copy_vscode_extensions_process.join()
+
+    if job or node:
+        if job and node:
+            logger.warning(
+                "Both job ID and node name were specified. Ignoring the node name and "
+                "only using the job id."
+            )
+        job_id_or_node = job or node
+        assert job_id_or_node is not None
+        compute_node = await connect_to_running_job(job_id_or_node, login_node)
+    else:
+        if cluster in DRAC_CLUSTERS and not any("--account" in flag for flag in alloc):
+            logger.warning(
+                "Warning: When using the DRAC clusters, you usually need to "
+                "specify the account to use when submitting a job. You can specify "
+                "this in the job resources with `--alloc`, like so: "
+                "`--alloc --account=<account>`, for example:\n"
+                f"mila code some_path --cluster {cluster} --alloc "
+                f"--account=your-account-here"
+            )
+        # Set the job name to `mila-code`. This should not be changed by the user
+        # ideally, so we can collect some simple stats about the use of `milatools` on
+        # the clusters.
+        if any(flag == "-J" or "-J=" in flag or "--job-name" in flag for flag in alloc):
+            # todo: Get the job name from the flags instead?
+            raise MilatoolsUserError(
+                "The job name flag (--job-name or -J) should be left unset for now "
+                "because we use the job name to gauge how many people use `mila code` "
+                "on the various clusters. We also make use of the job name when the "
+                "call to `salloc` is interrupted before we have a chance to know the "
+                "job id."
+            )
+        job_name = "mila-code"
+        alloc = alloc + [f"--job-name={job_name}"]
+
+        if persist:
+            compute_node = await sbatch(
+                login_node, sbatch_flags=alloc, job_name=job_name
+            )
+        else:
+            # NOTE: Here we actually need the job name to be known, so that we can
+            # scancel jobs if the call is interrupted.
+            compute_node = await salloc(
+                login_node, salloc_flags=alloc, job_name=job_name
+            )
+
+    try:
+        while True:
+            code_command_to_run = (
+                code_command,
+                "--new-window",
+                "--wait",
+                "--remote",
+                f"ssh-remote+{compute_node.hostname}",
+                path,
+            )
+            console.log(
+                f"(local) {shlex.join(code_command_to_run)}", style="bold green"
+            )
+            await run_async(code_command_to_run)
+            # TODO: BUG: This now requires two Ctrl+C's instead of one!
+            console.print(
+                "The editor was closed. Reopen it with <Enter> or terminate the "
+                "process with <Ctrl+C> (maybe twice)."
+            )
+            if currently_in_a_test():
+                # TODO: This early exit kills the job when it is not persistent.
+                break
+            input()
+    except KeyboardInterrupt:
+        logger.info("Keyboard interrupt.")
+
+    if not persist:
+        # Cancel the job explicitly.
+        await compute_node.close()
+        console.print(f"Ended session on '{compute_node.hostname}'")
+        return compute_node.job_id
+
+    console.print("This allocation is persistent and is still active.")
+    console.print("To reconnect to this job, run the following:")
+    console.print(
+        f"  mila code {path} "
+        + (f"--cluster {cluster} " if cluster != "mila" else "")
+        + f"--job {compute_node.job_id}",
+        style="bold",
+    )
+    console.print("To kill this allocation:")
+    console.print(f"  ssh {cluster} scancel {compute_node.job_id}", style="bold")
+    return compute_node
+
+
+async def connect_to_running_job(
+    jobid_or_nodename: int | str,
+    login_node: RemoteV2,
+) -> ComputeNode:
+    if isinstance(jobid_or_nodename, int):
+        job_id = jobid_or_nodename
+        await wait_while_job_is_pending(login_node, job_id=job_id)
+        return ComputeNode(login_node, job_id=job_id)
+
+    node_name = jobid_or_nodename
+    # we have to find the job id to use on the given node.
+    jobs_on_node = await login_node.get_output_async(
+        f"squeue --me --node {node_name} --noheader --format=%A"
+    )
+    jobs_on_node = [int(job_id.strip()) for job_id in jobs_on_node.splitlines()]
+    if len(jobs_on_node) == 0:
+        raise MilatoolsUserError(
+            f"You don't appear to have any jobs currently running on node {node_name}. "
+            "Please check again or specify the job id to connect to."
+        )
+    if len(jobs_on_node) > 1:
+        raise MilatoolsUserError(
+            f"You have more than one job running on node {node_name}: {jobs_on_node}.\n"
+            "Please use the `--job` flag to specify which job to connect to."
+        )
+    assert len(jobs_on_node) == 1
+    return ComputeNode(login_node=login_node, job_id=jobs_on_node[0])
+
+
+@deprecated(
+    "Support for the `mila code` command is now deprecated on Windows machines, as it "
+    "does not support ssh keys with passphrases or clusters where 2FA is enabled. "
+    "Please consider switching to the Windows Subsystem for Linux (WSL) to run "
+    "`mila code`."
+)
+def code_v1(
     path: str,
     command: str,
     persist: bool,
diff --git a/tests/integration/test_code_command.py b/tests/integration/test_code_command.py
index 572b4238..751b6893 100644
--- a/tests/integration/test_code_command.py
+++ b/tests/integration/test_code_command.py
@@ -1,26 +1,32 @@
 from __future__ import annotations
 
+import asyncio
+import contextlib
+import datetime
 import logging
 import re
 import subprocess
-import time
+import sys
 from datetime import timedelta
 from logging import getLogger as get_logger
 
 import pytest
+from pytest_regressions.file_regression import FileRegressionFixture
 
 from milatools.cli.code_command import code
 from milatools.cli.common import check_disk_quota
 from milatools.cli.utils import get_hostname_to_use_for_compute_node
+from milatools.utils.compute_node import ComputeNode
 from milatools.utils.remote_v1 import RemoteV1
 from milatools.utils.remote_v2 import RemoteV2
 
 from ..cli.common import skip_param_if_on_github_ci
+from ..conftest import job_name, launches_jobs
 from .conftest import (
     skip_if_not_already_logged_in,
     skip_param_if_not_already_logged_in,
 )
-from .test_slurm_remote import PARAMIKO_SSH_BANNER_BUG, get_recent_jobs_info_dicts
+from .test_slurm_remote import get_recent_jobs_info_dicts
 
 logger = get_logger(__name__)
 
@@ -62,67 +68,108 @@ def test_check_disk_quota(
     # IF the quota is met, then a `MilatoolsUserError` is logged.
 
 
+async def get_job_info(
+    job_id: int,
+    login_node: RemoteV2,
+    fields: tuple[str, ...] = ("JobID", "JobName", "Node", "WorkDir", "State"),
= ("JobID", "JobName", "Node", "WorkDir", "State"), +) -> dict: + return dict( + zip( + fields, + ( + await login_node.get_output_async( + f"sacct --noheader --allocations --user=$USER --jobs {job_id} " + "--format=" + ",".join(f"{field}%40" for field in fields), + display=False, + hide=True, + ) + ) + .strip() + .split(), + ) + ) + + +@launches_jobs @pytest.mark.slow -@PARAMIKO_SSH_BANNER_BUG -@pytest.mark.parametrize("persist", [True, False]) -def test_code( - login_node: RemoteV1 | RemoteV2, +@pytest.mark.asyncio +@pytest.mark.parametrize("persist", [True, False], ids=["sbatch", "salloc"]) +@pytest.mark.parametrize( + job_name.__name__, + [ + None, + ], + ids=[""], + indirect=True, +) +async def test_code( + login_node: RemoteV2, persist: bool, capsys: pytest.CaptureFixture, allocation_flags: list[str], + file_regression: FileRegressionFixture, + slurm_account_on_cluster: str, ): if login_node.hostname == "localhost": pytest.skip( "TODO: This test doesn't yet work with the slurm cluster spun up in the GitHub CI." ) - home = login_node.run("echo $HOME", display=False, hide=True).stdout.strip() - scratch = login_node.get_output("echo $SCRATCH") - relative_path = "bob" - code( - path=relative_path, - command="echo", # replace the usual `code` with `echo` for testing. - persist=persist, - job=None, - node=None, - alloc=allocation_flags, - cluster=login_node.hostname, # type: ignore + + home = await login_node.get_output_async("echo $HOME") + scratch = await login_node.get_output_async("echo $SCRATCH") + + start = datetime.datetime.now() - timedelta(minutes=5) + jobs_before = get_recent_jobs_info_dicts( + login_node, since=datetime.datetime.now() - start ) + jobs_before = { + int(job_info["JobID"]): job_info + for job_info in jobs_before + if job_info["JobName"] == "mila-code" + } + + relative_path = "bob" + + with contextlib.redirect_stderr(sys.stdout): + logger.info(f"{'sbatch' if persist else 'salloc'} flags: {allocation_flags}") + compute_node_or_job_id = await code( + path=relative_path, + command="echo", # replace the usual `code` with `echo` for testing. + # idea: Could probably also return the process ID of the `code` editor? + persist=persist, + job=None, + node=None, + alloc=allocation_flags, + cluster=login_node.hostname, # type: ignore + ) # Get the output that was printed while running that command. - # We expect our fake vscode command (with 'code' replaced with 'echo') to have been - # executed. - captured_output: str = capsys.readouterr().out + captured_output = capsys.readouterr().out - # Get the job id from the output just so we can more easily check the command output - # with sacct below. + node_hostname: str | None = None if persist: - m = re.search(r"Submitted batch job ([0-9]+)", captured_output) - assert m - job_id = int(m.groups()[0]) + assert isinstance(compute_node_or_job_id, ComputeNode) + compute_node = compute_node_or_job_id + assert compute_node is not None + job_id = compute_node.job_id + node_hostname = compute_node.hostname + else: - m = re.search(r"salloc: Granted job allocation ([0-9]+)", captured_output) - assert m - job_id = int(m.groups()[0]) + assert isinstance(compute_node_or_job_id, int) + job_id = compute_node_or_job_id + + await asyncio.sleep(5) # give a chance to sacct to update. - time.sleep(5) # give a chance to sacct to update. 
-    recent_jobs = get_recent_jobs_info_dicts(
-        since=timedelta(minutes=5),
+    job_info = await get_job_info(
+        job_id=job_id,
         login_node=login_node,
         fields=("JobID", "JobName", "Node", "WorkDir", "State"),
     )
-    job_id_to_job_info = {int(job_info["JobID"]): job_info for job_info in recent_jobs}
-    assert job_id in job_id_to_job_info, (job_id, job_id_to_job_info)
-    job_info = job_id_to_job_info[job_id]
-
-    node = job_info["Node"]
-    node_hostname = get_hostname_to_use_for_compute_node(
-        node, cluster=login_node.hostname
-    )
-    expected_line = f"(local) $ /usr/bin/echo -nw --remote ssh-remote+{node_hostname} {home}/{relative_path}"
-    assert any((expected_line in line) for line in captured_output.splitlines()), (
-        captured_output,
-        expected_line,
-    )
+    if node_hostname is None:
+        node_hostname = get_hostname_to_use_for_compute_node(
+            job_info["Node"], cluster=login_node.hostname
+        )
+    assert node_hostname and node_hostname != "None"
 
     # Check that the workdir is the scratch directory (because we cd'ed to $SCRATCH
     # before submitting the job)
@@ -132,11 +179,38 @@
     if persist:
         # Job should still be running since we're using `persist` (that's the whole
         # point.)
+        # NOTE: There's a fixture that scancel's all our jobs spawned during unit tests
+        # so there's no issue of lingering jobs on the cluster after the tests run/fail.
         assert job_info["State"] == "RUNNING"
+        await compute_node.close()
     else:
-        # Job should have been cancelled by us after the `echo` process finished.
-        # NOTE: This check is a bit flaky, perhaps our `scancel` command hasn't
-        # completed yet, or sacct doesn't show the change in status quick enough.
-        # Relaxing it a bit for now.
-        # assert "CANCELLED" in job_info["State"]
-        assert "CANCELLED" in job_info["State"] or job_info["State"] == "RUNNING"
+        # NOTE: Job is actually in the `COMPLETED` state because we exited cleanly (by
+        # passing `exit\n` to the salloc subprocess.)
+        assert job_info["State"] == "COMPLETED"
+
+    def filter_captured_output(captured_output: str) -> str:
+        # Remove information that may vary between runs from the regression test files.
+        def filter_line(line: str) -> str:
+            if (
+                regex := re.compile(
+                    r"Disk usage: \d+\.\d+ / \d+\.\d+ GiB and \d+ / \d+ files"
+                )
+            ).match(line):
+                # IDEA: Use regex to go from this:
+                # Disk usage: 66.56 / 100.00 GiB and 789192 / 1048576 files
+                # to this:
+                # Disk usage: X / LIMIT GiB and X / LIMIT files
+                line = regex.sub("Disk usage: X / LIMIT GiB and X / LIMIT files", line)
+            return (
+                line.rstrip()
+                .replace(str(job_id), "JOB_ID")
+                .replace(node_hostname, "COMPUTE_NODE")
+                .replace(home, "$HOME")
+                .replace(
+                    f"--account={slurm_account_on_cluster}", "--account=SLURM_ACCOUNT"
+                )
+            )
+
+        return "\n".join(filter_line(line) for line in captured_output.splitlines())
+
+    file_regression.check(filter_captured_output(captured_output))
diff --git a/tests/integration/test_code_command/test_code_mila0_None_True_.txt b/tests/integration/test_code_command/test_code_mila0_None_True_.txt
new file mode 100644
index 00000000..240c33f7
--- /dev/null
+++ b/tests/integration/test_code_command/test_code_mila0_None_True_.txt
@@ -0,0 +1,14 @@
+Checking disk quota on $HOME...
+Disk usage: X / LIMIT GiB and X / LIMIT files
+(mila) $ sbatch --parsable --wckey=milatools_test --account=SLURM_ACCOUNT --nodes=1 --ntasks=1 --cpus-per-task=1 --mem=1G --time=0:05:00 --oversubscribe --job-name=mila-code --wrap 'srun sleep 7d'
+JOB_ID
+
+(local) echo --new-window --wait --remote ssh-remote+cn-f002.server.mila.quebec /home/mila/n/normandf/bob
+--new-window --wait --remote ssh-remote+cn-f002.server.mila.quebec /home/mila/n/normandf/bob
+
+The editor was closed. Reopen it with <Enter> or terminate the process with <Ctrl+C> (maybe twice).
+This allocation is persistent and is still active.
+To reconnect to this job, run the following:
+  mila code /home/mila/n/normandf/bob --job JOB_ID
+To kill this allocation:
+  ssh mila scancel JOB_ID
\ No newline at end of file