Skip to content

Commit

Permalink
Also setup Windows SSH config from WSL
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Nov 8, 2023
1 parent 9b5fd37 commit cbe4d74
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 45 deletions.
50 changes: 10 additions & 40 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,24 @@
from typing_extensions import TypedDict

from ..version import version as mversion
from .init_command import setup_ssh_config
from .init_command import (
setup_ssh_config,
setup_windows_ssh_config_from_wsl,
)
from .local import Local
from .profile import ensure_program, setup_profile
from .remote import Remote, SlurmRemote
from .utils import (
CommandNotFoundError,
MilatoolsUserError,
SSHConfig,
SSHConnectionError,
T,
get_fully_qualified_name,
qualified,
randname,
with_control_file,
yn,
running_inside_WSL,
)

logger = get_logger(__name__)
Expand Down Expand Up @@ -390,7 +393,7 @@ def init():

print("Checking ssh config")

setup_ssh_config()
ssh_config = setup_ssh_config()
# TODO: Move the rest of this command to functions in the init_command module,
# so they can more easily be tested.

Expand Down Expand Up @@ -478,7 +481,8 @@ def init():
# id_rsa.pub and the config to the Windows paths (taking care to remove the
# ControlMaster-related entries) so that the user doesn't need to install Python on
# the Windows side.
warn_if_using_WSL_and_mila_init_not_done_on_Windows()
if running_inside_WSL():
setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config)

