Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the mila init and mila code commands on windows & WSL #65

Merged
merged 26 commits into from
Nov 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a5ecdcc
Dont add ControlMaster ssh entries on Windows
lebrice Nov 6, 2023
b5f6ce9
Remove the unused params of _add_ssh_entry
lebrice Nov 6, 2023
e43a4fb
Rebase on Master
lebrice Nov 6, 2023
05b2b9a
Fix isort issues
lebrice Nov 6, 2023
c6978af
Fix the `mila init` command on windows
lebrice Nov 2, 2023
fb30bb8
Fix black formatting errors
lebrice Nov 6, 2023
126a46b
Fix the equivalent of ssh-copy-id, now works!
lebrice Nov 6, 2023
557c9ac
Attempt to fix running `mila code` from WSL
lebrice Nov 7, 2023
ddd7ad0
Fix whitespace errors and run `code` directly
lebrice Nov 7, 2023
7f0b0c4
Set the remote.SSH settings in User Settings Json
lebrice Nov 7, 2023
440d1e3
Show that VsCode settings are found
lebrice Nov 7, 2023
1838323
Remove the VsCode settings thing for now
lebrice Nov 8, 2023
0b92c8f
Use `here.run` to run the command
lebrice Nov 7, 2023
a0ad53f
Skip the hostname checking on Windows machines
lebrice Nov 7, 2023
e881705
Fix black formatting issue
lebrice Nov 8, 2023
477997d
Fix missing condition in WSL check
lebrice Nov 8, 2023
9b5fd37
Warn WSL users if `mila init` not done on Windows
lebrice Nov 8, 2023
cbe4d74
Also setup Windows SSH config from WSL
lebrice Nov 8, 2023
611b2b6
Fix isort issues
lebrice Nov 8, 2023
f687847
Reuse existing fn to setup windows SSH file
lebrice Nov 9, 2023
62bd443
Preserve ordering of SSH entries when copying
lebrice Nov 9, 2023
a2f4b86
Copy both public and private key files
lebrice Nov 9, 2023
90b85b8
Add tests for the Windows SSH setup from WSL
lebrice Nov 9, 2023
dd87787
Remove need for pytest-mock dependency
lebrice Nov 9, 2023
531b89a
Fix a bug in the added tests
lebrice Nov 9, 2023
f073c9e
Remove accidentally-added .pre-commit-config.yaml
lebrice Nov 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 48 additions & 12 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import argparse
import functools
import operator
import os
import re
Expand All @@ -15,6 +16,7 @@
import time
import traceback
import typing
import warnings
import webbrowser
from argparse import ArgumentParser, _HelpAction
from contextlib import ExitStack
Expand All @@ -28,7 +30,7 @@
from typing_extensions import TypedDict

from ..version import version as mversion
from .init_command import setup_ssh_config
from .init_command import setup_ssh_config, setup_windows_ssh_config_from_wsl
from .local import Local
from .profile import ensure_program, setup_profile
from .remote import Remote, SlurmRemote
Expand All @@ -40,6 +42,7 @@
get_fully_qualified_name,
qualified,
randname,
running_inside_WSL,
with_control_file,
yn,
)
Expand All @@ -50,8 +53,9 @@


def main():
on_mila = get_fully_qualified_name().endswith(".server.mila.quebec")
if on_mila:
if sys.platform != "win32" and get_fully_qualified_name().endswith(
".server.mila.quebec"
):
exit(
"ERROR: 'mila ...' should be run on your local machine and not on the Mila cluster"
)
Expand Down Expand Up @@ -386,7 +390,7 @@ def init():

print("Checking ssh config")

setup_ssh_config()
ssh_config = setup_ssh_config()
# TODO: Move the rest of this command to functions in the init_command module,
# so they can more easily be tested.

Expand All @@ -408,6 +412,7 @@ def init():
for entry in os.listdir(sshdir)
):
if yn("You have no public keys. Generate one?"):
# TODO: need to get the location of the key as an output of this command!
here.run("ssh-keygen")
else:
exit("No public keys.")
Expand All @@ -418,7 +423,15 @@ def init():
if yn(
"Your public key does not appear be registered on the cluster. Register it?"
):
here.run("ssh-copy-id", "mila")
# NOTE: If we're on a Windows machine, we do something different here:
if sys.platform == "win32":
command = (
"powershell.exe type $env:USERPROFILE\\.ssh\\id_rsa.pub | ssh mila "
'"cat >> ~/.ssh/authorized_keys"'
)
here.run(command)
else:
here.run("ssh-copy-id", "mila")
if not here.check_passwordless("mila"):
exit("ssh-copy-id appears to have failed")
else:
Expand Down Expand Up @@ -461,6 +474,13 @@ def init():
else:
exit("You will not be able to SSH to a compute node")

