Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new mila run command #114

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@
from typing_extensions import TypedDict

from milatools.cli import console
from milatools.cli.login import login
from milatools.cli.run import run_command
from milatools.utils.local_v1 import LocalV1
from milatools.utils.remote_v1 import RemoteV1, SlurmRemote
from milatools.utils.remote_v2 import RemoteV2
from milatools.utils.remote_v2 import SSH_CONFIG_FILE, RemoteV2
from milatools.utils.vscode_utils import (
get_code_command,
# install_local_vscode_extensions_on_remote,
sync_vscode_extensions,
sync_vscode_extensions_with_hostnames,
)
Expand Down Expand Up @@ -170,6 +171,28 @@ def mila():

init_parser.set_defaults(function=init)

# ----- mila login ------
login_parser = subparsers.add_parser(
"login",
help="Sets up reusable SSH connections to the entries of the SSH config.",
formatter_class=SortingHelpFormatter,
)
login_parser.add_argument("--ssh_config_path", type=Path, default=SSH_CONFIG_FILE)
login_parser.set_defaults(function=login)

# ----- mila run ------
run_parser = subparsers.add_parser(
"run",
help="Runs a command over SSH on all the slurm clusters in the SSH config.",
formatter_class=SortingHelpFormatter,
)
run_parser.add_argument("--ssh_config_path", type=Path, default=SSH_CONFIG_FILE)
run_parser.add_argument("--show-table", action="store_true", default=False)
run_parser.add_argument(
"command", type=str, nargs=argparse.REMAINDER, help="The command to run."
)
run_parser.set_defaults(function=run_command)

# ----- mila forward ------

