Skip to content

Commit

Permalink
Fix the mila init and mila code commands on windows & WSL (#65)
Browse files Browse the repository at this point in the history
* Dont add ControlMaster ssh entries on Windows

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove the unused params of _add_ssh_entry

Signed-off-by: Fabrice Normandin <[email protected]>

* Rebase on Master

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix isort issues

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix the `mila init` command on windows

Fixes #63

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix black formatting errors

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix the equivalent of ssh-copy-id, now works!

Signed-off-by: Fabrice Normandin <[email protected]>

* Attempt to fix running `mila code` from WSL

Signed-off-by: Fabrice <[email protected]>

* Fix whitespace errors and run `code` directly

Signed-off-by: Fabrice Normandin <[email protected]>

* Set the remote.SSH settings in User Settings Json

Signed-off-by: Fabrice Normandin <[email protected]>

* Show that VsCode settings are found

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove the VsCode settings thing for now

Signed-off-by: Fabrice Normandin <[email protected]>

* Use `here.run` to run the command

Signed-off-by: Fabrice <[email protected]>

* Skip the hostname checking on Windows machines

Signed-off-by: Fabrice <[email protected]>

* Fix black formatting issue

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix missing condition in WSL check

Signed-off-by: Fabrice Normandin <[email protected]>

* Warn WSL users if `mila init` not done on Windows

Signed-off-by: Fabrice Normandin <[email protected]>

* Also setup Windows SSH config from WSL

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix isort issues

Signed-off-by: Fabrice Normandin <[email protected]>

* Reuse existing fn to setup windows SSH file

Signed-off-by: Fabrice Normandin <[email protected]>

* Preserve ordering of SSH entries when copying

Signed-off-by: Fabrice Normandin <[email protected]>

* Copy both public and private key files

Signed-off-by: Fabrice Normandin <[email protected]>

* Add tests for the Windows SSH setup from WSL

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove need for pytest-mock dependency

Signed-off-by: Fabrice Normandin <[email protected]>

* Fix a bug in the added tests

Signed-off-by: Fabrice Normandin <[email protected]>

* Remove accidentally-added .pre-commit-config.yaml

Signed-off-by: Fabrice Normandin <[email protected]>

---------

Signed-off-by: Fabrice Normandin <[email protected]>
Signed-off-by: Fabrice <[email protected]>
  • Loading branch information
lebrice authored Nov 11, 2023
1 parent f238cf5 commit 270e3c4
Show file tree
Hide file tree
Showing 6 changed files with 449 additions and 31 deletions.
60 changes: 48 additions & 12 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import argparse
import functools
import operator
import os
import re
Expand All @@ -15,6 +16,7 @@
import time
import traceback
import typing
import warnings
import webbrowser
from argparse import ArgumentParser, _HelpAction
from contextlib import ExitStack
Expand All @@ -28,7 +30,7 @@
from typing_extensions import TypedDict

from ..version import version as mversion
from .init_command import setup_ssh_config
from .init_command import setup_ssh_config, setup_windows_ssh_config_from_wsl
from .local import Local
from .profile import ensure_program, setup_profile
from .remote import Remote, SlurmRemote
Expand All @@ -40,6 +42,7 @@
get_fully_qualified_name,
qualified,
randname,
running_inside_WSL,
with_control_file,
yn,
)
Expand All @@ -50,8 +53,9 @@


def main():
on_mila = get_fully_qualified_name().endswith(".server.mila.quebec")
if on_mila:
if sys.platform != "win32" and get_fully_qualified_name().endswith(
".server.mila.quebec"
):
exit(
"ERROR: 'mila ...' should be run on your local machine and not on the Mila cluster"
)
Expand Down Expand Up @@ -386,7 +390,7 @@ def init():

print("Checking ssh config")

setup_ssh_config()
ssh_config = setup_ssh_config()
# TODO: Move the rest of this command to functions in the init_command module,
# so they can more easily be tested.

Expand All @@ -408,6 +412,7 @@ def init():
for entry in os.listdir(sshdir)
):
if yn("You have no public keys. Generate one?"):
# TODO: need to get the location of the key as an output of this command!
here.run("ssh-keygen")
else:
exit("No public keys.")
Expand All @@ -418,7 +423,15 @@ def init():
if yn(
"Your public key does not appear be registered on the cluster. Register it?"
):
here.run("ssh-copy-id", "mila")
# NOTE: If we're on a Windows machine, we do something different here:
if sys.platform == "win32":
command = (
"powershell.exe type $env:USERPROFILE\\.ssh\\id_rsa.pub | ssh mila "
'"cat >> ~/.ssh/authorized_keys"'
)
here.run(command)
else:
here.run("ssh-copy-id", "mila")
if not here.check_passwordless("mila"):
exit("ssh-copy-id appears to have failed")
else:
Expand Down Expand Up @@ -461,6 +474,13 @@ def init():
else:
exit("You will not be able to SSH to a compute node")

