Skip to content

Commit

Permalink
Update mila code to use RemoteV2
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Apr 18, 2024
1 parent 51bc370 commit 07741bf
Show file tree
Hide file tree
Showing 3 changed files with 347 additions and 53 deletions.
213 changes: 210 additions & 3 deletions milatools/cli/code_command.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from __future__ import annotations

import argparse
import shlex
import shutil
import sys
from logging import getLogger as get_logger

from typing_extensions import deprecated

from milatools.cli import console
from milatools.cli.common import (
check_disk_quota,
find_allocation,
)
from milatools.cli.init_command import DRAC_CLUSTERS
from milatools.cli.utils import (
CLUSTERS,
Cluster,
Expand All @@ -22,9 +26,18 @@
no_internet_on_compute_nodes,
running_inside_WSL,
)
from milatools.utils.compute_node import (
ComputeNode,
salloc,
sbatch,
wait_while_job_is_pending,
)
from milatools.utils.local_v1 import LocalV1
from milatools.utils.local_v2 import LocalV2
from milatools.utils.local_v2 import LocalV2, run_async
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.remote_v2 import (
RemoteV2,
)
from milatools.utils.vscode_utils import (
get_code_command,
sync_vscode_extensions,
Expand Down Expand Up @@ -84,10 +97,204 @@ def add_mila_code_arguments(subparsers: argparse._SubParsersAction):
action="store_true",
help="Whether the server should persist or not",
)
code_parser.set_defaults(function=code)
if sys.platform == "win32":
code_parser.set_defaults(function=code_v1)
else:
code_parser.set_defaults(function=code)


