diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index 89c70259..b94aa2d8 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -5,6 +5,7 @@ from __future__ import annotations import argparse +import functools import operator import os import re @@ -15,6 +16,7 @@ import time import traceback import typing +import warnings import webbrowser from argparse import ArgumentParser, _HelpAction from contextlib import ExitStack @@ -28,7 +30,7 @@ from typing_extensions import TypedDict from ..version import version as mversion -from .init_command import setup_ssh_config +from .init_command import setup_ssh_config, setup_windows_ssh_config_from_wsl from .local import Local from .profile import ensure_program, setup_profile from .remote import Remote, SlurmRemote @@ -40,6 +42,7 @@ get_fully_qualified_name, qualified, randname, + running_inside_WSL, with_control_file, yn, ) @@ -50,8 +53,9 @@ def main(): - on_mila = get_fully_qualified_name().endswith(".server.mila.quebec") - if on_mila: + if sys.platform != "win32" and get_fully_qualified_name().endswith( + ".server.mila.quebec" + ): exit( "ERROR: 'mila ...' should be run on your local machine and not on the Mila cluster" ) @@ -386,7 +390,7 @@ def init(): print("Checking ssh config") - setup_ssh_config() + ssh_config = setup_ssh_config() # TODO: Move the rest of this command to functions in the init_command module, # so they can more easily be tested. @@ -408,6 +412,7 @@ def init(): for entry in os.listdir(sshdir) ): if yn("You have no public keys. Generate one?"): + # TODO: need to get the location of the key as an output of this command! here.run("ssh-keygen") else: exit("No public keys.") @@ -418,7 +423,15 @@ def init(): if yn( "Your public key does not appear be registered on the cluster. Register it?" ): - here.run("ssh-copy-id", "mila") + # NOTE: If we're on a Windows machine, we do something different here: + if sys.platform == "win32": + command = ( + "powershell.exe type $env:USERPROFILE\\.ssh\\id_rsa.pub | ssh mila " + '"cat >> ~/.ssh/authorized_keys"' + ) + here.run(command) + else: + here.run("ssh-copy-id", "mila") if not here.check_passwordless("mila"): exit("ssh-copy-id appears to have failed") else: @@ -461,6 +474,13 @@ def init(): else: exit("You will not be able to SSH to a compute node") + # TODO: IF we're running on WSL, we could probably actually just copy the + # id_rsa.pub and the config to the Windows paths (taking care to remove the + # ControlMaster-related entries) so that the user doesn't need to install Python on + # the Windows side. + if running_inside_WSL(): + setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config) + ################### # Welcome message # ################### @@ -566,15 +586,31 @@ def code( command_path = shutil.which(command) if not command_path: raise CommandNotFoundError(command) + qualified_node_name = qualified(node_name) + + # Try to detect if this is being run from within the Windows Subsystem for Linux. + # If so, then we run `code` through a powershell.exe command to open VSCode without + # issues. + inside_WSL = running_inside_WSL() try: while True: - here.run( - command_path, - "-nw", - "--remote", - f"ssh-remote+{qualified(node_name)}", - path, - ) + if inside_WSL: + here.run( + "powershell.exe", + "code", + "-nw", + "--remote", + f"ssh-remote+{qualified_node_name}", + path, + ) + else: + here.run( + command_path, + "-nw", + "--remote", + f"ssh-remote+{qualified_node_name}", + path, + ) print( "The editor was closed. Reopen it with " " or terminate the process with " diff --git a/milatools/cli/init_command.py b/milatools/cli/init_command.py index 56905c4d..3e55acfa 100644 --- a/milatools/cli/init_command.py +++ b/milatools/cli/init_command.py @@ -1,19 +1,29 @@ from __future__ import annotations import difflib +import shutil +import subprocess +import sys +import warnings from logging import getLogger as get_logger from pathlib import Path +from typing import Any import questionary as qn -from .utils import SSHConfig, T, yn +from .utils import SSHConfig, T, running_inside_WSL, yn + +WINDOWS_UNSUPPORTED_KEYS = ["ControlMaster", "ControlPath", "ControlPersist"] logger = get_logger(__name__) +HOSTS = ["mila", "mila-cpu", "*.server.mila.quebec !*login.server.mila.quebec"] +"""List of host entries that get added to the SSH configurtion by `mila init`.""" + def setup_ssh_config( ssh_config_path: str | Path = "~/.ssh/config", -): +) -> SSHConfig: """Interactively sets up some useful entries in the ~/.ssh/config file on the local machine. Exits if the User cancels any of the prompts or doesn't confirm the changes when asked. @@ -27,6 +37,9 @@ def setup_ssh_config( directly to compute nodes. TODO: Also ask if we should add entries for the ComputeCanada/DRAC clusters. + + Returns: + The resulting SSHConfig if the changes are approved. """ ssh_config_path = _setup_ssh_config_file(ssh_config_path) @@ -39,21 +52,28 @@ def setup_ssh_config( # sure that the directory actually exists. control_path_dir.expanduser().mkdir(exist_ok=True, parents=True) + if sys.platform == "win32": + ssh_multiplexing_config = {} + else: + ssh_multiplexing_config = { + # Tries to reuse an existing connection, but if it fails, it will create a new one. + "ControlMaster": "auto", + # This makes a file per connection, like normandf@login.server.mila.quebec:2222 + "ControlPath": str(control_path_dir / r"%r@%h:%p"), + # persist for 10 minutes after the last connection ends. + "ControlPersist": 600, + } + _add_ssh_entry( ssh_config, - "mila", + host="mila", HostName="login.server.mila.quebec", User=username, PreferredAuthentications="publickey,keyboard-interactive", Port=2222, ServerAliveInterval=120, ServerAliveCountMax=5, - # Tries to reuse an existing connection, but if it fails, it will create a new one. - ControlMaster="auto", - # This makes a file per connection, like normandf@login.server.mila.quebec:2222 - ControlPath=str(control_path_dir / r"%r@%h:%p"), - # persist for 10 minutes after the last connection ends. - ControlPersist=600, + **ssh_multiplexing_config, ) _add_ssh_entry( @@ -97,12 +117,7 @@ def setup_ssh_config( HostName="%h", User=username, ProxyJump="mila", - # Tries to reuse an existing connection, but if it fails, it will create a new one. - ControlMaster="auto", - # This makes a file per connection, like normandf@login.server.mila.quebec:2222 - ControlPath=str(control_path_dir / r"%r@%h:%p"), - # persist for 10 minutes after the last connection ends. - ControlPersist=600, + **ssh_multiplexing_config, ) new_config = ssh_config.cfg.config() @@ -113,6 +128,76 @@ def setup_ssh_config( else: ssh_config.save() print(f"Wrote {ssh_config_path}") + return ssh_config + + +def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig): + """Setup the Windows SSH configuration and public key from within WSL. + + This copies over the entries from the linux ssh configuration file, except for the + values that aren't supported on Windows (e.g. "ControlMaster"). + + This also copies the public key file from the linux SSH directory over to the + Windows SSH directory if it isn't already present. + + This makes it so the user doesn't need to install Python/Anaconda on the Windows + side in order to use `mila code` from within WSL. + """ + assert running_inside_WSL() + # NOTE: This also assumes that a public/private key pair has already been generated + # at ~/.ssh/id_rsa.pub and ~/.ssh/id_rsa. + windows_home = get_windows_home_path_in_wsl() + windows_ssh_config_path = windows_home / ".ssh/config" + windows_ssh_config_path = _setup_ssh_config_file(windows_ssh_config_path) + + windows_ssh_config = SSHConfig(windows_ssh_config_path) + + initial_windows_config_contents = windows_ssh_config.cfg.config() + _copy_valid_ssh_entries_to_windows_ssh_config_file( + linux_ssh_config, windows_ssh_config + ) + new_windows_config_contents = windows_ssh_config.cfg.config() + + if ( + new_windows_config_contents != initial_windows_config_contents + and _confirm_changes(windows_ssh_config, initial_windows_config_contents) + ): + # We made changes and they were accepted. + windows_ssh_config.save() + else: + print(f"Did not change ssh config at path {windows_ssh_config.path}") + return # also skip copying the SSH keys. + + # Copy the SSH key to the windows folder so that passwordless SSH also works on + # Windows. + # TODO: This will need to change if we support using a non-default location at some + # point. + linux_private_key_file = Path.home() / ".ssh/id_rsa" + windows_private_key_file = windows_home / ".ssh/id_rsa" + + for linux_key_file, windows_key_file in [ + (linux_private_key_file, windows_private_key_file), + ( + linux_private_key_file.with_suffix(".pub"), + windows_private_key_file.with_suffix(".pub"), + ), + ]: + _copy_if_needed(linux_key_file, windows_key_file) + + +def _copy_if_needed(linux_key_file: Path, windows_key_file: Path): + if linux_key_file.exists() and not windows_key_file.exists(): + print( + f"Copying {linux_key_file} over to the Windows ssh folder at " + f"{windows_key_file}." + ) + shutil.copy2(src=linux_key_file, dst=windows_key_file) + + +def get_windows_home_path_in_wsl() -> Path: + assert running_inside_WSL() + windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip() + return Path(f"/mnt/c/Users/{windows_username}") def _setup_ssh_config_file(config_file_path: str | Path) -> Path: @@ -147,7 +232,12 @@ def _setup_ssh_config_file(config_file_path: str | Path) -> Path: def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool: - print(T.bold("The following modifications will be made to your ~/.ssh/config:\n")) + print( + T.bold( + f"The following modifications will be made to your SSH config file at " + f"{ssh_config.path}:\n" + ) + ) diff_lines = list( difflib.unified_diff( (previous + "\n").splitlines(True), @@ -206,6 +296,7 @@ def _add_ssh_entry( ssh_config: SSHConfig, host: str, Host: str | None = None, + *, _space_before: bool = True, _space_after: bool = False, **entry, @@ -223,7 +314,7 @@ def _add_ssh_entry( existing_entry = ssh_config.host(host) existing_entry.update(entry) ssh_config.cfg.set(host, **existing_entry) - logger.debug(f"Updated {host} entry in ssh config.") + logger.debug(f"Updated {host} entry in ssh config at path {ssh_config.path}.") else: ssh_config.add( host, @@ -231,4 +322,34 @@ def _add_ssh_entry( _space_after=_space_after, **entry, ) - logger.debug(f"Adding new {host} entry in ssh config.") + logger.debug( + f"Adding new {host} entry in ssh config at path {ssh_config.path}." + ) + + +def _copy_valid_ssh_entries_to_windows_ssh_config_file( + linux_ssh_config: SSHConfig, windows_ssh_config: SSHConfig +): + unsupported_keys_lowercase = set(k.lower() for k in WINDOWS_UNSUPPORTED_KEYS) + + # NOTE: need to preserve the ordering of entries: + for host in HOSTS + [ + host for host in linux_ssh_config.hosts() if host not in HOSTS + ]: + if host not in linux_ssh_config.hosts(): + warnings.warn( + RuntimeWarning( + f"Weird, we expected to have a {host!r} entry in the SSH config..." + ) + ) + continue + linux_ssh_entry: dict[str, Any] = linux_ssh_config.host(host) + _add_ssh_entry( + windows_ssh_config, + host, + **{ + key: value + for key, value in linux_ssh_entry.items() + if key.lower() not in unsupported_keys_lowercase + }, + ) diff --git a/milatools/cli/utils.py b/milatools/cli/utils.py index 4c0bdc65..3efd2ad1 100644 --- a/milatools/cli/utils.py +++ b/milatools/cli/utils.py @@ -1,11 +1,14 @@ from __future__ import annotations import contextvars +import functools import itertools import random import shlex +import shutil import socket import subprocess +import sys from contextlib import contextmanager from pathlib import Path @@ -140,6 +143,7 @@ class SSHConfig: """Wrapper around sshconf with some extra niceties.""" def __init__(self, path: str | Path): + self.path = path self.cfg = read_ssh_config(path) # self.add = self.cfg.add self.remove = self.cfg.remove @@ -220,3 +224,8 @@ def get_fully_qualified_name() -> str: except Exception: # Fall back, e.g. on Windows. return socket.getfqdn() + + +@functools.lru_cache() +def running_inside_WSL() -> bool: + return sys.platform == "linux" and bool(shutil.which("powershell.exe")) diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index 91f309f6..1dab0b94 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -4,18 +4,22 @@ import textwrap from functools import partial from pathlib import Path +from unittest.mock import Mock import pytest import questionary from prompt_toolkit.input import PipeInput, create_pipe_input from pytest_regressions.file_regression import FileRegressionFixture +from milatools.cli import init_command from milatools.cli.init_command import ( _get_username, _setup_ssh_config_file, + get_windows_home_path_in_wsl, setup_ssh_config, + setup_windows_ssh_config_from_wsl, ) -from milatools.cli.utils import SSHConfig +from milatools.cli.utils import SSHConfig, running_inside_WSL @pytest.fixture @@ -499,3 +503,140 @@ def test_fixes_dir_permission_issues( assert file.parent.stat().st_mode & 0o777 == 0o700 assert file.exists() assert file.stat().st_mode & 0o777 == 0o600 + + +@pytest.fixture +def linux_ssh_config( + tmp_path: Path, input_pipe: PipeInput, monkeypatch: pytest.MonkeyPatch +) -> SSHConfig: + """Creates the SSH config that would be generated by `mila init`.""" + # Enter username, accept fixing that entry, then confirm. + ssh_config_path = tmp_path / "ssh_config" + + for prompt in ["y", "bob\r", "y"]: + input_pipe.send_text(prompt) + + monkeypatch.setattr("sys.platform", "linux") + # TODO: The config will be different if we run the tests on Windows, it won't + # contain the ControlMaster entries. + setup_ssh_config(ssh_config_path) + + return SSHConfig(ssh_config_path) + + +@pytest.mark.parametrize("accept_changes", [True, False], ids=["accept", "reject"]) +def test_setup_windows_ssh_config_from_wsl( + tmp_path: Path, + linux_ssh_config: SSHConfig, + input_pipe: PipeInput, + file_regression: FileRegressionFixture, + monkeypatch: pytest.MonkeyPatch, + accept_changes: bool, +): + initial_contents = linux_ssh_config.cfg.config() + windows_home = tmp_path / "fake_windows_home" + windows_home.mkdir(exist_ok=False) + windows_ssh_config_path = windows_home / ".ssh" / "config" + + monkeypatch.setattr( + init_command, + running_inside_WSL.__name__, + Mock(spec=running_inside_WSL, return_value=True), + ) + monkeypatch.setattr( + init_command, + get_windows_home_path_in_wsl.__name__, + Mock(spec=get_windows_home_path_in_wsl, return_value=windows_home), + ) + user_inputs: list[str] = [] + if not windows_ssh_config_path.exists(): + # We accept creating the Windows SSH config file for now. + user_inputs.append("y") + user_inputs.append("y" if accept_changes else "n") + + for prompt in user_inputs: + input_pipe.send_text(prompt) + + setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config) + + assert windows_ssh_config_path.exists() + assert windows_ssh_config_path.stat().st_mode & 0o777 == 0o600 + assert windows_ssh_config_path.parent.stat().st_mode & 0o777 == 0o700 + if not accept_changes: + assert windows_ssh_config_path.read_text() == "" + + expected_text = "\n".join( + [ + "When this SSH config is already present in the WSL environment with " + + ( + "\n".join( + [ + "these initial contents:", + "```", + initial_contents, + "```", + "", + ] + ) + if initial_contents.strip() + else "no initial ssh config file" + ), + "", + f"and these user inputs: {tuple(user_inputs)}", + "leads the following ssh config file on the Windows side:", + "", + "```", + windows_ssh_config_path.read_text(), + "```", + ] + ) + + file_regression.check(expected_text, extension=".md") + + +def test_setup_windows_ssh_config_from_wsl_copies_keys( + tmp_path: Path, + linux_ssh_config: SSHConfig, + input_pipe: PipeInput, + monkeypatch: pytest.MonkeyPatch, +): + linux_home = tmp_path / "fake_linux_home" + linux_home.mkdir(exist_ok=False) + windows_home = tmp_path / "fake_windows_home" + windows_home.mkdir(exist_ok=False) + monkeypatch.setattr(Path, "home", Mock(spec=Path.home, return_value=linux_home)) + + monkeypatch.setattr( + init_command, + running_inside_WSL.__name__, + Mock(spec=running_inside_WSL, return_value=True), + ) + monkeypatch.setattr( + init_command, + get_windows_home_path_in_wsl.__name__, + Mock(spec=get_windows_home_path_in_wsl, return_value=windows_home), + ) + + fake_linux_ssh_dir = linux_home / ".ssh" + fake_linux_ssh_dir.mkdir(mode=0o700) + + private_key_text = "THIS IS A PRIVATE KEY" + linux_private_key_path = fake_linux_ssh_dir / "id_rsa" + linux_private_key_path.write_text(private_key_text) + + public_key_text = "THIS IS A PUBLIC KEY" + linux_public_key_path = linux_private_key_path.with_suffix(".pub") + linux_public_key_path.write_text(public_key_text) + + input_pipe.send_text("y") # accept creating the Windows config file + input_pipe.send_text("y") # accept the changes + + setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config) + + windows_private_key_path = windows_home / ".ssh" / "id_rsa" + windows_public_key_path = windows_private_key_path.with_suffix(".pub") + + assert windows_private_key_path.exists() + assert windows_private_key_path.read_text() == private_key_text + assert windows_public_key_path.exists() + assert windows_public_key_path.read_text() == public_key_text diff --git a/tests/cli/test_init_command/test_setup_windows_ssh_config_from_wsl_accept_.md b/tests/cli/test_init_command/test_setup_windows_ssh_config_from_wsl_accept_.md new file mode 100644 index 00000000..730cab85 --- /dev/null +++ b/tests/cli/test_init_command/test_setup_windows_ssh_config_from_wsl_accept_.md @@ -0,0 +1,68 @@ +When this SSH config is already present in the WSL environment with these initial contents: +``` + +Host mila + HostName login.server.mila.quebec + User bob + PreferredAuthentications publickey,keyboard-interactive + Port 2222 + ServerAliveInterval 120 + ServerAliveCountMax 5 + ControlMaster auto + ControlPath ~/.cache/ssh/%r@%h:%p + ControlPersist 600 + +Host mila-cpu + User bob + Port 2222 + ForwardAgent yes + StrictHostKeyChecking no + LogLevel ERROR + UserKnownHostsFile /dev/null + RequestTTY force + ConnectTimeout 600 + ServerAliveInterval 120 + ProxyCommand ssh mila "/cvmfs/config.mila.quebec/scripts/milatools/slurm-proxy.sh mila-cpu --mem=8G" + RemoteCommand /cvmfs/config.mila.quebec/scripts/milatools/entrypoint.sh mila-cpu + +Host *.server.mila.quebec !*login.server.mila.quebec + HostName %h + User bob + ProxyJump mila + ControlMaster auto + ControlPath ~/.cache/ssh/%r@%h:%p + ControlPersist 600 +``` + + +and these user inputs: ('y', 'y') +leads the following ssh config file on the Windows side: + +``` + +Host mila + HostName login.server.mila.quebec + User bob + PreferredAuthentications publickey,keyboard-interactive + Port 2222 + ServerAliveInterval 120 + ServerAliveCountMax 5 + +Host mila-cpu + User bob + Port 2222 + ForwardAgent yes + StrictHostKeyChecking no + LogLevel ERROR + UserKnownHostsFile /dev/null + RequestTTY force + ConnectTimeout 600 + ServerAliveInterval 120 + ProxyCommand ssh mila "/cvmfs/config.mila.quebec/scripts/milatools/slurm-proxy.sh mila-cpu --mem=8G" + remotecommand /cvmfs/config.mila.quebec/scripts/milatools/entrypoint.sh mila-cpu + +Host *.server.mila.quebec !*login.server.mila.quebec + HostName %h + User bob + ProxyJump mila +``` \ No newline at end of file diff --git a/tests/cli/test_init_command/test_setup_windows_ssh_config_from_wsl_reject_.md b/tests/cli/test_init_command/test_setup_windows_ssh_config_from_wsl_reject_.md new file mode 100644 index 00000000..5940aa01 --- /dev/null +++ b/tests/cli/test_init_command/test_setup_windows_ssh_config_from_wsl_reject_.md @@ -0,0 +1,43 @@ +When this SSH config is already present in the WSL environment with these initial contents: +``` + +Host mila + HostName login.server.mila.quebec + User bob + PreferredAuthentications publickey,keyboard-interactive + Port 2222 + ServerAliveInterval 120 + ServerAliveCountMax 5 + ControlMaster auto + ControlPath ~/.cache/ssh/%r@%h:%p + ControlPersist 600 + +Host mila-cpu + User bob + Port 2222 + ForwardAgent yes + StrictHostKeyChecking no + LogLevel ERROR + UserKnownHostsFile /dev/null + RequestTTY force + ConnectTimeout 600 + ServerAliveInterval 120 + ProxyCommand ssh mila "/cvmfs/config.mila.quebec/scripts/milatools/slurm-proxy.sh mila-cpu --mem=8G" + RemoteCommand /cvmfs/config.mila.quebec/scripts/milatools/entrypoint.sh mila-cpu + +Host *.server.mila.quebec !*login.server.mila.quebec + HostName %h + User bob + ProxyJump mila + ControlMaster auto + ControlPath ~/.cache/ssh/%r@%h:%p + ControlPersist 600 +``` + + +and these user inputs: ('y', 'n') +leads the following ssh config file on the Windows side: + +``` + +``` \ No newline at end of file