forward_parser = subparsers.add_parser(
Expand Down
50 changes: 50 additions & 0 deletions milatools/cli/login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

import asyncio
from pathlib import Path

from paramiko import SSHConfig

from milatools.cli import console
from milatools.cli.utils import CLUSTERS
from milatools.utils.remote_v2 import SSH_CONFIG_FILE, RemoteV2


async def login(
ssh_config_path: Path = SSH_CONFIG_FILE,
) -> list[RemoteV2]:
"""Logs in and sets up reusable SSH connections to all the hosts in the SSH config.

Returns the list of remotes where the connection was successfully established.
"""
ssh_config = SSHConfig.from_path(str(ssh_config_path.expanduser()))
potential_clusters = [
host
for host in ssh_config.get_hostnames()
if not any(c in host for c in ["*", "?", "!"])
]
potential_clusters = [
hostname
for hostname in potential_clusters
if hostname in CLUSTERS
# TODO: make this more generic with something like this:
# take out entries like `mila-cpu` that have a proxy and remote command.
if not (
(config := ssh_config.lookup(hostname)).get("proxycommand")
and config.get("remotecommand")
)
]
remotes = await asyncio.gather(
*(
RemoteV2.connect(hostname, ssh_config_path=ssh_config_path)
for hostname in potential_clusters
),
return_exceptions=True,
)
remotes = [remote for remote in remotes if isinstance(remote, RemoteV2)]
console.log(f"Successfully connected to {[remote.hostname for remote in remotes]}")
return remotes


if __name__ == "__main__":
asyncio.run(login())
117 changes: 117 additions & 0 deletions milatools/cli/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from __future__ import annotations

import asyncio
import shlex
import subprocess
import sys
from pathlib import Path

import rich
import rich.columns
import rich.live
import rich.table
import rich.text

from milatools.cli import console
from milatools.cli.login import login
from milatools.cli.utils import SSH_CONFIG_FILE
from milatools.utils.remote_v2 import RemoteV2


async def run_command(
command: str | list[str],
ssh_config_path: Path = SSH_CONFIG_FILE,
show_table: bool = False,
):
command = shlex.join(command) if isinstance(command, list) else command
if command.startswith("'") and command.endswith("'"):
# NOTE: Need to remove leading and trailing quotes so the ssh subprocess doesn't
# give an error. For example, with `mila run 'echo $SCRATCH'`, we would
# otherwise get the error: bash: echo: command not found
command = command[1:-1]

remotes = await login(ssh_config_path=ssh_config_path)

async def _is_slurm_cluster(remote: RemoteV2) -> bool:
sbatch_path = await remote.get_output_async(
"which sbatch", warn=True, hide=True, display=False
)
return bool(sbatch_path)

is_slurm_cluster = await asyncio.gather(
*(_is_slurm_cluster(remote) for remote in remotes),
)
cluster_login_nodes = [
remote
for remote, is_slurm_cluster in zip(remotes, is_slurm_cluster)
if is_slurm_cluster
]

results = await asyncio.gather(
*(
login_node.run_async(command=command, warn=True, display=True, hide=True)
for login_node in cluster_login_nodes
)
)
if show_table:
_print_with_table(command, cluster_login_nodes, results)
else:
_print_with_prefix(command, cluster_login_nodes, results)
return results


def _print_with_prefix(
command: str,
cluster_login_nodes: list[RemoteV2],
results: list[subprocess.CompletedProcess[str]],
):
for remote, result in zip(cluster_login_nodes, results):
for line in result.stdout.splitlines():
console.print(f"[bold]({remote.hostname})[/bold] {line}", markup=True)
for line in result.stderr.splitlines():
print(f"({remote.hostname}) {line}", file=sys.stderr)

# return results


def _print_with_table(
command: str,
cluster_login_nodes: list[RemoteV2],
results: list[subprocess.CompletedProcess[str]],
):
table = rich.table.Table(title=command)
table.add_column("Cluster")

# need an stdout column.
need_stdout_column = any(result.stdout for result in results)
need_stderr_column = any(result.stderr for result in results)

if not need_stderr_column and not need_stdout_column:
return results

if need_stdout_column:
table.add_column("stdout")
if need_stderr_column:
table.add_column("stderr")

for remote, result in zip(cluster_login_nodes, results):
row = [remote.hostname]
if need_stdout_column:
row.append(result.stdout)
if need_stderr_column:
row.append(result.stderr)
table.add_row(*row, end_section=True)

console.print(table)
# table = rich.table.Table(title=command)
# with rich.live.Live(table, refresh_per_second=1):

# async with asyncio.TaskGroup() as group:
# for remote in remotes:
# table.add_column(remote.hostname, no_wrap=True)
# task = group.create_task(remote.run_async(command))
# task.add_done_callback(lambda _: table.add_row())


if __name__ == "main":
asyncio.run(run_command("hostname"))
6 changes: 4 additions & 2 deletions tests/cli/test_commands/test_help_mila_.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
usage: mila [-h] [--version] [-v]
{docs,intranet,init,forward,code,sync,serve} ...
{docs,intranet,init,login,forward,code,sync,serve} ...

Tools to connect to and interact with the Mila cluster. Cluster documentation:
https://docs.mila.quebec/

positional arguments:
{docs,intranet,init,forward,code,sync,serve}
{docs,intranet,init,login,forward,code,sync,serve}
docs Open the Mila cluster documentation.
intranet Open the Mila intranet in a browser.
init Set up your configuration and credentials.
login Sets up reusable SSH connections to the entries of the
SSH config.
forward Forward a port on a compute node to your local
machine.
code Open a remote VSCode session on a compute node.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
usage: mila [-h] [--version] [-v]
{docs,intranet,init,forward,code,sync,serve} ...
{docs,intranet,init,login,forward,code,sync,serve} ...
mila: error: the following arguments are required: <command>
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
usage: mila [-h] [--version] [-v]
{docs,intranet,init,forward,code,sync,serve} ...
mila: error: argument <command>: invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'forward', 'code', 'sync', 'serve')
{docs,intranet,init,login,forward,code,sync,serve} ...
mila: error: argument <command>: invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'login', 'forward', 'code', 'sync', 'serve')
38 changes: 38 additions & 0 deletions tests/cli/test_login.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import textwrap
from logging import getLogger as get_logger
from pathlib import Path

import pytest

from milatools.cli.login import login
from milatools.utils.remote_v2 import SSH_CACHE_DIR, RemoteV2

from .common import requires_ssh_to_localhost

logger = get_logger(__name__)


@requires_ssh_to_localhost
@pytest.mark.asyncio
async def test_login(tmp_path: Path): # ssh_config_file: Path):
assert SSH_CACHE_DIR.exists()
ssh_config_path = tmp_path / "ssh_config"
ssh_config_path.write_text(
textwrap.dedent(
"""\
Host foo
hostname localhost
Host bar
hostname localhost
"""
)
+ "\n"
)

# Should create a connection to every host in the ssh config file.
remotes = await login(ssh_config_path=ssh_config_path)
assert all(isinstance(remote, RemoteV2) for remote in remotes)
assert set(remote.hostname for remote in remotes) == {"foo", "bar"}
for remote in remotes:
logger.info(f"Removing control socket at {remote.control_path}")
remote.control_path.unlink()
Loading