From 4bd3318e3ed78d62369cde8e994a0fc9f6c9f181 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 25 Apr 2024 11:10:31 -0400 Subject: [PATCH] Undo the change to commands.py to make PR simpler Signed-off-by: Fabrice Normandin --- milatools/cli/code_command.py | 246 --------- milatools/cli/commands.py | 693 +++++++++++++++++++++++-- milatools/cli/common.py | 426 --------------- milatools/utils/compute_node.py | 10 +- tests/cli/test_commands.py | 3 +- tests/integration/test_code_command.py | 3 +- 6 files changed, 665 insertions(+), 716 deletions(-) delete mode 100644 milatools/cli/code_command.py delete mode 100644 milatools/cli/common.py diff --git a/milatools/cli/code_command.py b/milatools/cli/code_command.py deleted file mode 100644 index 9795bb35..00000000 --- a/milatools/cli/code_command.py +++ /dev/null @@ -1,246 +0,0 @@ -from __future__ import annotations - -import argparse -import shutil -import sys -from logging import getLogger as get_logger - -from milatools.cli import console, currently_in_a_test -from milatools.cli.common import ( - check_disk_quota, - find_allocation, -) -from milatools.cli.utils import ( - CLUSTERS, - Cluster, - CommandNotFoundError, - MilatoolsUserError, - SortingHelpFormatter, - get_hostname_to_use_for_compute_node, - make_process, - no_internet_on_compute_nodes, - running_inside_WSL, -) -from milatools.utils.local_v1 import LocalV1 -from milatools.utils.local_v2 import LocalV2 -from milatools.utils.remote_v1 import RemoteV1 -from milatools.utils.vscode_utils import ( - get_code_command, - sync_vscode_extensions, - sync_vscode_extensions_with_hostnames, -) - -logger = get_logger(__name__) - - -def add_mila_code_arguments(subparsers: argparse._SubParsersAction): - code_parser: argparse.ArgumentParser = subparsers.add_parser( - "code", - help="Open a remote VSCode session on a compute node.", - formatter_class=SortingHelpFormatter, - ) - code_parser.add_argument( - "PATH", help="Path to open on the remote machine", type=str - ) - code_parser.add_argument( - "--cluster", - choices=CLUSTERS, # todo: widen based on the entries in ssh config? - default="mila", - help="Which cluster to connect to.", - ) - code_parser.add_argument( - "--alloc", - nargs=argparse.REMAINDER, - help="Extra options to pass to slurm", - metavar="VALUE", - default=[], - ) - code_parser.add_argument( - "--command", - default=get_code_command(), - help=( - "Command to use to start vscode\n" - '(defaults to "code" or the value of $MILATOOLS_CODE_COMMAND)' - ), - metavar="VALUE", - ) - code_parser.add_argument( - "--job", - type=int, - default=None, - help="Job ID to connect to", - metavar="VALUE", - ) - code_parser.add_argument( - "--node", - type=str, - default=None, - help="Node to connect to", - metavar="VALUE", - ) - code_parser.add_argument( - "--persist", - action="store_true", - help="Whether the server should persist or not", - ) - code_parser.set_defaults(function=code) - - -def code( - path: str, - command: str, - persist: bool, - job: int | None, - node: str | None, - alloc: list[str], - cluster: Cluster = "mila", -): - """Open a remote VSCode session on a compute node. 
-
-    Arguments:
-        path: Path to open on the remote machine
-        command: Command to use to start vscode
-            (defaults to "code" or the value of $MILATOOLS_CODE_COMMAND)
-        persist: Whether the server should persist or not
-        job: Job ID to connect to
-        node: Node to connect to
-        alloc: Extra options to pass to slurm
-    """
-    here = LocalV1()
-    remote = RemoteV1(cluster)
-
-    if cluster != "mila" and job is None and node is None:
-        if not any("--account" in flag for flag in alloc):
-            logger.warning(
-                "Warning: When using the DRAC clusters, you usually need to "
-                "specify the account to use when submitting a job. You can specify "
-                "this in the job resources with `--alloc`, like so: "
-                "`--alloc --account=<account_to_use>`, for example:\n"
-                f"mila code {path} --cluster {cluster} --alloc "
-                f"--account=your-account-here"
-            )
-
-    if command is None:
-        command = get_code_command()
-
-    try:
-        check_disk_quota(remote)
-    except MilatoolsUserError:
-        raise
-    except Exception as exc:
-        logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")
-
-    if sys.platform == "win32":
-        print(
-            "Syncing vscode extensions in the background isn't supported on "
-            "Windows. Skipping."
-        )
-    elif no_internet_on_compute_nodes(cluster):
-        # Sync the VsCode extensions from the local machine over to the target cluster.
-        # TODO: Make this happen in the background (without overwriting the output).
-        run_in_the_background = False
-        print(
-            console.log(
-                f"[cyan]Installing VSCode extensions that are on the local machine on "
-                f"{cluster}" + (" in the background." if run_in_the_background else ".")
-            )
-        )
-        if run_in_the_background:
-            copy_vscode_extensions_process = make_process(
-                sync_vscode_extensions_with_hostnames,
-                # todo: use the mila cluster as the source for vscode extensions? Or
-                # `localhost`?
-                source="localhost",
-                destinations=[cluster],
-            )
-            copy_vscode_extensions_process.start()
-        else:
-            sync_vscode_extensions(
-                LocalV2(),
-                [cluster],
-            )
-
-    if node is None:
-        cnode = find_allocation(
-            remote,
-            job_name="mila-code",
-            job=job,
-            node=node,
-            alloc=alloc,
-            cluster=cluster,
-        )
-        if persist:
-            cnode = cnode.persist()
-
-        data, proc = cnode.ensure_allocation()
-
-        node_name = data["node_name"]
-    else:
-        node_name = node
-        proc = None
-        data = None
-
-    if not path.startswith("/"):
-        # Get $HOME because we have to give the full path to code
-        home = remote.home()
-        path = home if path == "." else f"{home}/{path}"
-
-    command_path = shutil.which(command)
-    if not command_path:
-        raise CommandNotFoundError(command)
-
-    # NOTE: Since we have the config entries for the DRAC compute nodes, there is no
-    # need to use the fully qualified hostname here.
-    if cluster == "mila":
-        node_name = get_hostname_to_use_for_compute_node(node_name)
-
-    # Try to detect if this is being run from within the Windows Subsystem for Linux.
-    # If so, then we run `code` through a powershell.exe command to open VSCode without
-    # issues.
-    inside_WSL = running_inside_WSL()
-    try:
-        while True:
-            if inside_WSL:
-                here.run(
-                    "powershell.exe",
-                    "code",
-                    "-nw",
-                    "--remote",
-                    f"ssh-remote+{node_name}",
-                    path,
-                )
-            else:
-                here.run(
-                    command_path,
-                    "-nw",
-                    "--remote",
-                    f"ssh-remote+{node_name}",
-                    path,
-                )
-            print(
-                "The editor was closed. Reopen it with <Enter>"
-                " or terminate the process with <Ctrl+C>"
-            )
-            if currently_in_a_test():
-                break
-            input()
-
-    except KeyboardInterrupt:
-        if not persist:
-            if proc is not None:
-                proc.kill()
-            print(f"Ended session on '{node_name}'")
-
-    if persist:
-        print("This allocation is persistent and is still active.")
-        print("To reconnect to this node:")
-        console.print(
-            f" mila code {path} "
-            + (f"--cluster={cluster} " if cluster != "mila" else "")
-            + f"--node {node_name}",
-            style="bold",
-        )
-        print("To kill this allocation:")
-        assert data is not None
-        assert "jobid" in data
-        console.print(f" ssh {cluster} scancel {data['jobid']}", style="bold")
diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py
index 73a6c601..52e2bba9 100644
--- a/milatools/cli/commands.py
+++ b/milatools/cli/commands.py
@@ -9,13 +9,21 @@
 import asyncio
 import inspect
 import logging
+import operator
+import re
+import shutil
+import socket
+import subprocess
 import sys
+import time
 import traceback
 import typing
 import webbrowser
-from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, _HelpAction
 from collections.abc import Sequence
+from contextlib import ExitStack
 from logging import getLogger as get_logger
+from pathlib import Path
 from typing import Any
 from urllib.parse import urlencode
 
@@ -24,15 +32,18 @@
 from typing_extensions import TypedDict
 
 from milatools.cli import console
+from milatools.utils.local_v1 import LocalV1
+from milatools.utils.local_v2 import LocalV2
+from milatools.utils.remote_v1 import RemoteV1, SlurmRemote
+from milatools.utils.remote_v2 import RemoteV2
 from milatools.utils.vscode_utils import (
+    get_code_command,
+    # install_local_vscode_extensions_on_remote,
+    sync_vscode_extensions,
     sync_vscode_extensions_with_hostnames,
 )
 
 from ..__version__ import __version__
-from ..utils.local_v1 import LocalV1
-from ..utils.remote_v1 import RemoteV1
-from .code_command import add_mila_code_arguments
-from .common import forward, standard_server
 from .init_command import (
     print_welcome_message,
     setup_keys_on_login_node,
@@ -41,14 +52,23 @@
     setup_vscode_settings,
     setup_windows_ssh_config_from_wsl,
 )
+from .profile import ensure_program, setup_profile
 from .utils import (
     CLUSTERS,
+    Cluster,
+    CommandNotFoundError,
     MilatoolsUserError,
-    SortingHelpFormatter,
     SSHConnectionError,
     T,
+    cluster_to_connect_kwargs,
+    currently_in_a_test,
     get_fully_qualified_name,
+    get_hostname_to_use_for_compute_node,
+    make_process,
+    no_internet_on_compute_nodes,
+    randname,
     running_inside_WSL,
+    with_control_file,
 )
 
 if typing.TYPE_CHECKING:
@@ -175,7 +195,57 @@ def mila():
     forward_parser.set_defaults(function=forward)
 
     # ----- mila code ------
-    add_mila_code_arguments(subparsers)
+
+    code_parser = subparsers.add_parser(
+        "code",
+        help="Open a remote VSCode session on a compute node.",
+        formatter_class=SortingHelpFormatter,
+    )
+    code_parser.add_argument(
+        "PATH", help="Path to open on the remote machine", type=str
+    )
+    code_parser.add_argument(
+        "--cluster",
+        choices=CLUSTERS,
+        default="mila",
+        help="Which cluster to connect to.",
+    )
+    code_parser.add_argument(
+        "--alloc",
+        nargs=argparse.REMAINDER,
+        help="Extra options to pass to slurm",
+        metavar="VALUE",
+        default=[],
+    )
+    code_parser.add_argument(
+        "--command",
+        default=get_code_command(),
+        help=(
+            "Command to use to start vscode\n"
+            '(defaults to "code" or the value of $MILATOOLS_CODE_COMMAND)'
+        ),
+        metavar="VALUE",
+    )
+    code_parser.add_argument(
+        "--job",
+        type=str,
+        default=None,
+        help="Job ID to connect to",
+        metavar="VALUE",
+    )
+    code_parser.add_argument(
+        "--node",
+        type=str,
+        default=None,
+        help="Node to connect to",
+        metavar="VALUE",
+    )
+    code_parser.add_argument(
+        "--persist",
+        action="store_true",
+        help="Whether the server should persist or not",
+    )
+    code_parser.set_defaults(function=code)
 
     # ----- mila sync vscode-extensions ------
 
@@ -362,17 +432,11 @@ def mila():
     args_dict = _convert_uppercase_keys_to_lowercase(args_dict)
 
     if inspect.iscoroutinefunction(function):
-        # TODO: Need to let the function handle KeyboardInterrupt by itself, here it
-        # seems like it never gets there.
         try:
-            # NOTE: Not using `asyncio.run` here, because it doesn't exit as cleanly
-            # when interrupted (prints out some ignored exceptions in
-            # SubprocessTransport.__del__). Not sure why or what the difference is
-            # between them.
-            return asyncio.get_event_loop().run_until_complete(function(**args_dict))
+            return asyncio.run(function(**args_dict))
         except KeyboardInterrupt:
             console.log("Terminated by user.")
-            exit()
+            return
 
     assert callable(function)
     return function(**args_dict)
@@ -382,13 +446,11 @@ def setup_logging(verbose: int) -> None:
     global_loglevel = (
         logging.CRITICAL
         if verbose == 0
-        else (
-            logging.WARNING
-            if verbose == 1
-            else logging.INFO
-            if verbose == 2
-            else logging.DEBUG
-        )
+        else logging.WARNING
+        if verbose == 1
+        else logging.INFO
+        if verbose == 2
+        else logging.DEBUG
     )
     package_loglevel = (
         logging.WARNING
@@ -457,7 +519,7 @@ def init():
     print_welcome_message()
 
 
-def forward_command(
+def forward(
     remote: str,
     page: str | None,
     port: int | None,
@@ -469,7 +531,7 @@ def forward_command(
     except ValueError:
         pass
 
-    local_proc, _ = forward(
+    local_proc, _ = _forward(
         local=LocalV1(),
         node=f"{node}.server.mila.quebec",
         to_forward=remote_port,
@@ -485,12 +547,173 @@ def forward_command(
         local_proc.kill()
 
 
+def code(
+    path: str,
+    command: str,
+    persist: bool,
+    job: str | None,
+    node: str | None,
+    alloc: list[str],
+    cluster: Cluster = "mila",
+):
+    """Open a remote VSCode session on a compute node.
+
+    Arguments:
+        path: Path to open on the remote machine
+        command: Command to use to start vscode
+            (defaults to "code" or the value of $MILATOOLS_CODE_COMMAND)
+        persist: Whether the server should persist or not
+        job: Job ID to connect to
+        node: Node to connect to
+        alloc: Extra options to pass to slurm
+    """
+    here = LocalV1()
+    remote = RemoteV1(cluster)
+
+    if cluster != "mila" and job is None and node is None:
+        if not any("--account" in flag for flag in alloc):
+            logger.warning(
+                "Warning: When using the DRAC clusters, you usually need to "
+                "specify the account to use when submitting a job. You can specify "
+                "this in the job resources with `--alloc`, like so: "
+                "`--alloc --account=<account_to_use>`, for example:\n"
+                f"mila code {path} --cluster {cluster} --alloc "
+                f"--account=your-account-here"
+            )
+
+    if command is None:
+        command = get_code_command()
+
+    try:
+        check_disk_quota(remote)
+    except MilatoolsUserError:
+        raise
+    except Exception as exc:
+        logger.warning(f"Unable to check the disk-quota on the cluster: {exc}")
+
+    if sys.platform == "win32":
+        print(
+            "Syncing vscode extensions in the background isn't supported on "
+            "Windows. Skipping."
+        )
+    elif no_internet_on_compute_nodes(cluster):
+        # Sync the VsCode extensions from the local machine over to the target cluster.
+        # TODO: Make this happen in the background (without overwriting the output).
+        run_in_the_background = False
+        print(
+            console.log(
+                f"[cyan]Installing VSCode extensions that are on the local machine on "
+                f"{cluster}" + (" in the background." if run_in_the_background else ".")
+            )
+        )
+        if run_in_the_background:
+            copy_vscode_extensions_process = make_process(
+                sync_vscode_extensions_with_hostnames,
+                # todo: use the mila cluster as the source for vscode extensions? Or
+                # `localhost`?
+                source="localhost",
+                destinations=[cluster],
+            )
+            copy_vscode_extensions_process.start()
+        else:
+            sync_vscode_extensions(
+                LocalV2(),
+                [cluster],
+            )
+
+    if node is None:
+        cnode = _find_allocation(
+            remote,
+            job_name="mila-code",
+            job=job,
+            node=node,
+            alloc=alloc,
+            cluster=cluster,
+        )
+        if persist:
+            cnode = cnode.persist()
+
+        data, proc = cnode.ensure_allocation()
+
+        node_name = data["node_name"]
+    else:
+        node_name = node
+        proc = None
+        data = None
+
+    if not path.startswith("/"):
+        # Get $HOME because we have to give the full path to code
+        home = remote.home()
+        path = home if path == "." else f"{home}/{path}"
+
+    command_path = shutil.which(command)
+    if not command_path:
+        raise CommandNotFoundError(command)
+
+    # NOTE: Since we have the config entries for the DRAC compute nodes, there is no
+    # need to use the fully qualified hostname here.
+    if cluster == "mila":
+        node_name = get_hostname_to_use_for_compute_node(node_name)
+
+    # Try to detect if this is being run from within the Windows Subsystem for Linux.
+    # If so, then we run `code` through a powershell.exe command to open VSCode without
+    # issues.
+    inside_WSL = running_inside_WSL()
+    try:
+        while True:
+            if inside_WSL:
+                here.run(
+                    "powershell.exe",
+                    "code",
+                    "-nw",
+                    "--remote",
+                    f"ssh-remote+{node_name}",
+                    path,
+                )
+            else:
+                here.run(
+                    command_path,
+                    "-nw",
+                    "--remote",
+                    f"ssh-remote+{node_name}",
+                    path,
+                )
+            print(
+                "The editor was closed. Reopen it with <Enter>"
+                " or terminate the process with <Ctrl+C>"
+            )
+            if currently_in_a_test():
+                break
+            input()
+
+    except KeyboardInterrupt:
+        if not persist:
+            if proc is not None:
+                proc.kill()
+            print(f"Ended session on '{node_name}'")
+
+    if persist:
+        print("This allocation is persistent and is still active.")
+        print("To reconnect to this node:")
+        print(
+            T.bold(
+                f" mila code {path} "
+                + (f"--cluster={cluster} " if cluster != "mila" else "")
+                + f"--node {node_name}"
+            )
+        )
+        print("To kill this allocation:")
+        assert data is not None
+        assert "jobid" in data
+        print(T.bold(f" ssh {cluster} scancel {data['jobid']}"))
+
+
 def connect(identifier: str, port: int | None):
     """Reconnect to a persistent server."""
     remote = RemoteV1("mila")
     info = _get_server_info(remote, identifier)
-    local_proc, _ = forward(
+    local_proc, _ = _forward(
         local=LocalV1(),
         node=f"{info['node_name']}.server.mila.quebec",
         to_forward=info["to_forward"],
@@ -566,7 +789,7 @@ class StandardServerArgs(TypedDict):
     alloc: list[str]
     """Extra options to pass to slurm."""
 
-    job: int | None
+    job: str | None
     """Job ID to connect to."""
 
     name: str | None
@@ -595,7 +818,7 @@ def lab(path: str | None, **kwargs: Unpack[StandardServerArgs]):
     if path and path.endswith(".ipynb"):
         exit("Only directories can be given to the mila serve lab command")
 
-    standard_server(
+    _standard_server(
         path,
         program="jupyter-lab",
         installers={
@@ -618,7 +841,7 @@ def notebook(path: str | None, **kwargs: Unpack[StandardServerArgs]):
     if path and path.endswith(".ipynb"):
         exit("Only directories can be given to the mila serve notebook command")
 
-    standard_server(
+    _standard_server(
         path,
         program="jupyter-notebook",
         installers={
@@ -639,7 +862,7 @@ def tensorboard(logdir: str, **kwargs: Unpack[StandardServerArgs]):
         logdir: Path to the experiment logs
     """
 
-    standard_server(
+    _standard_server(
         logdir,
         program="tensorboard",
         installers={
@@ -659,7 +882,7 @@ def mlflow(logdir: str, **kwargs: Unpack[StandardServerArgs]):
         logdir: Path to the experiment logs
     """
 
-    standard_server(
+    _standard_server(
         logdir,
         program="mlflow",
         installers={
@@ -677,7 +900,7 @@ def aim(logdir: str, **kwargs: Unpack[StandardServerArgs]):
     Arguments:
         logdir: Path to the experiment logs
     """
-    standard_server(
+    _standard_server(
         logdir,
         program="aim",
         installers={
@@ -697,6 +920,18 @@ def _get_server_info(
     return info
 
 
+class SortingHelpFormatter(argparse.HelpFormatter):
+    """Taken and adapted from https://stackoverflow.com/a/12269143/6388696."""
+
+    def add_arguments(self, actions):
+        actions = sorted(actions, key=operator.attrgetter("option_strings"))
+        # put help actions first.
+ actions = sorted( + actions, key=lambda action: not isinstance(action, _HelpAction) + ) + super().add_arguments(actions) + + def _add_standard_server_args(parser: ArgumentParser): parser.add_argument( "--alloc", @@ -707,7 +942,7 @@ def _add_standard_server_args(parser: ArgumentParser): ) parser.add_argument( "--job", - type=int, + type=str, default=None, help="Job ID to connect to", metavar="VALUE", @@ -747,5 +982,399 @@ def _add_standard_server_args(parser: ArgumentParser): ) +def _standard_server( + path: str | None, + *, + program: str, + installers: dict[str, str], + command: str, + profile: str | None, + persist: bool, + port: int | None, + name: str | None, + node: str | None, + job: str | None, + alloc: list[str], + port_pattern=None, + token_pattern=None, +): + # Make the server visible from the login node (other users will be able to connect) + # Temporarily disabled + share = False + + if name is not None: + persist = True + elif persist: + name = program + + remote = RemoteV1("mila") + + path = path or "~" + if path == "~" or path.startswith("~/"): + path = remote.home() + path[1:] + + results: dict | None = None + node_name: str | None = None + to_forward: int | str | None = None + cf: str | None = None + proc = None + with ExitStack() as stack: + if persist: + cf = stack.enter_context(with_control_file(remote, name=name)) + else: + cf = None + + if profile: + prof = f"~/.milatools/profiles/{profile}.bash" + else: + prof = setup_profile(remote, path) + + qn.print(f"Using profile: {prof}") + cat_result = remote.run(f"cat {prof}", hide=True, warn=True) + if cat_result.ok: + qn.print("=" * 50) + qn.print(cat_result.stdout.rstrip()) + qn.print("=" * 50) + else: + exit(f"Could not find or load profile: {prof}") + + premote = remote.with_profile(prof) + + if not ensure_program( + remote=premote, + program=program, + installers=installers, + ): + exit(f"Exit: {program} is not installed.") + + cnode = _find_allocation( + remote, + job_name=f"mila-serve-{program}", + node=node, + job=job, + alloc=alloc, + cluster="mila", + ) + + patterns = { + "node_name": "#### ([A-Za-z0-9_-]+)", + } + + if port_pattern: + patterns["port"] = port_pattern + elif share: + exit( + "Server cannot be shared because it is serving over a Unix domain " + "socket" + ) + else: + remote.run("mkdir -p ~/.milatools/sockets", hide=True) + + if share: + host = "0.0.0.0" + else: + host = "localhost" + + sock_name = name or randname() + command = command.format( + path=path, + sock=f"~/.milatools/sockets/{sock_name}.sock", + host=host, + ) + + if token_pattern: + patterns["token"] = token_pattern + + if persist: + cnode = cnode.persist() + + proc, results = ( + cnode.with_profile(prof) + .with_precommand("echo '####' $(hostname)") + .extract( + command, + patterns=patterns, + ) + ) + node_name = results["node_name"] + + if port_pattern: + to_forward = int(results["port"]) + else: + to_forward = f"{remote.home()}/.milatools/sockets/{sock_name}.sock" + + if cf is not None: + remote.simple_run(f"echo program = {program} >> {cf}") + remote.simple_run(f"echo node_name = {results['node_name']} >> {cf}") + remote.simple_run(f"echo host = {host} >> {cf}") + remote.simple_run(f"echo to_forward = {to_forward} >> {cf}") + if token_pattern: + remote.simple_run(f"echo token = {results['token']} >> {cf}") + + assert results is not None + assert node_name is not None + assert to_forward is not None + assert proc is not None + if token_pattern: + options = {"token": results["token"]} + else: + options = {} + + local_proc, local_port 
= _forward( + local=LocalV1(), + node=get_hostname_to_use_for_compute_node(node_name, cluster="mila"), + to_forward=to_forward, + options=options, + port=port, + ) + + if cf is not None: + remote.simple_run(f"echo local_port = {local_port} >> {cf}") + + try: + local_proc.wait() + except KeyboardInterrupt: + qn.print("Terminated by user.") + if cf is not None: + name = Path(cf).name + qn.print("To reconnect to this server, use the command:") + qn.print(f" mila serve connect {name}", style="bold yellow") + qn.print("To kill this server, use the command:") + qn.print(f" mila serve kill {name}", style="bold red") + finally: + local_proc.kill() + proc.kill() + + +def _parse_lfs_quota_output( + lfs_quota_output: str, +) -> tuple[tuple[float, float], tuple[int, int]]: + """Parses space and # of files (usage, limit) from the output of `lfs quota`.""" + lines = lfs_quota_output.splitlines() + + header_line: str | None = None + header_line_index: int | None = None + for index, line in enumerate(lines): + if ( + len(line_parts := line.strip().split()) == 9 + and line_parts[0].lower() == "filesystem" + ): + header_line = line + header_line_index = index + break + assert header_line + assert header_line_index is not None + + values_line_parts: list[str] = [] + # The next line may overflow to two (or maybe even more?) lines if the name of the + # $HOME dir is too long. + for content_line in lines[header_line_index + 1 :]: + additional_values = content_line.strip().split() + assert len(values_line_parts) < 9 + values_line_parts.extend(additional_values) + if len(values_line_parts) == 9: + break + + assert len(values_line_parts) == 9, values_line_parts + ( + _filesystem, + used_kbytes, + _quota_kbytes, + limit_kbytes, + _grace_kbytes, + files, + _quota_files, + limit_files, + _grace_files, + ) = values_line_parts + + used_gb = int(used_kbytes.strip()) / (1024**2) + max_gb = int(limit_kbytes.strip()) / (1024**2) + used_files = int(files.strip()) + max_files = int(limit_files.strip()) + return (used_gb, max_gb), (used_files, max_files) + + +def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None: + cluster = remote.hostname + + # NOTE: This is what the output of the command looks like on the Mila cluster: + # + # Disk quotas for usr normandf (uid 1471600598): + # Filesystem kbytes quota limit grace files quota limit grace + # /home/mila/n/normandf + # 95747836 0 104857600 - 908722 0 1048576 - + # uid 1471600598 is using default block quota setting + # uid 1471600598 is using default file quota setting + + # Need to assert this, otherwise .get_output calls .run which would spawn a job! + assert not isinstance(remote, SlurmRemote) + if not remote.get_output("which lfs", hide=True): + logger.debug("Cluster doesn't have the lfs command. Skipping check.") + return + + console.log("Checking disk quota on $HOME...") + + home_disk_quota_output = remote.get_output("lfs quota -u $USER $HOME", hide=True) + if "not on a mounted Lustre filesystem" in home_disk_quota_output: + logger.debug("Cluster doesn't use lustre on $HOME filesystem. 
Skipping check.") + return + + (used_gb, max_gb), (used_files, max_files) = _parse_lfs_quota_output( + home_disk_quota_output + ) + + def get_colour(used: float, max: float) -> str: + return "red" if used >= max else "orange" if used / max > 0.7 else "green" + + disk_usage_style = get_colour(used_gb, max_gb) + num_files_style = get_colour(used_files, max_files) + from rich.text import Text + + console.log( + "Disk usage:", + Text(f"{used_gb:.2f} / {max_gb:.2f} GiB", style=disk_usage_style), + "and", + Text(f"{used_files} / {max_files} files", style=num_files_style), + markup=False, + ) + size_ratio = used_gb / max_gb + files_ratio = used_files / max_files + reason = ( + f"{used_gb:.1f} / {max_gb} GiB" + if size_ratio > files_ratio + else f"{used_files} / {max_files} files" + ) + + freeing_up_space_instructions = ( + "For example, temporary files (logs, checkpoints, etc.) can be moved to " + "$SCRATCH, while files that need to be stored for longer periods can be moved " + "to $ARCHIVE or to a shared project folder under /network/projects.\n" + "Visit https://docs.mila.quebec/Information.html#storage to learn more about " + "how to best make use of the different filesystems available on the cluster." + ) + + if used_gb >= max_gb or used_files >= max_files: + raise MilatoolsUserError( + T.red( + f"ERROR: Your disk quota on the $HOME filesystem is exceeded! " + f"({reason}).\n" + f"To fix this, login to the cluster with `ssh {cluster}` and free up " + f"some space, either by deleting files, or by moving them to a " + f"suitable filesystem.\n" + freeing_up_space_instructions + ) + ) + if max(size_ratio, files_ratio) > 0.9: + warning_message = ( + f"You are getting pretty close to your disk quota on the $HOME " + f"filesystem: ({reason})\n" + "Please consider freeing up some space in your $HOME folder, either by " + "deleting files, or by moving them to a more suitable filesystem.\n" + + freeing_up_space_instructions + ) + logger.warning(UserWarning(warning_message)) + + +def _find_allocation( + remote: RemoteV1, + node: str | None, + job: str | None, + alloc: list[str], + cluster: Cluster = "mila", + job_name: str = "mila-tools", +): + if (node is not None) + (job is not None) + bool(alloc) > 1: + exit("ERROR: --node, --job and --alloc are mutually exclusive") + + if node is not None: + node_name = get_hostname_to_use_for_compute_node(node, cluster=cluster) + return RemoteV1( + node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster) + ) + + elif job is not None: + node_name = remote.get_output(f"squeue --jobs {job} -ho %N") + return RemoteV1( + node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster) + ) + + else: + alloc = ["-J", job_name, *alloc] + return SlurmRemote( + connection=remote.connection, + alloc=alloc, + hostname=remote.hostname, + ) + + +def _forward( + local: LocalV1, + node: str, + to_forward: int | str, + port: int | None, + page: str | None = None, + options: dict[str, str | None] = {}, + through_login: bool = False, +): + if port is None: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Find a free local port by binding to port 0 + sock.bind(("localhost", 0)) + _, port = sock.getsockname() + # Close it for ssh -L. It is *unlikely* it will not be available. 
+ sock.close() + + if isinstance(to_forward, int) or re.match("[0-9]+", to_forward): + if through_login: + to_forward = f"{node}:{to_forward}" + args = [f"localhost:{port}:{to_forward}", "mila"] + else: + to_forward = f"localhost:{to_forward}" + args = [f"localhost:{port}:{to_forward}", node] + else: + args = [f"localhost:{port}:{to_forward}", node] + + proc = local.popen( + "ssh", + "-o", + "UserKnownHostsFile=/dev/null", + "-o", + "StrictHostKeyChecking=no", + "-nNL", + *args, + ) + + url = f"http://localhost:{port}" + if page is not None: + if not page.startswith("/"): + page = f"/{page}" + url += page + + options = {k: v for k, v in options.items() if v is not None} + if options: + url += f"?{urlencode(options)}" + + qn.print("Waiting for connection to be active...") + nsecs = 10 + period = 0.2 + for _ in range(int(nsecs / period)): + time.sleep(period) + try: + # This feels stupid, there's probably a better way + local.silent_get("nc", "-z", "localhost", str(port)) + except subprocess.CalledProcessError: + continue + except Exception: + break + break + + qn.print( + "Starting browser. You might need to refresh the page.", + style="bold", + ) + webbrowser.open(url) + return proc, port + + if __name__ == "__main__": main() diff --git a/milatools/cli/common.py b/milatools/cli/common.py deleted file mode 100644 index 39818891..00000000 --- a/milatools/cli/common.py +++ /dev/null @@ -1,426 +0,0 @@ -from __future__ import annotations - -import re -import socket -import subprocess -import time -import webbrowser -from contextlib import ExitStack -from logging import getLogger as get_logger -from pathlib import Path -from urllib.parse import urlencode - -import questionary as qn -from rich.text import Text - -from milatools.cli import console -from milatools.cli.profile import ensure_program, setup_profile -from milatools.cli.utils import ( - Cluster, - MilatoolsUserError, - T, - cluster_to_connect_kwargs, - get_hostname_to_use_for_compute_node, - randname, - with_control_file, -) -from milatools.utils.local_v1 import LocalV1 -from milatools.utils.remote_v1 import RemoteV1, SlurmRemote -from milatools.utils.remote_v2 import RemoteV2 - -logger = get_logger(__name__) - - -def _parse_lfs_quota_output( - lfs_quota_output: str, -) -> tuple[tuple[float, float], tuple[int, int]]: - """Parses space and # of files (usage, limit) from the output of `lfs quota`.""" - lines = lfs_quota_output.splitlines() - - header_line: str | None = None - header_line_index: int | None = None - for index, line in enumerate(lines): - if ( - len(line_parts := line.strip().split()) == 9 - and line_parts[0].lower() == "filesystem" - ): - header_line = line - header_line_index = index - break - assert header_line - assert header_line_index is not None - - values_line_parts: list[str] = [] - # The next line may overflow to two (or maybe even more?) lines if the name of the - # $HOME dir is too long. 
- for content_line in lines[header_line_index + 1 :]: - additional_values = content_line.strip().split() - assert len(values_line_parts) < 9 - values_line_parts.extend(additional_values) - if len(values_line_parts) == 9: - break - - assert len(values_line_parts) == 9, values_line_parts - ( - _filesystem, - used_kbytes, - _quota_kbytes, - limit_kbytes, - _grace_kbytes, - files, - _quota_files, - limit_files, - _grace_files, - ) = values_line_parts - - used_gb = int(used_kbytes.strip()) / (1024**2) - max_gb = int(limit_kbytes.strip()) / (1024**2) - used_files = int(files.strip()) - max_files = int(limit_files.strip()) - return (used_gb, max_gb), (used_files, max_files) - - -def check_disk_quota(remote: RemoteV1 | RemoteV2) -> None: - cluster = remote.hostname - - # NOTE: This is what the output of the command looks like on the Mila cluster: - # - # Disk quotas for usr normandf (uid 1471600598): - # Filesystem kbytes quota limit grace files quota limit grace - # /home/mila/n/normandf - # 95747836 0 104857600 - 908722 0 1048576 - - # uid 1471600598 is using default block quota setting - # uid 1471600598 is using default file quota setting - - # Need to assert this, otherwise .get_output calls .run which would spawn a job! - assert not isinstance(remote, SlurmRemote) - if not remote.get_output("which lfs", display=False, hide=True): - logger.debug("Cluster doesn't have the lfs command. Skipping check.") - return - - console.log("Checking disk quota on $HOME...") - - home_disk_quota_output = remote.get_output( - "lfs quota -u $USER $HOME", display=False, hide=True - ) - if "not on a mounted Lustre filesystem" in home_disk_quota_output: - logger.debug("Cluster doesn't use lustre on $HOME filesystem. Skipping check.") - return - - (used_gb, max_gb), (used_files, max_files) = _parse_lfs_quota_output( - home_disk_quota_output - ) - - def get_colour(used: float, max: float) -> str: - return "red" if used >= max else "orange" if used / max > 0.7 else "green" - - disk_usage_style = get_colour(used_gb, max_gb) - num_files_style = get_colour(used_files, max_files) - - console.log( - "Disk usage:", - Text(f"{used_gb:.2f} / {max_gb:.2f} GiB", style=disk_usage_style), - "and", - Text(f"{used_files} / {max_files} files", style=num_files_style), - markup=False, - ) - size_ratio = used_gb / max_gb - files_ratio = used_files / max_files - reason = ( - f"{used_gb:.1f} / {max_gb} GiB" - if size_ratio > files_ratio - else f"{used_files} / {max_files} files" - ) - - freeing_up_space_instructions = ( - "For example, temporary files (logs, checkpoints, etc.) can be moved to " - "$SCRATCH, while files that need to be stored for longer periods can be moved " - "to $ARCHIVE or to a shared project folder under /network/projects.\n" - "Visit https://docs.mila.quebec/Information.html#storage to learn more about " - "how to best make use of the different filesystems available on the cluster." - ) - - if used_gb >= max_gb or used_files >= max_files: - raise MilatoolsUserError( - T.red( - f"ERROR: Your disk quota on the $HOME filesystem is exceeded! 
" - f"({reason}).\n" - f"To fix this, login to the cluster with `ssh {cluster}` and free up " - f"some space, either by deleting files, or by moving them to a " - f"suitable filesystem.\n" + freeing_up_space_instructions - ) - ) - if max(size_ratio, files_ratio) > 0.9: - warning_message = ( - f"You are getting pretty close to your disk quota on the $HOME " - f"filesystem: ({reason})\n" - "Please consider freeing up some space in your $HOME folder, either by " - "deleting files, or by moving them to a more suitable filesystem.\n" - + freeing_up_space_instructions - ) - logger.warning(UserWarning(warning_message)) - - -def find_allocation( - remote: RemoteV1, - node: str | None, - job: int | None, - alloc: list[str], - cluster: Cluster = "mila", - job_name: str = "mila-tools", -): - if (node is not None) + (job is not None) + bool(alloc) > 1: - exit("ERROR: --node, --job and --alloc are mutually exclusive") - - if node is not None: - node_name = get_hostname_to_use_for_compute_node(node, cluster=cluster) - return RemoteV1( - node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster) - ) - - elif job is not None: - node_name = remote.get_output(f"squeue --jobs {job} -ho %N") - return RemoteV1( - node_name, connect_kwargs=cluster_to_connect_kwargs.get(cluster) - ) - - else: - alloc = ["-J", job_name, *alloc] - return SlurmRemote( - connection=remote.connection, - alloc=alloc, - hostname=remote.hostname, - ) - - -def forward( - local: LocalV1, - node: str, - to_forward: int | str, - port: int | None, - page: str | None = None, - options: dict[str, str | None] = {}, - through_login: bool = False, -): - if port is None: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # Find a free local port by binding to port 0 - sock.bind(("localhost", 0)) - _, port = sock.getsockname() - # Close it for ssh -L. It is *unlikely* it will not be available. - sock.close() - - if isinstance(to_forward, int) or re.match("[0-9]+", to_forward): - if through_login: - to_forward = f"{node}:{to_forward}" - args = [f"localhost:{port}:{to_forward}", "mila"] - else: - to_forward = f"localhost:{to_forward}" - args = [f"localhost:{port}:{to_forward}", node] - else: - args = [f"localhost:{port}:{to_forward}", node] - - proc = local.popen( - "ssh", - "-o", - "UserKnownHostsFile=/dev/null", - "-o", - "StrictHostKeyChecking=no", - "-nNL", - *args, - ) - - url = f"http://localhost:{port}" - if page is not None: - if not page.startswith("/"): - page = f"/{page}" - url += page - - options = {k: v for k, v in options.items() if v is not None} - if options: - url += f"?{urlencode(options)}" - - qn.print("Waiting for connection to be active...") - nsecs = 10 - period = 0.2 - for _ in range(int(nsecs / period)): - time.sleep(period) - try: - # This feels stupid, there's probably a better way - local.silent_get("nc", "-z", "localhost", str(port)) - except subprocess.CalledProcessError: - continue - except Exception: - break - break - - qn.print( - "Starting browser. 
You might need to refresh the page.", - style="bold", - ) - webbrowser.open(url) - return proc, port - - -def standard_server( - path: str | None, - *, - program: str, - installers: dict[str, str], - command: str, - profile: str | None, - persist: bool, - port: int | None, - name: str | None, - node: str | None, - job: int | None, - alloc: list[str], - port_pattern=None, - token_pattern=None, -): - # Make the server visible from the login node (other users will be able to connect) - # Temporarily disabled - share = False - - if name is not None: - persist = True - elif persist: - name = program - - remote = RemoteV1("mila") - - path = path or "~" - if path == "~" or path.startswith("~/"): - path = remote.home() + path[1:] - - results: dict | None = None - node_name: str | None = None - to_forward: int | str | None = None - cf: str | None = None - proc = None - with ExitStack() as stack: - if persist: - cf = stack.enter_context(with_control_file(remote, name=name)) - else: - cf = None - - if profile: - prof = f"~/.milatools/profiles/{profile}.bash" - else: - prof = setup_profile(remote, path) - - qn.print(f"Using profile: {prof}") - cat_result = remote.run(f"cat {prof}", hide=True, warn=True) - if cat_result.ok: - qn.print("=" * 50) - qn.print(cat_result.stdout.rstrip()) - qn.print("=" * 50) - else: - exit(f"Could not find or load profile: {prof}") - - premote = remote.with_profile(prof) - - if not ensure_program( - remote=premote, - program=program, - installers=installers, - ): - exit(f"Exit: {program} is not installed.") - - cnode = find_allocation( - remote, - job_name=f"mila-serve-{program}", - node=node, - job=job, - alloc=alloc, - cluster="mila", - ) - - patterns = { - "node_name": "#### ([A-Za-z0-9_-]+)", - } - - if port_pattern: - patterns["port"] = port_pattern - elif share: - exit( - "Server cannot be shared because it is serving over a Unix domain " - "socket" - ) - else: - remote.run("mkdir -p ~/.milatools/sockets", hide=True) - - if share: - host = "0.0.0.0" - else: - host = "localhost" - - sock_name = name or randname() - command = command.format( - path=path, - sock=f"~/.milatools/sockets/{sock_name}.sock", - host=host, - ) - - if token_pattern: - patterns["token"] = token_pattern - - if persist: - cnode = cnode.persist() - - proc, results = ( - cnode.with_profile(prof) - .with_precommand("echo '####' $(hostname)") - .extract( - command, - patterns=patterns, - ) - ) - node_name = results["node_name"] - - if port_pattern: - to_forward = int(results["port"]) - else: - to_forward = f"{remote.home()}/.milatools/sockets/{sock_name}.sock" - - if cf is not None: - remote.simple_run(f"echo program = {program} >> {cf}") - remote.simple_run(f"echo node_name = {results['node_name']} >> {cf}") - remote.simple_run(f"echo host = {host} >> {cf}") - remote.simple_run(f"echo to_forward = {to_forward} >> {cf}") - if token_pattern: - remote.simple_run(f"echo token = {results['token']} >> {cf}") - - assert results is not None - assert node_name is not None - assert to_forward is not None - assert proc is not None - if token_pattern: - options = {"token": results["token"]} - else: - options = {} - - local_proc, local_port = forward( - local=LocalV1(), - node=get_hostname_to_use_for_compute_node(node_name, cluster="mila"), - to_forward=to_forward, - options=options, - port=port, - ) - - if cf is not None: - remote.simple_run(f"echo local_port = {local_port} >> {cf}") - - try: - local_proc.wait() - except KeyboardInterrupt: - qn.print("Terminated by user.") - if cf is not None: - name = 
Path(cf).name - qn.print("To reconnect to this server, use the command:") - qn.print(f" mila serve connect {name}", style="bold yellow") - qn.print("To kill this server, use the command:") - qn.print(f" mila serve kill {name}", style="bold red") - finally: - local_proc.kill() - proc.kill() diff --git a/milatools/utils/compute_node.py b/milatools/utils/compute_node.py index 36b88cc8..116b4416 100644 --- a/milatools/utils/compute_node.py +++ b/milatools/utils/compute_node.py @@ -4,7 +4,6 @@ import contextlib import dataclasses import datetime -import inspect import re import shlex import subprocess @@ -53,6 +52,8 @@ class ComputeNode(Runner): salloc_subprocess: asyncio.subprocess.Process | None = None """A handle to the subprocess that is running the `salloc` command.""" + hostname: str = dataclasses.field(init=False) + _closed: bool = dataclasses.field(default=False, init=False, repr=False) def __post_init__(self): @@ -181,13 +182,6 @@ async def close_async(self): ) self._closed = True - def __repr__(self) -> str: - params = ", ".join( - f"{k}={repr(getattr(self, k))}" - for k in inspect.signature(type(self)).parameters - ) - return f"{type(self).__name__}({params})" - async def get_queued_milatools_job_ids( login_node: RemoteV2, job_name="mila-code" diff --git a/tests/cli/test_commands.py b/tests/cli/test_commands.py index 20e2e48f..98e3a40a 100644 --- a/tests/cli/test_commands.py +++ b/tests/cli/test_commands.py @@ -8,8 +8,7 @@ import pytest from pytest_regressions.file_regression import FileRegressionFixture -from milatools.cli.commands import main -from milatools.cli.common import _parse_lfs_quota_output +from milatools.cli.commands import _parse_lfs_quota_output, main from .common import requires_no_s_flag diff --git a/tests/integration/test_code_command.py b/tests/integration/test_code_command.py index d3727cb1..35caae01 100644 --- a/tests/integration/test_code_command.py +++ b/tests/integration/test_code_command.py @@ -9,8 +9,7 @@ import pytest -from milatools.cli.code_command import code -from milatools.cli.common import check_disk_quota +from milatools.cli.commands import check_disk_quota, code from milatools.cli.utils import get_hostname_to_use_for_compute_node from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import RemoteV2