# TODO: IF we're running on WSL, we could probably actually just copy the
# id_rsa.pub and the config to the Windows paths (taking care to remove the
# ControlMaster-related entries) so that the user doesn't need to install Python on
# the Windows side.
if running_inside_WSL():
setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config)

###################
# Welcome message #
###################
Expand Down Expand Up @@ -566,15 +586,31 @@ def code(
command_path = shutil.which(command)
if not command_path:
raise CommandNotFoundError(command)
qualified_node_name = qualified(node_name)

# Try to detect if this is being run from within the Windows Subsystem for Linux.
# If so, then we run `code` through a powershell.exe command to open VSCode without
# issues.
inside_WSL = running_inside_WSL()
try:
while True:
here.run(
command_path,
"-nw",
"--remote",
f"ssh-remote+{qualified(node_name)}",
path,
)
if inside_WSL:
here.run(
"powershell.exe",
"code",
"-nw",
"--remote",
f"ssh-remote+{qualified_node_name}",
path,
)
else:
here.run(
command_path,
"-nw",
"--remote",
f"ssh-remote+{qualified_node_name}",
path,
)
print(
"The editor was closed. Reopen it with <Enter>"
" or terminate the process with <Ctrl+C>"
Expand Down
157 changes: 139 additions & 18 deletions milatools/cli/init_command.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,29 @@
from __future__ import annotations

import difflib
import shutil
import subprocess
import sys
import warnings
from logging import getLogger as get_logger
from pathlib import Path
from typing import Any

import questionary as qn

from .utils import SSHConfig, T, yn
from .utils import SSHConfig, T, running_inside_WSL, yn

WINDOWS_UNSUPPORTED_KEYS = ["ControlMaster", "ControlPath", "ControlPersist"]

logger = get_logger(__name__)

HOSTS = ["mila", "mila-cpu", "*.server.mila.quebec !*login.server.mila.quebec"]
"""List of host entries that get added to the SSH configuration by `mila init`."""


def setup_ssh_config(
ssh_config_path: str | Path = "~/.ssh/config",
):
) -> SSHConfig:
"""Interactively sets up some useful entries in the ~/.ssh/config file on the local machine.

Exits if the User cancels any of the prompts or doesn't confirm the changes when asked.
Expand All @@ -27,6 +37,9 @@ def setup_ssh_config(
directly to compute nodes.

TODO: Also ask if we should add entries for the ComputeCanada/DRAC clusters.

Returns:
The resulting SSHConfig if the changes are approved.
"""

ssh_config_path = _setup_ssh_config_file(ssh_config_path)
Expand All @@ -39,21 +52,28 @@ def setup_ssh_config(
# sure that the directory actually exists.
control_path_dir.expanduser().mkdir(exist_ok=True, parents=True)

if sys.platform == "win32":
ssh_multiplexing_config = {}
else:
ssh_multiplexing_config = {
# Tries to reuse an existing connection, but if it fails, it will create a new one.
"ControlMaster": "auto",
# This makes a file per connection, like [email protected]:2222
"ControlPath": str(control_path_dir / r"%r@%h:%p"),
# persist for 10 minutes after the last connection ends.
"ControlPersist": 600,
}

_add_ssh_entry(
ssh_config,
"mila",
host="mila",
HostName="login.server.mila.quebec",
User=username,
PreferredAuthentications="publickey,keyboard-interactive",
Port=2222,
ServerAliveInterval=120,
ServerAliveCountMax=5,
# Tries to reuse an existing connection, but if it fails, it will create a new one.
ControlMaster="auto",
# This makes a file per connection, like [email protected]:2222
ControlPath=str(control_path_dir / r"%r@%h:%p"),
# persist for 10 minutes after the last connection ends.
ControlPersist=600,
**ssh_multiplexing_config,
)

_add_ssh_entry(
Expand Down Expand Up @@ -97,12 +117,7 @@ def setup_ssh_config(
HostName="%h",
User=username,
ProxyJump="mila",
# Tries to reuse an existing connection, but if it fails, it will create a new one.
ControlMaster="auto",
# This makes a file per connection, like [email protected]:2222
ControlPath=str(control_path_dir / r"%r@%h:%p"),
# persist for 10 minutes after the last connection ends.
ControlPersist=600,
**ssh_multiplexing_config,
)

new_config = ssh_config.cfg.config()
Expand All @@ -113,6 +128,76 @@ def setup_ssh_config(
else:
ssh_config.save()
print(f"Wrote {ssh_config_path}")
return ssh_config


def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig):
    """Mirror the SSH setup done in WSL onto the Windows side of the machine.

    The host entries of `linux_ssh_config` are copied into the Windows SSH config
    file, leaving out the options that aren't supported on Windows (e.g.
    "ControlMaster"). The user is asked to confirm the modifications before the
    Windows config file is saved.

    Once the config is saved, the default key pair (`id_rsa` and `id_rsa.pub`) is
    copied from the linux SSH directory to the Windows SSH directory, skipping any
    file that is already present there.

    When there is nothing to change, or when the changes are not confirmed, the
    function returns early: neither the config file nor the keys are touched.

    This makes it so the user doesn't need to install Python/Anaconda on the Windows
    side in order to use `mila code` from within WSL.
    """
    assert running_inside_WSL()
    # NOTE: This also assumes that a public/private key pair has already been generated
    # at ~/.ssh/id_rsa.pub and ~/.ssh/id_rsa.
    windows_home = get_windows_home_path_in_wsl()
    windows_ssh_config = SSHConfig(
        _setup_ssh_config_file(windows_home / ".ssh/config")
    )

    contents_before = windows_ssh_config.cfg.config()
    _copy_valid_ssh_entries_to_windows_ssh_config_file(
        linux_ssh_config, windows_ssh_config
    )
    contents_after = windows_ssh_config.cfg.config()

    # Only prompt the user when there is actually something to confirm.
    has_changes = contents_after != contents_before
    if not (has_changes and _confirm_changes(windows_ssh_config, contents_before)):
        print(f"Did not change ssh config at path {windows_ssh_config.path}")
        return  # also skip copying the SSH keys.

    # We made changes and they were accepted.
    windows_ssh_config.save()

    # Copy the SSH keys to the windows folder so that passwordless SSH also works on
    # Windows.
    # TODO: This will need to change if we support using a non-default location at some
    # point.
    linux_ssh_dir = Path.home() / ".ssh"
    windows_ssh_dir = windows_home / ".ssh"
    # Private key first, then the matching public key.
    for key_file_name in ("id_rsa", "id_rsa.pub"):
        _copy_if_needed(
            linux_ssh_dir / key_file_name, windows_ssh_dir / key_file_name
        )


def _copy_if_needed(linux_key_file: Path, windows_key_file: Path):
    """Copy a key file to the Windows SSH folder unless a copy is already there.

    Nothing happens when the source file is missing or when the destination file
    already exists (an existing destination is never overwritten).
    """
    if not linux_key_file.exists():
        return
    if windows_key_file.exists():
        return
    print(
        f"Copying {linux_key_file} over to the Windows ssh folder at "
        f"{windows_key_file}."
    )
    # copy2 also carries over the file metadata where the target allows it.
    shutil.copy2(src=linux_key_file, dst=windows_key_file)


def _capture_stdout(command: list[str]) -> str | None:
    """Run `command` and return its stripped stdout, or None if it failed or was empty."""
    try:
        result = subprocess.run(command, capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return None
    return result.stdout.strip() or None


def get_windows_home_path_in_wsl() -> Path:
    """Return the path of the Windows user's home directory, as seen from WSL.

    The Windows profile directory (`%USERPROFILE%`) is queried through
    `powershell.exe` and translated to a WSL path with `wslpath`. This handles
    profiles whose folder name differs from the login name, and profiles that do
    not live on the `C:` drive.

    If that lookup fails, this falls back to `/mnt/c/Users/<windows username>`.

    Returns:
        The Windows home directory as a path usable from within WSL.
    """
    assert running_inside_WSL()

    windows_profile = _capture_stdout(
        ["powershell.exe", "-NoProfile", "-Command", "$env:USERPROFILE"]
    )
    if windows_profile:
        # `wslpath -u` converts a Windows path (C:\Users\me) to its WSL mount path.
        wsl_profile = _capture_stdout(["wslpath", "-u", windows_profile])
        if wsl_profile and Path(wsl_profile).is_dir():
            return Path(wsl_profile)

    # Fallback: assume the default profile location on the C: drive.
    windows_username = subprocess.getoutput("powershell.exe '$env:UserName'").strip()
    return Path(f"/mnt/c/Users/{windows_username}")


def _setup_ssh_config_file(config_file_path: str | Path) -> Path:
Expand Down Expand Up @@ -147,7 +232,12 @@ def _setup_ssh_config_file(config_file_path: str | Path) -> Path:


def _confirm_changes(ssh_config: SSHConfig, previous: str) -> bool:
print(T.bold("The following modifications will be made to your ~/.ssh/config:\n"))
print(
T.bold(
f"The following modifications will be made to your SSH config file at "
f"{ssh_config.path}:\n"
)
)
diff_lines = list(
difflib.unified_diff(
(previous + "\n").splitlines(True),
Expand Down Expand Up @@ -206,6 +296,7 @@ def _add_ssh_entry(
ssh_config: SSHConfig,
host: str,
Host: str | None = None,
*,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the `*` doing here? Is it there to grab all extra args? Should an exception be thrown instead?

Copy link
Collaborator Author

@lebrice lebrice Nov 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's making it so _space_before and _space_after cannot be passed as positional arguments. This helps remove a typing error when calls to _add_ssh_entry are made with different **kwargs (Windows vs non-Windows)

Copy link
Collaborator Author

@lebrice lebrice Nov 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for reference: image

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah get it thanks

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we, should we assert that nothing is in this *?

def ... , *_args, ...:
    assert not _args

Only thinking out loud here

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly! That's what using a * without a name does

_space_before: bool = True,
_space_after: bool = False,
**entry,
Expand All @@ -223,12 +314,42 @@ def _add_ssh_entry(
existing_entry = ssh_config.host(host)
existing_entry.update(entry)
ssh_config.cfg.set(host, **existing_entry)
logger.debug(f"Updated {host} entry in ssh config.")
logger.debug(f"Updated {host} entry in ssh config at path {ssh_config.path}.")
else:
ssh_config.add(
host,
_space_before=_space_before,
_space_after=_space_after,
**entry,
)
logger.debug(f"Adding new {host} entry in ssh config.")
logger.debug(
f"Adding new {host} entry in ssh config at path {ssh_config.path}."
)


def _copy_valid_ssh_entries_to_windows_ssh_config_file(
    linux_ssh_config: SSHConfig, windows_ssh_config: SSHConfig
):
    """Copy the host entries of the linux SSH config into the Windows SSH config.

    The entries listed in `HOSTS` are processed first, in that order, followed by
    every other host of the linux config in its original order. Options whose name
    matches one of `WINDOWS_UNSUPPORTED_KEYS` (case-insensitively) are dropped from
    each entry before it is added to `windows_ssh_config`.

    A `RuntimeWarning` is emitted (and the host skipped) when one of the `HOSTS`
    entries is missing from the linux config.

    The Windows config object is only modified in memory; saving it is left to the
    caller.
    """
    unsupported_keys_lowercase = {key.lower() for key in WINDOWS_UNSUPPORTED_KEYS}

    # The linux config isn't modified here, so query its hosts once instead of on
    # every loop iteration.
    linux_hosts = list(linux_ssh_config.hosts())
    other_hosts = [host for host in linux_hosts if host not in HOSTS]

    # NOTE: need to preserve the ordering of entries:
    for host in HOSTS + other_hosts:
        if host not in linux_hosts:
            warnings.warn(
                RuntimeWarning(
                    f"Weird, we expected to have a {host!r} entry in the SSH config..."
                )
            )
            continue
        linux_ssh_entry: dict[str, Any] = linux_ssh_config.host(host)
        supported_options = {
            key: value
            for key, value in linux_ssh_entry.items()
            if key.lower() not in unsupported_keys_lowercase
        }
        _add_ssh_entry(windows_ssh_config, host, **supported_options)
9 changes: 9 additions & 0 deletions milatools/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import contextvars
import functools
import itertools
import random
import shlex
import shutil
import socket
import subprocess
import sys
from contextlib import contextmanager
from pathlib import Path

Expand Down Expand Up @@ -140,6 +143,7 @@ class SSHConfig:
"""Wrapper around sshconf with some extra niceties."""

def __init__(self, path: str | Path):
self.path = path
self.cfg = read_ssh_config(path)
# self.add = self.cfg.add
self.remove = self.cfg.remove
Expand Down Expand Up @@ -220,3 +224,8 @@ def get_fully_qualified_name() -> str:
except Exception:
# Fall back, e.g. on Windows.
return socket.getfqdn()


@functools.lru_cache()
def running_inside_WSL() -> bool:
    """Tell whether this process appears to run in the Windows Subsystem for Linux.

    The check is that the platform is Linux and that a `powershell.exe` executable
    can be found on the PATH. The answer is cached after the first call.
    """
    if sys.platform != "linux":
        return False
    powershell_path = shutil.which("powershell.exe")
    return bool(powershell_path)
Loading
Loading