diff --git a/milatools/cli/code.py b/milatools/cli/code.py index ead8391b..27fbff6e 100644 --- a/milatools/cli/code.py +++ b/milatools/cli/code.py @@ -7,10 +7,7 @@ from logging import getLogger as get_logger from typing import Awaitable -from typing_extensions import deprecated - from milatools.cli import console -from milatools.cli.commands import find_allocation from milatools.cli.init_command import DRAC_CLUSTERS from milatools.cli.utils import ( CLUSTERS, @@ -18,18 +15,15 @@ CommandNotFoundError, MilatoolsUserError, currently_in_a_test, - get_hostname_to_use_for_compute_node, - no_internet_on_compute_nodes, - running_inside_WSL, + internet_on_compute_nodes, ) from milatools.utils.compute_node import ( ComputeNode, salloc, sbatch, ) -from milatools.utils.disk_quota import check_disk_quota, check_disk_quota_v1 +from milatools.utils.disk_quota import check_disk_quota from milatools.utils.local_v2 import LocalV2 -from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import ( RemoteV2, ) @@ -99,6 +93,8 @@ def add_mila_code_arguments(subparsers: argparse._SubParsersAction): help="When set, skips the syncing of vscode extensions.", ) if sys.platform == "win32": + from milatools.cli.commands import code_v1 + code_parser.set_defaults(function=code_v1) else: code_parser.set_defaults(function=code) @@ -152,7 +148,7 @@ async def code( # NOTE: Perhaps we could eventually do this check dynamically, if the cluster is an # unknown cluster? sync_vscode_extensions_task = None - if no_internet_on_compute_nodes(cluster) and not no_sync: + if not internet_on_compute_nodes(cluster) and not no_sync: # Sync the VsCode extensions from the local machine over to the target cluster. console.log( f"Installing VSCode extensions that are on the local machine on " @@ -276,166 +272,3 @@ async def launch_vscode_loop(code_command: str, compute_node: ComputeNode, path: except asyncio.CancelledError: raise # return - - -@deprecated( - "Support for the `mila code` command is now deprecated on Windows machines, as it " - "does not support ssh keys with passphrases or clusters where 2FA is enabled. " - "Please consider switching to the Windows Subsystem for Linux (WSL) to run " - "`mila code`." -) -def code_v1( - path: str, - command: str, - persist: bool, - job: int | None, - node: str | None, - alloc: list[str], - cluster: Cluster = "mila", - no_sync: bool = False, -): - """Open a remote VSCode session on a compute node. - - Arguments: - path: Path to open on the remote machine - command: Command to use to start vscode - (defaults to "code" or the value of $MILATOOLS_CODE_COMMAND) - persist: Whether the server should persist or not - job: Job ID to connect to - node: Node to connect to - alloc: Extra options to pass to slurm - """ - here = LocalV2() - remote = RemoteV1(cluster) - - if cluster != "mila" and job is None and node is None: - if not any("--account" in flag for flag in alloc): - logger.warning( - "Warning: When using the DRAC clusters, you usually need to " - "specify the account to use when submitting a job. You can specify " - "this in the job resources with `--alloc`, like so: " - "`--alloc --account=`, for example:\n" - f"mila code {path} --cluster {cluster} --alloc " - f"--account=your-account-here" - ) - - if command is None: - command = get_code_command() - - try: - check_disk_quota_v1(remote) - except MilatoolsUserError: - raise - except Exception as exc: - logger.warning(f"Unable to check the disk-quota on the cluster: {exc}") - - if sys.platform == "win32": - print( - "Syncing vscode extensions in the background isn't supported on " - "Windows. Skipping." - ) - elif no_internet_on_compute_nodes(cluster): - # Sync the VsCode extensions from the local machine over to the target cluster. - # TODO: Make this happen in the background (without overwriting the output). - run_in_the_background = False - print( - console.log( - f"[cyan]Installing VSCode extensions that are on the local machine on " - f"{cluster}" + (" in the background." if run_in_the_background else ".") - ) - ) - asyncio.run( - sync_vscode_extensions( - LocalV2(), - [cluster], - ) - ) - - if node is None: - cnode = find_allocation( - remote, - job_name="mila-code", - job=job, - node=node, - alloc=alloc, - cluster=cluster, - ) - if persist: - cnode = cnode.persist() - - data, proc = cnode.ensure_allocation() - - node_name = data["node_name"] - else: - node_name = node - proc = None - data = None - - if not path.startswith("/"): - # Get $HOME because we have to give the full path to code - home = remote.home() - path = home if path == "." else f"{home}/{path}" - - command_path = shutil.which(command) - if not command_path: - raise CommandNotFoundError(command) - - # NOTE: Since we have the config entries for the DRAC compute nodes, there is no - # need to use the fully qualified hostname here. - if cluster == "mila": - node_name = get_hostname_to_use_for_compute_node(node_name) - - # Try to detect if this is being run from within the Windows Subsystem for Linux. - # If so, then we run `code` through a powershell.exe command to open VSCode without - # issues. - inside_WSL = running_inside_WSL() - try: - while True: - if inside_WSL: - here.run( - ( - "powershell.exe", - "code", - "-nw", - "--remote", - f"ssh-remote+{node_name}", - path, - ), - ) - else: - here.run( - ( - command_path, - "-nw", - "--remote", - f"ssh-remote+{node_name}", - path, - ), - ) - print( - "The editor was closed. Reopen it with " - " or terminate the process with " - ) - if currently_in_a_test(): - break - input() - - except KeyboardInterrupt: - if not persist: - if proc is not None: - proc.kill() - print(f"Ended session on '{node_name}'") - - if persist: - print("This allocation is persistent and is still active.") - print("To reconnect to this node:") - console.print( - f" mila code {path} " - + (f"--cluster={cluster} " if cluster != "mila" else "") - + f"--node {node_name}", - style="bold", - ) - print("To kill this allocation:") - assert data is not None - assert "jobid" in data - console.print(f" ssh {cluster} scancel {data['jobid']}", style="bold") diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index b9218ecd..a0472279 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -11,6 +11,7 @@ import logging import operator import re +import shutil import socket import subprocess import sys @@ -28,16 +29,23 @@ import questionary as qn import rich.logging -from typing_extensions import TypedDict +from typing_extensions import TypedDict, deprecated from milatools.cli import console from milatools.cli.utils import ( + CommandNotFoundError, + MilatoolsUserError, cluster_to_connect_kwargs, + currently_in_a_test, get_hostname_to_use_for_compute_node, + internet_on_compute_nodes, + running_inside_WSL, ) +from milatools.utils.disk_quota import check_disk_quota_v1 from milatools.utils.local_v1 import LocalV1 +from milatools.utils.local_v2 import LocalV2 from milatools.utils.remote_v1 import RemoteV1, SlurmRemote -from milatools.utils.vscode_utils import sync_vscode_extensions +from milatools.utils.vscode_utils import get_code_command, sync_vscode_extensions from ..__version__ import __version__ from .code import add_mila_code_arguments @@ -52,12 +60,10 @@ from .profile import ensure_program, setup_profile from .utils import ( CLUSTERS, - MilatoolsUserError, SSHConnectionError, T, get_fully_qualified_name, randname, - running_inside_WSL, with_control_file, ) @@ -489,6 +495,171 @@ def forward( local_proc.kill() +@deprecated( + "Support for the `mila code` command is now deprecated on Windows machines, as it " + "does not support ssh keys with passphrases or clusters where 2FA is enabled. " + "Please consider switching to the Windows Subsystem for Linux (WSL) to run " + "`mila code`." +) +def code_v1( + path: str, + command: str, + persist: bool, + job: int | None, + node: str | None, + alloc: list[str], + cluster: str = "mila", + no_sync: bool = False, +): + """Open a remote VSCode session on a compute node. + + Arguments: + path: Path to open on the remote machine + command: Command to use to start vscode + (defaults to "code" or the value of $MILATOOLS_CODE_COMMAND) + persist: Whether the server should persist or not + job: Job ID to connect to + node: Node to connect to + alloc: Extra options to pass to slurm + """ + here = LocalV2() + remote = RemoteV1(cluster) + + if cluster != "mila" and job is None and node is None: + if not any("--account" in flag for flag in alloc): + logger.warning( + "Warning: When using the DRAC clusters, you usually need to " + "specify the account to use when submitting a job. You can specify " + "this in the job resources with `--alloc`, like so: " + "`--alloc --account=`, for example:\n" + f"mila code {path} --cluster {cluster} --alloc " + f"--account=your-account-here" + ) + + if command is None: + command = get_code_command() + + try: + check_disk_quota_v1(remote) + except MilatoolsUserError: + raise + except Exception as exc: + logger.warning(f"Unable to check the disk-quota on the cluster: {exc}") + + if sys.platform == "win32": + print( + "Syncing vscode extensions in the background isn't supported on " + "Windows. Skipping." + ) + elif not internet_on_compute_nodes(cluster) and not no_sync: + # Sync the VsCode extensions from the local machine over to the target cluster. + # TODO: Make this happen in the background (without overwriting the output). + run_in_the_background = False + print( + console.log( + f"[cyan]Installing VSCode extensions that are on the local machine on " + f"{cluster}" + (" in the background." if run_in_the_background else ".") + ) + ) + asyncio.run( + sync_vscode_extensions( + LocalV2(), + [cluster], + ) + ) + + if node is None: + from milatools.cli.commands import find_allocation + + cnode = find_allocation( + remote, + job_name="mila-code", + job=job, + node=node, + alloc=alloc, + cluster=cluster, + ) + if persist: + cnode = cnode.persist() + + data, proc = cnode.ensure_allocation() + + node_name = data["node_name"] + else: + node_name = node + proc = None + data = None + + if not path.startswith("/"): + # Get $HOME because we have to give the full path to code + home = remote.home() + path = home if path == "." else f"{home}/{path}" + + command_path = shutil.which(command) + if not command_path: + raise CommandNotFoundError(command) + + # NOTE: Since we have the config entries for the DRAC compute nodes, there is no + # need to use the fully qualified hostname here. + if cluster == "mila": + node_name = get_hostname_to_use_for_compute_node(node_name) + + # Try to detect if this is being run from within the Windows Subsystem for Linux. + # If so, then we run `code` through a powershell.exe command to open VSCode without + # issues. + inside_WSL = running_inside_WSL() + try: + while True: + if inside_WSL: + here.run( + ( + "powershell.exe", + "code", + "-nw", + "--remote", + f"ssh-remote+{node_name}", + path, + ), + ) + else: + here.run( + ( + command_path, + "-nw", + "--remote", + f"ssh-remote+{node_name}", + path, + ), + ) + print( + "The editor was closed. Reopen it with " + " or terminate the process with " + ) + if currently_in_a_test(): + break + input() + + except KeyboardInterrupt: + if not persist: + if proc is not None: + proc.kill() + print(f"Ended session on '{node_name}'") + + if persist: + print("This allocation is persistent and is still active.") + print("To reconnect to this node:") + console.print( + f" mila code {path} " + + (f"--cluster={cluster} " if cluster != "mila" else "") + + f"--node {node_name}", + style="bold", + ) + print("To kill this allocation:") + assert data is not None + assert "jobid" in data + console.print(f" ssh {cluster} scancel {data['jobid']}", style="bold") + + def connect(identifier: str, port: int | None): """Reconnect to a persistent server.""" diff --git a/milatools/cli/utils.py b/milatools/cli/utils.py index c527f71c..7a33c4fc 100644 --- a/milatools/cli/utils.py +++ b/milatools/cli/utils.py @@ -83,17 +83,15 @@ def currently_in_a_test() -> bool: return "pytest" in sys.modules -def no_internet_on_compute_nodes( - cluster: Cluster, -) -> TypeGuard[ClusterWithoutInternetOnCNodes]: +def internet_on_compute_nodes(cluster: str) -> TypeGuard[ClusterWithInternetOnCNodes]: if cluster not in CLUSTERS: warnings.warn( UserWarning( - f"Unknown cluster {cluster}. Assuming that compute nodes do not have " - f"internet access on this cluster for now." + f"Unknown cluster {cluster}. Assuming that compute nodes of this " + f"cluster do NOT have access to the internet for now." ) ) - return cluster not in get_args(ClusterWithInternetOnCNodes) + return cluster in get_args(ClusterWithInternetOnCNodes) def randname(): diff --git a/tests/conftest.py b/tests/conftest.py index 4ec4e675..f83730cc 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,6 @@ import milatools.cli.code import milatools.cli.commands -import milatools.cli.common import milatools.utils.compute_node import milatools.utils.local_v2 import milatools.utils.parallel_progress @@ -69,7 +68,6 @@ def use_wider_console_during_tests(monkeypatch: pytest.MonkeyPatch): milatools.utils.parallel_progress, milatools.utils.remote_v2, test_parallel_progress, - milatools.cli.common, milatools.cli.code, ]: # These modules import the console from milatools.cli before this runs, so we