def code(
async def code(
path: str,
command: str,
persist: bool,
job: int | None,
node: str | None,
alloc: list[str],
cluster: Cluster = "mila",
) -> ComputeNode | int:
"""Open a remote VSCode session on a compute node.
Arguments:
path: Path to open on the remote machine
command: Command to use to start vscode (defaults to "code" or the value of \
$MILATOOLS_CODE_COMMAND)
persist: Whether the server should persist or not after exiting the terminal.
job: ID of the job to connect to
node: Name of the node to connect to
alloc: Extra options to pass to slurm
"""
# Check that the `code` command is in the $PATH so that we can use just `code` as
# the command.
code_command = command
if not shutil.which(code_command):
raise CommandNotFoundError(code_command)

# Connect to the cluster's login node.
login_node = await RemoteV2.connect(cluster)

if not path.startswith("/"):
# Get $HOME because we have to give the full path to code
home = login_node.get_output("echo $HOME", display=False, hide=True)
path = home if path == "." else f"{home}/{path}"

try:
check_disk_quota(login_node)
except MilatoolsUserError:
# Raise errors that are meant to be shown to the user (disk quota is reached).
raise
except Exception as exc:
logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")

# NOTE: Perhaps we could eventually do this check dynamically, if the cluster is an
# unknown cluster?
if no_internet_on_compute_nodes(cluster):
# Sync the VsCode extensions from the local machine over to the target cluster.
console.log(
f"Installing VSCode extensions that are on the local machine on "
f"{cluster} in the background.",
style="cyan",
)
# todo: use the mila or the local machine as the reference for vscode
# extensions?
# todo: could also perhaps make this function asynchronous instead of using a
# multiprocessing process for it.
copy_vscode_extensions_process = make_process(
sync_vscode_extensions,
LocalV2(),
[login_node],
)
copy_vscode_extensions_process.start()
# todo: could potentially do this at the same time as the blocks above and just wait
# for the result here, instead of running each block in sequence.
if currently_in_a_test():
copy_vscode_extensions_process.join()

if job or node:
if job and node:
logger.warning(
"Both job ID and node name were specified. Ignoring the node name and "
"only using the job id."
)
job_id_or_node = job or node
assert job_id_or_node is not None
compute_node = await connect_to_running_job(job_id_or_node, login_node)
else:
if cluster in DRAC_CLUSTERS and not any("--account" in flag for flag in alloc):
logger.warning(
"Warning: When using the DRAC clusters, you usually need to "
"specify the account to use when submitting a job. You can specify "
"this in the job resources with `--alloc`, like so: "
"`--alloc --account=<account_to_use>`, for example:\n"
f"mila code some_path --cluster {cluster} --alloc "
f"--account=your-account-here"
)
# Set the job name to `mila-code`. This should not be changed by the user
# ideally, so we can collect some simple stats about the use of `milatools` on
# the clusters.
if any(flag == "-J" or "-J=" in flag or "--job-name" in flag for flag in alloc):
# todo: Get the job name from the flags instead?
raise MilatoolsUserError(
"The job name flag (--job-name or -J) should be left unset for now "
"because we use the job name to gage how many people use `mila code` "
"on the various clusters. We also make use of the job name when the "
"call to `salloc` is interrupted before we have a chance to know the "
"job id."
)
job_name = "mila-code"
alloc = alloc + [f"--job-name={job_name}"]

if persist:
compute_node = await sbatch(
login_node, sbatch_flags=alloc, job_name=job_name
)
else:
# NOTE: Here we actually need the job name to be known, so that we can
# scancel jobs if the call is interrupted.
compute_node = await salloc(
login_node, salloc_flags=alloc, job_name=job_name
)

try:
while True:
code_command_to_run = (
code_command,
"--new-window",
"--wait",
"--remote",
f"ssh-remote+{compute_node.hostname}",
path,
)
console.log(
f"(local) {shlex.join(code_command_to_run)}", style="bold green"
)
await run_async(code_command_to_run)
# TODO: BUG: This now requires two Ctrl+C's instead of one!
console.print(
"The editor was closed. Reopen it with <Enter> or terminate the "
"process with <Ctrl+C> (maybe twice)."
)
if currently_in_a_test():
# TODO: This early exit kills the job when it is not persistent.
break
input()
except KeyboardInterrupt:
logger.info("Keyboard interrupt.")

if not persist:
# Cancel the job explicitly.
await compute_node.close()
console.print(f"Ended session on '{compute_node.hostname}'")
return compute_node.job_id

console.print("This allocation is persistent and is still active.")
console.print("To reconnect to this job, run the following:")
console.print(
f" mila code {path} "
+ (f"--cluster {cluster} " if cluster != "mila" else "")
+ f"--job {compute_node.job_id}",
style="bold",
)
console.print("To kill this allocation:")
console.print(f" ssh {cluster} scancel {compute_node.job_id}", style="bold")
return compute_node


async def connect_to_running_job(
jobid_or_nodename: int | str,
login_node: RemoteV2,
) -> ComputeNode:
if isinstance(jobid_or_nodename, int):
job_id = jobid_or_nodename
await wait_while_job_is_pending(login_node, job_id=job_id)
return ComputeNode(login_node, job_id=job_id)

node_name = jobid_or_nodename
# we have to find the job id to use on the given node.
jobs_on_node = await login_node.get_output_async(
f"squeue --me --node {node_name} --noheader --format=%A"
)
jobs_on_node = [int(job_id.strip()) for job_id in jobs_on_node.splitlines()]
if len(jobs_on_node) == 0:
raise MilatoolsUserError(
f"You don't appear to have any jobs currently running on node {node_name}. "
"Please check again or specify the job id to connect to."
)
if len(jobs_on_node) > 1:
raise MilatoolsUserError(
f"You have more than one job running on node {node_name}: {jobs_on_node}.\n"
"please use the `--job` flag to specify which job to connect to."
)
assert len(jobs_on_node) == 1
return ComputeNode(login_node=login_node, job_id=jobs_on_node[0])


@deprecated(
"Support for the `mila code` command is now deprecated on Windows machines, as it "
"does not support ssh keys with passphrases or clusters where 2FA is enabled. "
"Please consider switching to the Windows Subsystem for Linux (WSL) to run "
"`mila code`."
)
def code_v1(
path: str,
command: str,
persist: bool,
Expand Down
Loading

0 comments on commit 07741bf

Please sign in to comment.