From 6b0618cd291a0234a6f64f116b1bcbb67f4f0046 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 11 Apr 2024 11:02:58 -0400 Subject: [PATCH] Add new `mila login` command Signed-off-by: Fabrice Normandin --- milatools/cli/commands.py | 12 +++++ milatools/cli/login.py | 47 +++++++++++++++++++ tests/cli/test_commands/test_help_mila_.txt | 6 ++- .../test_invalid_command_output_mila_.txt | 2 +- ...alid_command_output_mila_search_conda_.txt | 4 +- tests/cli/test_login.py | 38 +++++++++++++++ 6 files changed, 104 insertions(+), 5 deletions(-) create mode 100644 milatools/cli/login.py create mode 100644 tests/cli/test_login.py diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index 52fd70f2..c1e191e7 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -16,6 +16,7 @@ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from collections.abc import Sequence from logging import getLogger as get_logger +from pathlib import Path from typing import Any from urllib.parse import urlencode @@ -23,6 +24,8 @@ import rich.logging from typing_extensions import TypedDict +from milatools.cli.login import login +from milatools.utils.remote_v2 import SSH_CONFIG_FILE from milatools.utils.vscode_utils import ( sync_vscode_extensions_with_hostnames, ) @@ -150,6 +153,15 @@ def mila(): init_parser.set_defaults(function=init) + # ----- mila login ------ + login_parser = subparsers.add_parser( + "login", + help="Sets up reusable SSH connections to the entries of the SSH config.", + formatter_class=SortingHelpFormatter, + ) + login_parser.add_argument("--ssh_config_path", type=Path, default=SSH_CONFIG_FILE) + login_parser.set_defaults(function=login) + # ----- mila forward ------ forward_parser = subparsers.add_parser( diff --git a/milatools/cli/login.py b/milatools/cli/login.py new file mode 100644 index 00000000..9a5bfcc2 --- /dev/null +++ b/milatools/cli/login.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path + +from paramiko import SSHConfig + +from milatools.cli import console +from milatools.utils.remote_v2 import SSH_CONFIG_FILE, RemoteV2 + + +async def login( + ssh_config_path: Path = SSH_CONFIG_FILE, +) -> list[RemoteV2]: + """Logs in and sets up reusable SSH connections to all the hosts in the SSH config. + + Returns the list of remotes where the connection was successfully established. + """ + ssh_config = SSHConfig.from_path(str(ssh_config_path.expanduser())) + potential_clusters = [ + host + for host in ssh_config.get_hostnames() + if not any(c in host for c in ["*", "?", "!"]) + ] + # take out entries like `mila-cpu` that have a proxy and remote command. + potential_clusters = [ + hostname + for hostname in potential_clusters + if not ( + (config := ssh_config.lookup(hostname)).get("proxycommand") + and config.get("remotecommand") + ) + ] + remotes = await asyncio.gather( + *( + RemoteV2.connect(hostname, ssh_config_path=ssh_config_path) + for hostname in potential_clusters + ), + return_exceptions=True, + ) + remotes = [remote for remote in remotes if isinstance(remote, RemoteV2)] + console.log(f"Successfully connected to {[remote.hostname for remote in remotes]}") + return remotes + + +if __name__ == "__main__": + asyncio.run(login()) diff --git a/tests/cli/test_commands/test_help_mila_.txt b/tests/cli/test_commands/test_help_mila_.txt index 88b5da3e..9b09c790 100644 --- a/tests/cli/test_commands/test_help_mila_.txt +++ b/tests/cli/test_commands/test_help_mila_.txt @@ -1,14 +1,16 @@ usage: mila [-h] [--version] [-v] - {docs,intranet,init,forward,code,sync,serve} ... + {docs,intranet,init,login,forward,code,sync,serve} ... Tools to connect to and interact with the Mila cluster. Cluster documentation: https://docs.mila.quebec/ positional arguments: - {docs,intranet,init,forward,code,sync,serve} + {docs,intranet,init,login,forward,code,sync,serve} docs Open the Mila cluster documentation. intranet Open the Mila intranet in a browser. init Set up your configuration and credentials. + login Sets up reusable SSH connections to the entries of the + SSH config. forward Forward a port on a compute node to your local machine. code Open a remote VSCode session on a compute node. diff --git a/tests/cli/test_commands/test_invalid_command_output_mila_.txt b/tests/cli/test_commands/test_invalid_command_output_mila_.txt index fa5c4b84..3ebe3e53 100644 --- a/tests/cli/test_commands/test_invalid_command_output_mila_.txt +++ b/tests/cli/test_commands/test_invalid_command_output_mila_.txt @@ -1,3 +1,3 @@ usage: mila [-h] [--version] [-v] - {docs,intranet,init,forward,code,sync,serve} ... + {docs,intranet,init,login,forward,code,sync,serve} ... mila: error: the following arguments are required: diff --git a/tests/cli/test_commands/test_invalid_command_output_mila_search_conda_.txt b/tests/cli/test_commands/test_invalid_command_output_mila_search_conda_.txt index 725b21ac..fa96195d 100644 --- a/tests/cli/test_commands/test_invalid_command_output_mila_search_conda_.txt +++ b/tests/cli/test_commands/test_invalid_command_output_mila_search_conda_.txt @@ -1,3 +1,3 @@ usage: mila [-h] [--version] [-v] - {docs,intranet,init,forward,code,sync,serve} ... -mila: error: argument : invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'forward', 'code', 'sync', 'serve') + {docs,intranet,init,login,forward,code,sync,serve} ... +mila: error: argument : invalid choice: 'search' (choose from 'docs', 'intranet', 'init', 'login', 'forward', 'code', 'sync', 'serve') diff --git a/tests/cli/test_login.py b/tests/cli/test_login.py new file mode 100644 index 00000000..b0a3f5eb --- /dev/null +++ b/tests/cli/test_login.py @@ -0,0 +1,38 @@ +import textwrap +from logging import getLogger as get_logger +from pathlib import Path + +import pytest + +from milatools.cli.login import login +from milatools.utils.remote_v2 import SSH_CACHE_DIR, RemoteV2 + +from .common import requires_ssh_to_localhost + +logger = get_logger(__name__) + + +@requires_ssh_to_localhost +@pytest.mark.asyncio +async def test_login(tmp_path: Path): # ssh_config_file: Path): + assert SSH_CACHE_DIR.exists() + ssh_config_path = tmp_path / "ssh_config" + ssh_config_path.write_text( + textwrap.dedent( + """\ + Host foo + hostname localhost + Host bar + hostname localhost + """ + ) + + "\n" + ) + + # Should create a connection to every host in the ssh config file. + remotes = await login(ssh_config_path=ssh_config_path) + assert all(isinstance(remote, RemoteV2) for remote in remotes) + assert set(remote.hostname for remote in remotes) == {"foo", "bar"} + for remote in remotes: + logger.info(f"Removing control socket at {remote.control_path}") + remote.control_path.unlink()