###################
# Welcome message #
Expand Down Expand Up @@ -590,12 +594,10 @@ def code(
# Try to detect if this is being run from within the Windows Subsystem for Linux.
# If so, then we run `code` through a powershell.exe command to open VSCode without
# issues.
running_inside_WSL = _running_inside_WSL()
warn_if_using_WSL_and_mila_init_not_done_on_Windows()

inside_WSL = running_inside_WSL()
try:
while True:
if running_inside_WSL:
if inside_WSL:
here.run(
"powershell.exe",
"code",
Expand Down Expand Up @@ -632,38 +634,6 @@ def code(
print(T.bold(f" ssh mila scancel {data['jobid']}"))


@functools.lru_cache()
def _running_inside_WSL() -> bool:
return sys.platform == "linux" and bool(shutil.which("powershell.exe"))


def _mila_init_also_done_on_windows() -> bool:
assert _running_inside_WSL()
windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip()
windows_ssh_config_file_path = f"/mnt/c/Users/{windows_username}/.ssh/config"
if not os.path.exists(windows_ssh_config_file_path):
return False
ssh_config = SSHConfig(windows_ssh_config_file_path)
configured_hosts = ssh_config.hosts()
if any(host not in configured_hosts for host in ["mila", "mila-cpu"]):
return False
return True


def warn_if_using_WSL_and_mila_init_not_done_on_Windows():
if _running_inside_WSL() and not _mila_init_also_done_on_windows():
warnings.warn(
T.orange(
"It seems like you are using the Windows Subsystem for Linux, and "
"haven't yet set-up your SSH config file on the Windows side.\n"
"Make sure to also `pip install milatools` and run `mila init` "
"from a powershell window (assuming you also already installed Python "
"on Windows) so that you can use `mila code` from within WSL without "
"errors."
)
)


def connect(identifier: str, port: int | None):
"""Reconnect to a persistent server."""

Expand Down
98 changes: 93 additions & 5 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
from __future__ import annotations

import difflib
import os
import shutil
import subprocess
import sys
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any
import warnings

import questionary as qn

from .utils import SSHConfig, T, yn
from .utils import SSHConfig, T, running_inside_WSL, yn

WINDOWS_UNSUPPORTED_KEYS = ["ControlMaster", "ControlPath", "ControlPersist"]

logger = get_logger(__name__)


def setup_ssh_config(
ssh_config_path: str | Path = "~/.ssh/config",
):
) -> SSHConfig:
"""Interactively sets up some useful entries in the ~/.ssh/config file on the local machine.
Exits if the User cancels any of the prompts or doesn't confirm the changes when asked.
Expand All @@ -28,6 +35,9 @@ def setup_ssh_config(
directly to compute nodes.
TODO: Also ask if we should add entries for the ComputeCanada/DRAC clusters.
Returns:
The resulting SSHConfig if the changes are approved.
"""

ssh_config_path = _setup_ssh_config_file(ssh_config_path)
Expand Down Expand Up @@ -116,6 +126,45 @@ def setup_ssh_config(
else:
ssh_config.save()
print(f"Wrote {ssh_config_path}")
return ssh_config


def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig):
"""Setup the Windows SSH configuration and public key from within WSL.
This copies over the entries from the linux ssh configuration file, except for the
values that aren't supported on Windows (e.g. "ControlMaster").
This also copies the public key file from the linux SSH directory over to the
Windows SSH directory if it isn't already present.
This makes it so the user doesn't need to install Python/Anaconda on the Windows
side in order to use `mila code` from within WSL.
"""
assert running_inside_WSL()
windows_home = get_windows_home_path_in_wsl()
windows_ssh_config_path = windows_home / ".ssh/config"
if not windows_ssh_config_path.exists():
# SSHConfig needs an existing file.
windows_ssh_config_path.touch(mode=0b110_000_000)

_copy_valid_ssh_entries_to_windows_ssh_config_file(
linux_ssh_config, windows_ssh_config_path
)

# copy the public_key_to_windows_ssh_folder
# TODO: This might be different if the user selects a non-default location during
# `ssh-keygen`.
linux_pubkey_file = Path.home() / ".ssh/id_rsa.pub"
windows_pubkey_file = windows_home / ".ssh/id_rsa.pub"
if linux_pubkey_file.exists() and not windows_pubkey_file.exists():
shutil.copy2(src=linux_pubkey_file, dst=windows_pubkey_file)


def get_windows_home_path_in_wsl() -> Path:
assert running_inside_WSL()
windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip()
return Path(f"/mnt/c/Users/{windows_username}")


def _setup_ssh_config_file(config_file_path: str | Path) -> Path:
Expand Down Expand Up @@ -150,7 +199,12 @@ def _setup_ssh_config_file(config_file_path: str | Path) -> Path:


def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool:
print(T.bold("The following modifications will be made to your ~/.ssh/config:\n"))
print(
T.bold(
f"The following modifications will be made to your SSH config file at "
f"{ssh_config.path}:\n"
)
)
diff_lines = list(
difflib.unified_diff(
(previous + "\n").splitlines(True),
Expand Down Expand Up @@ -227,12 +281,46 @@ def _add_ssh_entry(
existing_entry = ssh_config.host(host)
existing_entry.update(entry)
ssh_config.cfg.set(host, **existing_entry)
logger.debug(f"Updated {host} entry in ssh config.")
logger.debug(f"Updated {host} entry in ssh config at path {ssh_config.path}.")
else:
ssh_config.add(
host,
_space_before=_space_before,
_space_after=_space_after,
**entry,
)
logger.debug(f"Adding new {host} entry in ssh config.")
logger.debug(
f"Adding new {host} entry in ssh config at path {ssh_config.path}."
)


def _copy_valid_ssh_entries_to_windows_ssh_config_file(
linux_ssh_config: SSHConfig, windows_ssh_config_path: Path
):
windows_ssh_config = SSHConfig(windows_ssh_config_path)
initial_windows_config_contents = windows_ssh_config.cfg.config()

unsupported_keys_lowercase = set(k.lower() for k in WINDOWS_UNSUPPORTED_KEYS)

for host in linux_ssh_config.hosts():
linux_ssh_entry: dict[str, Any] = linux_ssh_config.host(host)
_add_ssh_entry(
windows_ssh_config,
host,
**{
key: value
for key, value in linux_ssh_entry.items()
if key.lower() not in unsupported_keys_lowercase
},
)

new_config_contents = windows_ssh_config.cfg.config()
if new_config_contents == initial_windows_config_contents:
print(f"Did not change ssh config at path {windows_ssh_config.path}")
return
if not _confirm_changes(
windows_ssh_config, previous=initial_windows_config_contents
):
exit()
# We made changes and they were accepted.
windows_ssh_config.save()
9 changes: 9 additions & 0 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from __future__ import annotations

import contextvars
import functools
import itertools
import random
import shlex
import shutil
import socket
import subprocess
from contextlib import contextmanager
from pathlib import Path
import sys

import blessed
import paramiko
Expand Down Expand Up @@ -140,6 +143,7 @@ class SSHConfig:
"""Wrapper around sshconf with some extra niceties."""

def __init__(self, path: str | Path):
self.path = path
self.cfg = read_ssh_config(path)
# self.add = self.cfg.add
self.remove = self.cfg.remove
Expand Down Expand Up @@ -220,3 +224,8 @@ def get_fully_qualified_name() -> str:
except Exception:
# Fall back, e.g. on Windows.
return socket.getfqdn()


@functools.lru_cache()
def running_inside_WSL() -> bool:
return sys.platform == "linux" and bool(shutil.which("powershell.exe"))

0 comments on commit cbe4d74

Please sign in to comment.