# TODO: IF we're running on WSL, we could probably actually just copy the
# id_rsa.pub and the config to the Windows paths (taking care to remove the
# ControlMaster-related entries) so that the user doesn't need to install Python on
# the Windows side.
if running_inside_WSL():
setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config)

###################
# Welcome message #
###################
Expand Down Expand Up @@ -566,15 +586,31 @@ def code(
command_path = shutil.which(command)
if not command_path:
raise CommandNotFoundError(command)
qualified_node_name = qualified(node_name)

# Try to detect if this is being run from within the Windows Subsystem for Linux.
# If so, then we run `code` through a powershell.exe command to open VSCode without
# issues.
inside_WSL = running_inside_WSL()
try:
while True:
here.run(
command_path,
"-nw",
"--remote",
f"ssh-remote+{qualified(node_name)}",
path,
)
if inside_WSL:
here.run(
"powershell.exe",
"code",
"-nw",
"--remote",
f"ssh-remote+{qualified_node_name}",
path,
)
else:
here.run(
command_path,
"-nw",
"--remote",
f"ssh-remote+{qualified_node_name}",
path,
)
print(
"The editor was closed. Reopen it with <Enter>"
" or terminate the process with <Ctrl+C>"
Expand Down
157 changes: 139 additions & 18 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from __future__ import annotations

import difflib
import shutil
import subprocess
import sys
import warnings
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any

import questionary as qn

from .utils import SSHConfig, T, yn
from .utils import SSHConfig, T, running_inside_WSL, yn

WINDOWS_UNSUPPORTED_KEYS = ["ControlMaster", "ControlPath", "ControlPersist"]

logger = get_logger(__name__)

HOSTS = ["mila", "mila-cpu", "*.server.mila.quebec !*login.server.mila.quebec"]
"""List of host entries that get added to the SSH configurtion by `mila init`."""


def setup_ssh_config(
ssh_config_path: str | Path = "~/.ssh/config",
):
) -> SSHConfig:
"""Interactively sets up some useful entries in the ~/.ssh/config file on the local machine.
Exits if the User cancels any of the prompts or doesn't confirm the changes when asked.
Expand All @@ -27,6 +37,9 @@ def setup_ssh_config(
directly to compute nodes.
TODO: Also ask if we should add entries for the ComputeCanada/DRAC clusters.
Returns:
The resulting SSHConfig if the changes are approved.
"""

ssh_config_path = _setup_ssh_config_file(ssh_config_path)
Expand All @@ -39,21 +52,28 @@ def setup_ssh_config(
# sure that the directory actually exists.
control_path_dir.expanduser().mkdir(exist_ok=True, parents=True)

if sys.platform == "win32":
ssh_multiplexing_config = {}
else:
ssh_multiplexing_config = {
# Tries to reuse an existing connection, but if it fails, it will create a new one.
"ControlMaster": "auto",
# This makes a file per connection, like [email protected]:2222
"ControlPath": str(control_path_dir / r"%r@%h:%p"),
# persist for 10 minutes after the last connection ends.
"ControlPersist": 600,
}

_add_ssh_entry(
ssh_config,
"mila",
host="mila",
HostName="login.server.mila.quebec",
User=username,
PreferredAuthentications="publickey,keyboard-interactive",
Port=2222,
ServerAliveInterval=120,
ServerAliveCountMax=5,
# Tries to reuse an existing connection, but if it fails, it will create a new one.
ControlMaster="auto",
# This makes a file per connection, like [email protected]:2222
ControlPath=str(control_path_dir / r"%r@%h:%p"),
# persist for 10 minutes after the last connection ends.
ControlPersist=600,
**ssh_multiplexing_config,
)

_add_ssh_entry(
Expand Down Expand Up @@ -97,12 +117,7 @@ def setup_ssh_config(
HostName="%h",
User=username,
ProxyJump="mila",
# Tries to reuse an existing connection, but if it fails, it will create a new one.
ControlMaster="auto",
# This makes a file per connection, like [email protected]:2222
ControlPath=str(control_path_dir / r"%r@%h:%p"),
# persist for 10 minutes after the last connection ends.
ControlPersist=600,
**ssh_multiplexing_config,
)

new_config = ssh_config.cfg.config()
Expand All @@ -113,6 +128,76 @@ def setup_ssh_config(
else:
ssh_config.save()
print(f"Wrote {ssh_config_path}")
return ssh_config


def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig):
"""Setup the Windows SSH configuration and public key from within WSL.
This copies over the entries from the linux ssh configuration file, except for the
values that aren't supported on Windows (e.g. "ControlMaster").
This also copies the public key file from the linux SSH directory over to the
Windows SSH directory if it isn't already present.
This makes it so the user doesn't need to install Python/Anaconda on the Windows
side in order to use `mila code` from within WSL.
"""
assert running_inside_WSL()
# NOTE: This also assumes that a public/private key pair has already been generated
# at ~/.ssh/id_rsa.pub and ~/.ssh/id_rsa.
windows_home = get_windows_home_path_in_wsl()
windows_ssh_config_path = windows_home / ".ssh/config"
windows_ssh_config_path = _setup_ssh_config_file(windows_ssh_config_path)

windows_ssh_config = SSHConfig(windows_ssh_config_path)

initial_windows_config_contents = windows_ssh_config.cfg.config()
_copy_valid_ssh_entries_to_windows_ssh_config_file(
linux_ssh_config, windows_ssh_config
)
new_windows_config_contents = windows_ssh_config.cfg.config()

if (
new_windows_config_contents != initial_windows_config_contents
and _confirm_changes(windows_ssh_config, initial_windows_config_contents)
):
# We made changes and they were accepted.
windows_ssh_config.save()
else:
print(f"Did not change ssh config at path {windows_ssh_config.path}")
return # also skip copying the SSH keys.

# Copy the SSH key to the windows folder so that passwordless SSH also works on
# Windows.
# TODO: This will need to change if we support using a non-default location at some
# point.
linux_private_key_file = Path.home() / ".ssh/id_rsa"
windows_private_key_file = windows_home / ".ssh/id_rsa"

for linux_key_file, windows_key_file in [
(linux_private_key_file, windows_private_key_file),
(
linux_private_key_file.with_suffix(".pub"),
windows_private_key_file.with_suffix(".pub"),
),
]:
_copy_if_needed(linux_key_file, windows_key_file)


def _copy_if_needed(linux_key_file: Path, windows_key_file: Path):
if linux_key_file.exists() and not windows_key_file.exists():
print(
f"Copying {linux_key_file} over to the Windows ssh folder at "
f"{windows_key_file}."
)
shutil.copy2(src=linux_key_file, dst=windows_key_file)


def get_windows_home_path_in_wsl() -> Path:
assert running_inside_WSL()
windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip()
return Path(f"/mnt/c/Users/{windows_username}")


def _setup_ssh_config_file(config_file_path: str | Path) -> Path:
Expand Down Expand Up @@ -147,7 +232,12 @@ def _setup_ssh_config_file(config_file_path: str | Path) -> Path:


def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool:
print(T.bold("The following modifications will be made to your ~/.ssh/config:\n"))
print(
T.bold(
f"The following modifications will be made to your SSH config file at "
f"{ssh_config.path}:\n"
)
)
diff_lines = list(
difflib.unified_diff(
(previous + "\n").splitlines(True),
Expand Down Expand Up @@ -206,6 +296,7 @@ def _add_ssh_entry(
ssh_config: SSHConfig,
host: str,
Host: str | None = None,
*,
_space_before: bool = True,
_space_after: bool = False,
**entry,
Expand All @@ -223,12 +314,42 @@ def _add_ssh_entry(
existing_entry = ssh_config.host(host)
existing_entry.update(entry)
ssh_config.cfg.set(host, **existing_entry)
logger.debug(f"Updated {host} entry in ssh config.")
logger.debug(f"Updated {host} entry in ssh config at path {ssh_config.path}.")
else:
ssh_config.add(
host,
_space_before=_space_before,
_space_after=_space_after,
**entry,
)
logger.debug(f"Adding new {host} entry in ssh config.")
logger.debug(
f"Adding new {host} entry in ssh config at path {ssh_config.path}."
)


def _copy_valid_ssh_entries_to_windows_ssh_config_file(
linux_ssh_config: SSHConfig, windows_ssh_config: SSHConfig
):
unsupported_keys_lowercase = set(k.lower() for k in WINDOWS_UNSUPPORTED_KEYS)

# NOTE: need to preserve the ordering of entries:
for host in HOSTS + [
host for host in linux_ssh_config.hosts() if host not in HOSTS
]:
if host not in linux_ssh_config.hosts():
warnings.warn(
RuntimeWarning(
f"Weird, we expected to have a {host!r} entry in the SSH config..."
)
)
continue
linux_ssh_entry: dict[str, Any] = linux_ssh_config.host(host)
_add_ssh_entry(
windows_ssh_config,
host,
**{
key: value
for key, value in linux_ssh_entry.items()
if key.lower() not in unsupported_keys_lowercase
},
)
9 changes: 9 additions & 0 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import contextvars
import functools
import itertools
import random
import shlex
import shutil
import socket
import subprocess
import sys
from contextlib import contextmanager
from pathlib import Path

Expand Down Expand Up @@ -140,6 +143,7 @@ class SSHConfig:
"""Wrapper around sshconf with some extra niceties."""

def __init__(self, path: str | Path):
self.path = path
self.cfg = read_ssh_config(path)
# self.add = self.cfg.add
self.remove = self.cfg.remove
Expand Down Expand Up @@ -220,3 +224,8 @@ def get_fully_qualified_name() -> str:
except Exception:
# Fall back, e.g. on Windows.
return socket.getfqdn()


@functools.lru_cache()
def running_inside_WSL() -> bool:
return sys.platform == "linux" and bool(shutil.which("powershell.exe"))
Loading

0 comments on commit 270e3c4

Please sign in to comment.