diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index 953f2697..b8f31d48 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -30,14 +30,16 @@ from typing_extensions import TypedDict from ..version import version as mversion -from .init_command import setup_ssh_config +from .init_command import ( + setup_ssh_config, + setup_windows_ssh_config_from_wsl, +) from .local import Local from .profile import ensure_program, setup_profile from .remote import Remote, SlurmRemote from .utils import ( CommandNotFoundError, MilatoolsUserError, - SSHConfig, SSHConnectionError, T, get_fully_qualified_name, @@ -45,6 +47,7 @@ randname, with_control_file, yn, + running_inside_WSL, ) logger = get_logger(__name__) @@ -390,7 +393,7 @@ def init(): print("Checking ssh config") - setup_ssh_config() + ssh_config = setup_ssh_config() # TODO: Move the rest of this command to functions in the init_command module, # so they can more easily be tested. @@ -478,7 +481,8 @@ def init(): # id_rsa.pub and the config to the Windows paths (taking care to remove the # ControlMaster-related entries) so that the user doesn't need to install Python on # the Windows side. - warn_if_using_WSL_and_mila_init_not_done_on_Windows() + if running_inside_WSL(): + setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config) ################### # Welcome message # @@ -590,12 +594,10 @@ def code( # Try to detect if this is being run from within the Windows Subsystem for Linux. # If so, then we run `code` through a powershell.exe command to open VSCode without # issues. - running_inside_WSL = _running_inside_WSL() - warn_if_using_WSL_and_mila_init_not_done_on_Windows() - + inside_WSL = running_inside_WSL() try: while True: - if running_inside_WSL: + if inside_WSL: here.run( "powershell.exe", "code", @@ -632,38 +634,6 @@ def code( print(T.bold(f" ssh mila scancel {data['jobid']}")) -@functools.lru_cache() -def _running_inside_WSL() -> bool: - return sys.platform == "linux" and bool(shutil.which("powershell.exe")) - - -def _mila_init_also_done_on_windows() -> bool: - assert _running_inside_WSL() - windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip() - windows_ssh_config_file_path = f"/mnt/c/Users/{windows_username}/.ssh/config" - if not os.path.exists(windows_ssh_config_file_path): - return False - ssh_config = SSHConfig(windows_ssh_config_file_path) - configured_hosts = ssh_config.hosts() - if any(host not in configured_hosts for host in ["mila", "mila-cpu"]): - return False - return True - - -def warn_if_using_WSL_and_mila_init_not_done_on_Windows(): - if _running_inside_WSL() and not _mila_init_also_done_on_windows(): - warnings.warn( - T.orange( - "It seems like you are using the Windows Subsystem for Linux, and " - "haven't yet set-up your SSH config file on the Windows side.\n" - "Make sure to also `pip install milatools` and run `mila init` " - "from a powershell window (assuming you also already installed Python " - "on Windows) so that you can use `mila code` from within WSL without " - "errors." - ) - ) - - def connect(identifier: str, port: int | None): """Reconnect to a persistent server.""" diff --git a/milatools/cli/init_command.py b/milatools/cli/init_command.py index b2338cb7..5468bf36 100644 --- a/milatools/cli/init_command.py +++ b/milatools/cli/init_command.py @@ -1,20 +1,27 @@ from __future__ import annotations import difflib +import os +import shutil +import subprocess import sys from logging import getLogger as get_logger from pathlib import Path +from typing import Any +import warnings import questionary as qn -from .utils import SSHConfig, T, yn +from .utils import SSHConfig, T, running_inside_WSL, yn + +WINDOWS_UNSUPPORTED_KEYS = ["ControlMaster", "ControlPath", "ControlPersist"] logger = get_logger(__name__) def setup_ssh_config( ssh_config_path: str | Path = "~/.ssh/config", -): +) -> SSHConfig: """Interactively sets up some useful entries in the ~/.ssh/config file on the local machine. Exits if the User cancels any of the prompts or doesn't confirm the changes when asked. @@ -28,6 +35,9 @@ def setup_ssh_config( directly to compute nodes. TODO: Also ask if we should add entries for the ComputeCanada/DRAC clusters. + + Returns: + The resulting SSHConfig if the changes are approved. """ ssh_config_path = _setup_ssh_config_file(ssh_config_path) @@ -116,6 +126,45 @@ def setup_ssh_config( else: ssh_config.save() print(f"Wrote {ssh_config_path}") + return ssh_config + + +def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig): + """Setup the Windows SSH configuration and public key from within WSL. + + This copies over the entries from the linux ssh configuration file, except for the + values that aren't supported on Windows (e.g. "ControlMaster"). + + This also copies the public key file from the linux SSH directory over to the + Windows SSH directory if it isn't already present. + + This makes it so the user doesn't need to install Python/Anaconda on the Windows + side in order to use `mila code` from within WSL. + """ + assert running_inside_WSL() + windows_home = get_windows_home_path_in_wsl() + windows_ssh_config_path = windows_home / ".ssh/config" + if not windows_ssh_config_path.exists(): + # SSHConfig needs an existing file. + windows_ssh_config_path.touch(mode=0b110_000_000) + + _copy_valid_ssh_entries_to_windows_ssh_config_file( + linux_ssh_config, windows_ssh_config_path + ) + + # copy the public_key_to_windows_ssh_folder + # TODO: This might be different if the user selects a non-default location during + # `ssh-keygen`. + linux_pubkey_file = Path.home() / ".ssh/id_rsa.pub" + windows_pubkey_file = windows_home / ".ssh/id_rsa.pub" + if linux_pubkey_file.exists() and not windows_pubkey_file.exists(): + shutil.copy2(src=linux_pubkey_file, dst=windows_pubkey_file) + + +def get_windows_home_path_in_wsl() -> Path: + assert running_inside_WSL() + windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip() + return Path(f"/mnt/c/Users/{windows_username}") def _setup_ssh_config_file(config_file_path: str | Path) -> Path: @@ -150,7 +199,12 @@ def _setup_ssh_config_file(config_file_path: str | Path) -> Path: def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool: - print(T.bold("The following modifications will be made to your ~/.ssh/config:\n")) + print( + T.bold( + f"The following modifications will be made to your SSH config file at " + f"{ssh_config.path}:\n" + ) + ) diff_lines = list( difflib.unified_diff( (previous + "\n").splitlines(True), @@ -227,7 +281,7 @@ def _add_ssh_entry( existing_entry = ssh_config.host(host) existing_entry.update(entry) ssh_config.cfg.set(host, **existing_entry) - logger.debug(f"Updated {host} entry in ssh config.") + logger.debug(f"Updated {host} entry in ssh config at path {ssh_config.path}.") else: ssh_config.add( host, @@ -235,4 +289,38 @@ def _add_ssh_entry( _space_after=_space_after, **entry, ) - logger.debug(f"Adding new {host} entry in ssh config.") + logger.debug( + f"Adding new {host} entry in ssh config at path {ssh_config.path}." + ) + + +def _copy_valid_ssh_entries_to_windows_ssh_config_file( + linux_ssh_config: SSHConfig, windows_ssh_config_path: Path +): + windows_ssh_config = SSHConfig(windows_ssh_config_path) + initial_windows_config_contents = windows_ssh_config.cfg.config() + + unsupported_keys_lowercase = set(k.lower() for k in WINDOWS_UNSUPPORTED_KEYS) + + for host in linux_ssh_config.hosts(): + linux_ssh_entry: dict[str, Any] = linux_ssh_config.host(host) + _add_ssh_entry( + windows_ssh_config, + host, + **{ + key: value + for key, value in linux_ssh_entry.items() + if key.lower() not in unsupported_keys_lowercase + }, + ) + + new_config_contents = windows_ssh_config.cfg.config() + if new_config_contents == initial_windows_config_contents: + print(f"Did not change ssh config at path {windows_ssh_config.path}") + return + if not _confirm_changes( + windows_ssh_config, previous=initial_windows_config_contents + ): + exit() + # We made changes and they were accepted. + windows_ssh_config.save() diff --git a/milatools/cli/utils.py b/milatools/cli/utils.py index 4c0bdc65..00da8ff3 100644 --- a/milatools/cli/utils.py +++ b/milatools/cli/utils.py @@ -1,13 +1,16 @@ from __future__ import annotations import contextvars +import functools import itertools import random import shlex +import shutil import socket import subprocess from contextlib import contextmanager from pathlib import Path +import sys import blessed import paramiko @@ -140,6 +143,7 @@ class SSHConfig: """Wrapper around sshconf with some extra niceties.""" def __init__(self, path: str | Path): + self.path = path self.cfg = read_ssh_config(path) # self.add = self.cfg.add self.remove = self.cfg.remove @@ -220,3 +224,8 @@ def get_fully_qualified_name() -> str: except Exception: # Fall back, e.g. on Windows. return socket.getfqdn() + + +@functools.lru_cache() +def running_inside_WSL() -> bool: + return sys.platform == "linux" and bool(shutil.which("powershell.exe"))